diff --git a/pyschism/forcing/source_sink/nwm.py b/pyschism/forcing/source_sink/nwm.py index 7c15a55f..947f0ecf 100644 --- a/pyschism/forcing/source_sink/nwm.py +++ b/pyschism/forcing/source_sink/nwm.py @@ -113,38 +113,45 @@ def __init__(self, hgrid: Gr3, nwm_file=None, workers=-1): raise IOError( "No National Water model intersections found on the mesh.") intersection = gpd.GeoDataFrame(data, crs=hgrid.crs) - #TODO: add exporting intersection as an option + #export intersection to a file #intersection.to_file('intersections.shp') del data - # 2) Generate element centroid KDTree - centroids = [] - for element in hgrid.elements.elements.values(): - cent = LinearRing( - hgrid.nodes.coord[list( - map(hgrid.nodes.get_index_by_id, element))] - ).centroid - centroids.append((cent.x, cent.y)) - tree = cKDTree(centroids) - del centroids - - # 3) Match reach/boundary intersection to nearest element centroid - coords = [ - np.array(inters.geometry.coords) for inters in intersection.itertuples() - ] - _, idxs = tree.query(np.vstack(coords), workers=workers) - del tree - - logger.info( - "Pairing features to corresponding element took " f"{time()-start}." - ) + #TODO:intead of creating a gdf for all elements, only create for nearest elements and its neighbors + buffered_geometry = self.hgrid.elements.gdf.geometry.buffer(0.0001) + elements_gdf = self.hgrid.elements.gdf.copy() + elements_gdf['geometry'] = buffered_geometry + joined_gdf = gpd.sjoin(intersection, elements_gdf, how="inner", predicate="intersects") + point_element_ids = joined_gdf.index_right.values.astype(int) + + ## 2) Generate element centroid KDTree + #centroids = [] + #for element in hgrid.elements.elements.values(): + # cent = LinearRing( + # hgrid.nodes.coord[list( + # map(hgrid.nodes.get_index_by_id, element))] + # ).centroid + # centroids.append((cent.x, cent.y)) + #tree = cKDTree(centroids) + #del centroids + + ## 3) Match reach/boundary intersection to nearest element centroid + #coords = [ + # np.array(inters.geometry.coords) for inters in intersection.itertuples() + #] + #_, idxs = tree.query(np.vstack(coords), workers=workers) + #del tree + + #logger.info( + # "Pairing features to corresponding element took " f"{time()-start}." + #) hull = hgrid.hull.multipolygon() start = time() sources = defaultdict(list) sinks = defaultdict(list) - for row in intersection.itertuples(): + for idx, row in enumerate(intersection.itertuples()): poi = row.geometry reach = reaches.iloc[row.reachIndex].geometry if not isinstance(reach, LineString): @@ -157,15 +164,18 @@ def __init__(self, hgrid: Gr3, nwm_file=None, workers=-1): d1 = segment_origin.distance(poi) downstream = segment.interpolate( d1 + np.finfo(np.float32).eps) - element = hgrid.elements.gdf.iloc[idxs[row.Index]] + #element = hgrid.elements.gdf.iloc[idxs[row.Index]] + element = point_element_ids[idx] + 1 if ( box(*LineString([poi, downstream]).bounds) .intersection(hull) .intersects(downstream) ): - sources[element.id].append(reaches.iloc[row.reachIndex].feature_id) + #sources[element.id].append(reaches.iloc[row.reachIndex].feature_id) + sources[str(element)].append(reaches.iloc[row.reachIndex].feature_id) else: - sinks[element.id].append(reaches.iloc[row.reachIndex].feature_id) + #sinks[element.id].append(reaches.iloc[row.reachIndex].feature_id) + sinks[str(element)].append(reaches.iloc[row.reachIndex].feature_id) break logger.info(