Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[query] pretty boxes #14254

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
283 changes: 241 additions & 42 deletions hail/python/hail/matrixtable.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import itertools
from typing import Iterable, Optional, Dict, Tuple, Any, List
from typing import Iterable, Optional, Dict, Tuple, Any, List, Callable
from collections import Counter
import hail as hl
from hail.expr.expressions import (
Expand Down Expand Up @@ -37,6 +37,7 @@
from hail.utils import storage_level, default_handler, deduplicate
from hail.utils.java import warning, Env, info
from hail.utils.misc import wrap_to_tuple, get_key_by_exprs, get_select_exprs, check_annotate_exprs, process_joins
from hailtop.utils.box_drawing import TableStyle, standard_ts
import warnings


Expand Down Expand Up @@ -2807,25 +2808,202 @@ def write(
Env.backend().execute(ir.MatrixWrite(self._mir, writer))

class _Show:
    """Deferred, printable preview of a matrix table.

    The constructor only records the display options; the table is
    evaluated each time the object is rendered, either as boxed text
    (``__str__`` / ``__repr__``) or as HTML (``_repr_html_``).

    ``MatrixTable.show`` resolves every optional argument before building
    this object. ``__str__`` relies on that: it compares and multiplies by
    ``n_rows``, ``n_cols``, ``width`` and ``truncate_limit``, so they must
    not be ``None`` by the time the object is rendered.
    """

    def __init__(
        self,
        matrix_table: 'MatrixTable',
        table_style: TableStyle,
        n_rows: Optional[int] = None,
        n_cols: Optional[int] = None,
        include_row_fields: bool = False,
        width: Optional[int] = None,
        truncate_limit: Optional[int] = None,
        should_show_types: bool = True,
    ):
        """Record what to show and how; no computation happens here.

        Parameters
        ----------
        matrix_table : MatrixTable
            The matrix table to preview.
        table_style : TableStyle
            Object whose ``format_line`` / ``format_row`` draw the boxes.
        n_rows : int
            Number of rows to display.
        n_cols : int
            Number of matrix table columns to display.
        include_row_fields : bool
            If true, show all row fields rather than only the row key.
        width : int
            Character budget per block of display columns.
        truncate_limit : int
            Longest string printed before it is cut and suffixed with "...".
        should_show_types : bool
            If true, add a header line with each field's type.
        """
        self.matrix_table = matrix_table
        self.table_style = table_style
        self.n_rows = n_rows
        self.n_cols = n_cols
        self.include_row_fields = include_row_fields
        self.width = width
        self.truncate_limit = truncate_limit
        self.should_show_types = should_show_types

    def __str__(self):
        """Render the preview as one or more boxed text blocks.

        Display columns that do not fit in ``width`` characters are moved
        to a following block. A trailing note reports when rows or columns
        were left out.
        """
        matrix_table = self.matrix_table
        n_rows = self.n_rows
        n_cols = self.n_cols
        width = self.width
        truncate_limit = self.truncate_limit
        should_show_types = self.should_show_types

        def truncate(string: str) -> str:
            if len(string) > truncate_limit:
                return string[: truncate_limit - 3] + "..."
            return string

        table = matrix_table.head(n_rows, n_cols).localize_entries('entries', 'cols')
        # Capture the key type first: order_by unkeys the table.
        table_key_dtype = table.key.dtype
        if len(table.key) > 0:
            table = table.order_by(*table.key)

        row_dtype = table.row.dtype
        rows, _ = table._take_n(n_rows)
        candidate_fields = list(row_dtype) if self.include_row_fields else list(table_key_dtype)
        row_fields = [field for field in candidate_fields if field != 'entries']
        truncated_row_fields = [truncate(field) for field in row_fields]
        truncated_rows = [[truncate(repr(row[field])) for field in row_fields] for row in rows]
        row_field_right_align = [hl.expr.types.is_numeric(row_dtype[field]) for field in row_fields]

        entry_dtype = matrix_table.entry.dtype
        entry_fields = list(entry_dtype)
        # entries has shape (n_rows, n_cols, len(entry_fields)).
        # NOTE(review): a whole entry may come back as None (presumably a
        # filtered/missing entry -- confirm); show each of its fields as
        # missing instead of failing on the subscript.
        entries = [
            [[None if entry is None else entry[field] for field in entry_fields] for entry in row['entries']]
            for row in rows
        ]
        truncated_entries = [[[truncate(repr(value)) for value in entry] for entry in row] for row in entries]
        # One alignment flag per display column: the per-field flags repeated n_cols times.
        entry_field_right_align = [hl.expr.types.is_numeric(entry_dtype[field]) for field in entry_fields] * n_cols

        # truncated_values has shape (n_rows, len(row_fields) + n_cols * len(entry_fields)).
        truncated_values = [
            row_cells + [cell for entry_cells in row_entries for cell in entry_cells]
            for row_cells, row_entries in zip(truncated_rows, truncated_entries)
        ]
        right_align = row_field_right_align + entry_field_right_align

        truncated_types = None
        if should_show_types:
            truncated_types = [truncate(str(row_dtype[field])) for field in row_fields]
            truncated_types += [truncate(str(entry_dtype[field])) for field in entry_fields] * n_cols

        # Here we must distinguish between matrix table columns
        # and the columns that are used to display this matrix table.
        # Hail displays both the row fields and the matrix table's columns
        # and entry fields as columns in a tabular format.
        # We refer to the columns used to display the matrix table as "display columns".
        #
        # Positional labels are the default. The column key values replace
        # them only when the key is a single str/int32/int64 field and the
        # shown values are all distinct, so a label is always available.
        column_labels = [f'<col {index}>' for index in range(0, n_cols)]
        col_key_dtype = matrix_table.col_key.dtype
        if len(col_key_dtype) == 1 and col_key_dtype[0] in (hl.tstr, hl.tint32, hl.tint64):
            columns = matrix_table.col_key[0].take(n_cols)
            if len(set(columns)) == len(columns):
                column_labels = [repr(column) for column in columns]
        display_column_names = truncated_row_fields + [
            f'{label}.{entry_field}' for label in column_labels for entry_field in entry_fields
        ]

        def max_value_width(index: int) -> int:
            return max((len(row[index]) for row in truncated_values), default=0)

        display_column_widths = [
            max(
                len(display_column_names[index]),
                len(truncated_types[index]) if should_show_types else 0,
                max_value_width(index),
            )
            for index in range(len(display_column_names))
        ]

        # Greedily pack display columns into blocks no wider than `width`.
        # The first column of a block is budgeted 4 extra characters and each
        # further column 3.
        display_block_slices = []
        start_index = 0
        end_index = 1
        block_width = display_column_widths[0] + 4 if display_column_widths else 0
        while end_index < len(display_column_names):
            block_width += display_column_widths[end_index] + 3
            if block_width > width:
                display_block_slices.append(slice(start_index, end_index))
                start_index = end_index
                block_width = display_column_widths[end_index] + 4
            end_index += 1
        display_block_slices.append(slice(start_index, end_index))

        def pad(value: str, target_width: int, align_right: bool) -> str:
            filler = ' ' * (target_width - len(value))
            return filler + value if align_right else value + filler

        def format_row(values: List[str], widths: List[int], alignments: List[bool], *, header: bool = False):
            return self.table_style.format_row(map(pad, values, widths, alignments), header=header)

        pieces = []
        for block_number, block_slice in enumerate(display_block_slices):
            if block_number > 0:
                # Separate consecutive blocks with a newline.
                pieces.append('\n')

            block_widths = display_column_widths[block_slice]
            block_right_align = right_align[block_slice]

            pieces.append(self.table_style.format_line(block_widths, 'top'))
            pieces.append(
                format_row(display_column_names[block_slice], block_widths, block_right_align, header=True)
            )
            if should_show_types:
                pieces.append(self.table_style.format_line(block_widths, 'top-inner'))
                pieces.append(
                    format_row(truncated_types[block_slice], block_widths, block_right_align, header=True)
                )
            pieces.append(self.table_style.format_line(block_widths, 'top-bottom'))
            pieces.extend(
                format_row(row[block_slice], block_widths, block_right_align, header=False)
                for row in truncated_values
            )
            pieces.append(self.table_style.format_line(block_widths, 'bottom'))

        if n_rows < matrix_table.count_rows():
            row_count = len(rows)
            pieces.append(f"showing top { row_count } { 'row' if row_count == 1 else 'rows' }\n")
        total_n_cols = matrix_table.count_cols()
        if n_cols != total_n_cols:
            pieces.append(f"showing the first { n_cols } of { total_n_cols } columns")

        return ''.join(pieces)

    def __repr__(self):
        """Return the same text as ``__str__``."""
        return self.__str__()

    def _repr_html_(self):
        """Render the preview as HTML.

        The matrix table is flattened into a table with one field per
        displayed column, and that table's own HTML renderer is used. A
        highlighted note is appended when not every column is displayed.
        """
        # This method is not thoroughly tested and should be used carefully.
        total_n_cols = self.matrix_table.count_cols()
        table = self.matrix_table.localize_entries('entries', 'cols')
        if len(table.key) > 0:
            table = table.order_by(*table.key)
        col_key_type = self.matrix_table.col_key.dtype

        # Positional headers unless a single str/int column key gives distinct names.
        col_headers = [f'<col {i}>' for i in range(0, self.n_cols)]
        if len(col_key_type) == 1 and col_key_type[0] in (hl.tstr, hl.tint32, hl.tint64):
            cols = self.matrix_table.col_key[0].take(self.n_cols)
            if len(set(cols)) == len(cols):
                col_headers = [repr(c) for c in cols]

        entries = {col_headers[i]: table.entries[i] for i in range(0, self.n_cols)}
        table = table.select(
            **{f: table[f] for f in self.matrix_table.row_key},
            **{f: table[f] for f in self.matrix_table.row_value if self.include_row_fields},
            **entries,
        )
        table_show = table._show(
            self.n_rows, self.width, self.truncate_limit, self.should_show_types, self.table_style
        )
        s = table_show._repr_html_()
        if self.n_cols != total_n_cols:
            s += '<p style="background: #fdd; padding: 0.4em;">'
            s += f"showing the first { self.n_cols } of { total_n_cols } columns"
            s += '</p>\n'
        return s

Expand All @@ -2837,11 +3015,20 @@ def _repr_html_(self):
truncate=nullable(int),
types=bool,
handler=nullable(anyfunc),
box=nullable(TableStyle),
)
def show(
self, n_rows=None, n_cols=None, include_row_fields=False, width=None, truncate=None, types=True, handler=None
self,
n_rows: Optional[int] = None,
n_cols: Optional[int] = None,
include_row_fields: bool = False,
width: Optional[int] = None,
truncate: Optional[int] = None,
types: bool = True,
handler: Optional[Callable[[str], Any]] = None,
table_style: Optional[TableStyle] = None,
):
"""Print the first few rows of the matrix table to the console.
"""Print the first few rows and columns of the matrix table to the console.

.. include:: _templates/experimental.rst

Expand All @@ -2857,6 +3044,8 @@ def show(
Maximum number of rows to show.
n_cols : :obj:`int`
Maximum number of columns to show.
include_row_fields : :obj:`bool`
Whether to include row fields in the output.
width : :obj:`int`
Horizontal width at which to break fields.
truncate : :obj:`int`, optional
Expand All @@ -2866,43 +3055,53 @@ def show(
Print an extra header line with the type of each field.
handler : Callable[[str], Any]
Handler function for data string.
table_style : TableStyle
A table style to use.
"""

def estimate_size(struct_expression):
return sum(max(len(f), len(str(x.dtype))) + 3 for f, x in struct_expression.flatten().items())

if n_cols is None:
if n_rows is None: # Careful with truthiness here, n_rows can be 0
n_rows = 10
n_rows = min(max(n_rows, 0), self.count_rows())
if n_cols is None or width is None:
import shutil

(characters, _) = shutil.get_terminal_size((80, 10))
characters -= 6 # borders
key_characters = estimate_size(self.row_key)
characters -= key_characters
if include_row_fields:
characters -= estimate_size(self.row_value)
characters = max(characters, 0)
n_cols = characters // (estimate_size(self.entry) + 4) # 4 for the column index
actual_n_cols = self.count_cols()
displayed_n_cols = min(actual_n_cols, n_cols)

t = self.localize_entries('entries', 'cols')
if len(t.key) > 0:
t = t.order_by(*t.key)
col_key_type = self.col_key.dtype

col_headers = [f'<col {i}>' for i in range(0, displayed_n_cols)]
if len(col_key_type) == 1 and col_key_type[0] in (hl.tstr, hl.tint32, hl.tint64):
cols = self.col_key[0].take(displayed_n_cols)
if len(set(cols)) == len(cols):
col_headers = [repr(c) for c in cols]

entries = {col_headers[i]: t.entries[i] for i in range(0, displayed_n_cols)}
t = t.select(
**{f: t[f] for f in self.row_key}, **{f: t[f] for f in self.row_value if include_row_fields}, **entries
)
(characters, _) = shutil.get_terminal_size()
if width is None:
width = characters
if n_cols is None:
characters -= 6 # borders
key_characters = estimate_size(self.row_key)
characters -= key_characters
if include_row_fields:
characters -= estimate_size(self.row_value)
characters = max(characters, 0)
n_cols = characters // (estimate_size(self.entry) + 4) # 4 for the column index
total_n_cols = self.count_cols()
n_cols = min(total_n_cols, n_cols)

truncate = min(truncate, width - 4) if truncate is not None else width - 4
truncate = max(truncate, 4)
truncate_limit = truncate

if handler is None:
handler = default_handler()
return handler(MatrixTable._Show(t, n_rows, actual_n_cols, displayed_n_cols, width, truncate, types))
if table_style is None:
table_style = standard_ts

show = self._Show(
self,
table_style,
n_rows,
n_cols,
include_row_fields,
width,
truncate_limit,
types,
)
handler(show)

def globals_table(self) -> Table:
"""Returns a table with a single row with the globals of the matrix table.
Expand Down
Loading
Loading