# Source code for hail.utils.struct

import pprint
from collections import OrderedDict
from collections.abc import Mapping
from typing import Any, Dict

from hail.typecheck import anytype, typecheck, typecheck_method
from hail.utils.misc import get_nice_attr_error, get_nice_field_error


class Struct(Mapping):
    """
    Nested annotation structure.

    >>> bar = hl.Struct(**{'foo': 5, '1kg': 10})

    Struct elements are treated as both 'items' and 'attributes', which
    allows either syntax for accessing the element "foo" of struct "bar":

    >>> bar.foo
    >>> bar['foo']

    Field names that are not valid Python identifiers, such as fields that
    start with numbers or contain spaces, must be accessed with the latter
    syntax:

    >>> bar['1kg']

    The ``pprint`` module can be used to print nested Structs in a more
    human-readable fashion:

    >>> from pprint import pprint
    >>> pprint(bar)

    Parameters
    ----------
    kwargs : keyword args
        Field names and values.

    Note
    ----
    This object refers to the Python value returned by taking or collecting
    Hail expressions, e.g. ``mt.info.take(5)``. This is rare; it is much
    more common to manipulate the :class:`.StructExpression` object, which is
    constructed using the :func:`.struct` function.
    """

    def __init__(self, **kwargs):
        # Stored through ``__dict__`` directly: ``__setattr__`` consults
        # ``self._fields``, which does not exist yet at this point.
        self.__dict__["_fields"] = kwargs

    def __contains__(self, item):
        return item in self._fields

    def __getstate__(self) -> Dict[str, Any]:
        return self._fields

    def __setstate__(self, state: Dict[str, Any]):
        # Same reasoning as in ``__init__``: bypass ``__setattr__``.
        self.__dict__["_fields"] = state

    def _get_field(self, item):
        """Return the value of field `item`, raising ``KeyError`` if absent."""
        if item in self._fields:
            return self._fields[item]
        else:
            raise KeyError(get_nice_field_error(self, item))

    @typecheck_method(item=str)
    def __getitem__(self, item):
        return self._get_field(item)

    def __setattr__(self, key, value):
        # Existing fields may not be overwritten; any other attribute name is
        # stored as a regular instance attribute.
        if key in self._fields:
            raise ValueError("Structs are immutable, cannot overwrite a field.")
        else:
            super().__setattr__(key, value)

    def __getattr__(self, item):
        """Resolve `item` as a field once normal attribute lookup has failed.

        Raises
        ------
        AttributeError
            If `item` is not a field, or if the instance has no ``_fields``
            mapping yet.
        """
        state = self.__dict__
        if item in state:
            return state[item]

        # Read ``_fields`` from ``__dict__`` rather than via ``self._fields``:
        # on an instance where it has not been set (e.g. one created with
        # ``__new__`` before ``__setstate__`` runs) the attribute lookup would
        # re-enter ``__getattr__`` and recurse until ``RecursionError``.
        fields = state.get("_fields")
        if fields is None:
            raise AttributeError(item)

        if item in fields:
            return fields[item]
        raise AttributeError(get_nice_attr_error(self, item))

    def __len__(self):
        return len(self._fields)

    def __repr__(self):
        return str(self)

    def __str__(self):
        # Keyword-argument form when every field name is an identifier,
        # otherwise the ``**{...}`` form with quoted names.
        if all(k.isidentifier() for k in self._fields):
            return 'Struct(' + ', '.join(f'{k}={v!r}' for k, v in self._fields.items()) + ')'
        return 'Struct(**{' + ', '.join(f'{k!r}: {v!r}' for k, v in self._fields.items()) + '})'

    def __eq__(self, other):
        return self._fields == other._fields if isinstance(other, Struct) else NotImplemented

    def __hash__(self):
        # Items are sorted so the hash does not depend on field order,
        # matching the order-insensitive dict comparison in ``__eq__``.
        return 37 + hash(tuple(sorted(self._fields.items())))

    def __iter__(self):
        return iter(self._fields)

    def __dir__(self):
        super_dir = super().__dir__()
        return super_dir + list(self._fields.keys())

    def annotate(self, **kwargs):
        """Add new fields or recompute existing fields.

        Notes
        -----
        If an expression in `kwargs` shares a name with a field of the
        struct, then that field will be replaced but keep its position in
        the struct. New fields will be appended to the end of the struct.

        Parameters
        ----------
        kwargs : keyword args
            Fields to add.

        Returns
        -------
        :class:`.Struct`
            Struct with new or updated fields.

        Examples
        --------

        Define a Struct `s`

        >>> s = hl.Struct(food=8, fruit=5)

        Add a new field to `s`

        >>> s.annotate(bar=2)
        Struct(food=8, fruit=5, bar=2)

        Add multiple fields to `s`

        >>> s.annotate(banana=2, apple=3)
        Struct(food=8, fruit=5, banana=2, apple=3)

        Recompute an existing field in `s`

        >>> s.annotate(bar=4, fruit=2)
        Struct(food=8, fruit=2, bar=4)
        """
        d = OrderedDict(self.items())
        # Updating an existing key keeps its position; new keys are appended.
        d.update(kwargs)
        return Struct(**d)

    @typecheck_method(fields=str, kwargs=anytype)
    def select(self, *fields, **kwargs):
        """Select existing fields and compute new ones.

        Notes
        -----
        The `fields` argument is a list of field names to keep. These fields
        will appear in the resulting struct in the order they appear in
        `fields`.

        The `kwargs` arguments are new fields to add.

        Parameters
        ----------
        fields : varargs of :class:`str`
            Field names to keep.
        kwargs : keyword args
            New field.

        Returns
        -------
        :class:`.Struct`
            Struct containing specified existing fields and computed fields.

        Raises
        ------
        ValueError
            If a name in `kwargs` is also listed in `fields`.

        Examples
        --------

        Define a Struct 's'

        >>> s = hl.Struct(foo=5, apple=10)

        Keep just one original field

        >>> s.select('foo')
        Struct(foo=5)

        Add one new field and keeps one old field

        >>> s.select('apple', bar=123)
        Struct(apple=10, bar=123)

        Adds two new fields and replaces old fields

        >>> s.select(bar=123, banana=1)
        Struct(bar=123, banana=1)
        """
        d = OrderedDict()
        for a in fields:
            d[a] = self[a]
        for k, v in kwargs.items():
            if k in d:
                raise ValueError("Cannot select and assign field '{}' in the same statement".format(k))
            d[k] = v
        return Struct(**d)

    @typecheck_method(args=str)
    def drop(self, *args):
        """Drop fields from the struct.

        Parameters
        ----------
        args : varargs of :class:`str`
            Fields to drop.

        Returns
        -------
        :class:`.Struct`
            Struct without certain fields.

        Examples
        --------

        Define a Struct `s`

        >>> s = hl.Struct(food=8, fruit=5, bar=2, apple=10)

        Drop one field from `s`

        >>> s.drop('bar')
        Struct(food=8, fruit=5, apple=10)

        Drop two fields from `s`

        >>> s.drop('food', 'fruit')
        Struct(bar=2, apple=10)
        """
        # Hash-based membership test instead of scanning the tuple per field.
        dropped = frozenset(args)
        d = OrderedDict((k, v) for k, v in self.items() if k not in dropped)
        return Struct(**d)
@typecheck(struct=Struct)
def to_dict(struct):
    """Return the fields of `struct` as a plain ``dict``, in field order."""
    return {name: value for name, value in struct.items()}


# The printer class that was installed in ``pprint`` before the patch below;
# ``StructPrettyPrinter`` delegates to it for everything that is not a Struct.
_old_printer = pprint.PrettyPrinter


class StructPrettyPrinter(pprint.PrettyPrinter):
    """``PrettyPrinter`` that lays out a :class:`Struct` one field per line
    when its single-line representation does not fit the available width."""

    def _format(self, obj, stream, indent, allowance, context, level, *args, **kwargs):
        # Anything that is not a Struct takes the stock formatting path.
        if not isinstance(obj, Struct):
            _old_printer._format(self, obj, stream, indent, allowance, context, level, *args, **kwargs)
            return

        # Use the one-line form whenever it fits in the remaining width.
        one_line = self._repr(obj, context, level)
        if len(one_line) <= self._width - indent - allowance:
            stream.write(one_line)
            return

        opener = 'Struct('
        stream.write(opener)
        indent += len(opener)

        # Mirror ``Struct.__str__``: ``name=value`` when every field name is
        # an identifier, otherwise ``**{'name': value}``.
        keyword_style = all(name.isidentifier() for name in obj)
        if keyword_style:
            joiner = '='
        else:
            dict_opener = '**{'
            stream.write(dict_opener)
            indent += len(dict_opener)
            joiner = ': '

        padding = ' ' * indent
        last_pos = len(obj) - 1
        for pos, (name, value) in enumerate(obj.items()):
            label = name if keyword_style else repr(name)
            # The first field continues the opener's line; later fields are
            # aligned underneath it.
            if pos > 0:
                stream.write(padding)
            stream.write(label)
            stream.write(joiner)
            value_indent = indent + len(label) + len(joiner)
            self._format(value, stream, value_indent, allowance, context, level, *args, **kwargs)
            if pos < last_pos:
                stream.write(',\n')

        if not keyword_style:
            stream.write('}')
        stream.write(')')


pprint.PrettyPrinter = StructPrettyPrinter  # monkey-patch pprint