Skip to content

Commit

Permalink
Convert proj to queries crs
Browse files Browse the repository at this point in the history
  • Loading branch information
mpiannucci committed Nov 14, 2024
1 parent b022ebf commit 6f004dd
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 8 deletions.
21 changes: 20 additions & 1 deletion tests/test_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def test_select_position_regular_xy(regular_xy_dataset):


def test_select_position_projected_xy(projected_xy_dataset):
from xpublish_edr.geometry.common import project_geometry
from xpublish_edr.geometry.common import project_geometry, project_dataset

point = Point((64.59063409, 66.66454929))
projected_point = project_geometry(projected_xy_dataset, "EPSG:4326", point)
Expand All @@ -157,6 +157,25 @@ def test_select_position_projected_xy(projected_xy_dataset):
projected_xy_dataset.sel(rlon=[18.045], rlat=[21.725], method="nearest"),
)

projected_ds = project_dataset(ds, "EPSG:4326")
(
npt.assert_approx_equal(projected_ds.longitude.values, 64.59063409),
"Longitude is incorrect",
)
(
npt.assert_approx_equal(projected_ds.latitude.values, 66.66454929),
"Latitude is incorrect",
)
(
npt.assert_approx_equal(
projected_ds.temp.values,
projected_xy_dataset.sel(
rlon=[18.045], rlat=[21.725], method="nearest"
).temp.values,
),
"Temperature is incorrect",
)


def test_select_position_regular_xy_interpolate(regular_xy_dataset):
point = Point((204, 44))
Expand Down
82 changes: 75 additions & 7 deletions xpublish_edr/geometry/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,24 +30,92 @@ def is_regular_xy_coords(ds: xr.Dataset) -> bool:
return coord_is_regular(ds.cf["X"]) and coord_is_regular(ds.cf["Y"])


def project_geometry(ds: xr.Dataset, geometry_crs: str, geometry: Geometry) -> Geometry:
"""
Get the projection from the dataset
"""
def dataset_crs(ds: xr.Dataset) -> pyproj.CRS:
grid_mapping_names = ds.cf.grid_mapping_names
if len(grid_mapping_names) == 0:
# TODO: Should we require a grid mapping? For now return as is
return geometry
# Default to WGS84
return pyproj.crs.CRS.from_epsg(4326)
if len(grid_mapping_names) > 1:
raise ValueError(f"Multiple grid mappings found: {grid_mapping_names!r}!")
(grid_mapping_var,) = tuple(itertools.chain(*ds.cf.grid_mapping_names.values()))

grid_mapping = ds[grid_mapping_var]
data_crs = pyproj.crs.CRS.from_cf(grid_mapping.attrs)
return pyproj.CRS.from_cf(grid_mapping.attrs)


def project_geometry(ds: xr.Dataset, geometry_crs: str, geometry: Geometry) -> Geometry:
"""
Get the projection from the dataset
"""
data_crs = dataset_crs(ds)

transformer = transformer_from_crs(
crs_from=geometry_crs,
crs_to=data_crs,
always_xy=True,
)
return transform(transformer.transform, geometry)


def project_dataset(ds: xr.Dataset, query_crs: str) -> xr.Dataset:
"""
Project the dataset to the given CRS
"""
data_crs = dataset_crs(ds)
target_crs = pyproj.CRS.from_string(query_crs)
if data_crs == target_crs:
return ds

transformer = transformer_from_crs(
crs_from=data_crs,
crs_to=target_crs,
always_xy=True,
)

# TODO: Handle rotated pole
cf_coords = target_crs.coordinate_system.to_cf()

# Get the new X and Y coordinates
target_y_coord = next(coord for coord in cf_coords if coord["axis"] == "Y")
target_x_coord = next(coord for coord in cf_coords if coord["axis"] == "X")

# Transform the coordinates
# If the data is vectorized, we just transform the points in full
# TODO: Handle 2D coordinates
if not is_regular_xy_coords(ds):
raise NotImplementedError("Only 1D coordinates are supported")

x_dim = ds.cf["X"].dims[0]
y_dim = ds.cf["Y"].dims[0]
if x_dim == [VECTORIZED_DIM]:
x = ds.cf["X"]
y = ds.cf["Y"]
else:
# Otherwise we need to transform the full grid
x, y = xr.broadcast(ds.cf["X"], ds.cf["Y"])

x, y = transformer.transform(x, y)

coords_to_drop = [
c for c in ds.coords if x_dim in ds[c].dims or y_dim in ds[c].dims
]

target_x_coord_name = target_x_coord["standard_name"]
target_y_coord_name = target_y_coord["standard_name"]

if target_x_coord_name in ds:
target_x_coord_name += "_"
if target_y_coord_name in ds:
target_y_coord_name += "_"

# Create the new dataset with vectorized coordinates
ds = ds.assign_coords(
{
target_x_coord_name: ((x_dim, y_dim), x),
target_y_coord_name: ((x_dim, y_dim), y),
}
)

ds = ds.drop(coords_to_drop)

return ds

0 comments on commit 6f004dd

Please sign in to comment.