import hail as hl
from hail.expr.expressions import (
    expr_array,
    expr_call,
    expr_numeric,
    raise_unless_entry_indexed,
    raise_unless_row_indexed,
)
from hail.typecheck import typecheck
@typecheck(call_expr=expr_call, loadings_expr=expr_array(expr_numeric), af_expr=expr_numeric)
def pc_project(call_expr, loadings_expr, af_expr):
    """Projects genotypes onto pre-computed PCs.

    Requires loadings and allele-frequency from a reference dataset (see
    example). Note that `loadings_expr` must have no missing data and reflect
    the rows from the original PCA run for this method to be accurate.

    Example
    -------
    >>> # Compute loadings and allele frequency for reference dataset
    >>> _, _, loadings_ht = hl.hwe_normalized_pca(mt.GT, k=10, compute_loadings=True)   # doctest: +SKIP
    >>> mt = mt.annotate_rows(af=hl.agg.mean(mt.GT.n_alt_alleles()) / 2)                # doctest: +SKIP
    >>> loadings_ht = loadings_ht.annotate(af=mt.rows()[loadings_ht.key].af)            # doctest: +SKIP
    >>> # Project new genotypes onto loadings
    >>> ht = pc_project(mt_to_project.GT, loadings_ht.loadings, loadings_ht.af)         # doctest: +SKIP

    Notes
    -----
    Rows are kept only when both the loading and the allele frequency are
    defined and the allele frequency lies strictly between 0 and 1. For the
    remaining rows each genotype is normalized as
    ``(n_alt_alleles - 2 * af) / sqrt(n_variants * 2 * af * (1 - af))``,
    where ``n_variants`` is the total row count of the dataset that
    `loadings_expr` comes from (counted before the filter above). The score
    of a column is the element-wise sum over rows of ``loadings * gt_norm``.

    Parameters
    ----------
    call_expr : :class:`.CallExpression`
        Entry-indexed call expression for genotypes
        to project onto loadings.
    loadings_expr : :class:`.ArrayNumericExpression`
        Location of expression for loadings
    af_expr : :class:`.Float64Expression`
        Location of expression for allele frequency

    Returns
    -------
    :class:`.Table`
        Table with scores calculated from loadings in column `scores`
    """
    # Index validation is delegated to the hail.expr.expressions helpers.
    raise_unless_entry_indexed('pc_project', call_expr)
    raise_unless_row_indexed('pc_project', loadings_expr)
    raise_unless_row_indexed('pc_project', af_expr)

    gt_source = call_expr._indices.source
    loadings_source = loadings_expr._indices.source
    af_source = af_expr._indices.source

    # Loadings / allele frequencies may live on a different dataset than the
    # genotypes; if so, bring them over with a join on the genotype row key.
    loadings_expr = _get_expr_or_join(loadings_expr, loadings_source, gt_source, '_loadings')
    af_expr = _get_expr_or_join(af_expr, af_source, gt_source, '_af')

    mt = gt_source._annotate_all(
        row_exprs={'_loadings': loadings_expr, '_af': af_expr}, entry_exprs={'_call': call_expr}
    )

    # The normalization constant uses the row count of the dataset the
    # loadings come from, not of the dataset being projected.
    if isinstance(loadings_source, hl.MatrixTable):
        n_variants = loadings_source.count_rows()
    else:
        n_variants = loadings_source.count()

    # Drop rows that cannot be normalized: missing loadings or allele
    # frequency, or af of exactly 0 or 1 (which would make the denominator 0).
    mt = mt.filter_rows(hl.is_defined(mt._loadings) & hl.is_defined(mt._af) & (mt._af > 0) & (mt._af < 1))

    gt_norm = (mt._call.n_alt_alleles() - 2 * mt._af) / hl.sqrt(n_variants * 2 * mt._af * (1 - mt._af))

    return mt.select_cols(scores=hl.agg.array_sum(mt._loadings * gt_norm)).cols()
def _get_expr_or_join(expr, source, other_source, loc):
    """Return `expr` in a form that can be used as a row annotation on `other_source`.

    Parameters
    ----------
    expr
        Row-indexed expression whose source is `source`.
    source
        Dataset that `expr` is indexed by; handled as a
        :class:`.MatrixTable` when it is one, otherwise annotated via
        ``annotate`` and indexed with ``[]`` (i.e. treated as a table).
    other_source
        Dataset the result should be usable on; its ``row_key`` is used as
        the join key.
    loc : str
        Name of the temporary field under which `expr` is stored on `source`
        before the join.

    Returns
    -------
    `expr` unchanged when `source` and `other_source` compare equal,
    otherwise the field `loc` of `source` looked up by
    ``other_source.row_key``.
    """
    if source != other_source:
        if isinstance(source, hl.MatrixTable):
            source = source.annotate_rows(**{loc: expr})
            # MatrixTable indexing with [] is for field lookup by name or
            # entry joins (mt[row_key, col_key]); a row join has to go
            # through index_rows.
            expr = source.index_rows(other_source.row_key)[loc]
        else:
            source = source.annotate(**{loc: expr})
            expr = source[other_source.row_key][loc]
    return expr