diff --git a/sympl/_core/dataarray.py b/sympl/_core/dataarray.py index b6b3f84..933786b 100644 --- a/sympl/_core/dataarray.py +++ b/sympl/_core/dataarray.py @@ -4,6 +4,7 @@ class DataArray(xr.DataArray): + __slots__ = [] def __add__(self, other): """If this DataArray is on the left side of the addition, keep its diff --git a/sympl/_core/get_np_arrays.py b/sympl/_core/get_np_arrays.py index 9a84b57..1cb98d4 100644 --- a/sympl/_core/get_np_arrays.py +++ b/sympl/_core/get_np_arrays.py @@ -43,13 +43,16 @@ def get_numpy_array(data_array, out_dims, dim_lengths): dict of dim_lengths that will give the length of any missing dims in the data_array. """ - if len(data_array.values.shape) == 0 and len(out_dims) == 0: - return data_array.values # special case, 0-dimensional scalar array + if len(data_array.data.shape) == 0 and len(out_dims) == 0: + return data_array.data # special case, 0-dimensional scalar array else: missing_dims = [dim for dim in out_dims if dim not in data_array.dims] for dim in missing_dims: data_array = data_array.expand_dims(dim) - numpy_array = data_array.transpose(*out_dims).values + if not all(dim1 == dim2 for dim1, dim2 in zip(data_array.dims, out_dims)): + numpy_array = data_array.transpose(*out_dims).data + else: + numpy_array = data_array.data if len(missing_dims) == 0: out_array = numpy_array else: # expand out missing dims which are currently length 1. diff --git a/sympl/_core/restore_dataarray.py b/sympl/_core/restore_dataarray.py index 791fba1..cf87e90 100644 --- a/sympl/_core/restore_dataarray.py +++ b/sympl/_core/restore_dataarray.py @@ -8,9 +8,10 @@ def ensure_values_are_arrays(array_dict): - for name, value in array_dict.items(): - if not isinstance(value, np.ndarray): - array_dict[name] = np.asarray(value) + pass + # for name, value in array_dict.items(): + # if not isinstance(value, np.ndarray): + # array_dict[name] = np.asarray(value) def get_alias_or_name(name, output_properties, input_properties): diff --git a/sympl/_core/units.py b/sympl/_core/units.py index d2977e8..2278ce0 100644 --- a/sympl/_core/units.py +++ b/sympl/_core/units.py @@ -1,22 +1,24 @@ # -*- coding: utf-8 -*- +import functools import pint class UnitRegistry(pint.UnitRegistry): - + @functools.lru_cache def __call__(self, input_string, **kwargs): return super(UnitRegistry, self).__call__( - input_string.replace( - u'%', 'percent').replace( - u'°', 'degree' - ), - **kwargs) + input_string.replace(u"%", "percent").replace(u"°", "degree"), **kwargs + ) unit_registry = UnitRegistry() -unit_registry.define('degrees_north = degree_north = degree_N = degrees_N = degreeN = degreesN') -unit_registry.define('degrees_east = degree_east = degree_E = degrees_E = degreeE = degreesE') -unit_registry.define('percent = 0.01*count = %') +unit_registry.define( + "degrees_north = degree_north = degree_N = degrees_N = degreeN = degreesN" +) +unit_registry.define( + "degrees_east = degree_east = degree_E = degrees_E = degreeE = degreesE" +) +unit_registry.define("percent = 0.01*count = %") def units_are_compatible(unit1, unit2): @@ -63,9 +65,7 @@ def clean_units(unit_string): def is_valid_unit(unit_string): """Returns True if the unit string is recognized, and False otherwise.""" - unit_string = unit_string.replace( - '%', 'percent').replace( - '°', 'degree') + unit_string = unit_string.replace("%", "percent").replace("°", "degree") try: unit_registry(unit_string) except pint.UndefinedUnitError: @@ -75,16 +75,17 @@ def is_valid_unit(unit_string): def data_array_to_units(value, units): - if not hasattr(value, 'attrs') or 'units' not in value.attrs: - raise TypeError( - 'Cannot retrieve units from type {}'.format(type(value))) - elif unit_registry(value.attrs['units']) != unit_registry(units): - attrs = value.attrs.copy() - value = unit_registry.Quantity(value, value.attrs['units']).to(units).magnitude - attrs['units'] = units - value.attrs = attrs + if not hasattr(value, "attrs") or "units" not in value.attrs: + raise TypeError("Cannot retrieve units from type {}".format(type(value))) + elif unit_registry(value.attrs["units"]) != unit_registry(units): + out = value.copy() + out.data[...] = ( + unit_registry.convert(1, value.attrs["units"], units) * value.data + ) + out.attrs["units"] = units + value = out return value def from_unit_to_another(value, original_units, new_units): - return (unit_registry(original_units)*value).to(new_units).magnitude + return (unit_registry(original_units) * value).to(new_units).magnitude