Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

grass.script: Grass script keyvalue typing #331

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 32 additions & 18 deletions python/grass/script/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,24 @@
import json
import csv
import io
from collections.abc import Mapping
from tempfile import NamedTemporaryFile
from pathlib import Path
from typing import TYPE_CHECKING, TypeVar

from .utils import KeyValue, parse_key_val, basename, encode, decode, try_remove
from grass.exceptions import ScriptError, CalledModuleError
from grass.grassdb.manage import resolve_mapset_path


if TYPE_CHECKING:
from _typeshed import StrPath


T = TypeVar("T")
_Env = Mapping[str, str]


# subprocess wrapper that uses shell on Windows
class Popen(subprocess.Popen):
_builtin_exts = {".com", ".exe", ".bat", ".cmd"}
Expand Down Expand Up @@ -1032,20 +1042,22 @@ def _compare_units(dic):


def _text_to_key_value_dict(
filename, sep=":", val_sep=",", checkproj=False, checkunits=False
):
filename: StrPath,
sep: str = ":",
val_sep: str = ",",
checkproj: bool = False,
checkunits: bool = False,
) -> KeyValue[list[int | float | str]]:
"""Convert a key-value text file, where entries are separated by newlines
and the key and value are separated by `sep`, into a key-value dictionary
and discover/use the correct data types (float, int or string) for values.

:param str filename: The name or name and path of the text file to convert
:param str sep: The character that separates the keys and values, default
is ":"
:param str val_sep: The character that separates the values of a single
:param filename: The name or name and path of the text file to convert
:param sep: The character that separates the keys and values, default is ":"
:param val_sep: The character that separates the values of a single
key, default is ","
:param bool checkproj: True if it has to check some information about
projection system
:param bool checkproj: True if it has to check some information about units
:param checkproj: True if it has to check some information about projection system
:param checkunits: True if it has to check some information about units

:return: The dictionary

Expand All @@ -1066,7 +1078,7 @@ def _text_to_key_value_dict(
"""
with Path(filename).open() as f:
text = f.readlines()
kvdict = KeyValue()
kvdict: KeyValue[list[int | float | str]] = KeyValue()

for line in text:
if line.find(sep) >= 0:
Expand All @@ -1077,7 +1089,7 @@ def _text_to_key_value_dict(
# Jump over empty values
continue
values = value.split(val_sep)
value_list = []
value_list: list[int | float | str] = []

for value in values:
not_float = False
Expand Down Expand Up @@ -1173,7 +1185,7 @@ def compare_key_value_text_files(
# interface to g.gisenv


def gisenv(env=None):
def gisenv(env: _Env | None = None) -> KeyValue[str | None]:
"""Returns the output from running g.gisenv (with no arguments), as a
dictionary. Example:

Expand All @@ -1191,14 +1203,14 @@ def gisenv(env=None):
# interface to g.region


def locn_is_latlong(env=None) -> bool:
def locn_is_latlong(env: _Env | None = None) -> bool:
"""Tests if location is lat/long. Value is obtained
by checking the "g.region -pu" projection code.

:return: True for a lat/long region, False otherwise
"""
s = read_command("g.region", flags="pu", env=env)
kv = parse_key_val(s, ":")
kv: KeyValue[str | None] = parse_key_val(s, ":")
return kv["projection"].split(" ")[0] == "3"


Expand Down Expand Up @@ -1246,7 +1258,9 @@ def region(region3d=False, complete=False, env=None):
return reg


def region_env(region3d=False, flags=None, env=None, **kwargs):
def region_env(
region3d: bool = False, flags: str | None = None, env: _Env | None = None, **kwargs
) -> str:
Returns region settings as a string which can be used as
GRASS_REGION environmental variable.

Expand All @@ -1256,8 +1270,8 @@ def region_env(region3d=False, flags=None, env=None, **kwargs):
See also :func:`use_temp_region()` for an alternative method of defining a
temporary region used for raster-based computation.

:param bool region3d: True to get 3D region
:param string flags: for example 'a'
:param region3d: True to get 3D region
:param flags: for example 'a'
:param env: dictionary with system environment variables (`os.environ` by default)
:param kwargs: g.region's parameters like 'raster', 'vector' or 'region'

Expand All @@ -1271,7 +1285,7 @@ def region_env(region3d=False, flags=None, env=None, **kwargs):
:return: empty string on error
"""
# read proj/zone from WIND file
gis_env = gisenv(env)
gis_env: KeyValue[str | None] = gisenv(env)
windfile = os.path.join(
gis_env["GISDBASE"], gis_env["LOCATION_NAME"], gis_env["MAPSET"], "WIND"
)
Expand Down
106 changes: 83 additions & 23 deletions python/grass/script/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,16 @@
import string

from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, AnyStr, Callable, TypeVar, cast, overload


if TYPE_CHECKING:
from _typeshed import FileDescriptorOrPath, StrPath, StrOrBytesPath
from _typeshed import FileDescriptorOrPath, StrOrBytesPath, StrPath


# Type variables
T = TypeVar("T")
VT = TypeVar("VT") # Value type


def float_or_dms(s) -> float:
Expand Down Expand Up @@ -144,7 +150,7 @@ def basename(path: StrPath, ext: str | None = None) -> str:
return name


class KeyValue(dict):
class KeyValue(dict[str, VT]):
"""A general-purpose key-value store.

KeyValue is a subclass of dict, but also allows entries to be read and
Expand All @@ -157,16 +163,19 @@ class KeyValue(dict):
>>> reg.south = 205
>>> reg['south']
205

The keys of KeyValue are strings. To use other key types, use other mapping types.
To use the attribute syntax, the keys must be valid Python attribute names.
"""

def __getattr__(self, key):
def __getattr__(self, key: str) -> VT:
return self[key]

def __setattr__(self, key, value):
def __setattr__(self, key: str, value: VT) -> None:
self[key] = value


def _get_encoding():
def _get_encoding() -> str:
try:
# Python >= 3.11
encoding = locale.getencoding()
Expand All @@ -177,7 +186,7 @@ def _get_encoding():
return encoding


def decode(bytes_, encoding=None):
def decode(bytes_: AnyStr, encoding: str | None = None) -> str:
"""Decode bytes with default locale and return (unicode) string

No-op if parameter is not bytes (assumed unicode string).
Expand Down Expand Up @@ -205,13 +214,13 @@ def decode(bytes_, encoding=None):
raise TypeError(msg)


def encode(string, encoding=None):
def encode(string: AnyStr, encoding: str | None = None) -> bytes:
"""Encode string with default locale and return bytes with that encoding

No-op if parameter is bytes (assumed already encoded).
This ensures garbage in, garbage out.

:param str string: the string to encode
:param string: the string to encode
:param encoding: encoding to be used, default value is None

Example
Expand All @@ -230,36 +239,77 @@ def encode(string, encoding=None):
enc = _get_encoding() if encoding is None else encoding
return string.encode(enc)
# if something else than text
msg = "can only accept types str and bytes"
msg = "Can only accept types str and bytes"
raise TypeError(msg)


def text_to_string(text, encoding=None):
def text_to_string(text: AnyStr, encoding: str | None = None) -> str:
"""Convert text to str. Useful when passing text into environments,
in Python 2 it needs to be bytes on Windows, in Python 3 it needs unicode.
"""
return decode(text, encoding=encoding)


def parse_key_val(s, sep="=", dflt=None, val_type=None, vsep=None) -> KeyValue:
@overload
def parse_key_val(
s: AnyStr,
sep: str = "=",
dflt: T | None = None,
val_type: None = ...,
vsep: str | None = None,
) -> KeyValue[str | T | None]:
pass


@overload
def parse_key_val(
s: AnyStr,
sep: str = "=",
dflt: T | None = None,
val_type: Callable[[str], T] = ...,
vsep: str | None = None,
) -> KeyValue[T | None]:
pass


@overload
def parse_key_val(
s: AnyStr,
sep: str = "=",
dflt: T | None = None,
val_type: Callable[[str], T] | None = None,
vsep: str | None = None,
) -> KeyValue[str | T] | KeyValue[T | None] | KeyValue[T] | KeyValue[str | T | None]:
pass


def parse_key_val(
s: AnyStr,
sep: str = "=",
dflt: T | None = None,
val_type: Callable[[str], T] | None = None,
vsep: str | None = None,
) -> KeyValue[str | T] | KeyValue[T | None] | KeyValue[T] | KeyValue[str | T | None]:
"""Parse a string into a dictionary, where entries are separated
by newlines and the key and value are separated by `sep` (default: `=`)

>>> parse_key_val('min=20\\nmax=50') == {'min': '20', 'max': '50'}
True
>>> parse_key_val('min=20\\nmax=50',
... val_type=float) == {'min': 20, 'max': 50}
>>> parse_key_val('min=20\\nmax=50', val_type=float) == {'min': 20, 'max': 50}
True

:param str s: string to be parsed
:param str sep: key/value separator
:param s: string to be parsed
:param sep: key/value separator
:param dflt: default value to be used
:param val_type: value type (None for no cast)
:param vsep: vertical separator (default is Python 'universal newlines' approach)

:return: parsed input (dictionary of keys/values)
"""
result = KeyValue()

result: (
KeyValue[str | T] | KeyValue[T | None] | KeyValue[T] | KeyValue[str | T | None]
) = KeyValue()

if not s:
return result
Expand All @@ -269,23 +319,33 @@ def parse_key_val(s, sep="=", dflt=None, val_type=None, vsep=None) -> KeyValue:
vsep = encode(vsep) if vsep else vsep

if vsep:
lines = s.split(vsep)
lines: list[bytes] | list[str] = s.split(vsep)
try:
lines.remove("\n")
except ValueError:
pass
else:
lines = s.splitlines()

if callable(val_type):
result = cast("KeyValue[T | None]", result)
for line in lines:
kv: list[bytes] | list[str] = line.split(sep, 1)
k: str = decode(kv[0].strip())
result[k] = val_type(decode(kv[1].strip())) if len(kv) > 1 else dflt

if dflt is not None:
result = cast("KeyValue[T]", result)
return result

result = cast("KeyValue[str | T | None]", result)
for line in lines:
kv = line.split(sep, 1)
k = decode(kv[0].strip())
v = decode(kv[1].strip()) if len(kv) > 1 else dflt
result[k] = decode(kv[1].strip()) if len(kv) > 1 else dflt

if val_type:
result[k] = val_type(v)
else:
result[k] = v
if dflt is not None:
result = cast("KeyValue[str | T]", result)

return result

Expand Down
Loading