diff --git a/ush/SpatialTemporalStatsTool/SpatialTemporalStats.py b/ush/SpatialTemporalStatsTool/SpatialTemporalStats.py
index d465418..c2a7244 100644
--- a/ush/SpatialTemporalStatsTool/SpatialTemporalStats.py
+++ b/ush/SpatialTemporalStatsTool/SpatialTemporalStats.py
@@ -1,11 +1,13 @@
-import xarray
-import numpy as np
+import os
+from datetime import datetime
+
 import geopandas as gpd
-from shapely.geometry import Polygon, Point
 import matplotlib.pyplot as plt
+import numpy as np
 import pandas as pd
-from datetime import datetime
-import os
+import xarray
+from shapely.geometry import Point, Polygon
+
 
 class SpatialTemporalStats:
     def __init__(self):
@@ -15,168 +17,250 @@ def __init__(self):
     def generate_grid(self, resolution=1):
         self.resolution = resolution
         # Generate the latitude and longitude values using meshgrid
-        grid_lons, grid_lats = np.meshgrid(np.arange(-180, 181, resolution),
-                                           np.arange(-90, 91, resolution))
+        grid_lons, grid_lats = np.meshgrid(
+            np.arange(-180, 181, resolution), np.arange(-90, 91, resolution)
+        )
 
         # Flatten the arrays to get coordinates
         grid_coords = np.vstack([grid_lons.flatten(), grid_lats.flatten()]).T
 
         # Create a GeoDataFrame from the coordinates
-        self.grid_gdf = gpd.GeoDataFrame(geometry=[Polygon([(lon, lat), (lon + resolution, lat),
-                                                            (lon + resolution, lat + resolution), (lon, lat + resolution)])
-                                                   for lon, lat in grid_coords],
-                                         crs='EPSG:4326') # CRS for WGS84
-        self.grid_gdf["grid_id"]=np.arange(1,len(self.grid_gdf)+1)
-
-
-    def _extract_date_times(self,filenames):
+        self.grid_gdf = gpd.GeoDataFrame(
+            geometry=[
+                Polygon(
+                    [
+                        (lon, lat),
+                        (lon + resolution, lat),
+                        (lon + resolution, lat + resolution),
+                        (lon, lat + resolution),
+                    ]
+                )
+                for lon, lat in grid_coords
+            ],
+            crs="EPSG:4326",
+        )  # CRS for WGS84
+        self.grid_gdf["grid_id"] = np.arange(1, len(self.grid_gdf) + 1)
+
+    def _extract_date_times(self, filenames):
        date_times = []
         for filename in filenames:
             # Split the filename by '.' to get the parts
-            parts = filename.split('.')
-
+            parts = filename.split(".")
+
             # Extract the last part which contains the date/time information
             date_time_part = parts[-2]
-
-            # The date/time format in the filename seems to be 'YYYYMMDDHH', so we can parse it accordingly
+
+            # date/time format in filename is 'YYYYMMDDHH', can parse it accordingly
             year = int(date_time_part[:4])
             month = int(date_time_part[4:6])
             day = int(date_time_part[6:8])
             hour = int(date_time_part[8:10])
-
+
             # Construct the datetime object
             date_time = datetime(year, month, day, hour)
-
-            date_times.append(date_time)
-
-        return date_times
-
-
-    def read_obs_values(self, obs_files_path, sensor,var_name, channel_no, start_date, end_date,filter_by_vars,QC_filter):
-        self.sensor=sensor
-        self.channel_no=channel_no
-        #read all obs files
-        all_files = os.listdir(obs_files_path)
-        #obs_files = [os.path.join(obs_files_path, file) for file in all_files if file.endswith('.nc4')]
-        obs_files = [os.path.join(obs_files_path, file) for file in all_files if file.endswith('.nc4') and "diag_%s_ges" % sensor in file]
-
-        # get date time from file names. alternatively could get from attribute but that needs reading the entire nc4
-        files_date_times_df=pd.DataFrame()
-
-        files_date_times=self._extract_date_times(obs_files)
-        files_date_times_df["file_name"]=obs_files
-        files_date_times_df["date_time"]=files_date_times
-        files_date_times_df['date'] = pd.to_datetime(files_date_times_df['date_time'].dt.date)
-        #read start date
-        start_date = datetime.strptime(start_date, '%Y-%m-%d')
-        end_date = datetime.strptime(end_date, '%Y-%m-%d')
+            date_times.append(date_time)
 
-        studied_cycle_files=files_date_times_df[((files_date_times_df["date"]>=start_date) & ((files_date_times_df["date"]<=end_date)))]["file_name"]
+        return date_times
 
-        studied_gdf_list=[]
+    def read_obs_values(
+        self,
+        obs_files_path,
+        sensor,
+        var_name,
+        channel_no,
+        start_date,
+        end_date,
+        filter_by_vars,
+        QC_filter,
+    ):
+        self.sensor = sensor
+        self.channel_no = channel_no
+        # read all obs files
+        all_files = os.listdir(obs_files_path)
+        obs_files = [
+            os.path.join(obs_files_path, file)
+            for file in all_files
+            if file.endswith(".nc4") and "diag_%s_ges" % sensor in file
+        ]
+
+        # get date time from file names
+        files_date_times_df = pd.DataFrame()
+
+        files_date_times = self._extract_date_times(obs_files)
+        files_date_times_df["file_name"] = obs_files
+        files_date_times_df["date_time"] = files_date_times
+        files_date_times_df["date"] = pd.to_datetime(
+            files_date_times_df["date_time"].dt.date
+        )
+
+        # read start date
+        start_date = datetime.strptime(start_date, "%Y-%m-%d")
+        end_date = datetime.strptime(end_date, "%Y-%m-%d")
+
+        studied_cycle_files = files_date_times_df[
+            (
+                (files_date_times_df["date"] >= start_date)
+                & ((files_date_times_df["date"] <= end_date))
+            )
+        ]["file_name"]
+
+        studied_gdf_list = []
         for this_cycle_obs_file in studied_cycle_files:
-            ds=xarray.open_dataset(this_cycle_obs_file)
+            ds = xarray.open_dataset(this_cycle_obs_file)
 
-            Combined_bool=ds["Channel_Index"].data==channel_no
+            Combined_bool = ds["Channel_Index"].data == channel_no
             if QC_filter:
-                QC_bool=ds["QC_Flag"].data==0
-                Combined_bool=Combined_bool*QC_bool
+                QC_bool = ds["QC_Flag"].data == 0
+                Combined_bool = Combined_bool * QC_bool
 
-            #apply filters by variable
+            # apply filters by variable
             for this_filter in filter_by_vars:
                 filter_var_name, filter_operation, filter_value = this_filter
-                if filter_operation=='lt':
-                    this_filter_bool=ds[filter_var_name].data<=filter_value
+                if filter_operation == "lt":
+                    this_filter_bool = ds[filter_var_name].data <= filter_value
                 else:
-                    this_filter_bool=ds[filter_var_name].data>=filter_value
-                Combined_bool=Combined_bool*~this_filter_bool #here we have to negate the above bool to make it right
-
-            this_cycle_var_values=ds[var_name].data[Combined_bool]
-            this_cycle_lat_values=ds["Latitude"].data[Combined_bool]
-            this_cycle_long_values=ds["Longitude"].data[Combined_bool]
-            this_cycle_long_values=np.where(this_cycle_long_values<=180, this_cycle_long_values,this_cycle_long_values-360)
-
-            geometry = [Point(xy) for xy in zip(this_cycle_long_values, this_cycle_lat_values)]
+                    this_filter_bool = ds[filter_var_name].data >= filter_value
+                Combined_bool = (
+                    Combined_bool * ~this_filter_bool
+                )  # negate the mask so points matching the filter condition are dropped
+
+            this_cycle_var_values = ds[var_name].data[Combined_bool]
+            this_cycle_lat_values = ds["Latitude"].data[Combined_bool]
+            this_cycle_long_values = ds["Longitude"].data[Combined_bool]
+            this_cycle_long_values = np.where(
+                this_cycle_long_values <= 180,
+                this_cycle_long_values,
+                this_cycle_long_values - 360,
+            )
+
+            geometry = [
+                Point(xy) for xy in zip(this_cycle_long_values, this_cycle_lat_values)
+            ]
             # Create a GeoDataFrame
-            this_cycle_gdf = gpd.GeoDataFrame(geometry=geometry, crs='EPSG:4326')
-            this_cycle_gdf["value"]=this_cycle_var_values
+            this_cycle_gdf = gpd.GeoDataFrame(geometry=geometry, crs="EPSG:4326")
+            this_cycle_gdf["value"] = this_cycle_var_values
             studied_gdf_list.append(this_cycle_gdf)
-
         studied_gdf = pd.concat(studied_gdf_list)
 
         # Perform spatial join
-        joined_gdf = gpd.sjoin(studied_gdf, self.grid_gdf, op='within', how='right')
+        joined_gdf = gpd.sjoin(studied_gdf, self.grid_gdf, op="within", how="right")
 
         # Calculate average values of points in each polygon
-        self.obs_gdf=self.grid_gdf.copy()
-        self.obs_gdf[var_name +'_Average'] = joined_gdf.groupby('grid_id')['value'].mean()
-        self.obs_gdf[var_name + '_RMS'] = (joined_gdf.groupby('grid_id')['value'].apply(lambda x: np.sqrt((x**2).mean())))
-        self.obs_gdf[var_name +"_Count"] = joined_gdf.groupby('grid_id')['value'].count()
+        self.obs_gdf = self.grid_gdf.copy()
+        self.obs_gdf[var_name + "_Average"] = joined_gdf.groupby("grid_id")[
+            "value"
+        ].mean()
+        self.obs_gdf[var_name + "_RMS"] = joined_gdf.groupby("grid_id")["value"].apply(
+            lambda x: np.sqrt((x**2).mean())
+        )
+        self.obs_gdf[var_name + "_Count"] = joined_gdf.groupby("grid_id")[
+            "value"
+        ].count()
+
+        # convert count of zero to null. This will help also for plotting
+        self.obs_gdf[var_name + "_Count"] = np.where(
+            self.obs_gdf[var_name + "_Count"].values == 0,
+            np.nan,
+            self.obs_gdf[var_name + "_Count"].values,
+        )
 
         return self.obs_gdf
 
-    def filter_by_variable(self, var_name, comparison_type, value):
-        pass
-
-    def plot_obs(self, selected_var_gdf, var_name, region, resolution,output_path):
+    def plot_obs(self, selected_var_gdf, var_name, region, resolution, output_path):
         self.resolution = resolution
-        var_names = [var_name + '_Average', var_name + "_Count", var_name + '_RMS']
+        var_names = [var_name + "_Average", var_name + "_Count", var_name + "_RMS"]
 
-        for idx, item in enumerate(var_names, start=1):
+        for item in var_names:
             plt.figure(figsize=(12, 8))
             ax = plt.subplot(1, 1, 1)
-
-
             if region == 1:  # Plotting global region (no need for filtering)
-                title = 'Global Region'
-                filtered_gdf=selected_var_gdf
+                title = "Global Region"
+                filtered_gdf = selected_var_gdf
             elif region == 2:  # Plotting polar region (+60 latitude and above)
-                title = 'Polar Region (+60 latitude and above)'
+                title = "Polar Region (+60 latitude and above)"
                 filtered_gdf = selected_var_gdf[
-                    selected_var_gdf.geometry.apply(lambda geom: self.is_polygon_in_polar_region(geom, 60))]
+                    selected_var_gdf.geometry.apply(
+                        lambda geom: self.is_polygon_in_polar_region(geom, 60)
+                    )
+                ]
             elif region == 3:  # Plotting northern mid-latitudes region (20 to 60 latitude)
-                title = 'Northern Mid-latitudes Region (20 to 60 latitude)'
+                title = "Northern Mid-latitudes Region (20 to 60 latitude)"
                 filtered_gdf = selected_var_gdf[
-                    selected_var_gdf.geometry.apply(lambda geom: self.is_polygon_in_latitude_range(geom, 20, 60))]
+                    selected_var_gdf.geometry.apply(
+                        lambda geom: self.is_polygon_in_latitude_range(geom, 20, 60)
+                    )
+                ]
             elif region == 4:  # Plotting tropics region (-20 to 20 latitude)
-                title = 'Tropics Region (-20 to 20 latitude)'
+                title = "Tropics Region (-20 to 20 latitude)"
                 filtered_gdf = selected_var_gdf[
-                    selected_var_gdf.geometry.apply(lambda geom: self.is_polygon_in_latitude_range(geom, -20, 20))]
+                    selected_var_gdf.geometry.apply(
+                        lambda geom: self.is_polygon_in_latitude_range(geom, -20, 20)
+                    )
+                ]
             elif region == 5:  # Plotting southern mid-latitudes region (-60 to -20 latitude)
-                title = 'Southern Mid-latitudes Region (-60 to -20 latitude)'
+                title = "Southern Mid-latitudes Region (-60 to -20 latitude)"
                 filtered_gdf = selected_var_gdf[
-                    selected_var_gdf.geometry.apply(lambda geom: self.is_polygon_in_latitude_range(geom, -60, -20))]
+                    selected_var_gdf.geometry.apply(
+                        lambda geom: self.is_polygon_in_latitude_range(geom, -60, -20)
+                    )
+                ]
             elif region == 6:  # Plotting southern polar region (less than -60 latitude)
-                title = 'Southern Polar Region (less than -60 latitude)'
-                filtered_gdf = selected_var_gdf[selected_var_gdf.geometry.apply(lambda geom: geom.centroid.y < -60)]
-
-            min_val, max_val, std_val = filtered_gdf[item].min(), filtered_gdf[item].max(), \
-                                        filtered_gdf[item].std()
-            cbar_label = 'grid=%dX%d, min=%.3lf, max=%.3lf, std=%.3lf\n' % (resolution, resolution, min_val, max_val, std_val)
-
-            filtered_gdf.plot(ax=ax, cmap='jet', column=item, legend=True,
-                              missing_kwds={'color': 'lightgrey'},
-                              legend_kwds={'orientation': 'horizontal', 'shrink': 0.5, 'label': cbar_label})
-
-            plt.title("%s\n%s ch:%d %s" % (title,self.sensor, self.channel_no, item))
-            plt.savefig(os.path.join(output_path,"%s_ch%d_%s_region_%d.png" % (self.sensor, self.channel_no, item, region)))
+                title = "Southern Polar Region (less than -60 latitude)"
+                filtered_gdf = selected_var_gdf[
+                    selected_var_gdf.geometry.apply(lambda geom: geom.centroid.y < -60)
+                ]
+
+            min_val, max_val, std_val = (
+                filtered_gdf[item].min(),
+                filtered_gdf[item].max(),
+                filtered_gdf[item].std(),
+            )
+            cbar_label = "grid=%dX%d, min=%.3lf, max=%.3lf, std=%.3lf\n" % (
+                resolution,
+                resolution,
+                min_val,
+                max_val,
+                std_val,
+            )
+
+            filtered_gdf.plot(
+                ax=ax,
+                cmap="jet",
+                column=item,
+                legend=True,
+                missing_kwds={"color": "lightgrey"},
+                legend_kwds={
+                    "orientation": "horizontal",
+                    "shrink": 0.5,
+                    "label": cbar_label,
+                },
+            )
+
+            plt.title("%s\n%s ch:%d %s" % (title, self.sensor, self.channel_no, item))
+            plt.savefig(
+                os.path.join(
+                    output_path,
+                    "%s_ch%d_%s_region_%d.png"
+                    % (self.sensor, self.channel_no, item, region),
+                )
+            )
             plt.close()
 
     def is_polygon_in_polar_region(self, polygon, latitude_threshold):
@@ -204,78 +288,127 @@ def is_polygon_in_latitude_range(self, polygon, min_latitude, max_latitude):
         # Check if the latitude is within the specified range
         return min_latitude <= centroid_latitude <= max_latitude
-
+
     def list_variable_names(self, file_path):
-        ds=xarray.open_dataset(file_path)
+        ds = xarray.open_dataset(file_path)
         print(ds.info())
 
-    def make_summary_plots(self, obs_files_path, sensor,var_name, start_date, end_date,QC_filter,output_path):
-        self.sensor=sensor
-        #read all obs files
+    def make_summary_plots(
+        self,
+        obs_files_path,
+        sensor,
+        var_name,
+        start_date,
+        end_date,
+        QC_filter,
+        output_path,
+    ):
+        self.sensor = sensor
+        # read all obs files
         all_files = os.listdir(obs_files_path)
-        #obs_files = [os.path.join(obs_files_path, file) for file in all_files if file.endswith('.nc4')]
-        obs_files = [os.path.join(obs_files_path, file) for file in all_files if file.endswith('.nc4') and "diag_%s_ges" % sensor in file]
+        # obs_files = [os.path.join(obs_files_path, file) for file in all_files if file.endswith('.nc4')]
+        obs_files = [
+            os.path.join(obs_files_path, file)
+            for file in all_files
+            if file.endswith(".nc4") and "diag_%s_ges" % sensor in file
+        ]
 
         # get date time from file names. alternatively could get from attribute but that needs reading the entire nc4
-        files_date_times_df=pd.DataFrame()
-
-        files_date_times=self._extract_date_times(obs_files)
-        files_date_times_df["file_name"]=obs_files
-        files_date_times_df["date_time"]=files_date_times
-        files_date_times_df['date'] = pd.to_datetime(files_date_times_df['date_time'].dt.date)
-
-        #read start date
-        start_date = datetime.strptime(start_date, '%Y-%m-%d')
-        end_date = datetime.strptime(end_date, '%Y-%m-%d')
-
-        studied_cycle_files=files_date_times_df[((files_date_times_df["date"]>=start_date) & ((files_date_times_df["date"]<=end_date)))]["file_name"]
-        Summary_results=[]
-
-        #get unique channels from one of the files
-        ds=xarray.open_dataset(studied_cycle_files[0])
-        unique_channels=np.unique(ds["Channel_Index"].data).tolist()
+        files_date_times_df = pd.DataFrame()
+
+        files_date_times = self._extract_date_times(obs_files)
+        files_date_times_df["file_name"] = obs_files
+        files_date_times_df["date_time"] = files_date_times
+        files_date_times_df["date"] = pd.to_datetime(
+            files_date_times_df["date_time"].dt.date
+        )
+
+        # read start date
+        start_date = datetime.strptime(start_date, "%Y-%m-%d")
+        end_date = datetime.strptime(end_date, "%Y-%m-%d")
+
+        studied_cycle_files = files_date_times_df[
+            (
+                (files_date_times_df["date"] >= start_date)
+                & ((files_date_times_df["date"] <= end_date))
+            )
+        ]["file_name"]
+        Summary_results = []
+
+        # get unique channels from one of the files
+        ds = xarray.open_dataset(studied_cycle_files[0])
+        unique_channels = np.unique(ds["Channel_Index"].data).tolist()
 
         for this_channel in unique_channels:
             this_channel_values = np.empty(shape=(0,))
             for this_cycle_obs_file in studied_cycle_files:
-                ds=xarray.open_dataset(this_cycle_obs_file)
-                Combined_bool=ds["Channel_Index"].data==this_channel
+                ds = xarray.open_dataset(this_cycle_obs_file)
+                Combined_bool = ds["Channel_Index"].data == this_channel
                 if QC_filter:
-                    QC_bool=ds["QC_Flag"].data==0
-                    Combined_bool=Combined_bool*QC_bool
-
-                this_cycle_var_values=ds[var_name].data[Combined_bool]
-                this_channel_values = np.append(this_channel_values, this_cycle_var_values)
-
-            Summary_results.append([this_channel,np.size(this_channel_values), np.std(this_channel_values),np.mean(this_channel_values) ])
-
-        Summary_resultsDF=pd.DataFrame(Summary_results, columns=["channel", "count", "std", 'mean'])
+                    QC_bool = ds["QC_Flag"].data == 0
+                    Combined_bool = Combined_bool * QC_bool
+
+                this_cycle_var_values = ds[var_name].data[Combined_bool]
+                this_channel_values = np.append(
+                    this_channel_values, this_cycle_var_values
+                )
+
+            Summary_results.append(
+                [
+                    this_channel,
+                    np.size(this_channel_values),
+                    np.std(this_channel_values),
+                    np.mean(this_channel_values),
+                ]
+            )
+
+        Summary_resultsDF = pd.DataFrame(
+            Summary_results, columns=["channel", "count", "std", "mean"]
+        )
 
         # Plotting
         plt.figure(figsize=(10, 6))
         plt.scatter(Summary_resultsDF["channel"], Summary_resultsDF["count"], s=50)
         plt.xlabel("Channel")
         plt.ylabel("Count")
-        plt.title("%s %s"%((self.sensor,var_name)))
+        plt.title("%s %s" % (self.sensor, var_name))
         plt.xticks(Summary_resultsDF["channel"])
         plt.xticks(rotation=45)
         plt.grid(True)
         plt.tight_layout()
-        plt.savefig(os.path.join(output_path,"%s_%s_sumamryCounts.png" % (self.sensor,var_name)))
+        plt.savefig(
+            os.path.join(
+                output_path, "%s_%s_summaryCounts.png" % (self.sensor, var_name)
+            )
+        )
         plt.close()
 
         # Plotting scatter plot for mean and std
         plt.figure(figsize=(10, 6))
-        plt.scatter(Summary_resultsDF["channel"], Summary_resultsDF["mean"], s=50, c='red', label="Mean")
Summary_resultsDF["mean"], s=50, c='red', label="Mean") - plt.scatter(Summary_resultsDF["channel"], Summary_resultsDF["std"], s=50, c='green', label="Std") + plt.scatter( + Summary_resultsDF["channel"], + Summary_resultsDF["mean"], + s=50, + c="red", + label="Mean", + ) + plt.scatter( + Summary_resultsDF["channel"], + Summary_resultsDF["std"], + s=50, + c="green", + label="Std", + ) plt.xlabel("Channel") plt.ylabel("Statistics") - plt.title("%s %s"%((self.sensor,var_name))) + plt.title("%s %s" % ((self.sensor, var_name))) plt.xticks(Summary_resultsDF["channel"]) plt.xticks(rotation=45) plt.grid(True) plt.tight_layout() plt.legend() - plt.savefig(os.path.join(output_path, "%s_%s_mean_std.png" % (self.sensor, var_name))) + plt.savefig( + os.path.join(output_path, "%s_%s_mean_std.png" % (self.sensor, var_name)) + ) return Summary_resultsDF diff --git a/ush/SpatialTemporalStatsTool/user_Analysis.py b/ush/SpatialTemporalStatsTool/user_Analysis.py index 22929fa..8a6efb7 100644 --- a/ush/SpatialTemporalStatsTool/user_Analysis.py +++ b/ush/SpatialTemporalStatsTool/user_Analysis.py @@ -1,47 +1,67 @@ from SpatialTemporalStats import SpatialTemporalStats # Set input and output paths -input_path= r"/scratch2/NCEPDEV/stmp1/Emily.Liu/GDAS-ops-radstat/data/" -output_path=r'./Results' +input_path = "/PATH/TO/Input/Files" +output_path = r'./Results' + # Set sensor name -sensor="atms_n20" +sensor = "atms_n20" # Set variable name and channel number -var_name= "Obs_Minus_Forecast_adjusted" -channel_no=1 +var_name = "Obs_Minus_Forecast_adjusted" +channel_no = 1 # Set start and end dates start_date, end_date = '2023-03-01', '2023-03-10' -# Set region (1: global, 2: polar region, 3: mid-latitudes region, 4:tropics region, 5:southern mid-latitudes region, 6:southern polar region) -region=1 +# Set region +# 1: global, 2: polar region, 3: mid-latitudes region, +# 4: tropics region, 5:southern mid-latitudes region, 6: southern polar region +region = 3 # Initialize SpatialTemporalStats object my_tool = SpatialTemporalStats() # Set resolution for grid generation -resolution=2 +resolution = 2 # Generate grid -my_tool.generate_grid(resolution) # Call generate_grid method -# observation_gdf=my_tool.read_obs_values(input_path, sensor, var_name, channel_no, start_date, end_date) +my_tool.generate_grid(resolution) # Call generate_grid method) # Set QC filter -QC_filter= True # should be always False or true +QC_filter = True # should be always False or true # Set filter by variables -filter_by_vars=[("Land_Fraction", 'lt', 0.9),] #list each case in a separate tuple inside this list. can be an empty list #options are 'lt' or 'gt' stands for 'less than' and 'greater than' -#filter_by_vars=[] +# can be an empty list +filter_by_vars=[] + +#filter_by_vars = [("Land_Fraction", "lt", 0.9),] +# list each case in a separate tuple inside this list. 
+# options are 'lt' or 'gt' for 'less than' and 'greater than'
 
 # Read observational values and perform analysis
-o_minus_f_gdf=my_tool.read_obs_values(input_path, sensor,var_name, channel_no, start_date, end_date,filter_by_vars,QC_filter)
+o_minus_f_gdf = my_tool.read_obs_values(
+    input_path,
+    sensor,
+    var_name,
+    channel_no,
+    start_date,
+    end_date,
+    filter_by_vars,
+    QC_filter,
+)
+
+# Can save the results in a gpkg file
+# o_minus_f_gdf.to_file("filename.gpkg", driver="GPKG")
 
 # Plot observations
 my_tool.plot_obs(o_minus_f_gdf, var_name, region, resolution, output_path)
 
 # Make summary plots
-summary_results=my_tool.make_summary_plots( input_path, sensor,var_name, start_date, end_date,QC_filter,output_path)
+summary_results = my_tool.make_summary_plots(
+    input_path, sensor, var_name, start_date, end_date, QC_filter, output_path
+)
 
 # Print summary results
 print(summary_results)
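
Note (not part of the patch): the commented-out to_file() call added to user_Analysis.py hints at persisting the gridded result. Below is a minimal sketch of that GeoPackage round trip, assuming the script above has already produced o_minus_f_gdf and that geopandas is installed with a GPKG-capable driver; the output file name is illustrative.

import geopandas as gpd

# Persist the gridded O-minus-F statistics so a later session can skip re-reading the diag files
out_file = "o_minus_f_stats.gpkg"  # hypothetical file name
o_minus_f_gdf.to_file(out_file, driver="GPKG")

# Reload later (e.g., in a notebook) for further inspection or plotting
reloaded_gdf = gpd.read_file(out_file)
print(reloaded_gdf.head())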