# Source code for hail.experimental.pca

import hail as hl
from hail.expr.expressions import (
    expr_array,
    expr_call,
    expr_numeric,
    raise_unless_entry_indexed,
    raise_unless_row_indexed,
)
from hail.typecheck import typecheck


@typecheck(call_expr=expr_call, loadings_expr=expr_array(expr_numeric), af_expr=expr_numeric)
def pc_project(call_expr, loadings_expr, af_expr):
    """Projects genotypes onto pre-computed PCs.

    Requires loadings and allele-frequency from a reference dataset (see
    example). Note that `loadings_expr` must have no missing data and reflect
    the rows from the original PCA run for this method to be accurate.

    Example
    -------
    >>> # Compute loadings and allele frequency for reference dataset
    >>> _, _, loadings_ht = hl.hwe_normalized_pca(mt.GT, k=10, compute_loadings=True)  # doctest: +SKIP
    >>> mt = mt.annotate_rows(af=hl.agg.mean(mt.GT.n_alt_alleles()) / 2)  # doctest: +SKIP
    >>> loadings_ht = loadings_ht.annotate(af=mt.rows()[loadings_ht.key].af)  # doctest: +SKIP
    >>> # Project new genotypes onto loadings
    >>> ht = pc_project(mt_to_project.GT, loadings_ht.loadings, loadings_ht.af)  # doctest: +SKIP

    Parameters
    ----------
    call_expr : :class:`.CallExpression`
        Entry-indexed call expression for genotypes to project onto loadings.
    loadings_expr : :class:`.ArrayNumericExpression`
        Row-indexed expression holding the loadings. It may come from the
        genotype dataset itself or from a separate table/matrix table, in
        which case it is joined on the genotype dataset's row key.
    af_expr : :class:`.Float64Expression`
        Row-indexed expression holding the reference allele frequency, joined
        the same way as `loadings_expr`.

    Returns
    -------
    :class:`.Table`
        Table with scores calculated from loadings in column `scores`.
    """
    # Validate indexing up front: genotypes per entry, loadings/AF per row.
    raise_unless_entry_indexed('pc_project', call_expr)
    for row_expr in (loadings_expr, af_expr):
        raise_unless_row_indexed('pc_project', row_expr)

    target = call_expr._indices.source
    loadings_src = loadings_expr._indices.source
    af_src = af_expr._indices.source

    # Bring loadings and AF onto the genotype dataset's rows (joining when they
    # live elsewhere), then stash all three inputs as temporary fields.
    row_fields = {
        '_loadings': _get_expr_or_join(loadings_expr, loadings_src, target, '_loadings'),
        '_af': _get_expr_or_join(af_expr, af_src, target, '_af'),
    }
    mt = target._annotate_all(row_exprs=row_fields, entry_exprs={'_call': call_expr})

    # The normaliser uses the number of variants in the loadings source,
    # not the number of rows that survive the filter below.
    count_variants = loadings_src.count_rows if isinstance(loadings_src, hl.MatrixTable) else loadings_src.count
    n_variants = count_variants()

    # Drop rows whose loadings/AF are missing or whose AF is monomorphic
    # (af == 0 or af == 1 would make the denominator zero).
    usable = hl.is_defined(mt._loadings) & hl.is_defined(mt._af) & (mt._af > 0) & (mt._af < 1)
    mt = mt.filter_rows(usable)

    # Field references below must come from the filtered `mt`.
    af = mt._af
    centered = mt._call.n_alt_alleles() - 2 * af
    scale = hl.sqrt(n_variants * 2 * af * (1 - af))
    projected = hl.agg.array_sum(mt._loadings * (centered / scale))
    return mt.select_cols(scores=projected).cols()
def _get_expr_or_join(expr, source, other_source, loc):
    """Return `expr` expressed in terms of the rows of `other_source`.

    If `expr` already belongs to `other_source`, it is returned unchanged.
    Otherwise `expr` is stored as a temporary field named `loc` on its own
    `source`, and that field is looked up from `other_source` by joining on
    `other_source.row_key`.

    Parameters
    ----------
    expr : :class:`.Expression`
        Row-indexed expression to bring over.
    source : :class:`.Table` or :class:`.MatrixTable`
        The object `expr` is indexed by.
    other_source : :class:`.MatrixTable`
        The object the returned expression should be indexed by (only its
        `row_key` is used).
    loc : :class:`str`
        Name of the temporary field used to carry `expr` through the join.

    Returns
    -------
    :class:`.Expression`
        `expr` itself, or the joined field indexed by `other_source`'s rows.
    """
    if source != other_source:
        if isinstance(source, hl.MatrixTable):
            # Reduce to the rows table before joining. Keyed Table indexing
            # (`ht[key]`) is the documented join syntax. Indexing a MatrixTable
            # directly by a row-key expression is not its documented row-join
            # API (that is `index_rows`).
            keyed = source.annotate_rows(**{loc: expr}).rows()
        else:
            keyed = source.annotate(**{loc: expr})
        expr = keyed[other_source.row_key][loc]
    return expr