Skip to content

Commit

Permalink
Fix #499 pkcli gets commands from class Commands (#500)
Browse files Browse the repository at this point in the history
- Instantiate the class first
- Ignore superclass methods
  • Loading branch information
robnagler authored Aug 23, 2024
1 parent 49d5c46 commit 80278ba
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 41 deletions.
62 changes: 36 additions & 26 deletions pykern/pkcli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,13 @@
import argparse
import importlib
import inspect
import itertools
import os
import os.path
import pkgutil
import re
import sys
import types

# Avoid pykern imports so avoid dependency issues for pkconfig
from pykern import pkconfig
Expand Down Expand Up @@ -139,6 +141,8 @@ def main(root_pkg, argv=None):
cli = _module(root_pkg, module_name)
if not cli:
return 1
if c := getattr(cli, "Commands", None):
cli = c()
prog = prog + " " + module_name
parser = CustomParser(prog)
cmds = _commands(cli)
Expand All @@ -160,8 +164,6 @@ def main(root_pkg, argv=None):
parser.error("too few arguments")
if argv[0][0] != "-":
argv[0] = _module_to_cmd(argv[0])
from pykern.pkdebug import pkdp

try:
res = argh.dispatch(parser, argv=argv)
except CommandError as e:
Expand All @@ -187,20 +189,42 @@ def _argh_name_mapping_policy():


def _commands(cli):
"""Extracts all public functions from `cli`
"""Extracts all public functions or methods from `cli`
Args:
cli (module): where commands are executed from
cli (object): where commands are executed from
Returns:
list of function: public functions sorted alphabetically
list: commands sorted alphabetically
"""
res = []
for n, t in inspect.getmembers(cli):
if _is_command(t, cli):
res.append(t)
sorted(res, key=lambda f: f.__name__.lower())
return res

def _functions():
return _iter(
lambda t: inspect.isfunction(t)
and hasattr(t, "__module__")
and t.__module__ == cli.__name__
)

def _iter(predicate):
for n, t in inspect.getmembers(cli, predicate=predicate):
if not n.startswith("_"):
yield (t)

def _methods():
x = frozenset(_super_methods())
return _iter(
lambda t: inspect.ismethod(t)
and t.__name__ not in x
and t.__name__ in dir(cli)
)

def _super_methods():
return itertools.chain(*(dir(b) for b in cli.__class__.__bases__))

return sorted(
_functions() if isinstance(cli, types.ModuleType) else _methods(),
key=lambda f: f.__name__.lower(),
)


def _default_command(cmds, argv):
Expand Down Expand Up @@ -270,24 +294,10 @@ def _imp(path_list):
return _imp(path + [name])


def _is_command(obj, cli):
"""Is this a valid command function?
Args:
obj (object): candidate
cli (module): module to which function should belong
Returns:
bool: True if obj is a valid command
"""
if not inspect.isfunction(obj) or obj.__name__.startswith("_"):
return False
return hasattr(obj, "__module__") and obj.__module__ == cli.__name__


def _is_help(argv):
"""Does the user want help?
Args:
argv (list): list of args
Expand Down
4 changes: 1 addition & 3 deletions tests/pkcli_data/package1/pkcli/conf1.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
from __future__ import absolute_import, division, print_function

last_cmd = None

from pykern.pkdebug import pkdp

def cmd1(arg1):
"""Subject line for cmd1
Expand All @@ -14,6 +11,7 @@ def cmd1(arg1):
last_cmd = cmd1
return


def cmd2():
"""Subject line for cmd2
Expand Down
3 changes: 1 addition & 2 deletions tests/pkcli_data/package1/pkcli/conf2.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import absolute_import, division, print_function

last_cmd = None


def cmd1(arg1):
global last_cmd
last_cmd = cmd1
Expand Down
4 changes: 1 addition & 3 deletions tests/pkcli_data/package1/pkcli/conf3.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from __future__ import absolute_import, division, print_function

last_cmd = None

last_arg = None


def default_command(arg1):
global last_cmd, last_arg
last_cmd = default_command
last_arg = arg1
return
22 changes: 22 additions & 0 deletions tests/pkcli_data/package1/pkcli/conf4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
last_self = None
last_cmd = None
last_arg = None


def should_not_find():
pass


class Commands:
def __init__(self):
global last_self

last_self = self

def default_command(self, arg1):
global last_cmd, last_arg
last_cmd = self.default_command
last_arg = arg1

def _should_not_find():
pass
22 changes: 22 additions & 0 deletions tests/pkcli_data/package1/pkcli/conf5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
last_self = None
last_cmd = None
last_arg = None


def should_not_find():
pass


class Commands:
def __init__(self):
global last_self

last_self = self

def cmd1(self, arg1):
global last_cmd, last_arg
last_cmd = self.cmd1
last_arg = arg1

def _should_not_find():
pass
27 changes: 20 additions & 7 deletions tests/pkcli_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,19 +68,23 @@ def test_main1():
_conf(rp, ["conf1", "cmd2"], first_time=False)
_conf(rp, ["conf2", "cmd1", "2"])
_conf(rp, ["conf3", "3"], default_command=True)
_conf(rp, ["conf4", "99"], default_command=True)
first_self = _conf(rp, ["conf5", "cmd1", "10"])
_conf(rp, ["conf5", "cmd1", "3"], first_self=first_self)


def test_main2(capsys):
from pykern import pkconfig
import six

all_modules = r":\nconf1\nconf2\nconf3\n$"
all_modules = r":\nconf1\nconf2\nconf3\nconf4\nconf5\n$"
pkconfig.reset_state_for_testing()
rp = "package1"
_deviance(rp, [], None, all_modules, capsys)
_deviance(rp, ["--help"], None, all_modules, capsys)
_deviance(rp, ["conf1"], SystemExit, r"cmd1,cmd2.*too few", capsys)
_deviance(rp, ["conf1", "-h"], SystemExit, r"\{cmd1,cmd2\}.*commands", capsys)
_deviance(rp, ["conf5", "-h"], SystemExit, r"\{cmd1\}.*commands", capsys)
if six.PY2:
_deviance(rp, ["not_found"], None, r"no module", capsys)
else:
Expand Down Expand Up @@ -112,19 +116,28 @@ def test_command_info():
)


def _conf(root_pkg, argv, first_time=True, default_command=False):
def _conf(root_pkg, argv, first_time=True, default_command=False, first_self=None):
from pykern.pkunit import pkeq, pkne, pkok
import sys

rv = None
full_name = ".".join([root_pkg, "pkcli", argv[0]])
if not first_time:
assert not hasattr(sys.modules, full_name)
assert _main(root_pkg, argv) == 0, "Unexpected exit"
pkok(not hasattr(sys.modules, full_name), "module loaded before first call")
pkeq(0, _main(root_pkg, argv), "Unexpected exit")
m = sys.modules[full_name]
if default_command:
assert m.last_cmd.__name__ == "default_command"
assert m.last_arg == argv[1]
pkeq("default_command", m.last_cmd.__name__)
pkeq(argv[1], m.last_arg)
else:
assert m.last_cmd.__name__ == argv[1]
pkeq(argv[1], m.last_cmd.__name__)
if hasattr(m, "last_self"):
if first_self:
pkne(first_self, m.last_self)
else:
pkok(m.last_self, "")
rv = m.last_self
return rv


def _deviance(root_pkg, argv, exc, expect, capsys):
Expand Down

0 comments on commit 80278ba

Please sign in to comment.