Source code for hail.utils.misc

import atexit
import datetime
import difflib
import json
import os
import re
import secrets
import shutil
import string
import tempfile
from collections import Counter, defaultdict
from contextlib import contextmanager
from io import StringIO
from typing import Iterator, Literal, Optional
from urllib.parse import urlparse

import hail
import hail as hl
from hail.typecheck import enumeration, nullable, typecheck
from hail.utils.java import Env, error


@typecheck(n_rows=int, n_cols=int, n_partitions=nullable(int))
def range_matrix_table(n_rows, n_cols, n_partitions=None) -> 'hail.MatrixTable':
    """Construct a matrix table with row and column indices and no entry fields.

    Examples
    --------

    >>> range_ds = hl.utils.range_matrix_table(n_rows=100, n_cols=10)

    >>> range_ds.count_rows()
    100

    >>> range_ds.count_cols()
    10

    Notes
    -----
    The resulting matrix table contains the following fields:

     - `row_idx` (:py:data:`.tint32`) - Row index (row key).
     - `col_idx` (:py:data:`.tint32`) - Column index (column key).

    It contains no entry fields.

    This method is meant for testing and learning, and is not optimized for
    production performance.

    Parameters
    ----------
    n_rows : :obj:`int`
        Number of rows.
    n_cols : :obj:`int`
        Number of columns.
    n_partitions : int, optional
        Number of partitions (uses Spark default parallelism if None).

    Returns
    -------
    :class:`.MatrixTable`
    """
    caller = 'range_matrix_table'
    check_nonnegative_and_in_range(caller, 'n_rows', n_rows)
    check_nonnegative_and_in_range(caller, 'n_cols', n_cols)
    if n_partitions is not None:
        check_positive_and_in_range(caller, 'n_partitions', n_partitions)

    # The result type is passed explicitly as ``_assert_type``.
    # It has empty globals and no entry fields.
    # Columns are keyed by ``col_idx`` and rows by ``row_idx``.
    col_type = hl.tstruct(col_idx=hl.tint32)
    row_type = hl.tstruct(row_idx=hl.tint32)
    result_type = hl.tmatrix(
        hl.tstruct(),
        col_type,
        ['col_idx'],
        row_type,
        ['row_idx'],
        hl.tstruct(),
    )

    reader = hail.ir.MatrixRangeReader(n_rows, n_cols, n_partitions)
    return hail.MatrixTable(hail.ir.MatrixRead(reader, _assert_type=result_type))
@typecheck(n=int, n_partitions=nullable(int))
def range_table(n, n_partitions=None) -> 'hail.Table':
    """Construct a table with the row index and no other fields.

    Examples
    --------

    >>> df = hl.utils.range_table(100)

    >>> df.count()
    100

    Notes
    -----
    The resulting table contains one field:

     - `idx` (:py:data:`.tint32`) - Row index (key).

    This method is meant for testing and learning, and is not optimized for
    production performance.

    Parameters
    ----------
    n : int
        Number of rows.
    n_partitions : int, optional
        Number of partitions (uses Spark default parallelism if None).

    Returns
    -------
    :class:`.Table`
    """
    caller = 'range_table'
    check_nonnegative_and_in_range(caller, 'n', n)
    if n_partitions is not None:
        check_positive_and_in_range(caller, 'n_partitions', n_partitions)

    table_ir = hail.ir.TableRange(n, n_partitions)
    return hail.Table(table_ir)
def check_positive_and_in_range(caller, name, value):
    """Raise :class:`ValueError` unless ``0 < value <= hail.tint32.max_value``.

    Parameters
    ----------
    caller : :obj:`str`
        Name of the public function doing the check (used in the message).
    name : :obj:`str`
        Name of the parameter being checked (used in the message).
    value : :obj:`int`
        Value to validate.
    """
    if value <= 0:
        raise ValueError(f"'{caller}': parameter '{name}' must be positive, found {value}")
    elif value > hail.tint32.max_value:
        raise ValueError(
            f"'{caller}': parameter '{name}' must be less than or equal to {hail.tint32.max_value}, "
            f"found {value}"
        )


def check_nonnegative_and_in_range(caller, name, value):
    """Raise :class:`ValueError` unless ``0 <= value <= hail.tint32.max_value``.

    Parameters are as for :func:`check_positive_and_in_range`.
    """
    if value < 0:
        raise ValueError(f"'{caller}': parameter '{name}' must be non-negative, found {value}")
    elif value > hail.tint32.max_value:
        raise ValueError(
            f"'{caller}': parameter '{name}' must be less than or equal to {hail.tint32.max_value}, "
            f"found {value}"
        )


def wrap_to_list(s):
    """Return ``s`` as a list.

    A list is returned unchanged and a tuple is converted. Any other value is
    wrapped in a one-element list.
    """
    if isinstance(s, list):
        return s
    elif isinstance(s, tuple):
        return list(s)
    else:
        return [s]


def wrap_to_tuple(x):
    """Return ``x`` unchanged if it is a tuple, otherwise ``(x,)``."""
    if isinstance(x, tuple):
        return x
    else:
        return (x,)


def wrap_to_sequence(x):
    """Return ``x`` as a tuple.

    A tuple is returned unchanged and a list is converted. Any other value is
    wrapped in a one-element tuple.
    """
    if isinstance(x, tuple):
        return x
    if isinstance(x, list):
        return tuple(x)
    else:
        return (x,)


def get_env_or_default(maybe, envvar, default):
    """Return the first truthy value among ``maybe``, ``$envvar`` and ``default``.

    Falsy values such as ``''`` or ``0`` fall through to the next option.
    """
    return maybe or os.environ.get(envvar) or default


def uri_path(uri):
    """Return the path component of ``uri``."""
    return urlparse(uri).path


def local_path_uri(path):
    """Return a ``file://`` URI for the local ``path``."""
    return 'file://' + path


def new_temp_file(prefix=None, extension=None):
    """Return a fresh random path inside the current Hail context's temporary directory.

    Nothing is created on disk. The file name is a 22-character token drawn
    with :mod:`secrets` from ASCII letters and digits. It is optionally
    preceded by ``"{prefix}-"`` and followed by ``".{extension}"``.
    """
    tmpdir = Env.hc()._tmpdir
    alphabet = string.ascii_letters + string.digits
    token = ''.join([secrets.choice(alphabet) for _ in range(22)])
    f = token
    if prefix is not None:
        f = f'{prefix}-{f}'
    if extension is not None:
        f = f'{f}.{extension}'
    return f'{tmpdir}/{f}'


def new_local_temp_dir(suffix=None, prefix=None, dir=None):
    """Create a local temporary directory that is removed when the interpreter exits.

    The arguments are forwarded to :func:`tempfile.mkdtemp`. Returns the
    directory's path.
    """
    local_temp_dir = tempfile.mkdtemp(suffix, prefix, dir)
    # Removal at exit is best-effort. The directory may already have been
    # deleted by then, and an exception raised from an atexit hook only
    # produces a spurious traceback.
    atexit.register(shutil.rmtree, local_temp_dir, ignore_errors=True)
    return local_temp_dir


def new_local_temp_file(filename: str = 'temp') -> str:
    """Return the path of ``filename`` inside a new local temporary directory.

    The file itself is not created. The enclosing directory is scheduled for
    removal at exit by :func:`new_local_temp_dir`.
    """
    local_temp_dir = new_local_temp_dir()
    path = local_temp_dir + "/" + filename
    return path


@contextmanager
def with_local_temp_file(filename: str = 'temp') -> Iterator[str]:
    """Context manager yielding a local temporary file path.

    The file at that path is deleted when the ``with`` block exits. If it was
    never created (or is already gone), that is not an error.
    """
    path = new_local_temp_file(filename)
    try:
        yield path
    finally:
        try:
            os.remove(path)
        except FileNotFoundError:
            pass


# Type-check enumeration of accepted storage-level names. The values mirror
# Spark's ``StorageLevel`` constants.
storage_level = enumeration(
    'NONE',
    'DISK_ONLY',
    'DISK_ONLY_2',
    'MEMORY_ONLY',
    'MEMORY_ONLY_2',
    'MEMORY_ONLY_SER',
    'MEMORY_ONLY_SER_2',
    'MEMORY_AND_DISK',
    'MEMORY_AND_DISK_2',
    'MEMORY_AND_DISK_SER',
    'MEMORY_AND_DISK_SER_2',
    'OFF_HEAP',
)


def run_command(args):
    """Run ``args`` as a subprocess, discarding its output on success.

    On a non-zero exit, the combined stdout/stderr is printed and the
    :class:`subprocess.CalledProcessError` is re-raised.
    """
    import subprocess as sp

    try:
        sp.check_output(args, stderr=sp.STDOUT)
    except sp.CalledProcessError as e:
        print(e.output)
        raise


def hl_plural(orig, n, alternate=None):
    """Hail-expression version of :func:`plural`.

    Returns ``orig`` when ``n == 1``. Otherwise it returns ``alternate``, or
    ``orig + 's'`` when ``alternate`` is None.
    """
    if alternate is None:
        plural = orig + 's'
    else:
        plural = alternate
    return hl.if_else(n == 1, orig, plural)


def plural(orig, n, alternate=None):
    """Return ``orig`` if ``n == 1``, else ``alternate`` if truthy, else ``orig + 's'``."""
    if n == 1:
        return orig
    elif alternate:
        return alternate
    else:
        return orig + 's'


def get_obj_metadata(obj):
    """Describe ``obj`` for use in attribute/field error messages.

    Returns
    -------
    tuple
        ``(class_name, cls, field_formatter, has_describe)``.
        ``field_formatter(field)`` renders a quoted field name. For table-like
        objects it also tags the field's axis: ``[globals]``, ``[row]``,
        ``[col]`` or ``[entry]``.

    Raises
    ------
    NotImplementedError
        If ``obj`` is not one of the supported container types.
    """
    from hail.expr.expressions import ArrayStructExpression, SetStructExpression, StructExpression
    from hail.matrixtable import GroupedMatrixTable, MatrixTable
    from hail.table import GroupedTable, Table
    from hail.utils import Struct

    def table_error(index_obj):
        def fmt_field(field):
            assert field in index_obj._fields
            inds = index_obj[field]._indices
            if inds == index_obj._global_indices:
                return "'{}' [globals]".format(field)
            elif inds == index_obj._row_indices:
                return "'{}' [row]".format(field)
            elif inds == index_obj._col_indices:  # Table will never get here
                return "'{}' [col]".format(field)
            else:
                assert inds == index_obj._entry_indices
                return "'{}' [entry]".format(field)

        return fmt_field

    def struct_error(s):
        def fmt_field(field):
            assert field in s._fields
            return "'{}'".format(field)

        return fmt_field

    if isinstance(obj, MatrixTable):
        return 'MatrixTable', MatrixTable, table_error(obj), True
    elif isinstance(obj, GroupedMatrixTable):
        return 'GroupedMatrixTable', GroupedMatrixTable, table_error(obj._parent), True
    elif isinstance(obj, Table):
        return 'Table', Table, table_error(obj), True
    elif isinstance(obj, GroupedTable):
        return 'GroupedTable', GroupedTable, table_error(obj), False
    elif isinstance(obj, Struct):
        return 'Struct', Struct, struct_error(obj), False
    elif isinstance(obj, StructExpression):
        return 'StructExpression', StructExpression, struct_error(obj), True
    elif isinstance(obj, ArrayStructExpression):
        return 'ArrayStructExpression', ArrayStructExpression, struct_error(obj), True
    elif isinstance(obj, SetStructExpression):
        return 'SetStructExpression', SetStructExpression, struct_error(obj), True
    else:
        raise NotImplementedError(obj)


def get_nice_attr_error(obj, item):
    """Build an error message for a failed attribute lookup of ``item`` on ``obj``.

    For non-underscore names, up to five close matches (via
    :func:`difflib.get_close_matches`) are suggested from each of these
    categories: data fields, methods defined on the class, properties defined
    on the class, and inherited members.
    """
    class_name, cls, handler, has_describe = get_obj_metadata(obj)

    if item.startswith('_'):
        # don't handle 'private' attribute access
        return "{} instance has no attribute '{}'".format(class_name, item)
    else:
        # Match case-insensitively against field names, but report the
        # original spellings.
        field_names = obj._fields.keys()
        field_dict = defaultdict(list)
        for f in field_names:
            field_dict[f.lower()].append(f)

        obj_namespace = {x for x in dir(cls) if not x.startswith('_')}
        inherited = {x for x in obj_namespace if x not in cls.__dict__}
        methods = {x for x in obj_namespace if x in cls.__dict__ and callable(cls.__dict__[x])}
        props = obj_namespace - methods - inherited

        item_lower = item.lower()

        field_matches = difflib.get_close_matches(item_lower, field_dict, n=5)
        inherited_matches = difflib.get_close_matches(item_lower, inherited, n=5)
        method_matches = difflib.get_close_matches(item_lower, methods, n=5)
        prop_matches = difflib.get_close_matches(item_lower, props, n=5)

        s = ["{} instance has no field, method, or property '{}'".format(class_name, item)]
        if any([field_matches, method_matches, prop_matches, inherited_matches]):
            s.append('\n Did you mean:')
            if field_matches:
                fs = []
                for f in field_matches:
                    fs.extend(field_dict[f])
                word = plural('field', len(fs))
                s.append('\n Data {}: {}'.format(word, ', '.join(handler(f) for f in fs)))
            if method_matches:
                word = plural('method', len(method_matches))
                s.append(
                    '\n {} {}: {}'.format(class_name, word, ', '.join("'{}'".format(m) for m in method_matches))
                )
            if prop_matches:
                word = plural('property', len(prop_matches), 'properties')
                s.append(
                    '\n {} {}: {}'.format(class_name, word, ', '.join("'{}'".format(p) for p in prop_matches))
                )
            if inherited_matches:
                word = plural('inherited method', len(inherited_matches))
                s.append(
                    '\n {} {}: {}'.format(
                        class_name, word, ', '.join("'{}'".format(m) for m in inherited_matches)
                    )
                )
        elif has_describe:
            s.append("\n Hint: use 'describe()' to show the names of all data fields.")
        return ''.join(s)


def get_nice_field_error(obj, item):
    """Build an error message for a missing field ``item`` on ``obj``.

    Up to five case-insensitive close matches among the data fields are
    suggested.
    """
    class_name, _, handler, has_describe = get_obj_metadata(obj)

    field_names = obj._fields.keys()
    dd = defaultdict(list)
    for f in field_names:
        dd[f.lower()].append(f)

    item_lower = item.lower()

    field_matches = difflib.get_close_matches(item_lower, dd, n=5)

    s = ["{} instance has no field '{}'".format(class_name, item)]
    if field_matches:
        s.append('\n Did you mean:')
        for f in field_matches:
            for orig_f in dd[f]:
                s.append("\n {}".format(handler(orig_f)))
    if has_describe:
        s.append("\n Hint: use 'describe()' to show the names of all data fields.")
    return ''.join(s)


def check_collisions(caller, names, indices, override_protected_indices=None):
    """Raise ``ExpressionException`` if ``names`` would clash with existing fields.

    A name clashes when it belongs to an existing field of ``indices.source``
    that is "invalid". By default, a field is invalid when it is indexed
    differently from ``indices``. When ``override_protected_indices`` is
    given, a field is invalid when its indices are in that collection.
    Duplicate entries within ``names`` are also rejected.
    """
    from hail.expr.expressions import ExpressionException

    fields = indices.source._fields

    if override_protected_indices is not None:

        def invalid(e):
            return e._indices in override_protected_indices

    else:

        def invalid(e):
            return e._indices != indices

    # check collisions with fields on other axes
    for name in names:
        if name in fields and invalid(fields[name]):
            msg = f"{caller!r}: name collision with field indexed by {list(fields[name]._indices.axes)}: {name!r}"
            error('Analysis exception: {}'.format(msg))
            raise ExpressionException(msg)

    # check duplicate fields
    for k, v in Counter(names).items():
        if v > 1:
            raise ExpressionException(f"{caller!r}: selection would produce duplicate field {k!r}")


def get_key_by_exprs(caller, exprs, named_exprs, indices, override_protected_indices=None):
    """Resolve the arguments of a ``key_by``-style call.

    Positional ``exprs`` may be field names or expressions. Each must refer to
    a (possibly nested) field; complex expressions must be passed as keyword
    arguments in ``named_exprs``.

    Returns
    -------
    tuple
        ``(final_key, bindings)``. ``final_key`` is the ordered list of key
        field names. ``bindings`` maps each name that is not an existing
        top-level field to the expression that computes it.
    """
    from hail.expr.expressions import ExpressionException, analyze, to_expr

    exprs = [indices.source[e] if isinstance(e, str) else e for e in exprs]
    named_exprs = {k: to_expr(v) for k, v in named_exprs.items()}

    bindings = []

    def is_top_level_field(e):
        return e in indices.source._fields_inverse

    final_key = []
    for e in exprs:
        analyze(caller, e, indices, broadcast=False)
        if not e._ir.is_nested_field:
            raise ExpressionException(
                f"{caller!r} expects keyword arguments for complex expressions\n"
                f" Correct: ht = ht.key_by('x')\n"
                f" Correct: ht = ht.key_by(ht.x)\n"
                f" Correct: ht = ht.key_by(x = ht.x.replace(' ', '_'))\n"
                f" INCORRECT: ht = ht.key_by(ht.x.replace(' ', '_'))"
            )

        name = e._ir.name
        final_key.append(name)

        # Existing top-level fields can be keyed on directly; nested fields
        # must be bound to a new top-level name first.
        if not is_top_level_field(e):
            bindings.append((name, e))

    final_key.extend(named_exprs)
    bindings.extend(named_exprs.items())
    check_collisions(caller, final_key, indices, override_protected_indices=override_protected_indices)
    return final_key, dict(bindings)


def check_keys(caller, name, protected_key):
    """Raise ``ExpressionException`` if ``name`` is a protected key field."""
    from hail.expr.expressions import ExpressionException

    if name in protected_key:
        msg = (
            f"{caller!r}: cannot overwrite key field {name!r} with annotate, select or drop; "
            f"use key_by to modify keys."
        )
        error('Analysis exception: {}'.format(msg))
        raise ExpressionException(msg)


def get_select_exprs(caller, exprs, named_exprs, indices, base_struct):
    """Build the struct expression produced by a ``select``-style call.

    The result's fields are, in order: the protected key fields of
    ``indices``, then the positional ``exprs`` (field names or (nested) field
    references), then the keyword ``named_exprs``.
    """
    from hail.expr.expressions import ExpressionException, analyze, to_expr

    exprs = [indices.source[e] if isinstance(e, str) else e for e in exprs]
    named_exprs = {k: to_expr(v) for k, v in named_exprs.items()}
    select_fields = indices.protected_key[:]
    protected_key = set(select_fields)
    insertions = {}

    final_fields = select_fields[:]

    def is_top_level_field(e):
        return e in indices.source._fields_inverse

    for e in exprs:
        if not e._ir.is_nested_field:
            raise ExpressionException(
                f"{caller!r} expects keyword arguments for complex expressions\n"
                f" Correct: ht = ht.select('x')\n"
                f" Correct: ht = ht.select(ht.x)\n"
                f" Correct: ht = ht.select(x = ht.x.replace(' ', '_'))\n"
                f" INCORRECT: ht = ht.select(ht.x.replace(' ', '_'))"
            )
        analyze(caller, e, indices, broadcast=False)
        name = e._ir.name
        check_keys(caller, name, protected_key)
        final_fields.append(name)
        if is_top_level_field(e):
            select_fields.append(name)
        else:
            insertions[name] = e
    for k, e in named_exprs.items():
        check_keys(caller, k, protected_key)
        final_fields.append(k)
        insertions[k] = e

    check_collisions(caller, final_fields, indices)

    if final_fields == select_fields + list(insertions):
        # don't clog the IR with redundant field names
        s = base_struct.select(*select_fields).annotate(**insertions)
    else:
        s = base_struct.select(*select_fields)._annotate_ordered(insertions, final_fields)

    assert list(s) == final_fields
    return s


def check_annotate_exprs(caller, named_exprs, indices, agg_axes):
    """Validate the keyword expressions of an ``annotate``-style call.

    Each expression is analyzed against ``indices``, with aggregation over
    ``agg_axes`` allowed and broadcasting enabled. It must not overwrite a
    protected key field, and must not collide with fields on other axes.
    Returns ``named_exprs`` unchanged.
    """
    from hail.expr.expressions import analyze

    protected_key = set(indices.protected_key)
    for k, v in named_exprs.items():
        analyze(f'{caller}: field {k!r}', v, indices, agg_axes, broadcast=True)
        check_keys(caller, k, protected_key)
    check_collisions(caller, list(named_exprs), indices)
    return named_exprs


def process_joins(obj, exprs):
    """Apply every ``hail.ir.Join`` found in ``exprs`` to ``obj``.

    Each distinct join (by ``idx``) is applied once, in ``idx`` order within
    each expression.

    Returns
    -------
    tuple
        ``(joined, cleanup)``. ``joined`` is ``obj`` with the joins applied.
        ``cleanup(table)`` drops any of the joins' temporary fields still
        present on ``table``.
    """
    all_uids = []
    left = obj
    used_joins = set()

    for e in exprs:
        joins = e._ir.search(lambda a: isinstance(a, hail.ir.Join))
        for j in sorted(joins, key=lambda j: j.idx):  # Make sure joins happen in order
            if j.idx not in used_joins:
                left = j.join_func(left)
                all_uids.extend(j.temp_vars)
                used_joins.add(j.idx)

    def cleanup(table):
        remaining_uids = [uid for uid in all_uids if uid in table._fields]
        return table.drop(*remaining_uids)

    return left, cleanup


def divide_null(num, denom):
    """Return ``num / denom``, or a missing value of the unified type when ``denom == 0``."""
    from hail.expr import if_else, missing
    from hail.expr.expressions.base_expression import unify_types_limited

    typ = unify_types_limited(num.dtype, denom.dtype)
    assert typ is not None
    return if_else(denom != 0, num / denom, missing(typ))


def lookup_bit(byte, which_bit):
    """Return bit ``which_bit`` (0 = least significant) of ``byte``."""
    return (byte >> which_bit) & 1


def timestamp_path(base, suffix=''):
    """Return ``base`` + ``'-'`` + the local time as ``YYYYmmdd-HHMM`` + ``suffix``."""
    return ''.join([base, '-', datetime.datetime.now().strftime("%Y%m%d-%H%M"), suffix])


def upper_hex(n, num_digits=None):
    """Format ``n`` as uppercase hex, zero-padded to ``num_digits`` if given."""
    if num_digits is None:
        return "{0:X}".format(n)
    else:
        return "{0:0{1}X}".format(n, num_digits)


def escape_str(s, backticked=False):
    r"""Escape ``s`` for embedding in a double-quoted (default) or backticked literal.

    The rules are:

    - Non-ASCII characters and control characters become ``\uXXXX`` escapes,
      always with exactly four uppercase hex digits.
    - Characters outside the Basic Multilingual Plane are written as a UTF-16
      surrogate pair (two ``\u`` escapes).
    - ``\b``, ``\n``, ``\t``, ``\f`` and ``\r`` use their short forms.
    - Backslashes are doubled.
    - Only the delimiter in use (``"`` or a backtick) is escaped.
    """
    rewrite_dict = {'\b': '\\b', '\n': '\\n', '\t': '\\t', '\f': '\\f', '\r': '\\r'}
    sb = StringIO()
    for ch in s:
        ch_num = ord(ch)
        if ch_num > 0xFFFF:
            # A \u escape carries exactly four hex digits, so a single escape
            # such as \u1F600 would be read back as \u1F60 followed by '0'.
            # Encode the code point as a UTF-16 surrogate pair instead, which is
            # the JVM-style representation (the unescaper is presumably JVM-side;
            # TODO confirm).
            offset = ch_num - 0x10000
            sb.write("\\u" + upper_hex(0xD800 | (offset >> 10), 4))
            sb.write("\\u" + upper_hex(0xDC00 | (offset & 0x3FF), 4))
        elif ch_num > 0x7F:
            sb.write("\\u" + upper_hex(ch_num, 4))
        elif ch_num < 32:
            if ch in rewrite_dict:
                sb.write(rewrite_dict[ch])
            else:
                sb.write("\\u" + upper_hex(ch_num, 4))
        elif ch == '"':
            sb.write('"' if backticked else '\\"')
        elif ch == '`':
            sb.write("\\`" if backticked else "`")
        elif ch == '\\':
            sb.write('\\\\')
        else:
            sb.write(ch)

    escaped = sb.getvalue()
    sb.close()
    return escaped


def escape_id(s):
    """Return ``s`` if it is a plain identifier, otherwise a backtick-quoted escaped form.

    A plain identifier is a letter or underscore followed by word characters.
    """
    if re.fullmatch(r'[_a-zA-Z]\w*', s):
        return s
    else:
        return "`{}`".format(escape_str(s, backticked=True))


def parsable_strings(strs):
    """Render ``strs`` as a parenthesized, space-separated list of escaped quoted strings.

    For example, ``['a', 'b']`` becomes ``("a" "b")``.
    """
    strs = ' '.join(f'"{escape_str(s)}"' for s in strs)
    return f"({strs})"


def _dumps_partitions(partitions, row_key_type):
    """Serialize an array-of-intervals expression of partition bounds to JSON.

    If the interval point type equals the type of the first row key field,
    each bound is wrapped in a single-field struct. Otherwise the point type
    must be a struct that is a prefix of ``row_key_type``.

    Returns
    -------
    tuple
        ``(json_string, dtype)`` of the (possibly rewrapped) partitions.

    Raises
    ------
    ValueError
        If ``partitions`` is not an array of intervals, if ``row_key_type``
        has no fields, or if the interval point type is incompatible with
        the row key.
    """
    parts_type = partitions.dtype
    if not (isinstance(parts_type, hl.tarray) and isinstance(parts_type.element_type, hl.tinterval)):
        raise ValueError(f'partitions type invalid: {parts_type} must be array of intervals')

    point_type = parts_type.element_type.point_type

    # Without this guard an empty row key surfaces as a bare StopIteration.
    first_key_field = next(iter(row_key_type.items()), None)
    if first_key_field is None:
        raise ValueError(f'partitions invalid: row key type {row_key_type} has no fields')
    f1, t1 = first_key_field

    if point_type == t1:
        partitions = hl.map(
            lambda x: hl.interval(
                start=hl.struct(**{f1: x.start}),
                end=hl.struct(**{f1: x.end}),
                includes_start=x.includes_start,
                includes_end=x.includes_end,
            ),
            partitions,
        )
    else:
        if not isinstance(point_type, hl.tstruct):
            raise ValueError(f'partitions has wrong type: {point_type} must be struct or type of first row key field')
        if not point_type._is_prefix_of(row_key_type):
            raise ValueError(f'partitions type invalid: {point_type} must be prefix of {row_key_type}')

    s = json.dumps(partitions.dtype._convert_to_json(hl.eval(partitions)))
    return s, partitions.dtype


def default_handler():
    """Return ``IPython.display.display`` if IPython is importable, else :func:`print`."""
    try:
        from IPython.display import display

        return display
    except ImportError:
        return print


def guess_cloud_spark_provider() -> Optional[Literal['dataproc', 'hdinsight']]:
    """Guess the managed Spark provider from environment variables.

    Returns ``'dataproc'`` if ``HAIL_DATAPROC`` is set. Returns
    ``'hdinsight'`` if ``AZURE_SPARK`` is set or ``CLASSPATH`` mentions
    ``hdinsight``. Otherwise returns None.
    """
    if 'HAIL_DATAPROC' in os.environ:
        return 'dataproc'
    if 'AZURE_SPARK' in os.environ or 'hdinsight' in os.getenv('CLASSPATH', ''):
        return 'hdinsight'
    return None


def no_service_backend(unsupported_feature):
    """Raise :class:`NotImplementedError` if the current backend is the service backend."""
    from hail import current_backend
    from hail.backend.service_backend import ServiceBackend

    if isinstance(current_backend(), ServiceBackend):
        raise NotImplementedError(
            f'{unsupported_feature!r} is not yet supported on the service backend.'
            f'\n If this is a pressing need, please alert the team on the discussion'
            f'\n forum to aid in prioritization: https://discuss.hail.is'
        )


ANY_REGION = ['any_region']