import builtins
import itertools
import math
from typing import Callable, Dict, List, Optional, Tuple, Union
import hail as hl
import hail.expr.aggregators as agg
from hail import ir
from hail.expr import (
Expression,
ExpressionException,
NDArrayNumericExpression,
StructExpression,
analyze,
expr_any,
expr_call,
expr_float64,
expr_locus,
expr_numeric,
matrix_table_source,
raise_unless_column_indexed,
raise_unless_entry_indexed,
raise_unless_row_indexed,
table_source,
)
from hail.expr.functions import expit
from hail.expr.types import tarray, tbool, tfloat64, tint32, tndarray, tstruct
from hail.genetics.reference_genome import reference_genome_type
from hail.linalg import BlockMatrix
from hail.matrixtable import MatrixTable
from hail.methods.misc import require_biallelic, require_row_key_variant
from hail.stats import LinearMixedModel
from hail.table import Table
from hail.typecheck import anytype, enumeration, nullable, numeric, oneof, sequenceof, sized_tupleof, typecheck
from hail.utils import FatalError, new_temp_file, wrap_to_list
from hail.utils.java import Env, info, warning
from ..backend.spark_backend import SparkBackend
from . import pca, relatedness
pc_relate = relatedness.pc_relate
identity_by_descent = relatedness.identity_by_descent
_blanczos_pca = pca._blanczos_pca
_hwe_normalized_blanczos = pca._hwe_normalized_blanczos
_spectral_moments = pca._spectral_moments
_pca_and_moments = pca._pca_and_moments
hwe_normalized_pca = pca.hwe_normalized_pca
pca = pca.pca
tvector64 = tndarray(tfloat64, 1)
tmatrix64 = tndarray(tfloat64, 2)
numerical_regression_fit_dtype = tstruct(
b=tvector64,
score=tvector64,
fisher=tmatrix64,
mu=tvector64,
n_iterations=tint32,
log_lkhd=tfloat64,
converged=tbool,
exploded=tbool,
)
@typecheck(
call=expr_call,
aaf_threshold=numeric,
include_par=bool,
female_threshold=numeric,
male_threshold=numeric,
aaf=nullable(str),
)
def impute_sex(call, aaf_threshold=0.0, include_par=False, female_threshold=0.2, male_threshold=0.8, aaf=None) -> Table:
r"""Impute sex of samples by calculating inbreeding coefficient on the
X chromosome.
.. include:: ../_templates/req_tvariant.rst
.. include:: ../_templates/req_biallelic.rst
Examples
--------
Remove samples where imputed sex does not equal reported sex:
>>> imputed_sex = hl.impute_sex(dataset.GT)
>>> dataset_result = dataset.filter_cols(imputed_sex[dataset.s].is_female != dataset.pheno.is_female,
... keep=False)
Notes
-----
We have used the same implementation as `PLINK v1.7
<https://zzz.bwh.harvard.edu/plink/summary.shtml#sexcheck>`__.
Let `gr` be the reference genome of the type of the `locus` key (as
given by :attr:`.tlocus.reference_genome`)
1. Filter the dataset to loci on the X contig defined by `gr`.
2. Calculate alternate allele frequency (AAF) for each row from the dataset.
3. Filter to variants with AAF above `aaf_threshold`.
4. Remove loci in the pseudoautosomal region, as defined by `gr`, unless
`include_par` is ``True`` (it defaults to ``False``)
5. For each row and column with a non-missing genotype call, :math:`E`, the
expected number of homozygotes (from population AAF), is computed as
:math:`1.0 - (2.0*\mathrm{maf}*(1.0-\mathrm{maf}))`.
6. For each row and column with a non-missing genotype call, :math:`O`, the
observed number of homozygotes, is computed interpreting ``0`` as
heterozygote and ``1`` as homozygote
7. For each row and column with a non-missing genotype call, :math:`N` is
incremented by 1
8. For each column, :math:`E`, :math:`O`, and :math:`N` are combined across
variants
9. For each column, :math:`F` is calculated by :math:`(O - E) / (N - E)`
10. A sex is assigned to each sample with the following criteria:
- Female when ``F < 0.2``
- Male when ``F > 0.8``
Use `female_threshold` and `male_threshold` to change this behavior.
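As an illustrative sketch of steps 5-9 in plain Python (toy values chosen for
illustration, not the Hail implementation):
.. code-block:: python
# one sample's statistics across three toy variants
aafs = [0.3, 0.4, 0.25]  # population alternate allele frequencies
calls = [1, 0, 1]  # 0 = heterozygote, 1 = homozygote
E = sum(1.0 - 2.0 * p * (1.0 - p) for p in aafs)  # expected homozygotes
O = sum(calls)  # observed homozygotes
N = len(calls)  # number of calls considered
F = (O - E) / (N - E)  # inbreeding coefficient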
**Annotations**
The returned column-key indexed :class:`.Table` has the following fields in
addition to the matrix table's column keys:
- **is_female** (:py:data:`.tbool`) -- True if the imputed sex is female,
false if male, missing if undetermined.
- **f_stat** (:py:data:`.tfloat64`) -- Inbreeding coefficient.
- **n_called** (:py:data:`.tint64`) -- Number of variants with a genotype call.
- **expected_homs** (:py:data:`.tfloat64`) -- Expected number of homozygotes.
- **observed_homs** (:py:data:`.tint64`) -- Observed number of homozygotes.
Parameters
----------
call : :class:`.CallExpression`
A genotype call for each row and column. The source dataset's row keys
must be [[locus], alleles] with types :class:`.tlocus` and
:class:`.tarray` of :obj:`.tstr`. Moreover, the alleles array must have
exactly two elements (i.e. the variant must be biallelic).
aaf_threshold : :obj:`float`
Minimum alternate allele frequency threshold.
include_par : :obj:`bool`
Include pseudoautosomal regions.
female_threshold : :obj:`float`
Samples are called females if F < female_threshold.
male_threshold : :obj:`float`
Samples are called males if F > male_threshold.
aaf : :class:`str` or :obj:`None`
A field defining the alternate allele frequency for each row. If
``None``, AAF will be computed from `call`.
Returns
-------
:class:`.Table`
Sex imputation statistics per sample.
"""
if aaf_threshold < 0.0 or aaf_threshold > 1.0:
raise FatalError("Invalid argument for `aaf_threshold`. Must be in range [0, 1].")
mt = call._indices.source
mt, _ = mt._process_joins(call)
mt = mt.annotate_entries(call=call)
mt = require_biallelic(mt, 'impute_sex')
if aaf is None:
mt = mt.annotate_rows(aaf=agg.call_stats(mt.call, mt.alleles).AF[1])
aaf = 'aaf'
rg = mt.locus.dtype.reference_genome
mt = hl.filter_intervals(
mt, hl.map(lambda x_contig: hl.parse_locus_interval(x_contig, rg), rg.x_contigs), keep=True
)
if not include_par:
interval_type = hl.tarray(hl.tinterval(hl.tlocus(rg)))
mt = hl.filter_intervals(mt, hl.literal(rg.par, interval_type), keep=False)
mt = mt.filter_rows((mt[aaf] > aaf_threshold) & (mt[aaf] < (1 - aaf_threshold)))
mt = mt.annotate_cols(ib=agg.inbreeding(mt.call, mt[aaf]))
kt = mt.select_cols(
is_female=hl.if_else(
mt.ib.f_stat < female_threshold, True, hl.if_else(mt.ib.f_stat > male_threshold, False, hl.missing(tbool))
),
**mt.ib,
).cols()
return kt
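# Shared helper for the regression methods below: builds the mapping of output
# row-field names to expressions for `pass_through` fields, checking that each
# requested field is row-indexed and does not duplicate an existing non-key field.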
def _get_regression_row_fields(mt, pass_through, method) -> Dict[str, str]:
row_fields = dict(zip(mt.row_key.keys(), mt.row_key.keys()))
for f in pass_through:
if isinstance(f, str):
if f not in mt.row:
raise ValueError(f"'{method}/pass_through': MatrixTable has no row field {f!r}")
if f in row_fields:
# allow silent pass through of key fields
if f in mt.row_key:
pass
else:
raise ValueError(f"'{method}/pass_through': found duplicated field {f!r}")
row_fields[f] = mt[f]
else:
assert isinstance(f, Expression)
if not f._ir.is_nested_field:
raise ValueError(f"'{method}/pass_through': expect fields or nested fields, not complex expressions")
if not f._indices == mt._row_indices:
raise ExpressionException(
f"'{method}/pass_through': require row-indexed fields, found indices {f._indices.axes}"
)
name = f._ir.name
if name in row_fields:
# allow silent pass through of key fields
if not (name in mt.row_key and f._ir == mt[name]._ir):
raise ValueError(f"'{method}/pass_through': found duplicated field {name!r}")
row_fields[name] = f
for k in mt.row_key:
del row_fields[k]
return row_fields
@typecheck(
y=oneof(expr_float64, sequenceof(expr_float64), sequenceof(sequenceof(expr_float64))),
x=expr_float64,
covariates=sequenceof(expr_float64),
block_size=int,
pass_through=sequenceof(oneof(str, Expression)),
weights=nullable(oneof(expr_float64, sequenceof(expr_float64))),
)
def linear_regression_rows(y, x, covariates, block_size=16, pass_through=(), *, weights=None) -> Table:
r"""For each row, test an input variable for association with
response variables using linear regression.
Examples
--------
>>> result_ht = hl.linear_regression_rows(
... y=dataset.pheno.height,
... x=dataset.GT.n_alt_alleles(),
... covariates=[1, dataset.pheno.age, dataset.pheno.is_female])
Warning
-------
As in the example, the intercept covariate ``1`` must be
included **explicitly** if desired.
Warning
-------
If `y` is a single value or a list, :func:`.linear_regression_rows`
considers the same set of columns (i.e., samples, points) for every response
variable and row, namely those columns for which **all** response variables
and covariates are defined.
If `y` is a list of lists, then each inner list is treated as an
independent group, subsetting columns for missingness separately.
Notes
-----
With the default root and `y` a single expression, the following row-indexed
fields are added.
- **<row key fields>** (Any) -- Row key fields.
- **<pass_through fields>** (Any) -- Row fields in `pass_through`.
- **n** (:py:data:`.tint32`) -- Number of columns used.
- **sum_x** (:py:data:`.tfloat64`) -- Sum of input values `x`.
- **y_transpose_x** (:py:data:`.tfloat64`) -- Dot product of response
vector `y` with the input vector `x`.
- **beta** (:py:data:`.tfloat64`) --
Fit effect coefficient of `x`, :math:`\hat\beta_1` below.
- **standard_error** (:py:data:`.tfloat64`) --
Estimated standard error, :math:`\widehat{\mathrm{se}}_1`.
- **t_stat** (:py:data:`.tfloat64`) -- :math:`t`-statistic, equal to
:math:`\hat\beta_1 / \widehat{\mathrm{se}}_1`.
- **p_value** (:py:data:`.tfloat64`) -- :math:`p`-value.
If `y` is a list of expressions, then the last five fields instead have type
:class:`.tarray` of :py:data:`.tfloat64`, with corresponding indexing of
the list and each array.
If `y` is a list of lists of expressions, then `n` and `sum_x` are of type
``array<float64>``, and the last five fields are of type
``array<array<float64>>``. Index into these arrays with
``a[index_in_outer_list][index_in_inner_list]``. For example, if
``y=[[a], [b, c]]`` then the p-value for ``b`` is ``p_value[1][0]``.
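For example, a chained call on the doctest dataset (the grouping here is purely
illustrative and reuses the same phenotype) could look like:
>>> grouped_ht = hl.linear_regression_rows(
...     y=[[dataset.pheno.height], [dataset.pheno.height, dataset.pheno.height]],
...     x=dataset.GT.n_alt_alleles(),
...     covariates=[1, dataset.pheno.age, dataset.pheno.is_female])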
In the statistical genetics example above, the input variable `x` encodes
genotype as the number of alternate alleles (0, 1, or 2). For each variant
(row), genotype is tested for association with height controlling for age
and sex, by fitting the linear regression model:
.. math::
\mathrm{height} = \beta_0 + \beta_1 \, \mathrm{genotype}
+ \beta_2 \, \mathrm{age}
+ \beta_3 \, \mathrm{is\_female}
+ \varepsilon,
\quad
\varepsilon \sim \mathrm{N}(0, \sigma^2)
Boolean covariates like :math:`\mathrm{is\_female}` are encoded as 1 for
``True`` and 0 for ``False``. The null model sets :math:`\beta_1 = 0`.
The standard least-squares linear regression model is derived in Section
3.2 of `The Elements of Statistical Learning, 2nd Edition
<http://statweb.stanford.edu/~tibs/ElemStatLearn/printings/ESLII_print10.pdf>`__.
See equation 3.12 for the t-statistic which follows the t-distribution with
:math:`n - k - 1` degrees of freedom, under the null hypothesis of no
effect, with :math:`n` samples and :math:`k` covariates in addition to
``x``.
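As a sketch of how the reported fields relate (``t_stat``, ``n`` and ``k`` below stand
for the quantities just described, not literal fields; ``hl.expr.functions.pT`` is the
Student-t CDF used by the lowered implementation in this module):
.. code-block:: python
# two-sided p-value recovered from the t-statistic
p_value = 2 * hl.expr.functions.pT(-hl.abs(t_stat), n - k - 1, True, False)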
Note
----
Use the `pass_through` parameter to include additional row fields from
the matrix table underlying ``x``. For example, to include an "rsid" field, set
``pass_through=['rsid']`` or ``pass_through=[mt.rsid]``.
Parameters
----------
y : :class:`.Float64Expression` or :obj:`list` of :class:`.Float64Expression`
One or more column-indexed response expressions.
x : :class:`.Float64Expression`
Entry-indexed expression for input variable.
covariates : :obj:`list` of :class:`.Float64Expression`
List of column-indexed covariate expressions.
block_size : :obj:`int`
Number of row regressions to perform simultaneously per core. Larger blocks
require more memory but may improve performance.
pass_through : :obj:`list` of :class:`str` or :class:`.Expression`
Additional row fields to include in the resulting table.
weights : :class:`.Float64Expression` or :obj:`list` of :class:`.Float64Expression`
Optional column-indexed weighting for doing weighted least squares regression. Specify a single weight if a
single y or list of ys is specified. If a list of lists of ys is specified, specify one weight per inner list.
Returns
-------
:class:`.Table`
"""
if not isinstance(Env.backend(), SparkBackend) or weights is not None:
return _linear_regression_rows_nd(y, x, covariates, block_size, weights, pass_through)
mt = matrix_table_source('linear_regression_rows/x', x)
raise_unless_entry_indexed('linear_regression_rows/x', x)
y_is_list = isinstance(y, list)
if y_is_list and len(y) == 0:
raise ValueError("'linear_regression_rows': found no values for 'y'")
is_chained = y_is_list and isinstance(y[0], list)
if is_chained and any(len(lst) == 0 for lst in y):
raise ValueError("'linear_regression_rows': found empty inner list for 'y'")
y = [
raise_unless_column_indexed('linear_regression_rows_nd/y', y) or ys
for ys in wrap_to_list(y)
for y in (ys if is_chained else [ys])
]
for e in itertools.chain.from_iterable(y) if is_chained else y:
analyze('linear_regression_rows/y', e, mt._col_indices)
for e in covariates:
analyze('linear_regression_rows/covariates', e, mt._col_indices)
_warn_if_no_intercept('linear_regression_rows', covariates)
x_field_name = Env.get_uid()
if is_chained:
y_field_names = [[f'__y_{i}_{j}' for j in range(len(y[i]))] for i in range(len(y))]
y_dict = dict(zip(itertools.chain.from_iterable(y_field_names), itertools.chain.from_iterable(y)))
func = 'LinearRegressionRowsChained'
else:
y_field_names = list(f'__y_{i}' for i in range(len(y)))
y_dict = dict(zip(y_field_names, y))
func = 'LinearRegressionRowsSingle'
cov_field_names = list(f'__cov{i}' for i in range(len(covariates)))
row_fields = _get_regression_row_fields(mt, pass_through, 'linear_regression_rows')
# FIXME: selecting an existing entry field should be emitted as a SelectFields
mt = mt._select_all(
col_exprs=dict(**y_dict, **dict(zip(cov_field_names, covariates))),
row_exprs=row_fields,
col_key=[],
entry_exprs={x_field_name: x},
)
config = {
'name': func,
'yFields': y_field_names,
'xField': x_field_name,
'covFields': cov_field_names,
'rowBlockSize': block_size,
'passThrough': [x for x in row_fields if x not in mt.row_key],
}
ht_result = Table(ir.MatrixToTableApply(mt._mir, config))
if not y_is_list:
fields = ['y_transpose_x', 'beta', 'standard_error', 't_stat', 'p_value']
ht_result = ht_result.annotate(**{f: ht_result[f][0] for f in fields})
return ht_result.persist()
@typecheck(
y=oneof(expr_float64, sequenceof(expr_float64), sequenceof(sequenceof(expr_float64))),
x=expr_float64,
covariates=sequenceof(expr_float64),
block_size=int,
weights=nullable(oneof(expr_float64, sequenceof(expr_float64))),
pass_through=sequenceof(oneof(str, Expression)),
)
def _linear_regression_rows_nd(y, x, covariates, block_size=16, weights=None, pass_through=()) -> Table:
mt = matrix_table_source('linear_regression_rows_nd/x', x)
raise_unless_entry_indexed('linear_regression_rows_nd/x', x)
y_is_list = isinstance(y, list)
if y_is_list and len(y) == 0:
raise ValueError("'linear_regression_rows_nd': found no values for 'y'")
is_chained = y_is_list and isinstance(y[0], list)
if is_chained and any(len(lst) == 0 for lst in y):
raise ValueError("'linear_regression_rows': found empty inner list for 'y'")
y = [
raise_unless_column_indexed('linear_regression_rows_nd/y', y) or ys
for ys in wrap_to_list(y)
for y in (ys if is_chained else [ys])
]
if weights is not None:
if y_is_list and is_chained and not isinstance(weights, list):
raise ValueError("When y is a list of lists, weights should be a list.")
elif y_is_list and not is_chained and isinstance(weights, list):
raise ValueError("When y is a single list, weights should be a single expression.")
elif not y_is_list and isinstance(weights, list):
raise ValueError("When y is a single expression, weights should be a single expression.")
weights = wrap_to_list(weights) if weights is not None else None
for e in itertools.chain.from_iterable(y) if is_chained else y:
analyze('linear_regression_rows_nd/y', e, mt._col_indices)
for e in covariates:
analyze('linear_regression_rows_nd/covariates', e, mt._col_indices)
_warn_if_no_intercept('linear_regression_rows_nd', covariates)
x_field_name = Env.get_uid()
if is_chained:
y_field_name_groups = [[f'__y_{i}_{j}' for j in range(len(y[i]))] for i in range(len(y))]
y_dict = dict(zip(itertools.chain.from_iterable(y_field_name_groups), itertools.chain.from_iterable(y)))
if weights is not None and len(weights) != len(y):
raise ValueError("Must specify same number of weights as groups of phenotypes")
else:
y_field_name_groups = list(f'__y_{i}' for i in range(len(y)))
y_dict = dict(zip(y_field_name_groups, y))
# Wrapping in a list since the code is written for the more general chained case.
y_field_name_groups = [y_field_name_groups]
if weights is not None and len(weights) != 1:
raise ValueError("Must specify same number of weights as groups of phenotypes")
cov_field_names = list(f'__cov{i}' for i in range(len(covariates)))
weight_field_names = list(f'__weight_for_group_{i}' for i in range(len(weights))) if weights is not None else None
weight_dict = dict(zip(weight_field_names, weights)) if weights is not None else {}
row_field_names = _get_regression_row_fields(mt, pass_through, 'linear_regression_rows_nd')
# FIXME: selecting an existing entry field should be emitted as a SelectFields
mt = mt._select_all(
col_exprs=dict(**y_dict, **weight_dict, **dict(zip(cov_field_names, covariates))),
row_exprs=row_field_names,
col_key=[],
entry_exprs={x_field_name: x},
)
entries_field_name = 'ent'
sample_field_name = "by_sample"
num_y_lists = len(y_field_name_groups)
# Given a hail array, get the mean of the nonmissing entries and
# return new array where the missing entries are the mean.
def mean_impute(hl_array):
non_missing_mean = hl.mean(hl_array, filter_missing=True)
return hl_array.map(lambda entry: hl.if_else(hl.is_defined(entry), entry, non_missing_mean))
def select_array_indices(hl_array, indices):
return indices.map(lambda i: hl_array[i])
def dot_rows_with_themselves(matrix):
return (matrix * matrix).sum(1)
def no_missing(hail_array):
return hail_array.all(lambda element: hl.is_defined(element))
ht_local = mt._localize_entries(entries_field_name, sample_field_name)
ht = ht_local.transmute(**{entries_field_name: ht_local[entries_field_name][x_field_name]})
def setup_globals(ht):
# cov_arrays is per sample, then per cov.
if covariates:
ht = ht.annotate_globals(
cov_arrays=ht[sample_field_name].map(
lambda sample_struct: [sample_struct[cov_name] for cov_name in cov_field_names]
)
)
else:
ht = ht.annotate_globals(
cov_arrays=ht[sample_field_name].map(lambda sample_struct: hl.empty_array(hl.tfloat64))
)
y_arrays_per_group = [
ht[sample_field_name].map(lambda sample_struct: [sample_struct[y_name] for y_name in one_y_field_name_set])
for one_y_field_name_set in y_field_name_groups
]
if weight_field_names:
weight_arrays = ht[sample_field_name].map(
lambda sample_struct: [sample_struct[weight_name] for weight_name in weight_field_names]
)
else:
weight_arrays = ht[sample_field_name].map(lambda sample_struct: hl.empty_array(hl.tfloat64))
ht = ht.annotate_globals(y_arrays_per_group=y_arrays_per_group, weight_arrays=weight_arrays)
ht = ht.annotate_globals(all_covs_defined=ht.cov_arrays.map(lambda sample_covs: no_missing(sample_covs)))
def get_kept_samples(group_idx, sample_ys):
# sample_ys is an array of samples, with each element being an array of the y_values
return (
hl.enumerate(sample_ys)
.filter(
lambda idx_and_y_values: ht.all_covs_defined[idx_and_y_values[0]]
& no_missing(idx_and_y_values[1])
& (hl.is_defined(ht.weight_arrays[idx_and_y_values[0]][group_idx]) if weights else True)
)
.map(lambda idx_and_y_values: idx_and_y_values[0])
)
ht = ht.annotate_globals(kept_samples=hl.enumerate(ht.y_arrays_per_group).starmap(get_kept_samples))
ht = ht.annotate_globals(
y_nds=hl.zip(ht.kept_samples, ht.y_arrays_per_group).starmap(
lambda sample_indices, y_arrays: hl.nd.array(sample_indices.map(lambda idx: y_arrays[idx]))
)
)
ht = ht.annotate_globals(
cov_nds=ht.kept_samples.map(lambda group: hl.nd.array(group.map(lambda idx: ht.cov_arrays[idx])))
)
if weights is None:
ht = ht.annotate_globals(sqrt_weights=hl.missing(hl.tarray(hl.tndarray(hl.tfloat64, 2))))
ht = ht.annotate_globals(scaled_y_nds=ht.y_nds)
ht = ht.annotate_globals(scaled_cov_nds=ht.cov_nds)
else:
ht = ht.annotate_globals(
weight_nds=hl.enumerate(ht.kept_samples).starmap(
lambda group_idx, group_sample_indices: hl.nd.array(
group_sample_indices.map(lambda group_sample_idx: ht.weight_arrays[group_sample_idx][group_idx])
)
)
)
ht = ht.annotate_globals(
sqrt_weights=ht.weight_nds.map(lambda weight_nd: weight_nd.map(lambda e: hl.sqrt(e)))
)
ht = ht.annotate_globals(
scaled_y_nds=hl.zip(ht.y_nds, ht.sqrt_weights).starmap(
lambda y, sqrt_weight: y * sqrt_weight.reshape(-1, 1)
)
)
ht = ht.annotate_globals(
scaled_cov_nds=hl.zip(ht.cov_nds, ht.sqrt_weights).starmap(
lambda cov, sqrt_weight: cov * sqrt_weight.reshape(-1, 1)
)
)
k = builtins.len(covariates)
ht = ht.annotate_globals(ns=ht.kept_samples.map(lambda one_sample_set: hl.len(one_sample_set)))
def log_message(i):
if is_chained:
return (
"linear regression_rows["
+ hl.str(i)
+ "] running on "
+ hl.str(ht.ns[i])
+ " samples for "
+ hl.str(ht.scaled_y_nds[i].shape[1])
+ f" response variables y, with input variables x, and {len(covariates)} additional covariates..."
)
else:
return (
"linear_regression_rows running on "
+ hl.str(ht.ns[0])
+ " samples for "
+ hl.str(ht.scaled_y_nds[i].shape[1])
+ f" response variables y, with input variables x, and {len(covariates)} additional covariates..."
)
ht = ht.annotate_globals(ns=hl.range(num_y_lists).map(lambda i: hl._console_log(log_message(i), ht.ns[i])))
ht = ht.annotate_globals(
cov_Qts=hl.if_else(
k > 0,
ht.scaled_cov_nds.map(lambda one_cov_nd: hl.nd.qr(one_cov_nd)[0].T),
ht.ns.map(lambda n: hl.nd.zeros((0, n))),
)
)
ht = ht.annotate_globals(Qtys=hl.zip(ht.cov_Qts, ht.scaled_y_nds).starmap(lambda cov_qt, y: cov_qt @ y))
return ht.select_globals(
kept_samples=ht.kept_samples,
__scaled_y_nds=ht.scaled_y_nds,
__sqrt_weight_nds=ht.sqrt_weights,
ns=ht.ns,
ds=ht.ns.map(lambda n: n - k - 1),
__cov_Qts=ht.cov_Qts,
__Qtys=ht.Qtys,
__yyps=hl.range(num_y_lists).map(
lambda i: dot_rows_with_themselves(ht.scaled_y_nds[i].T) - dot_rows_with_themselves(ht.Qtys[i].T)
),
)
ht = setup_globals(ht)
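# What follows implements the blocked regression: for each group of responses, the
# block's rows are restricted to that group's kept samples and mean-imputed, the
# covariates are projected out using the precomputed Q matrices, and beta, standard
# error, t-statistic and p-value are computed for every row in the block at once.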
def process_block(block):
rows_in_block = hl.len(block)
# Processes one block group based on given idx. Returns a single struct.
def process_y_group(idx):
if weights is not None:
X = (
hl.nd.array(
block[entries_field_name].map(
lambda row: mean_impute(select_array_indices(row, ht.kept_samples[idx]))
)
)
* ht.__sqrt_weight_nds[idx]
).T
else:
X = hl.nd.array(
block[entries_field_name].map(
lambda row: mean_impute(select_array_indices(row, ht.kept_samples[idx]))
)
).T
n = ht.ns[idx]
sum_x = X.sum(0)
Qtx = ht.__cov_Qts[idx] @ X
ytx = ht.__scaled_y_nds[idx].T @ X
xyp = ytx - (ht.__Qtys[idx].T @ Qtx)
xxpRec = (dot_rows_with_themselves(X.T) - dot_rows_with_themselves(Qtx.T)).map(lambda entry: 1 / entry)
b = xyp * xxpRec
se = ((1.0 / ht.ds[idx]) * (ht.__yyps[idx].reshape((-1, 1)) @ xxpRec.reshape((1, -1)) - (b * b))).map(
lambda entry: hl.sqrt(entry)
)
t = b / se
return hl.rbind(
t,
lambda t: hl.rbind(
ht.ds[idx],
lambda d: hl.rbind(
t.map(lambda entry: 2 * hl.expr.functions.pT(-hl.abs(entry), d, True, False)),
lambda p: hl.struct(
n=hl.range(rows_in_block).map(lambda i: n),
sum_x=sum_x._data_array(),
y_transpose_x=ytx.T._data_array(),
beta=b.T._data_array(),
standard_error=se.T._data_array(),
t_stat=t.T._data_array(),
p_value=p.T._data_array(),
),
),
),
)
per_y_list = hl.range(num_y_lists).map(lambda i: process_y_group(i))
key_field_names = [key_field for key_field in ht.key]
def build_row(row_idx):
# For every field we care about, map across all y's, getting the row_idxth one from each.
idxth_keys = {field_name: block[field_name][row_idx] for field_name in key_field_names}
computed_row_field_names = ['n', 'sum_x', 'y_transpose_x', 'beta', 'standard_error', 't_stat', 'p_value']
computed_row_fields = {
field_name: per_y_list.map(lambda one_y: one_y[field_name][row_idx])
for field_name in computed_row_field_names
}
pass_through_rows = {field_name: block[field_name][row_idx] for field_name in row_field_names}
if not is_chained:
computed_row_fields = {key: value[0] for key, value in computed_row_fields.items()}
return hl.struct(**{**idxth_keys, **computed_row_fields, **pass_through_rows})
new_rows = hl.range(rows_in_block).map(build_row)
return new_rows
def process_partition(part):
grouped = part.grouped(block_size)
return grouped.flatmap(lambda block: process_block(block)._to_stream())
res = ht._map_partitions(process_partition)
if not y_is_list:
fields = ['y_transpose_x', 'beta', 'standard_error', 't_stat', 'p_value']
res = res.annotate(**{f: res[f][0] for f in fields})
res = res.select_globals()
temp_file_name = hl.utils.new_temp_file("_linear_regression_rows_nd", "result")
res = res.checkpoint(temp_file_name)
return res
@typecheck(
test=enumeration('wald', 'lrt', 'score', 'firth'),
y=oneof(expr_float64, sequenceof(expr_float64)),
x=expr_float64,
covariates=sequenceof(expr_float64),
pass_through=sequenceof(oneof(str, Expression)),
max_iterations=nullable(int),
tolerance=nullable(float),
)
def logistic_regression_rows(
test, y, x, covariates, pass_through=(), *, max_iterations: Optional[int] = None, tolerance: Optional[float] = None
) -> Table:
r"""For each row, test an input variable for association with a
binary response variable using logistic regression.
Examples
--------
Run the logistic regression Wald test per variant using a Boolean
phenotype, intercept and two covariates stored in column-indexed
fields:
>>> result_ht = hl.logistic_regression_rows(
... test='wald',
... y=dataset.pheno.is_case,
... x=dataset.GT.n_alt_alleles(),
... covariates=[1, dataset.pheno.age, dataset.pheno.is_female])
Run the logistic regression Wald test per variant using a list of binary (0/1)
phenotypes, intercept and two covariates stored in column-indexed
fields:
>>> result_ht = hl.logistic_regression_rows(
... test='wald',
... y=[dataset.pheno.is_case, dataset.pheno.is_case], # where pheno values are 0, 1, or missing
... x=dataset.GT.n_alt_alleles(),
... covariates=[1, dataset.pheno.age, dataset.pheno.is_female])
As above but with at most 100 Newton iterations and a stricter-than-default tolerance of 1e-8:
>>> result_ht = hl.logistic_regression_rows(
... test='wald',
... y=[dataset.pheno.is_case, dataset.pheno.is_case], # where pheno values are 0, 1, or missing
... x=dataset.GT.n_alt_alleles(),
... covariates=[1, dataset.pheno.age, dataset.pheno.is_female],
... max_iterations=100,
... tolerance=1e-8)
Warning
-------
:func:`.logistic_regression_rows` considers the same set of
columns (i.e., samples, points) for every row, namely those columns for
which **all** response variables and covariates are defined. For each row, missing values of
`x` are mean-imputed over these columns. As in the example, the
intercept covariate ``1`` must be included **explicitly** if desired.
Notes
-----
This method performs, for each row, a significance test of the input
variable in predicting a binary (case-control) response variable based
on the logistic regression model. The response variable type must either
be numeric (with all present values 0 or 1) or Boolean, in which case
true and false are coded as 1 and 0, respectively.
Hail supports the Wald test ('wald'), likelihood ratio test ('lrt'),
Rao score test ('score'), and Firth test ('firth'). Hail only includes
columns for which the response variable and all covariates are defined.
For each row, Hail imputes missing input values as the mean of the
non-missing values.
The example above considers a model of the form
.. math::
\mathrm{Prob}(\mathrm{is\_case}) =
\mathrm{sigmoid}(\beta_0 + \beta_1 \, \mathrm{gt}
+ \beta_2 \, \mathrm{age}
+ \beta_3 \, \mathrm{is\_female} + \varepsilon),
\quad
\varepsilon \sim \mathrm{N}(0, \sigma^2)
where :math:`\mathrm{sigmoid}` is the `sigmoid function`_, the genotype
:math:`\mathrm{gt}` is coded as 0 for HomRef, 1 for Het, and 2 for
HomVar, and the Boolean covariate :math:`\mathrm{is\_female}` is coded as
1 for ``True`` (female) and 0 for ``False`` (male). The null model sets
:math:`\beta_1 = 0`.
.. _sigmoid function: https://en.wikipedia.org/wiki/Sigmoid_function
The structure of the emitted row field depends on the test statistic as
shown in the tables below.
========== ================== ======= ============================================
Test Field Type Value
========== ================== ======= ============================================
Wald `beta` float64 fit effect coefficient,
:math:`\hat\beta_1`
Wald `standard_error` float64 estimated standard error,
:math:`\widehat{\mathrm{se}}`
Wald `z_stat` float64 Wald :math:`z`-statistic, equal to
:math:`\hat\beta_1 / \widehat{\mathrm{se}}`
Wald `p_value` float64 Wald p-value testing :math:`\beta_1 = 0`
LRT, Firth `beta` float64 fit effect coefficient,
:math:`\hat\beta_1`
LRT, Firth `chi_sq_stat` float64 deviance statistic
LRT, Firth `p_value` float64 LRT / Firth p-value testing
:math:`\beta_1 = 0`
Score `chi_sq_stat` float64 score statistic
Score `p_value` float64 score p-value testing :math:`\beta_1 = 0`
========== ================== ======= ============================================
For the Wald and likelihood ratio tests, Hail fits the logistic model for
each row using Newton iteration and only emits the above fields
when the maximum likelihood estimate of the coefficients converges. The
Firth test uses a modified form of Newton iteration. To help diagnose
convergence issues, Hail also emits three fields which summarize the
iterative fitting process:
================ =================== ======= ===============================
Test Field Type Value
================ =================== ======= ===============================
Wald, LRT, Firth `fit.n_iterations` int32 number of iterations until
convergence, explosion, or
reaching the max (by default,
25 for Wald, LRT; 100 for Firth)
Wald, LRT, Firth `fit.converged` bool ``True`` if iteration converged
Wald, LRT, Firth `fit.exploded` bool ``True`` if iteration exploded
================ =================== ======= ===============================
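Continuing the single-phenotype Wald example above, one way to use these
diagnostics is to keep only rows whose fit converged:
>>> converged_ht = result_ht.filter(result_ht.fit.converged)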
We consider iteration to have converged when every coordinate of
:math:`\beta` changes by less than :math:`10^{-6}` by default. For Wald and
LRT, up to 25 iterations are attempted by default; in testing we find 4 or 5
iterations nearly always suffice. Convergence may also fail due to
explosion, which refers to low-level numerical linear algebra exceptions
caused by manipulating ill-conditioned matrices. Explosion may result from
(nearly) linearly dependent covariates or complete separation_.
.. _separation: https://en.wikipedia.org/wiki/Separation_(statistics)
A more common situation in genetics is quasi-complete separation, e.g.
variants that are observed only in cases (or controls). Such variants
inevitably arise when testing millions of variants with very low minor
allele count. The maximum likelihood estimate of :math:`\beta` under
logistic regression is then undefined but convergence may still occur
after a large number of iterations due to a very flat likelihood
surface. In testing, we find that such variants produce a secondary bump
from 10 to 15 iterations in the histogram of number of iterations per
variant. We also find that this faux convergence produces large standard
errors and large (insignificant) p-values. To not miss such variants,
consider using Firth logistic regression, linear regression, or
group-based tests.
Here's a concrete illustration of quasi-complete separation in R. Suppose
we have 2010 samples distributed as follows for a particular variant:
======= ====== === ======
Status HomRef Het HomVar
======= ====== === ======
Case 1000 10 0
Control 1000 0 0
======= ====== === ======
The following R code fits the (standard) logistic, Firth logistic,
and linear regression models to this data, where ``x`` is genotype,
``y`` is phenotype, and ``logistf`` is from the logistf package:
.. code-block:: R
x <- c(rep(0,1000), rep(0,1000), rep(1,10))
y <- c(rep(0,1000), rep(1,1000), rep(1,10))
logfit <- glm(y ~ x, family=binomial())
firthfit <- logistf(y ~ x)
linfit <- lm(y ~ x)
The resulting p-values for the genotype coefficient are 0.991, 0.00085,
and 0.0016, respectively. The erroneous value 0.991 is due to
quasi-complete separation. Moving one of the 10 hets from case to control
eliminates this quasi-complete separation; the p-values from R are then
0.0373, 0.0111, and 0.0116, respectively, as expected for a less
significant association.
The Firth test reduces bias from small counts and resolves the issue of
separation by penalizing maximum likelihood estimation by the `Jeffreys
invariant prior <https://en.wikipedia.org/wiki/Jeffreys_prior>`__. This test
is slower, as both the null and full model must be fit per variant, and
convergence of the modified Newton method is linear rather than
quadratic. For Firth, 100 iterations are attempted by default for the null
model and, if that is successful, for the full model as well. In testing we
find 20 iterations nearly always suffice. If the null model fails to
converge, then the `logreg.fit` fields reflect the null model; otherwise,
they reflect the full model.
See
`Recommended joint and meta-analysis strategies for case-control association testing of single low-count variants <http://www.ncbi.nlm.nih.gov/pmc/articles/PMC4049324/>`__
for an empirical comparison of the logistic Wald, LRT, score, and Firth
tests. The theoretical foundations of the Wald, likelihood ratio, and score
tests may be found in Chapter 3 of Gesine Reinert's notes
`Statistical Theory <http://www.stats.ox.ac.uk/~reinert/stattheory/theoryshort09.pdf>`__.
Firth introduced his approach in
`Bias reduction of maximum likelihood estimates, 1993 <http://www2.stat.duke.edu/~scs/Courses/Stat376/Papers/GibbsFieldEst/BiasReductionMLE.pdf>`__.
Heinze and Schemper further analyze Firth's approach in
`A solution to the problem of separation in logistic regression, 2002 <https://cemsiis.meduniwien.ac.at/fileadmin/msi_akim/CeMSIIS/KB/volltexte/Heinze_Schemper_2002_Statistics_in_Medicine.pdf>`__.
Hail's logistic regression tests correspond to the ``b.wald``,
``b.lrt``, and ``b.score`` tests in `EPACTS`_. For each variant, Hail
imputes missing input values as the mean of non-missing input values,
whereas EPACTS subsets to those samples with called genotypes. Hence,
Hail and EPACTS results will currently only agree for variants with no
missing genotypes.
.. _EPACTS: http://genome.sph.umich.edu/wiki/EPACTS#Single_Variant_Tests
Note
----
Use the `pass_through` parameter to include additional row fields from
the matrix table underlying ``x``. For example, to include an "rsid" field, set
``pass_through=['rsid']`` or ``pass_through=[mt.rsid]``.
Parameters
----------
test : {'wald', 'lrt', 'score', 'firth'}
Statistical test.
y : :class:`.Float64Expression` or :obj:`list` of :class:`.Float64Expression`
One or more column-indexed response expressions.
All non-missing values must evaluate to 0 or 1.
Note that a :class:`.BooleanExpression` will be implicitly converted to
a :class:`.Float64Expression` with this property.
x : :class:`.Float64Expression`
Entry-indexed expression for input variable.
covariates : :obj:`list` of :class:`.Float64Expression`
Non-empty list of column-indexed covariate expressions.
pass_through : :obj:`list` of :class:`str` or :class:`.Expression`
Additional row fields to include in the resulting table.
max_iterations : :obj:`int`
The maximum number of iterations.
tolerance : :obj:`float`, optional
The iterative fit of this model is considered "converged" if the change in the estimated
beta is smaller than tolerance. By default the tolerance is 1e-6.
Returns
-------
:class:`.Table`
"""
if max_iterations is None:
max_iterations = 25 if test != 'firth' else 100
if hl.current_backend().requires_lowering:
return _logistic_regression_rows_nd(
test, y, x, covariates, pass_through, max_iterations=max_iterations, tolerance=tolerance
)
if tolerance is None:
tolerance = 1e-6
assert tolerance > 0.0
if len(covariates) == 0:
raise ValueError('logistic regression requires at least one covariate expression')
mt = matrix_table_source('logistic_regression_rows/x', x)
raise_unless_entry_indexed('logistic_regression_rows/x', x)
y_is_list = isinstance(y, list)
if y_is_list and len(y) == 0:
raise ValueError("'logistic_regression_rows': found no values for 'y'")
y = [raise_unless_column_indexed('logistic_regression_rows/y', y) or y for y in wrap_to_list(y)]
for e in covariates:
analyze('logistic_regression_rows/covariates', e, mt._col_indices)
_warn_if_no_intercept('logistic_regression_rows', covariates)
x_field_name = Env.get_uid()
y_field = [f'__y_{i}' for i in range(len(y))]
y_dict = dict(zip(y_field, y))
cov_field_names = [f'__cov{i}' for i in range(len(covariates))]
row_fields = _get_regression_row_fields(mt, pass_through, 'logistic_regression_rows')
# FIXME: selecting an existing entry field should be emitted as a SelectFields
mt = mt._select_all(
col_exprs=dict(**y_dict, **dict(zip(cov_field_names, covariates))),
row_exprs=row_fields,
col_key=[],
entry_exprs={x_field_name: x},
)
config = {
'name': 'LogisticRegression',
'test': test,
'yFields': y_field,
'xField': x_field_name,
'covFields': cov_field_names,
'passThrough': [x for x in row_fields if x not in mt.row_key],
'maxIterations': max_iterations,
'tolerance': tolerance,
}
result = Table(ir.MatrixToTableApply(mt._mir, config))
if not y_is_list:
result = result.transmute(**result.logistic_regression[0])
return result.persist()
# Helpers for logreg:
def mean_impute(hl_array):
non_missing_mean = hl.mean(hl_array, filter_missing=True)
return hl_array.map(lambda entry: hl.coalesce(entry, non_missing_mean))
sigmoid = expit
def nd_max(hl_nd):
return hl.max(hl.array(hl_nd.reshape(-1)))
def logreg_fit(
X: NDArrayNumericExpression, # (N, K)
y: NDArrayNumericExpression, # (N,)
null_fit: Optional[StructExpression],
max_iterations: int,
tolerance: float,
) -> StructExpression:
"""Iteratively reweighted least squares to fit the model y ~ Bernoulli(logit(X \beta))
When fitting the null model, K=n_covariates, otherwise K=n_covariates + 1.
"""
assert max_iterations >= 0
assert X.ndim == 2
assert y.ndim == 1
# X is samples by covs.
# y is length num samples, for one cov.
n = X.shape[0]
m = X.shape[1]
if null_fit is None:
avg = y.sum() / n
logit_avg = hl.log(avg / (1 - avg))
b = hl.nd.hstack([hl.nd.array([logit_avg]), hl.nd.zeros((hl.int32(m - 1)))])
mu = sigmoid(X @ b)
score = X.T @ (y - mu)
# Reshape so we do a rowwise multiply
fisher = X.T @ (X * (mu * (1 - mu)).reshape(-1, 1))
else:
# num covs used to fit null model.
m0 = null_fit.b.shape[0]
m_diff = m - m0
X0 = X[:, 0:m0]
X1 = X[:, m0:]
b = hl.nd.hstack([null_fit.b, hl.nd.zeros((m_diff,))])
mu = sigmoid(X @ b)
score = hl.nd.hstack([null_fit.score, X1.T @ (y - mu)])
fisher00 = null_fit.fisher
fisher01 = X0.T @ (X1 * (mu * (1 - mu)).reshape(-1, 1))
fisher10 = fisher01.T
fisher11 = X1.T @ (X1 * (mu * (1 - mu)).reshape(-1, 1))
fisher = hl.nd.vstack([hl.nd.hstack([fisher00, fisher01]), hl.nd.hstack([fisher10, fisher11])])
dtype = numerical_regression_fit_dtype
blank_struct = hl.struct(**{k: hl.missing(dtype[k]) for k in dtype})
def search(recur, iteration, b, mu, score, fisher):
def cont(exploded, delta_b, max_delta_b):
log_lkhd = hl.log((y * mu) + (1 - y) * (1 - mu)).sum()
next_b = b + delta_b
next_mu = sigmoid(X @ next_b)
next_score = X.T @ (y - next_mu)
next_fisher = X.T @ (X * (next_mu * (1 - next_mu)).reshape(-1, 1))
return (
hl.case()
.when(
exploded | hl.is_nan(delta_b[0]),
blank_struct.annotate(n_iterations=iteration, log_lkhd=log_lkhd, converged=False, exploded=True),
)
.when(
max_delta_b < tolerance,
hl.struct(
b=b,
score=score,
fisher=fisher,
mu=mu,
n_iterations=iteration,
log_lkhd=log_lkhd,
converged=True,
exploded=False,
),
)
.when(
iteration == max_iterations,
blank_struct.annotate(n_iterations=iteration, log_lkhd=log_lkhd, converged=False, exploded=False),
)
.default(recur(iteration + 1, next_b, next_mu, next_score, next_fisher))
)
delta_b_struct = hl.nd.solve(fisher, score, no_crash=True)
exploded = delta_b_struct.failed
delta_b = delta_b_struct.solution
max_delta_b = nd_max(hl.abs(delta_b))
return hl.bind(cont, exploded, delta_b, max_delta_b)
if max_iterations == 0:
return blank_struct.annotate(n_iterations=0, log_lkhd=0, converged=False, exploded=False)
return hl.experimental.loop(search, numerical_regression_fit_dtype, 1, b, mu, score, fisher)
def wald_test(X, fit):
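# Wald test on the full fit: standard errors are the square roots of the diagonal of
# the inverse Fisher information, z = beta / se, and the two-sided p-value is
# 2 * pnorm(-|z|). Only the entries for the last column of X (the tested input
# variable) are reported.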
se = hl.sqrt(hl.nd.diagonal(hl.nd.inv(fit.fisher)))
z = fit.b / se
p = z.map(lambda e: 2 * hl.pnorm(-hl.abs(e)))
return hl.struct(
beta=fit.b[X.shape[1] - 1],
standard_error=se[X.shape[1] - 1],
z_stat=z[X.shape[1] - 1],
p_value=p[X.shape[1] - 1],
fit=fit.select('n_iterations', 'converged', 'exploded'),
)
def lrt_test(X, null_fit, fit):
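# Likelihood ratio test: chi_sq = 2 * (full-model log likelihood - null-model log
# likelihood), with degrees of freedom equal to the number of added parameters
# (here one, the tested input variable); missing if the full fit did not converge.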
chi_sq = hl.if_else(~fit.converged, hl.missing(hl.tfloat64), 2 * (fit.log_lkhd - null_fit.log_lkhd))
p = hl.pchisqtail(chi_sq, X.shape[1] - null_fit.b.shape[0])
return hl.struct(
beta=fit.b[X.shape[1] - 1],
chi_sq_stat=chi_sq,
p_value=p,
fit=fit.select('n_iterations', 'converged', 'exploded'),
)
def logistic_score_test(X, y, null_fit):
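# Rao score test: the score vector and Fisher information are evaluated at the null
# fit (the coefficient of the tested variable fixed at 0); the statistic is
# score.T @ inv(fisher) @ score, computed via hl.nd.solve, on m - m0 degrees of freedom.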
m = X.shape[1]
m0 = null_fit.b.shape[0]
b = hl.nd.hstack([null_fit.b, hl.nd.zeros((hl.int32(m - m0)))])
X0 = X[:, 0:m0]
X1 = X[:, m0:]
mu = hl.expit(X @ b)
score_0 = null_fit.score
score_1 = X1.T @ (y - mu)
score = hl.nd.hstack([score_0, score_1])
fisher00 = null_fit.fisher
fisher01 = X0.T @ (X1 * (mu * (1 - mu)).reshape(-1, 1))
fisher10 = fisher01.T
fisher11 = X1.T @ (X1 * (mu * (1 - mu)).reshape(-1, 1))
fisher = hl.nd.vstack([hl.nd.hstack([fisher00, fisher01]), hl.nd.hstack([fisher10, fisher11])])
solve_attempt = hl.nd.solve(fisher, score, no_crash=True)
chi_sq = hl.or_missing(~solve_attempt.failed, (score * solve_attempt.solution).sum())
p = hl.pchisqtail(chi_sq, m - m0)
return hl.struct(chi_sq_stat=chi_sq, p_value=p)
def _firth_fit(
b: NDArrayNumericExpression, # (K,)
X: NDArrayNumericExpression, # (N, K)
y: NDArrayNumericExpression, # (N,)
max_iterations: int,
tolerance: float,
) -> StructExpression:
"""Iteratively reweighted least squares using Firth's regression to fit the model y ~ Bernoulli(logit(X \beta))
When fitting the null model, K=n_covariates, otherwise K=n_covariates + 1.
"""
assert max_iterations >= 0
assert X.ndim == 2
assert y.ndim == 1
assert b.ndim == 1
dtype = numerical_regression_fit_dtype._drop_fields(['score', 'fisher'])
blank_struct = hl.struct(**{k: hl.missing(dtype[k]) for k in dtype})
X_bslice = X[:, : b.shape[0]]
def fit(recur, iteration, b):
def cont(exploded, delta_b, max_delta_b):
log_lkhd_left = hl.log(y * mu + (hl.literal(1.0) - y) * (1 - mu)).sum()
log_lkhd_right = hl.log(hl.abs(hl.nd.diagonal(r))).sum()
log_lkhd = log_lkhd_left + log_lkhd_right
next_b = b + delta_b
return (
hl.case()
.when(
exploded | hl.is_nan(delta_b[0]),
blank_struct.annotate(n_iterations=iteration, log_lkhd=log_lkhd, converged=False, exploded=True),
)
.when(
max_delta_b < tolerance,
hl.struct(b=b, mu=mu, n_iterations=iteration, log_lkhd=log_lkhd, converged=True, exploded=False),
)
.when(
iteration == max_iterations,
blank_struct.annotate(n_iterations=iteration, log_lkhd=log_lkhd, converged=False, exploded=False),
)
.default(recur(iteration + 1, next_b))
)
m = b.shape[0] # n_covariates or n_covariates + 1, depending on improved null fit vs full fit
mu = sigmoid(X_bslice @ b)
sqrtW = hl.sqrt(mu * (1 - mu))
q, r = hl.nd.qr(X * sqrtW.T.reshape(-1, 1))
h = (q * q).sum(1)
coef = r[:m, :m]
residual = y - mu
dep = q[:, :m].T @ ((residual + (h * (0.5 - mu))) / sqrtW)
delta_b_struct = hl.nd.solve_triangular(coef, dep.reshape(-1, 1), no_crash=True)
exploded = delta_b_struct.failed
delta_b = delta_b_struct.solution.reshape(-1)
max_delta_b = nd_max(hl.abs(delta_b))
return hl.bind(cont, exploded, delta_b, max_delta_b)
if max_iterations == 0:
return blank_struct.annotate(n_iterations=0, log_lkhd=0, converged=False, exploded=False)
return hl.experimental.loop(fit, dtype, 1, b)
def _firth_test(null_fit, X, y, max_iterations, tolerance) -> StructExpression:
firth_improved_null_fit = _firth_fit(null_fit.b, X, y, max_iterations=max_iterations, tolerance=tolerance)
dof = 1 # 1 variant
def cont(firth_improved_null_fit):
initial_b_full_model = hl.nd.hstack([firth_improved_null_fit.b, hl.nd.array([0.0])])
firth_fit = _firth_fit(initial_b_full_model, X, y, max_iterations=max_iterations, tolerance=tolerance)
def cont2(firth_fit):
firth_chi_sq = 2 * (firth_fit.log_lkhd - firth_improved_null_fit.log_lkhd)
firth_p = hl.pchisqtail(firth_chi_sq, dof)
blank_struct = hl.struct(
beta=hl.missing(hl.tfloat64),
chi_sq_stat=hl.missing(hl.tfloat64),
p_value=hl.missing(hl.tfloat64),
firth_null_fit=hl.missing(firth_improved_null_fit.dtype),
fit=hl.missing(firth_fit.dtype),
)
return (
hl.case()
.when(
firth_improved_null_fit.converged,
hl.case()
.when(
firth_fit.converged,
hl.struct(
beta=firth_fit.b[firth_fit.b.shape[0] - 1],
chi_sq_stat=firth_chi_sq,
p_value=firth_p,
firth_null_fit=firth_improved_null_fit,
fit=firth_fit,
),
)
.default(blank_struct.annotate(firth_null_fit=firth_improved_null_fit, fit=firth_fit)),
)
.default(blank_struct.annotate(firth_null_fit=firth_improved_null_fit))
)
return hl.bind(cont2, firth_fit)
return hl.bind(cont, firth_improved_null_fit)
@typecheck(
test=enumeration('wald', 'lrt', 'score', 'firth'),
y=oneof(expr_float64, sequenceof(expr_float64)),
x=expr_float64,
covariates=sequenceof(expr_float64),
pass_through=sequenceof(oneof(str, Expression)),
max_iterations=nullable(int),
tolerance=nullable(float),
)
def _logistic_regression_rows_nd(
test, y, x, covariates, pass_through=(), *, max_iterations: Optional[int] = None, tolerance: Optional[float] = None
) -> Table:
r"""For each row, test an input variable for association with a
binary response variable using logistic regression.
Examples
--------
Run the logistic regression Wald test per variant using a Boolean
phenotype, intercept and two covariates stored in column-indexed
fields:
>>> result_ht = hl.logistic_regression_rows(
... test='wald',
... y=dataset.pheno.is_case,
... x=dataset.GT.n_alt_alleles(),
... covariates=[1, dataset.pheno.age, dataset.pheno.is_female])
Run the logistic regression Wald test per variant using a list of binary (0/1)
phenotypes, intercept and two covariates stored in column-indexed
fields:
>>> result_ht = hl.logistic_regression_rows(
... test='wald',
... y=[dataset.pheno.is_case, dataset.pheno.is_case], # where pheno values are 0, 1, or missing
... x=dataset.GT.n_alt_alleles(),
... covariates=[1, dataset.pheno.age, dataset.pheno.is_female])
Warning
-------
:func:`.logistic_regression_rows` considers the same set of
columns (i.e., samples, points) for every row, namely those columns for
which **all** response variables and covariates are defined. For each row, missing values of
`x` are mean-imputed over these columns. As in the example, the
intercept covariate ``1`` must be included **explicitly** if desired.
Notes
-----
This method performs, for each row, a significance test of the input
variable in predicting a binary (case-control) response variable based
on the logistic regression model. The response variable type must either
be numeric (with all present values 0 or 1) or Boolean, in which case
true and false are coded as 1 and 0, respectively.
Hail supports the Wald test ('wald'), likelihood ratio test ('lrt'),
Rao score test ('score'), and Firth test ('firth'). Hail only includes
columns for which the response variable and all covariates are defined.
For each row, Hail imputes missing input values as the mean of the
non-missing values.
The example above considers a model of the form
.. math::
\mathrm{Prob}(\mathrm{is\_case}) =
\mathrm{sigmoid}(\beta_0 + \beta_1 \, \mathrm{gt}
+ \beta_2 \, \mathrm{age}
+ \beta_3 \, \mathrm{is\_female} + \varepsilon),
\quad
\varepsilon \sim \mathrm{N}(0, \sigma^2)
where :math:`\mathrm{sigmoid}` is the `sigmoid function`_, the genotype
:math:`\mathrm{gt}` is coded as 0 for HomRef, 1 for Het, and 2 for
HomVar, and the Boolean covariate :math:`\mathrm{is\_female}` is coded as
1 for ``True`` (female) and 0 for ``False`` (male). The null model sets
:math:`\beta_1 = 0`.
.. _sigmoid function: https://en.wikipedia.org/wiki/Sigmoid_function
The structure of the emitted row field depends on the test statistic as
shown in the tables below.
========== ================== ======= ============================================
Test Field Type Value
========== ================== ======= ============================================
Wald `beta` float64 fit effect coefficient,
:math:`\hat\beta_1`
Wald `standard_error` float64 estimated standard error,
:math:`\widehat{\mathrm{se}}`
Wald `z_stat` float64 Wald :math:`z`-statistic, equal to
:math:`\hat\beta_1 / \widehat{\mathrm{se}}`
Wald `p_value` float64 Wald p-value testing :math:`\beta_1 = 0`
LRT, Firth `beta` float64 fit effect coefficient,
:math:`\hat\beta_1`
LRT, Firth `chi_sq_stat` float64 deviance statistic
LRT, Firth `p_value` float64 LRT / Firth p-value testing
:math:`\beta_1 = 0`
Score `chi_sq_stat` float64 score statistic
Score `p_value` float64 score p-value testing :math:`\beta_1 = 0`
========== ================== ======= ============================================
For the Wald and likelihood ratio tests, Hail fits the logistic model for
each row using Newton iteration and only emits the above fields
when the maximum likelihood estimate of the coefficients converges. The
Firth test uses a modified form of Newton iteration. To help diagnose
convergence issues, Hail also emits three fields which summarize the
iterative fitting process:
================ =================== ======= ===============================
Test Field Type Value
================ =================== ======= ===============================
Wald, LRT, Firth `fit.n_iterations` int32 number of iterations until
convergence, explosion, or
reaching the max (25 for
Wald, LRT; 100 for Firth)
Wald, LRT, Firth `fit.converged` bool ``True`` if iteration converged
Wald, LRT, Firth `fit.exploded` bool ``True`` if iteration exploded
================ =================== ======= ===============================
We consider iteration to have converged when every coordinate of
:math:`\beta` changes by less than :math:`10^{-6}`. For Wald and LRT,
up to 25 iterations are attempted; in testing we find 4 or 5 iterations
nearly always suffice. Convergence may also fail due to explosion,
which refers to low-level numerical linear algebra exceptions caused by
manipulating ill-conditioned matrices. Explosion may result from (nearly)
linearly dependent covariates or complete separation_.
.. _separation: https://en.wikipedia.org/wiki/Separation_(statistics)
A more common situation in genetics is quasi-complete separation, e.g.
variants that are observed only in cases (or controls). Such variants
inevitably arise when testing millions of variants with very low minor
allele count. The maximum likelihood estimate of :math:`\beta` under
logistic regression is then undefined but convergence may still occur
after a large number of iterations due to a very flat likelihood
surface. In testing, we find that such variants produce a secondary bump
from 10 to 15 iterations in the histogram of number of iterations per
variant. We also find that this faux convergence produces large standard
errors and large (insignificant) p-values. To not miss such variants,
consider using Firth logistic regression, linear regression, or
group-based tests.
Here's a concrete illustration of quasi-complete separation in R. Suppose
we have 2010 samples distributed as follows for a particular variant:
======= ====== === ======
Status HomRef Het HomVar
======= ====== === ======
Case 1000 10 0
Control 1000 0 0
======= ====== === ======
The following R code fits the (standard) logistic, Firth logistic,
and linear regression models to this data, where ``x`` is genotype,
``y`` is phenotype, and ``logistf`` is from the logistf package:
.. code-block:: R
x <- c(rep(0,1000), rep(0,1000), rep(1,10))
y <- c(rep(0,1000), rep(1,1000), rep(1,10))
logfit <- glm(y ~ x, family=binomial())
firthfit <- logistf(y ~ x)
linfit <- lm(y ~ x)
The resulting p-values for the genotype coefficient are 0.991, 0.00085,
and 0.0016, respectively. The erroneous value 0.991 is due to
quasi-complete separation. Moving one of the 10 hets from case to control
eliminates this quasi-complete separation; the p-values from R are then
0.0373, 0.0111, and 0.0116, respectively, as expected for a less
significant association.
The Firth test reduces bias from small counts and resolves the issue of
separation by penalizing maximum likelihood estimation by the `Jeffreys
invariant prior <https://en.wikipedia.org/wiki/Jeffreys_prior>`__. This
test is slower, as both the null and full model must be fit per variant,
and convergence of the modified Newton method is linear rather than
quadratic. For Firth, 100 iterations are attempted for the null model
and, if that is successful, for the full model as well. In testing we
find 20 iterations nearly always suffice. If the null model fails to
converge, then the `logreg.fit` fields reflect the null model;
otherwise, they reflect the full model.
See
`Recommended joint and meta-analysis strategies for case-control association testing of single low-count variants <http://www.ncbi.nlm.nih.gov/pmc/articles/PMC4049324/>`__
for an empirical comparison of the logistic Wald, LRT, score, and Firth
tests. The theoretical foundations of the Wald, likelihood ratio, and score
tests may be found in Chapter 3 of Gesine Reinert's notes
`Statistical Theory <http://www.stats.ox.ac.uk/~reinert/stattheory/theoryshort09.pdf>`__.
Firth introduced his approach in
`Bias reduction of maximum likelihood estimates, 1993 <http://www2.stat.duke.edu/~scs/Courses/Stat376/Papers/GibbsFieldEst/BiasReductionMLE.pdf>`__.
Heinze and Schemper further analyze Firth's approach in
`A solution to the problem of separation in logistic regression, 2002 <https://cemsiis.meduniwien.ac.at/fileadmin/msi_akim/CeMSIIS/KB/volltexte/Heinze_Schemper_2002_Statistics_in_Medicine.pdf>`__.
Hail's logistic regression tests correspond to the ``b.wald``,
``b.lrt``, and ``b.score`` tests in `EPACTS`_. For each variant, Hail
imputes missing input values as the mean of non-missing input values,
whereas EPACTS subsets to those samples with called genotypes. Hence,
Hail and EPACTS results will currently only agree for variants with no
missing genotypes.
.. _EPACTS: http://genome.sph.umich.edu/wiki/EPACTS#Single_Variant_Tests
Note
----
Use the `pass_through` parameter to include additional row fields from
the matrix table underlying ``x``. For example, to include an "rsid" field, set
``pass_through=['rsid']`` or ``pass_through=[mt.rsid]``.
Parameters
----------
test : {'wald', 'lrt', 'score', 'firth'}
Statistical test.
y : :class:`.Float64Expression` or :obj:`list` of :class:`.Float64Expression`
One or more column-indexed response expressions.
All non-missing values must evaluate to 0 or 1.
Note that a :class:`.BooleanExpression` will be implicitly converted to
a :class:`.Float64Expression` with this property.
x : :class:`.Float64Expression`
Entry-indexed expression for input variable.
covariates : :obj:`list` of :class:`.Float64Expression`
Non-empty list of column-indexed covariate expressions.
pass_through : :obj:`list` of :class:`str` or :class:`.Expression`
Additional row fields to include in the resulting table.
Returns
-------
:class:`.Table`
"""
if max_iterations is None:
max_iterations = 25 if test != 'firth' else 100
if tolerance is None:
tolerance = 1e-8
assert tolerance > 0.0
if len(covariates) == 0:
raise ValueError('logistic regression requires at least one covariate expression')
mt = matrix_table_source('logistic_regression_rows/x', x)
raise_unless_entry_indexed('logistic_regression_rows/x', x)
y_is_list = isinstance(y, list)
if y_is_list and len(y) == 0:
raise ValueError("'logistic_regression_rows': found no values for 'y'")
y = [raise_unless_column_indexed('logistic_regression_rows/y', y) or y for y in wrap_to_list(y)]
for e in covariates:
analyze('logistic_regression_rows/covariates', e, mt._col_indices)
# _warn_if_no_intercept('logistic_regression_rows', covariates)
x_field_name = Env.get_uid()
y_field_names = [f'__y_{i}' for i in range(len(y))]
y_dict = dict(zip(y_field_names, y))
cov_field_names = [f'__cov{i}' for i in range(len(covariates))]
row_fields = _get_regression_row_fields(mt, pass_through, 'logistic_regression_rows')
# Handle filtering columns with missing values:
mt = mt.filter_cols(hl.array(y + covariates).all(hl.is_defined))
# FIXME: selecting an existing entry field should be emitted as a SelectFields
mt = mt._select_all(
col_exprs=dict(**y_dict, **dict(zip(cov_field_names, covariates))),
row_exprs=row_fields,
col_key=[],
entry_exprs={x_field_name: x},
)
ht = mt._localize_entries('entries', 'samples')
# covmat rows are samples, columns are the different covariates
ht = ht.annotate_globals(
covmat=hl.nd.array(ht.samples.map(lambda s: [s[cov_name] for cov_name in cov_field_names]))
)
# yvecs is a list of sample-length vectors, one for each dependent variable.
ht = ht.annotate_globals(yvecs=[hl.nd.array(ht.samples[y_name]) for y_name in y_field_names])
# Fit null models, which means doing a logreg fit with just the covariates for each phenotype.
def fit_null(yvec):
def error_if_not_converged(null_fit):
return (
hl.case()
.when(
~null_fit.exploded,
(
hl.case()
.when(null_fit.converged, null_fit)
.or_error(
"Failed to fit logistic regression null model (standard MLE with covariates only): "
"Newton iteration failed to converge"
)
),
)
.or_error(
hl.format(
"Failed to fit logistic regression null model (standard MLE with covariates only): "
"exploded at Newton iteration %d",
null_fit.n_iterations,
)
)
)
null_fit = logreg_fit(ht.covmat, yvec, None, max_iterations=max_iterations, tolerance=tolerance)
return hl.bind(error_if_not_converged, null_fit)
ht = ht.annotate_globals(null_fits=ht.yvecs.map(fit_null))
ht = ht.transmute(x=hl.nd.array(mean_impute(ht.entries[x_field_name])))
ht = ht.annotate(covs_and_x=hl.nd.hstack([ht.covmat, ht.x.reshape((-1, 1))]))
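# covs_and_x is the per-variant design matrix: n_samples rows by (n_covariates + 1) columns,
# i.e. the covariates followed by the mean-imputed genotype column.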
def run_test(yvec, null_fit):
if test == 'score':
return logistic_score_test(ht.covs_and_x, yvec, null_fit)
if test == 'firth':
return _firth_test(null_fit, ht.covs_and_x, yvec, max_iterations=max_iterations, tolerance=tolerance)
test_fit = logreg_fit(ht.covs_and_x, yvec, null_fit, max_iterations=max_iterations, tolerance=tolerance)
if test == 'wald':
return wald_test(ht.covs_and_x, test_fit)
assert test == 'lrt', test
return lrt_test(ht.covs_and_x, null_fit, test_fit)
ht = ht.select(
logistic_regression=hl.starmap(run_test, hl.zip(ht.yvecs, ht.null_fits)), **{f: ht[f] for f in row_fields}
)
assert 'null_fits' not in row_fields
assert 'logistic_regression' not in row_fields
if not y_is_list:
assert all(f not in row_fields for f in ht.null_fits[0])
assert all(f not in row_fields for f in ht.logistic_regression[0])
ht = ht.select_globals(**ht.null_fits[0])
return ht.transmute(**ht.logistic_regression[0])
ht = ht.select_globals('null_fits')
return ht
[docs]@typecheck(
test=enumeration('wald', 'lrt', 'score'),
y=expr_float64,
x=expr_float64,
covariates=sequenceof(expr_float64),
pass_through=sequenceof(oneof(str, Expression)),
max_iterations=int,
tolerance=nullable(float),
)
def poisson_regression_rows(
test, y, x, covariates, pass_through=(), *, max_iterations: int = 25, tolerance: Optional[float] = None
) -> Table:
r"""For each row, test an input variable for association with a
count response variable using `Poisson regression <https://en.wikipedia.org/wiki/Poisson_regression>`__.
Notes
-----
See :func:`.logistic_regression_rows` for more info on statistical tests
of generalized linear models.
Note
----
Use the `pass_through` parameter to include additional row fields from the
matrix table underlying ``x``. For example, to include an "rsid" field, set
``pass_through=['rsid']`` or ``pass_through=[mt.rsid]``.
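For instance, a minimal sketch of a Wald test against a hypothetical count phenotype
``dataset.pheno.n_episodes`` and a hypothetical ``rsid`` row field (substitute column and
row fields from your own dataset):
>>> result_ht = hl.poisson_regression_rows(
...     test='wald',
...     y=dataset.pheno.n_episodes,
...     x=dataset.GT.n_alt_alleles(),
...     covariates=[1.0],
...     pass_through=['rsid'])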
Parameters
----------
y : :class:`.Float64Expression`
Column-indexed response expression.
All non-missing values must evaluate to a non-negative integer.
x : :class:`.Float64Expression`
Entry-indexed expression for input variable.
covariates : :obj:`list` of :class:`.Float64Expression`
Non-empty list of column-indexed covariate expressions.
pass_through : :obj:`list` of :class:`str` or :class:`.Expression`
Additional row fields to include in the resulting table.
tolerance : :obj:`float`, optional
The iterative fit of this model is considered "converged" if the change in the estimated
beta is smaller than tolerance. By default the tolerance is 1e-6.
Returns
-------
:class:`.Table`
"""
if hl.current_backend().requires_lowering:
return _lowered_poisson_regression_rows(
test, y, x, covariates, pass_through, max_iterations=max_iterations, tolerance=tolerance
)
if tolerance is None:
tolerance = 1e-6
assert tolerance > 0.0
if len(covariates) == 0:
raise ValueError('Poisson regression requires at least one covariate expression')
mt = matrix_table_source('poisson_regression_rows/x', x)
raise_unless_entry_indexed('poisson_regression_rows/x', x)
analyze('poisson_regression_rows/y', y, mt._col_indices)
all_exprs = [y]
for e in covariates:
all_exprs.append(e)
analyze('poisson_regression_rows/covariates', e, mt._col_indices)
_warn_if_no_intercept('poisson_regression_rows', covariates)
x_field_name = Env.get_uid()
y_field_name = '__y'
cov_field_names = list(f'__cov{i}' for i in range(len(covariates)))
row_fields = _get_regression_row_fields(mt, pass_through, 'poisson_regression_rows')
# FIXME: selecting an existing entry field should be emitted as a SelectFields
mt = mt._select_all(
col_exprs=dict(**{y_field_name: y}, **dict(zip(cov_field_names, covariates))),
row_exprs=row_fields,
col_key=[],
entry_exprs={x_field_name: x},
)
config = {
'name': 'PoissonRegression',
'test': test,
'yField': y_field_name,
'xField': x_field_name,
'covFields': cov_field_names,
'passThrough': [x for x in row_fields if x not in mt.row_key],
'maxIterations': max_iterations,
'tolerance': tolerance,
}
return Table(ir.MatrixToTableApply(mt._mir, config)).persist()
@typecheck(
test=enumeration('wald', 'lrt', 'score'),
y=expr_float64,
x=expr_float64,
covariates=sequenceof(expr_float64),
pass_through=sequenceof(oneof(str, Expression)),
max_iterations=int,
tolerance=nullable(float),
)
def _lowered_poisson_regression_rows(
test, y, x, covariates, pass_through=(), *, max_iterations: int = 25, tolerance: Optional[float] = None
):
assert max_iterations > 0
if tolerance is None:
tolerance = 1e-8
assert tolerance > 0.0
k = len(covariates)
if k == 0:
raise ValueError('_lowered_poisson_regression_rows: at least one covariate is required.')
_warn_if_no_intercept('_lowered_poisson_regression_rows', covariates)
mt = matrix_table_source('_lowered_poisson_regression_rows/x', x)
raise_unless_entry_indexed('_lowered_poisson_regression_rows/x', x)
row_exprs = _get_regression_row_fields(mt, pass_through, '_lowered_poisson_regression_rows')
mt = mt._select_all(
row_exprs=dict(pass_through=hl.struct(**row_exprs)),
col_exprs=dict(y=y, covariates=covariates),
entry_exprs=dict(x=x),
)
# FIXME: the order of the columns is irrelevant to regression
mt = mt.key_cols_by()
mt = mt.filter_cols(hl.all(hl.is_defined(mt.y), *[hl.is_defined(mt.covariates[i]) for i in range(k)]))
mt = mt.annotate_globals(
**mt.aggregate_cols(
hl.struct(
yvec=hl.agg.collect(hl.float(mt.y)),
covmat=hl.agg.collect(mt.covariates.map(hl.float)),
n=hl.agg.count(),
),
_localize=False,
)
)
mt = mt.annotate_globals(
yvec=(
hl.case()
.when(mt.n - k - 1 >= 1, hl.nd.array(mt.yvec))
.or_error(
hl.format("_lowered_poisson_regression_rows: insufficient degrees of freedom: n=%s, k=%s", mt.n, k)
)
),
covmat=hl.nd.array(mt.covmat),
n_complete_samples=mt.n,
)
covmat = mt.covmat
yvec = mt.yvec
n = mt.n_complete_samples
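# Starting point for the null-model Newton iteration: with b = [log(mean(y)), 0, ..., 0],
# mu = exp(covmat @ b) is constant at mean(y). This assumes the first covariate is an
# intercept column of ones (see _warn_if_no_intercept above).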
logmean = hl.log(yvec.sum() / n)
b = hl.nd.array([logmean, *[0 for _ in range(k - 1)]])
mu = hl.exp(covmat @ b)
residual = yvec - mu
score = covmat.T @ residual
fisher = (mu * covmat.T) @ covmat
mt = mt.annotate_globals(null_fit=_poisson_fit(covmat, yvec, b, mu, score, fisher, max_iterations, tolerance))
mt = mt.annotate_globals(
null_fit=hl.case()
.when(mt.null_fit.converged, mt.null_fit)
.or_error(
hl.format(
'_lowered_poisson_regression_rows: null model did not converge: %s',
mt.null_fit.select('n_iterations', 'log_lkhd', 'converged', 'exploded'),
)
)
)
mt = mt.annotate_rows(mean_x=hl.agg.mean(mt.x))
mt = mt.annotate_rows(xvec=hl.nd.array(hl.agg.collect(hl.coalesce(mt.x, mt.mean_x))))
ht = mt.rows()
covmat = ht.covmat
null_fit = ht.null_fit
# FIXME: we should test a whole block of variants at a time not one-by-one
xvec = ht.xvec
yvec = ht.yvec
if test == 'score':
chi_sq, p = _poisson_score_test(null_fit, covmat, yvec, xvec)
return ht.select(chi_sq_stat=chi_sq, p_value=p, **ht.pass_through).select_globals('null_fit')
X = hl.nd.hstack([covmat, xvec.T.reshape(-1, 1)])
b = hl.nd.hstack([null_fit.b, hl.nd.array([0.0])])
mu = hl.exp(X @ b)
residual = yvec - mu
score = hl.nd.hstack([null_fit.score, hl.nd.array([xvec @ residual])])
fisher00 = null_fit.fisher
fisher01 = ((covmat.T * mu) @ xvec).reshape((-1, 1))
fisher10 = fisher01.T
fisher11 = hl.nd.array([[(mu * xvec.T) @ xvec]])
fisher = hl.nd.vstack([hl.nd.hstack([fisher00, fisher01]), hl.nd.hstack([fisher10, fisher11])])
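# The Fisher information for the expanded design [covmat | xvec] is assembled in blocks:
# the null fit's (K, K) block, the (K, 1) covariate/genotype cross terms, and the scalar
# genotype/genotype term, all evaluated at the null coefficients extended with 0 for x.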
test_fit = _poisson_fit(X, yvec, b, mu, score, fisher, max_iterations, tolerance)
if test == 'lrt':
return ht.select(test_fit=test_fit, **lrt_test(X, null_fit, test_fit), **ht.pass_through).select_globals(
'null_fit'
)
assert test == 'wald'
return ht.select(test_fit=test_fit, **wald_test(X, test_fit), **ht.pass_through).select_globals('null_fit')
def _poisson_fit(
X: NDArrayNumericExpression, # (N, K)
y: NDArrayNumericExpression, # (N,)
b: NDArrayNumericExpression, # (K,)
mu: NDArrayNumericExpression, # (N,)
score: NDArrayNumericExpression, # (K,)
fisher: NDArrayNumericExpression, # (K, K)
max_iterations: int,
tolerance: float,
) -> StructExpression:
"""Iteratively reweighted least squares to fit the model y ~ Poisson(exp(X \beta))
When fitting the null model, K=n_covariates, otherwise K=n_covariates + 1.
"""
assert max_iterations >= 0
assert X.ndim == 2
assert y.ndim == 1
assert b.ndim == 1
assert mu.ndim == 1
assert score.ndim == 1
assert fisher.ndim == 2
dtype = numerical_regression_fit_dtype
blank_struct = hl.struct(**{k: hl.missing(dtype[k]) for k in dtype})
def fit(recur, iteration, b, mu, score, fisher):
def cont(exploded, delta_b, max_delta_b):
log_lkhd = y @ hl.log(mu) - mu.sum()
next_b = b + delta_b
next_mu = hl.exp(X @ next_b)
next_score = X.T @ (y - next_mu)
next_fisher = (next_mu * X.T) @ X
return (
hl.case()
.when(
exploded | hl.is_nan(delta_b[0]),
blank_struct.annotate(n_iterations=iteration, log_lkhd=log_lkhd, converged=False, exploded=True),
)
.when(
max_delta_b < tolerance,
hl.struct(
b=b,
score=score,
fisher=fisher,
mu=mu,
n_iterations=iteration,
log_lkhd=log_lkhd,
converged=True,
exploded=False,
),
)
.when(
iteration == max_iterations,
blank_struct.annotate(n_iterations=iteration, log_lkhd=log_lkhd, converged=False, exploded=False),
)
.default(recur(iteration + 1, next_b, next_mu, next_score, next_fisher))
)
delta_b_struct = hl.nd.solve(fisher, score, no_crash=True)
exploded = delta_b_struct.failed
delta_b = delta_b_struct.solution
max_delta_b = nd_max(delta_b.map(lambda e: hl.abs(e)))
return hl.bind(cont, exploded, delta_b, max_delta_b)
if max_iterations == 0:
return blank_struct.annotate(n_iterations=0, log_lkhd=0, converged=False, exploded=False)
return hl.experimental.loop(fit, dtype, 1, b, mu, score, fisher)
def _poisson_score_test(null_fit, covmat, y, xvec):
dof = 1
X = hl.nd.hstack([covmat, xvec.T.reshape(-1, 1)])
b = hl.nd.hstack([null_fit.b, hl.nd.array([0.0])])
mu = hl.exp(X @ b)
score = hl.nd.hstack([null_fit.score, hl.nd.array([xvec @ (y - mu)])])
fisher00 = null_fit.fisher
fisher01 = ((mu * covmat.T) @ xvec).reshape((-1, 1))
fisher10 = fisher01.T
fisher11 = hl.nd.array([[(mu * xvec.T) @ xvec]])
fisher = hl.nd.vstack([hl.nd.hstack([fisher00, fisher01]), hl.nd.hstack([fisher10, fisher11])])
fisher_div_score = hl.nd.solve(fisher, score, no_crash=True)
chi_sq = hl.or_missing(~fisher_div_score.failed, score @ fisher_div_score.solution)
p = hl.pchisqtail(chi_sq, dof)
return chi_sq, p
[docs]def linear_mixed_model(y, x, z_t=None, k=None, p_path=None, overwrite=False, standardize=True, mean_impute=True):
r"""Initialize a linear mixed model from a matrix table.
.. warning::
This functionality is no longer implemented/supported as of Hail 0.2.94.
"""
raise NotImplementedError("linear_mixed_model is no longer implemented/supported as of Hail 0.2.94")
[docs]@typecheck(
entry_expr=expr_float64,
model=LinearMixedModel,
pa_t_path=nullable(str),
a_t_path=nullable(str),
mean_impute=bool,
partition_size=nullable(int),
pass_through=sequenceof(oneof(str, Expression)),
)
def linear_mixed_regression_rows(
entry_expr, model, pa_t_path=None, a_t_path=None, mean_impute=True, partition_size=None, pass_through=()
):
"""For each row, test an input variable for association using a linear
mixed model.
.. warning::
This functionality is no longer implemented/supported as of Hail 0.2.94.
"""
raise NotImplementedError("linear_mixed_model is no longer implemented/supported as of Hail 0.2.94")
@typecheck(
group=expr_any,
weight=expr_float64,
y=expr_float64,
x=expr_float64,
covariates=sequenceof(expr_float64),
max_size=int,
accuracy=numeric,
iterations=int,
)
def _linear_skat(
group, weight, y, x, covariates, max_size: int = 46340, accuracy: float = 1e-6, iterations: int = 10000
):
r"""The linear sequence kernel association test (SKAT).
Linear SKAT tests if the phenotype, `y`, is significantly associated with the genotype, `x`. For
:math:`N` samples, in a group of :math:`M` variants, with :math:`K` covariates, the model is
given by:
.. math::
\begin{align*}
X &: R^{N \times K} \quad\quad \textrm{covariates} \\
G &: \{0, 1, 2\}^{N \times M} \textrm{genotypes} \\
\\
\varepsilon &\sim N(0, \sigma^2) \\
y &= \beta_0 X + \beta_1 G + \varepsilon
\end{align*}
The usual null hypothesis is :math:`\beta_1 = 0`. SKAT tests for an association, but does not
provide an effect size or other information about the association.
Wu et al. argue that, under the null hypothesis, a particular value, :math:`Q`, is distributed
according to a generalized chi-squared distribution with parameters determined by the genotypes,
weights, and residual phenotypes. The SKAT p-value is the probability of drawing even larger
values of :math:`Q`. :math:`Q` is defined by Wu et al. as:
.. math::
\begin{align*}
r &= y - \widehat{\beta_\textrm{null}} X \\
W_{ii} &= w_i \\
\\
Q &= r^T G W G^T r
\end{align*}
:math:`\widehat{\beta_\textrm{null}}` is the best-fit beta under the null model:
.. math::
y = \beta_\textrm{null} X + \varepsilon \quad\quad \varepsilon \sim N(0, \sigma^2)
Therefore :math:`r`, the residual phenotype, is the portion of the phenotype unexplained by the
covariates alone. Also notice:
1. The residual phenotypes are normally distributed with mean zero and variance
:math:`\sigma^2`.
2. :math:`G W G^T`, is a symmetric positive-definite matrix when the weights are non-negative.
We can transform the residuals into standard normal variables by normalizing by their
variance. Note that the variance is corrected for the degrees of freedom in the null model:
.. math::
\begin{align*}
\widehat{\sigma} &= \frac{1}{N - K} r^T r \\
h &= \frac{1}{\widehat{\sigma}} r \\
h &\sim N(0, 1) \\
r &= h \widehat{\sigma}
\end{align*}
We can rewrite :math:`Q` in terms of a Gram matrix and these new standard normal random variables:
.. math::
\begin{align*}
Q &= h^T \widehat{\sigma} G W G^T \widehat{\sigma} h \\
A &= \widehat{\sigma} G W^{1/2} \\
B &= A A^T \\
\\
Q &= h^T B h \\
\end{align*}
This expression is a `"quadratic form" <https://en.wikipedia.org/wiki/Quadratic_form>`__ of the
vector :math:`h`. Because :math:`B` is a real symmetric matrix, we can eigendecompose it into an
orthogonal matrix and a diagonal matrix of eigenvalues:
.. math::
\begin{align*}
U \Lambda U^T &= B \quad\quad \Lambda \textrm{ diagonal } U \textrm{ orthogonal} \\
Q &= h^T U \Lambda U^T h
\end{align*}
An orthogonal matrix transforms a vector of i.i.d. standard normal variables into a new vector
of different i.i.d. standard normal variables, so we can interpret :math:`Q` as a weighted sum of
i.i.d. standard normal variables:
.. math::
\begin{align*}
\tilde{h} &= U^T h \\
Q &= \sum_s \Lambda_{ss} \tilde{h}_s^2
\end{align*}
The distribution of such sums (indeed, any quadratic form of i.i.d. standard normal variables)
is governed by the generalized chi-squared distribution (the CDF is available in Hail as
:func:`.pgenchisq`):
.. math::
\begin{align*}
\lambda_i &= \Lambda_{ii} \\
Q &\sim \mathrm{GeneralizedChiSquared}(\lambda, \vec{1}, \vec{0}, 0, 0)
\end{align*}
Therefore, we can test the null hypothesis by calculating the probability of receiving values
larger than :math:`Q`. If that probability is very small, then the residual phenotypes are
likely not i.i.d. normal variables with variance :math:`\widehat{\sigma}^2`.
The SKAT method was originally described in:
Wu MC, Lee S, Cai T, Li Y, Boehnke M, Lin X. *Rare-variant association testing for
sequencing data with the sequence kernel association test.* Am J Hum Genet. 2011 Jul
15;89(1):82-93. doi: 10.1016/j.ajhg.2011.05.029. Epub 2011 Jul 7. PMID: 21737059; PMCID:
PMC3135811. https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3135811/
Examples
--------
Generate a dataset with a phenotype noisily computed from the genotypes:
>>> hl.reset_global_randomness()
>>> mt = hl.balding_nichols_model(1, n_samples=100, n_variants=20)
>>> mt = mt.annotate_rows(gene = mt.locus.position // 12)
>>> mt = mt.annotate_rows(weight = 1)
>>> mt = mt.annotate_cols(phenotype = hl.agg.sum(mt.GT.n_alt_alleles()) - 20 + hl.rand_norm(0, 1))
Test if the phenotype is significantly associated with the genotype:
>>> skat = hl._linear_skat(
... mt.gene,
... mt.weight,
... mt.phenotype,
... mt.GT.n_alt_alleles(),
... covariates=[1.0])
>>> skat.show()
+-------+-------+----------+----------+-------+
| group | size | q_stat | p_value | fault |
+-------+-------+----------+----------+-------+
| int32 | int64 | float64 | float64 | int32 |
+-------+-------+----------+----------+-------+
| 0 | 11 | 8.76e+02 | 1.23e-05 | 0 |
| 1 | 9 | 8.13e+02 | 3.95e-05 | 0 |
+-------+-------+----------+----------+-------+
The same test, but using the original paper's suggested weights which are derived from the
allele frequency.
>>> mt = hl.variant_qc(mt)
>>> skat = hl._linear_skat(
... mt.gene,
... hl.dbeta(mt.variant_qc.AF[0], 1, 25),
... mt.phenotype,
... mt.GT.n_alt_alleles(),
... covariates=[1.0])
>>> skat.show()
+-------+-------+----------+----------+-------+
| group | size | q_stat | p_value | fault |
+-------+-------+----------+----------+-------+
| int32 | int64 | float64 | float64 | int32 |
+-------+-------+----------+----------+-------+
| 0 | 11 | 2.39e+01 | 4.32e-01 | 0 |
| 1 | 9 | 1.69e+01 | 7.82e-02 | 0 |
+-------+-------+----------+----------+-------+
Our simulated data was unweighted, so the null hypothesis appears true. In real datasets, we
expect the allele frequency to correlate with effect size.
A fault flag of 1 would indicate that the numerical integration used to calculate the
p-value failed to achieve the requested accuracy (by default, 1e-6). When that happens, the
reported p-value may be unreliable; for example, the integration can return a (nonsensical)
value greater than one even when the null hypothesis is likely true.
The `max_size` parameter allows us to skip large genes that would cause "out of memory" errors:
>>> skat = hl._linear_skat(
... mt.gene,
... mt.weight,
... mt.phenotype,
... mt.GT.n_alt_alleles(),
... covariates=[1.0],
... max_size=10)
>>> skat.show()
+-------+-------+----------+----------+-------+
| group | size | q_stat | p_value | fault |
+-------+-------+----------+----------+-------+
| int32 | int64 | float64 | float64 | int32 |
+-------+-------+----------+----------+-------+
| 0 | 11 | NA | NA | NA |
| 1 | 9 | 8.13e+02 | 3.95e-05 | 0 |
+-------+-------+----------+----------+-------+
Notes
-----
In the SKAT R package, the "weights" are actually the *square root* of the weight expression
from the paper. This method uses the definition from the paper.
The paper includes an explicit intercept term but this method expects the user to specify the
intercept as an extra covariate with the value 1.
This method does not perform small sample size correction.
The `q_stat` return value is *not* the :math:`Q` statistic from the paper. We match the output
of the SKAT R package which returns :math:`\tilde{Q}`:
.. math::
\tilde{Q} = \frac{Q}{2 \widehat{\sigma}^2}
Parameters
----------
group : :class:`.Expression`
Row-indexed expression indicating to which group a variant belongs. This is typically a gene
name or an interval.
weight : :class:`.Float64Expression`
Row-indexed expression for weights. Must be non-negative.
y : :class:`.Float64Expression`
Column-indexed response (dependent variable) expression.
x : :class:`.Float64Expression`
Entry-indexed expression for input (independent variable).
covariates : :obj:`list` of :class:`.Float64Expression`
List of column-indexed covariate expressions. You must explicitly provide an intercept term
if desired. You must provide at least one covariate.
max_size : :obj:`int`
Maximum size of group on which to run the test. Groups which exceed this size will have a
missing p-value and missing q statistic. Defaults to 46340.
accuracy : :obj:`float`
The accuracy of the p-value if fault value is zero. Defaults to 1e-6.
iterations : :obj:`int`
The maximum number of iterations used to calculate the p-value (which has no closed
form). Defaults to 10000.
Returns
-------
:class:`.Table`
One row per-group. The key is `group`. The row fields are:
- group : the `group` parameter.
- size : :obj:`.tint64`, the number of variants in this group.
- q_stat : :obj:`.tfloat64`, the :math:`Q` statistic, see Notes for why this differs from the paper.
- p_value : :obj:`.tfloat64`, the test p-value for the null hypothesis that the genotypes
have no linear influence on the phenotypes.
- fault : :obj:`.tint32`, the fault flag from :func:`.pgenchisq`.
The global fields are:
- n_complete_samples : :obj:`.tint32`, the number of samples with neither a missing
phenotype nor a missing covariate.
- y_residual : :obj:`.tndarray` vector of :obj:`.tfloat64`, the residual phenotype from the null model. This may be
interpreted as the component of the phenotype not explained by the covariates alone.
- s2 : :obj:`.tfloat64`, the variance of the residuals, :math:`\sigma^2` in the paper.
"""
mt = matrix_table_source('skat/x', x)
k = len(covariates)
if k == 0:
raise ValueError('_linear_skat: at least one covariate is required.')
_warn_if_no_intercept('_linear_skat', covariates)
mt = mt._select_all(
row_exprs=dict(group=group, weight=weight), col_exprs=dict(y=y, covariates=covariates), entry_exprs=dict(x=x)
)
mt = mt.filter_cols(hl.all(hl.is_defined(mt.y), *[hl.is_defined(mt.covariates[i]) for i in range(k)]))
yvec, covmat, n = mt.aggregate_cols(
(hl.agg.collect(hl.float(mt.y)), hl.agg.collect(mt.covariates.map(hl.float)), hl.agg.count()), _localize=False
)
mt = mt.annotate_globals(yvec=hl.nd.array(yvec), covmat=hl.nd.array(covmat), n_complete_samples=n)
# Instead of finding the best-fit beta, we go directly to the best-predicted value using the
# reduced QR decomposition:
#
# Q @ R = X
# y = X beta
# X^T y = X^T X beta
# (X^T X)^-1 X^T y = beta
# (R^T Q^T Q R)^-1 R^T Q^T y = beta
# (R^T R)^-1 R^T Q^T y = beta
# R^-1 R^T^-1 R^T Q^T y = beta
# R^-1 Q^T y = beta
#
# X beta = X R^-1 Q^T y
# = Q R R^-1 Q^T y
# = Q Q^T y
#
covmat_Q, _ = hl.nd.qr(mt.covmat)
mt = mt.annotate_globals(covmat_Q=covmat_Q)
null_mu = mt.covmat_Q @ (mt.covmat_Q.T @ mt.yvec)
y_residual = mt.yvec - null_mu
mt = mt.annotate_globals(y_residual=y_residual, s2=y_residual @ y_residual.T / (n - k))
mt = mt.annotate_rows(G_row_mean=hl.agg.mean(mt.x))
mt = mt.annotate_rows(G_row=hl.agg.collect(hl.coalesce(mt.x, mt.G_row_mean)))
ht = mt.rows()
ht = ht.filter(hl.all(hl.is_defined(ht.group), hl.is_defined(ht.weight)))
ht = ht.group_by('group').aggregate(
weight_take=hl.agg.take(ht.weight, n=max_size + 1),
G_take=hl.agg.take(ht.G_row, n=max_size + 1),
size=hl.agg.count(),
)
ht = ht.annotate(
weight=hl.nd.array(hl.or_missing(hl.len(ht.weight_take) <= max_size, ht.weight_take)),
G=hl.nd.array(hl.or_missing(hl.len(ht.G_take) <= max_size, ht.G_take)).T,
)
ht = ht.annotate(Q=((ht.y_residual @ ht.G).map(lambda x: x**2) * ht.weight).sum(0))
# Null model:
#
# y = X b + e, e ~ N(0, \sigma^2)
#
# We can find a best-fit b, bhat, and a best-fit y, yhat:
#
# bhat = (X.T X).inv X.T y
#
# Q R = X (reduced QR decomposition)
# bhat = R.inv Q.T y
#
# yhat = X bhat
# = Q R R.inv Q.T y
# = Q Q.T y
#
# The residual phenotype not captured by the covariates alone is r:
#
# r = y - yhat
# = (I - Q Q.T) y
#
# We can factor the Q-statistic (note there are two Qs: the Q from the QR decomposition and the
# Q-statistic from the paper):
#
# Q = r.T G diag(w) G.T r
# Z = r.T G diag(sqrt(w))
# Q = Z Z.T
#
# Plugging in our expression for r:
#
# Z = y.T (I - Q Q.T) G diag(sqrt(w))
#
# Notice that I - Q Q.T is symmetric (it equals its own transpose) because each summand is symmetric and sums
# of symmetric matrices are symmetric matrices.
#
# We have asserted that
#
# y ~ N(0, \sigma^2)
#
# It will soon be apparent that the distribution of Q is easier to characterize if our random
# variables are standard normals:
#
# h ~ N(0, 1)
# y = \sigma h
#
# We set \sigma^2 to the sample variance of the residual vectors.
#
# Returning to Z:
#
# Z = h.T \sigma (I - Q Q.T) G diag(sqrt(w))
# Q = Z Z.T
#
# Which we can factor into a symmetric matrix and a standard normal:
#
# A = \sigma (I - Q Q.T) G diag(sqrt(w))
# B = A A.T
# Q = h.T B h
#
# This is called a "quadratic form". It is a weighted sum of products of pairs of entries of h,
# which we have asserted are i.i.d. standard normal variables. The distribution of such sums is
# given by the generalized chi-squared distribution:
#
# U L U.T = B B is symmetric and thus has an eigendecomposition
# h.T B h = Q ~ GeneralizedChiSquare(L, 1, 0, 0, 0)
#
# The orthogonal matrix U remixes the vector of i.i.d. normal variables into a new vector of
# different i.i.d. normal variables. The L matrix is diagonal and scales each squared normal
# variable.
#
# Since B = A A.T is symmetric, its eigenvalues are the square of the singular values of A or
# A.T:
#
# W S V = A
# U L U.T = B
# = A A.T
# = W S V V.T S W
# = W S S W V is orthogonal so V V.T = I
# = W S^2 W
weights_arr = hl.array(ht.weight)
A = (
hl.case()
.when(
hl.all(weights_arr.map(lambda x: x >= 0)),
(ht.G - ht.covmat_Q @ (ht.covmat_Q.T @ ht.G)) * hl.sqrt(ht.weight),
)
.or_error(
hl.format(
'hl._linear_skat: every weight must be non-negative, in group %s, the weights were: %s',
ht.group,
weights_arr,
)
)
)
singular_values = hl.nd.svd(A, compute_uv=False)
# SVD(M) = U S V. U and V are unitary, therefore SVD(k M) = U (k S) V.
eigenvalues = ht.s2 * singular_values.map(lambda x: x**2)
# The R implementation of SKAT, Function.R, Get_Lambda_Approx filters the eigenvalues,
# presumably because a good estimate of the Generalized Chi-Squared CDF is not significantly
# affected by chi-squared components with very tiny weights.
threshold = 1e-5 * eigenvalues.sum() / eigenvalues.shape[0]
w = hl.array(eigenvalues).filter(lambda y: y >= threshold)
genchisq_data = hl.pgenchisq(
ht.Q,
w=w,
k=hl.nd.ones(hl.len(w), dtype=hl.tint32),
lam=hl.nd.zeros(hl.len(w)),
mu=0,
sigma=0,
min_accuracy=accuracy,
max_iterations=iterations,
)
ht = ht.select(
'size',
# for reasons unknown, the R implementation calls this expression the Q statistic (which is
# *not* what they write in the paper)
q_stat=ht.Q / 2 / ht.s2,
# The reasoning for taking the complement of the CDF value is:
#
# 1. Q is a measure of variance and thus positive.
#
# 2. We want to know the probability of obtaining a variance even larger ("more extreme")
#
# Ergo, we want to check the right-tail of the distribution.
p_value=1.0 - genchisq_data.value,
fault=genchisq_data.fault,
)
return ht.select_globals('y_residual', 's2', 'n_complete_samples')
[docs]@typecheck(
group=expr_any,
weight=expr_float64,
y=expr_float64,
x=expr_float64,
covariates=sequenceof(expr_float64),
max_size=int,
null_max_iterations=int,
null_tolerance=float,
accuracy=numeric,
iterations=int,
)
def _logistic_skat(
group,
weight,
y,
x,
covariates,
max_size: int = 46340,
null_max_iterations: int = 25,
null_tolerance: float = 1e-6,
accuracy: float = 1e-6,
iterations: int = 10000,
):
r"""The logistic sequence kernel association test (SKAT).
Logistic SKAT tests if the phenotype, `y`, is significantly associated with the genotype,
`x`. For :math:`N` samples, in a group of :math:`M` variants, with :math:`K` covariates, the
model is given by:
.. math::
\begin{align*}
X &: R^{N \times K} \\
G &: \{0, 1, 2\}^{N \times M} \\
\\
Y &\sim \textrm{Bernoulli}(\textrm{logit}^{-1}(\beta_0 X + \beta_1 G))
\end{align*}
The usual null hypothesis is :math:`\beta_1 = 0`. SKAT tests for an association, but does not
provide an effect size or other information about the association.
Wu et al. argue that, under the null hypothesis, a particular value, :math:`Q`, is distributed
according to a generalized chi-squared distribution with parameters determined by the genotypes,
weights, and residual phenotypes. The SKAT p-value is the probability of drawing even larger
values of :math:`Q`. If :math:`\widehat{\beta_\textrm{null}}` is the best-fit beta under the
null model:
.. math::
Y \sim \textrm{Bernoulli}(\textrm{logit}^{-1}(\beta_\textrm{null} X))
Then :math:`Q` is defined by Wu et al. as:
.. math::
\begin{align*}
p_i &= \textrm{logit}^{-1}(\widehat{\beta_\textrm{null}} X) \\
r_i &= y_i - p_i \\
W_{ii} &= w_i \\
\\
Q &= r^T G W G^T r
\end{align*}
Therefore :math:`r_i`, the residual phenotype, is the portion of the phenotype unexplained by
the covariates alone. Also notice:
1. Each sample's phenotype is Bernoulli distributed with mean :math:`p_i` and variance
:math:`\sigma^2_i = p_i(1 - p_i)`, the binomial variance.
2. :math:`G W G^T`, is a symmetric positive-definite matrix when the weights are non-negative.
We describe below our interpretation of the mathematics as described in the main body and
appendix of Wu, et al. According to the paper, the distribution of :math:`Q` is given by a
generalized chi-squared distribution whose weights are the eigenvalues of a symmetric matrix
which we call :math:`Z Z^T`:
.. math::
\begin{align*}
V_{ii} &= \sigma^2_i \\
W_{ii} &= w_i \quad\quad \textrm{the weight for variant } i \\
\\
P_0 &= V - V X (X^T V X)^{-1} X^T V \\
Z Z^T &= P_0^{1/2} G W G^T P_0^{1/2}
\end{align*}
The eigenvalues of :math:`Z Z^T` and :math:`Z^T Z` are the squared singular values of :math:`Z`;
therefore, we instead focus on :math:`Z^T Z`. In the expressions below, we elide transpositions
of symmetric matrices:
.. math::
\begin{align*}
Z Z^T &= P_0^{1/2} G W G^T P_0^{1/2} \\
Z &= P_0^{1/2} G W^{1/2} \\
Z^T Z &= W^{1/2} G^T P_0 G W^{1/2}
\end{align*}
Before substituting the definition of :math:`P_0`, simplify it using the reduced QR
decomposition:
.. math::
\begin{align*}
Q R &= V^{1/2} X \\
R^T Q^T &= X^T V^{1/2} \\
\\
P_0 &= V - V X (X^T V X)^{-1} X^T V \\
&= V - V X (R^T Q^T Q R)^{-1} X^T V \\
&= V - V X (R^T R)^{-1} X^T V \\
&= V - V X R^{-1} (R^T)^{-1} X^T V \\
&= V - V^{1/2} Q (R^T)^{-1} X^T V^{1/2} \\
&= V - V^{1/2} Q Q^T V^{1/2} \\
&= V^{1/2} (I - Q Q^T) V^{1/2} \\
\end{align*}
Substitute this simplified expression into :math:`Z`:
.. math::
\begin{align*}
Z^T Z &= W^{1/2} G^T V^{1/2} (I - Q Q^T) V^{1/2} G W^{1/2} \\
\end{align*}
Split this symmetric matrix by observing that :math:`I - Q Q^T` is idempotent:
.. math::
\begin{align*}
I - Q Q^T &= (I - Q Q^T)(I - Q Q^T)^T \\
\\
Z &= (I - Q Q^T) V^{1/2} G W^{1/2} \\
Z &= (G - Q Q^T G) V^{1/2} W^{1/2}
\end{align*}
Finally, the squared singular values of :math:`Z` are the eigenvalues of :math:`Z^T Z`, so
:math:`Q` should be distributed as follows:
.. math::
\begin{align*}
U S V^T &= Z \quad\quad \textrm{the singular value decomposition} \\
\lambda_s &= S_{ss}^2 \\
\\
Q &\sim \textrm{GeneralizedChiSquared}(\lambda, \vec{1}, \vec{0}, 0, 0)
\end{align*}
The null hypothesis test tests for the probability of observing even larger values of :math:`Q`.
The SKAT method was originally described in:
Wu MC, Lee S, Cai T, Li Y, Boehnke M, Lin X. *Rare-variant association testing for
sequencing data with the sequence kernel association test.* Am J Hum Genet. 2011 Jul
15;89(1):82-93. doi: 10.1016/j.ajhg.2011.05.029. Epub 2011 Jul 7. PMID: 21737059; PMCID:
PMC3135811. https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3135811/
Examples
--------
Generate a dataset with a phenotype noisily computed from the genotypes:
>>> hl.reset_global_randomness()
>>> mt = hl.balding_nichols_model(1, n_samples=100, n_variants=20)
>>> mt = mt.annotate_rows(gene = mt.locus.position // 12)
>>> mt = mt.annotate_rows(weight = 1)
>>> mt = mt.annotate_cols(phenotype = (hl.agg.sum(mt.GT.n_alt_alleles()) - 20 + hl.rand_norm(0, 1)) > 0.5)
Test if the phenotype is significantly associated with the genotype:
>>> skat = hl._logistic_skat(
... mt.gene,
... mt.weight,
... mt.phenotype,
... mt.GT.n_alt_alleles(),
... covariates=[1.0])
>>> skat.show()
+-------+-------+----------+----------+-------+
| group | size | q_stat | p_value | fault |
+-------+-------+----------+----------+-------+
| int32 | int64 | float64 | float64 | int32 |
+-------+-------+----------+----------+-------+
| 0 | 11 | 1.78e+02 | 1.68e-04 | 0 |
| 1 | 9 | 1.39e+02 | 1.82e-03 | 0 |
+-------+-------+----------+----------+-------+
The same test, but using the original paper's suggested weights which are derived from the
allele frequency.
>>> mt = hl.variant_qc(mt)
>>> skat = hl._logistic_skat(
... mt.gene,
... hl.dbeta(mt.variant_qc.AF[0], 1, 25),
... mt.phenotype,
... mt.GT.n_alt_alleles(),
... covariates=[1.0])
>>> skat.show()
+-------+-------+----------+----------+-------+
| group | size | q_stat | p_value | fault |
+-------+-------+----------+----------+-------+
| int32 | int64 | float64 | float64 | int32 |
+-------+-------+----------+----------+-------+
| 0 | 11 | 8.04e+00 | 3.50e-01 | 0 |
| 1 | 9 | 1.22e+00 | 5.04e-01 | 0 |
+-------+-------+----------+----------+-------+
Our simulated data was unweighted, so the null hypothesis appears true. In real datasets, we
expect the allele frequency to correlate with effect size.
A fault flag of 1 would indicate that the numerical integration used to calculate the
p-value failed to achieve the requested accuracy (by default, 1e-6). When that happens, the
reported p-value may be unreliable; for example, the integration can return a (nonsensical)
value greater than one even when the null hypothesis is likely true.
The `max_size` parameter allows us to skip large genes that would cause "out of memory" errors:
>>> skat = hl._logistic_skat(
... mt.gene,
... mt.weight,
... mt.phenotype,
... mt.GT.n_alt_alleles(),
... covariates=[1.0],
... max_size=10)
>>> skat.show()
+-------+-------+----------+----------+-------+
| group | size | q_stat | p_value | fault |
+-------+-------+----------+----------+-------+
| int32 | int64 | float64 | float64 | int32 |
+-------+-------+----------+----------+-------+
| 0 | 11 | NA | NA | NA |
| 1 | 9 | 1.39e+02 | 1.82e-03 | 0 |
+-------+-------+----------+----------+-------+
Notes
-----
In the SKAT R package, the "weights" are actually the *square root* of the weight expression
from the paper. This method uses the definition from the paper.
The paper includes an explicit intercept term but this method expects the user to specify the
intercept as an extra covariate with the value 1.
This method does not perform small sample size correction.
The `q_stat` return value is *not* the :math:`Q` statistic from the paper. We match the output
of the SKAT R package which returns :math:`\tilde{Q}`:
.. math::
\tilde{Q} = \frac{Q}{2}
Parameters
----------
group : :class:`.Expression`
Row-indexed expression indicating to which group a variant belongs. This is typically a gene
name or an interval.
weight : :class:`.Float64Expression`
Row-indexed expression for weights. Must be non-negative.
y : :class:`.Float64Expression`
Column-indexed response (dependent variable) expression.
x : :class:`.Float64Expression`
Entry-indexed expression for input (independent variable).
covariates : :obj:`list` of :class:`.Float64Expression`
List of column-indexed covariate expressions. You must explicitly provide an intercept term
if desired. You must provide at least one covariate.
max_size : :obj:`int`
Maximum size of group on which to run the test. Groups which exceed this size will have a
missing p-value and missing q statistic. Defaults to 46340.
null_max_iterations : :obj:`int`
The maximum number of iterations when fitting the logistic null model. Defaults to 25.
null_tolerance : :obj:`float`
The logistic regression fit of the null model is considered converged when the error is
less than this. Defaults to 1e-6.
accuracy : :obj:`float`
The accuracy of the p-value if fault value is zero. Defaults to 1e-6.
iterations : :obj:`int`
The maximum number of iterations used to calculate the p-value (which has no closed
form). Defaults to 10000.
Returns
-------
:class:`.Table`
One row per-group. The key is `group`. The row fields are:
- group : the `group` parameter.
- size : :obj:`.tint64`, the number of variants in this group.
- q_stat : :obj:`.tfloat64`, the :math:`Q` statistic, see Notes for why this differs from the paper.
- p_value : :obj:`.tfloat64`, the test p-value for the null hypothesis that the genotypes
have no linear influence on the phenotypes.
- fault : :obj:`.tint32`, the fault flag from :func:`.pgenchisq`.
The global fields are:
- n_complete_samples : :obj:`.tint32`, the number of samples with neither a missing
phenotype nor a missing covariate.
- y_residual : :obj:`.tndarray` vector of :obj:`.tfloat64`, the residual phenotype from the null model. This may be
interpreted as the component of the phenotype not explained by the covariates alone.
- s2 : :obj:`.tndarray` vector of :obj:`.tfloat64`, the per-sample binomial variances under the null model, :math:`p_i (1 - p_i)`.
- null_fit:
- b : :obj:`.tndarray` vector of coefficients.
- score : :obj:`.tndarray` vector of score statistics.
- fisher : :obj:`.tndarray` matrix of fisher statistics.
- mu : :obj:`.tndarray` the expected value under the null model.
- n_iterations : :obj:`.tint32` the number of iterations before termination.
- log_lkhd : :obj:`.tfloat64` the log-likelihood of the final iteration.
- converged : :obj:`.tbool` True if the null model converged.
- exploded : :obj:`.tbool` True if the null model failed to converge due to numerical
explosion.
"""
mt = matrix_table_source('skat/x', x)
k = len(covariates)
if k == 0:
raise ValueError('_logistic_skat: at least one covariate is required.')
_warn_if_no_intercept('_logistic_skat', covariates)
mt = mt._select_all(
row_exprs=dict(group=group, weight=weight), col_exprs=dict(y=y, covariates=covariates), entry_exprs=dict(x=x)
)
mt = mt.filter_cols(hl.all(hl.is_defined(mt.y), *[hl.is_defined(mt.covariates[i]) for i in range(k)]))
if mt.y.dtype != hl.tbool:
mt = mt.annotate_cols(
y=(
hl.case()
.when(hl.any(mt.y == 0, mt.y == 1), hl.bool(mt.y))
.or_error(
hl.format(
f'hl._logistic_skat: phenotypes must either be True, False, 0, or 1, found: %s of type {mt.y.dtype}',
mt.y,
)
)
)
)
yvec, covmat, n = mt.aggregate_cols(
(hl.agg.collect(hl.float(mt.y)), hl.agg.collect(mt.covariates.map(hl.float)), hl.agg.count()), _localize=False
)
mt = mt.annotate_globals(yvec=hl.nd.array(yvec), covmat=hl.nd.array(covmat), n_complete_samples=n)
null_fit = logreg_fit(mt.covmat, mt.yvec, None, max_iterations=null_max_iterations, tolerance=null_tolerance)
mt = mt.annotate_globals(
null_fit=hl.case()
.when(null_fit.converged, null_fit)
.or_error(hl.format('hl._logistic_skat: null model did not converge: %s', null_fit))
)
null_mu = mt.null_fit.mu
y_residual = mt.yvec - null_mu
mt = mt.annotate_globals(y_residual=y_residual, s2=null_mu * (1 - null_mu))
mt = mt.annotate_rows(G_row_mean=hl.agg.mean(mt.x))
mt = mt.annotate_rows(G_row=hl.agg.collect(hl.coalesce(mt.x, mt.G_row_mean)))
ht = mt.rows()
ht = ht.filter(hl.all(hl.is_defined(ht.group), hl.is_defined(ht.weight)))
ht = ht.group_by('group').aggregate(
weight_take=hl.agg.take(ht.weight, n=max_size + 1),
G_take=hl.agg.take(ht.G_row, n=max_size + 1),
size=hl.agg.count(),
)
ht = ht.annotate(
weight=hl.nd.array(hl.or_missing(hl.len(ht.weight_take) <= max_size, ht.weight_take)),
G=hl.nd.array(hl.or_missing(hl.len(ht.G_take) <= max_size, ht.G_take)).T,
)
ht = ht.annotate(
# Q=ht.y_residual @ (ht.G * ht.weight) @ ht.G.T @ ht.y_residual.T
Q=((ht.y_residual @ ht.G).map(lambda x: x**2) * ht.weight).sum(0)
)
# See linear SKAT code comment for an extensive description of the mathematics here.
sqrtv = hl.sqrt(ht.s2)
Q, _ = hl.nd.qr(ht.covmat * sqrtv.reshape(-1, 1))
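# Q is the orthonormal basis of V^{1/2} X from the reduced QR decomposition in the docstring
# derivation; projecting with Q @ Q.T removes the covariate directions under the binomial
# variance weighting sqrtv = sqrt(p * (1 - p)).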
weights_arr = hl.array(ht.weight)
G_scaled = ht.G * sqrtv.reshape(-1, 1)
A = (
hl.case()
.when(hl.all(weights_arr.map(lambda x: x >= 0)), (G_scaled - Q @ (Q.T @ G_scaled)) * hl.sqrt(ht.weight))
.or_error(
hl.format(
'hl._logistic_skat: every weight must be non-negative, in group %s, the weights were: %s',
ht.group,
weights_arr,
)
)
)
singular_values = hl.nd.svd(A, compute_uv=False)
eigenvalues = singular_values.map(lambda x: x**2)
# The R implementation of SKAT, Function.R, Get_Lambda_Approx filters the eigenvalues,
# presumably because a good estimate of the Generalized Chi-Squared CDF is not significantly
# affected by chi-squared components with very tiny weights.
threshold = 1e-5 * eigenvalues.sum() / eigenvalues.shape[0]
w = hl.array(eigenvalues).filter(lambda y: y >= threshold)
genchisq_data = hl.pgenchisq(
ht.Q,
w=w,
k=hl.nd.ones(hl.len(w), dtype=hl.tint32),
lam=hl.nd.zeros(hl.len(w)),
mu=0,
sigma=0,
min_accuracy=accuracy,
max_iterations=iterations,
)
ht = ht.select(
'size',
# for reasons unknown, the R implementation calls this expression the Q statistic (which is
# *not* what they write in the paper)
q_stat=ht.Q / 2,
# The reasoning for taking the complement of the CDF value is:
#
# 1. Q is a measure of variance and thus positive.
#
# 2. We want to know the probability of obtaining a variance even larger ("more extreme")
#
# Ergo, we want to check the right-tail of the distribution.
p_value=1.0 - genchisq_data.value,
fault=genchisq_data.fault,
)
return ht.select_globals('y_residual', 's2', 'n_complete_samples', 'null_fit')
[docs]@typecheck(
key_expr=expr_any,
weight_expr=expr_float64,
y=expr_float64,
x=expr_float64,
covariates=sequenceof(expr_float64),
logistic=oneof(bool, sized_tupleof(nullable(int), nullable(float))),
max_size=int,
accuracy=numeric,
iterations=int,
)
def skat(
key_expr,
weight_expr,
y,
x,
covariates,
logistic: Union[bool, Tuple[int, float]] = False,
max_size: int = 46340,
accuracy: float = 1e-6,
iterations: int = 10000,
) -> Table:
r"""Test each keyed group of rows for association by linear or logistic
SKAT test.
Examples
--------
Test each gene for association using the linear sequence kernel association
test:
>>> skat_table = hl.skat(key_expr=burden_ds.gene,
... weight_expr=burden_ds.weight,
... y=burden_ds.burden.pheno,
... x=burden_ds.GT.n_alt_alleles(),
... covariates=[1, burden_ds.burden.cov1, burden_ds.burden.cov2])
.. caution::
By default, the Davies algorithm iterates up to 10k times until an
accuracy of 1e-6 is achieved. Hence a reported p-value of zero with no
issues may truly be as large as 1e-6. The accuracy and maximum number of
iterations may be controlled by the corresponding function parameters.
In general, higher accuracy requires more iterations.
.. caution::
To process a group with :math:`m` rows, several copies of an
:math:`m \times m` matrix of doubles must fit in worker memory. Groups
with tens of thousands of rows may exhaust worker memory causing the
entire job to fail. In this case, use the `max_size` parameter to skip
groups larger than `max_size`.
Warning
-------
:func:`.skat` considers the same set of columns (i.e., samples, points) for
every group, namely those columns for which **all** covariates are defined.
For each row, missing values of `x` are mean-imputed over these columns.
As in the example, the intercept covariate ``1`` must be included
**explicitly** if desired.
Notes
-----
This method provides a scalable implementation of the score-based
variance-component test originally described in
`Rare-Variant Association Testing for Sequencing Data with the Sequence Kernel Association Test
<https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3135811/>`__.
Row weights must be non-negative. Rows with missing weights are ignored. In
the R package ``skat``---which assumes rows are variants---default weights
are given by evaluating the Beta(1, 25) density at the minor allele
frequency. To replicate these weights in Hail using alternate allele
frequencies stored in a row-indexed field `AF`, one can use the expression:
>>> hl.dbeta(hl.min(ds2.AF), 1.0, 25.0) ** 2
In the logistic case, the response `y` must either be numeric (with all
present values 0 or 1) or Boolean, in which case true and false are coded
as 1 and 0, respectively.
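A minimal sketch of the logistic variant of the test, assuming ``burden_ds.burden`` also
carries a Boolean (or 0/1) phenotype field ``is_case`` (a hypothetical name):
>>> skat_table = hl.skat(key_expr=burden_ds.gene,
...                      weight_expr=burden_ds.weight,
...                      y=burden_ds.burden.is_case,
...                      x=burden_ds.GT.n_alt_alleles(),
...                      covariates=[1, burden_ds.burden.cov1, burden_ds.burden.cov2],
...                      logistic=True)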
The resulting :class:`.Table` provides the group's key (`id`), the number of
rows in the group (`size`), the variance component score `q_stat`, the SKAT
`p-value`, and a `fault` flag. For the toy example above, the table has the
form:
+-------+------+--------+---------+-------+
| id    | size | q_stat | p_value | fault |
+=======+======+========+=========+=======+
| geneA | 2    | 4.136  | 0.205   | 0     |
+-------+------+--------+---------+-------+
| geneB | 1    | 5.659  | 0.195   | 0     |
+-------+------+--------+---------+-------+
| geneC | 3    | 4.122  | 0.192   | 0     |
+-------+------+--------+---------+-------+
Groups larger than `max_size` appear with missing `q_stat`, `p_value`, and
`fault`. The hard limit on the number of rows in a group is 46340.
Note that the variance component score `q_stat` agrees with ``Q`` in the R
package ``skat``, but both differ from :math:`Q` in the paper by the factor
:math:`\frac{1}{2\sigma^2}` in the linear case and :math:`\frac{1}{2}` in
the logistic case, where :math:`\sigma^2` is the unbiased estimator of
residual variance for the linear null model. The R package also applies a
"small-sample adjustment" to the null distribution in the logistic case
when the sample size is less than 2000. Hail does not apply this
adjustment.
The fault flag is an integer indicating whether any issues occurred when
running the Davies algorithm to compute the p-value as the right tail of a
weighted sum of :math:`\chi^2(1)` distributions.
+-------------+-----------------------------------------+
| fault value | Description                             |
+=============+=========================================+
| 0           | no issues                               |
+-------------+-----------------------------------------+
| 1           | accuracy NOT achieved                   |
+-------------+-----------------------------------------+
| 2           | round-off error possibly significant    |
+-------------+-----------------------------------------+
| 3           | invalid parameters                      |
+-------------+-----------------------------------------+
| 4           | unable to locate integration parameters |
+-------------+-----------------------------------------+
| 5           | out of memory                           |
+-------------+-----------------------------------------+
Parameters
----------
key_expr : :class:`.Expression`
Row-indexed expression for key associated to each row.
weight_expr : :class:`.Float64Expression`
Row-indexed expression for row weights.
y : :class:`.Float64Expression`
Column-indexed response expression.
If `logistic` is ``True``, all non-missing values must evaluate to 0 or
1. Note that a :class:`.BooleanExpression` will be implicitly converted
to a :class:`.Float64Expression` with this property.
x : :class:`.Float64Expression`
Entry-indexed expression for input variable.
covariates : :obj:`list` of :class:`.Float64Expression`
List of column-indexed covariate expressions.
logistic : :obj:`bool` or :obj:`tuple` of :obj:`int` and :obj:`float`
If false, use the linear test. If true, use the logistic test with no
more than 25 logistic iterations and a convergence tolerance of 1e-6. If
a tuple is given, use the logistic test with the tuple elements as the
maximum number of iterations and convergence tolerance, respectively.
max_size : :obj:`int`
Maximum size of group on which to run the test.
accuracy : :obj:`float`
Accuracy achieved by the Davies algorithm if fault value is zero.
iterations : :obj:`int`
Maximum number of iterations attempted by the Davies algorithm.
Returns
-------
:class:`.Table`
Table of SKAT results.
"""
if hl.current_backend().requires_lowering:
if logistic:
kwargs = {'accuracy': accuracy, 'iterations': iterations}
if logistic is not True:
null_max_iterations, null_tolerance = logistic
kwargs['null_max_iterations'] = null_max_iterations
kwargs['null_tolerance'] = null_tolerance
ht = hl._logistic_skat(key_expr, weight_expr, y, x, covariates, max_size, **kwargs)
else:
ht = hl._linear_skat(key_expr, weight_expr, y, x, covariates, max_size, accuracy, iterations)
ht = ht.select_globals()
return ht
mt = matrix_table_source('skat/x', x)
raise_unless_entry_indexed('skat/x', x)
analyze('skat/key_expr', key_expr, mt._row_indices)
analyze('skat/weight_expr', weight_expr, mt._row_indices)
analyze('skat/y', y, mt._col_indices)
all_exprs = [key_expr, weight_expr, y]
for e in covariates:
all_exprs.append(e)
analyze('skat/covariates', e, mt._col_indices)
_warn_if_no_intercept('skat', covariates)
# FIXME: remove this logic when annotation is better optimized
if x in mt._fields_inverse:
x_field_name = mt._fields_inverse[x]
entry_expr = {}
else:
x_field_name = Env.get_uid()
entry_expr = {x_field_name: x}
y_field_name = '__y'
weight_field_name = '__weight'
key_field_name = '__key'
cov_field_names = list(f'__cov{i}' for i in range(len(covariates)))
mt = mt._select_all(
col_exprs=dict(**{y_field_name: y}, **dict(zip(cov_field_names, covariates))),
row_exprs={weight_field_name: weight_expr, key_field_name: key_expr},
entry_exprs=entry_expr,
)
if logistic is True:
use_logistic = True
max_iterations = 25
tolerance = 1e-6
elif logistic is False:
use_logistic = False
max_iterations = 0
tolerance = 0.0
else:
assert isinstance(logistic, tuple) and len(logistic) == 2
use_logistic = True
max_iterations, tolerance = logistic
config = {
'name': 'Skat',
'keyField': key_field_name,
'weightField': weight_field_name,
'xField': x_field_name,
'yField': y_field_name,
'covFields': cov_field_names,
'logistic': use_logistic,
'maxSize': max_size,
'accuracy': accuracy,
'iterations': iterations,
'logistic_max_iterations': max_iterations,
'logistic_tolerance': tolerance,
}
return Table(ir.MatrixToTableApply(mt._mir, config)).persist()
[docs]@typecheck(p_value=expr_numeric, approximate=bool)
def lambda_gc(p_value, approximate=True):
"""
Compute genomic inflation factor (lambda GC) from an Expression of p-values.
.. include:: ../_templates/experimental.rst
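Examples
--------
A minimal sketch, assuming `gwas` is a table produced by one of the row-level association
tests with a row field `p_value` (adapt the names to your pipeline):
>>> lgc = hl.lambda_gc(gwas.p_value)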
Parameters
----------
p_value : :class:`.NumericExpression`
Row-indexed numeric expression of p-values.
approximate : :obj:`bool`
If False, computes exact lambda GC (slower and uses more memory).
Returns
-------
:obj:`float`
Genomic inflation factor (lambda genomic control).
"""
raise_unless_row_indexed('lambda_gc', p_value)
t = table_source('lambda_gc', p_value)
med_chisq = _lambda_gc_agg(p_value, approximate)
return t.aggregate(med_chisq)
@typecheck(p_value=expr_numeric, approximate=bool)
def _lambda_gc_agg(p_value, approximate=True):
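# Lambda GC compares the median of the observed chi-squared statistics (each p-value is
# converted back to a 1-degree-of-freedom chi-squared quantile) against the median of the
# theoretical chi-squared(1) distribution, qchisqtail(0.5, 1) ~= 0.455.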
chisq = hl.qchisqtail(p_value, 1)
if approximate:
med_chisq = hl.agg.filter(~hl.is_nan(p_value), hl.agg.approx_quantiles(chisq, 0.5))
else:
med_chisq = hl.agg.filter(~hl.is_nan(p_value), hl.median(hl.agg.collect(chisq)))
return med_chisq / hl.qchisqtail(0.5, 1)
[docs]@typecheck(ds=oneof(Table, MatrixTable), keep_star=bool, left_aligned=bool, permit_shuffle=bool)
def split_multi(ds, keep_star=False, left_aligned=False, *, permit_shuffle=False):
"""Split multiallelic variants.
Warning
-------
In order to support a wide variety of data types, this function splits only
the variants on a :class:`.MatrixTable`, but **not the genotypes**. Use
:func:`.split_multi_hts` if possible, or split the genotypes yourself using
one of the entry modification methods: :meth:`.MatrixTable.annotate_entries`,
:meth:`.MatrixTable.select_entries`, :meth:`.MatrixTable.transmute_entries`.
The resulting dataset will be keyed by the split locus and alleles.
:func:`.split_multi` adds the following fields:
- `was_split` (*bool*) -- ``True`` if this variant was originally
multiallelic, otherwise ``False``.
- `a_index` (*int*) -- The original index of this alternate allele in the
multiallelic representation (NB: 1 is the first alternate allele or the
only alternate allele in a biallelic variant). For example, 1:100:A:T,C
splits into two variants: 1:100:A:T with ``a_index = 1`` and 1:100:A:C
with ``a_index = 2``.
- `old_locus` (*locus*) -- The original, unsplit locus.
- `old_alleles` (*array<str>*) -- The original, unsplit alleles.
All other fields are left unchanged.
Warning
-------
This method assumes `ds` contains at most one non-split variant per locus. This assumption permits the
most efficient implementation of the splitting algorithm. If your queries involving `split_multi`
crash with errors about out-of-order keys, this assumption may be violated. Otherwise, this
warning likely does not apply to your dataset.
If each locus in `ds` contains one multiallelic variant and one or more biallelic variants, you
can filter to the multiallelic variants, split those, and then combine the split variants with
the original biallelic variants.
For example, the following code splits a dataset `mt` which contains a mixture of split and
non-split variants.
>>> bi = mt.filter_rows(hl.len(mt.alleles) == 2)
>>> bi = bi.annotate_rows(a_index=1, was_split=False, old_locus=bi.locus, old_alleles=bi.alleles)
>>> multi = mt.filter_rows(hl.len(mt.alleles) > 2)
>>> split = hl.split_multi(multi)
>>> mt = split.union_rows(bi)
Example
-------
:func:`.split_multi_hts`, which splits multiallelic variants for the HTS
genotype schema and updates the entry fields by downcoding the genotype, is
implemented as:
>>> sm = hl.split_multi(ds)
>>> pl = hl.or_missing(
... hl.is_defined(sm.PL),
... (hl.range(0, 3).map(lambda i: hl.min(hl.range(0, hl.len(sm.PL))
... .filter(lambda j: hl.downcode(hl.unphased_diploid_gt_index_call(j), sm.a_index) == hl.unphased_diploid_gt_index_call(i))
... .map(lambda j: sm.PL[j])))))
>>> split_ds = sm.annotate_entries(
... GT=hl.downcode(sm.GT, sm.a_index),
... AD=hl.or_missing(hl.is_defined(sm.AD),
... [hl.sum(sm.AD) - sm.AD[sm.a_index], sm.AD[sm.a_index]]),
... DP=sm.DP,
... PL=pl,
... GQ=hl.gq_from_pl(pl)).drop('old_locus', 'old_alleles')
See Also
--------
:func:`.split_multi_hts`
Parameters
----------
ds : :class:`.MatrixTable` or :class:`.Table`
An unsplit dataset.
keep_star : :obj:`bool`
Do not filter out * alleles.
left_aligned : :obj:`bool`
If ``True``, variants are assumed to be left aligned and have unique
loci. This avoids a shuffle. If the assumption is violated, an error
is generated.
permit_shuffle : :obj:`bool`
If ``True``, permit a data shuffle to sort out-of-order split results.
This will only be required if input data has duplicate loci, one of
which contains more than one alternate allele.
Returns
-------
:class:`.MatrixTable` or :class:`.Table`
"""
require_row_key_variant(ds, "split_multi")
new_id = Env.get_uid()
is_table = isinstance(ds, Table)
old_row = ds.row if is_table else ds._rvrow
kept_alleles = hl.range(1, hl.len(old_row.alleles))
if not keep_star:
kept_alleles = kept_alleles.filter(lambda i: old_row.alleles[i] != "*")
def new_struct(variant, i):
return hl.struct(alleles=variant.alleles, locus=variant.locus, a_index=i, was_split=hl.len(old_row.alleles) > 2)
def split_rows(expr, rekey):
if isinstance(ds, MatrixTable):
mt = ds.annotate_rows(**{new_id: expr}).explode_rows(new_id)
if rekey:
mt = mt.key_rows_by()
else:
mt = mt.key_rows_by('locus')
new_row_expr = mt._rvrow.annotate(
locus=mt[new_id]['locus'],
alleles=mt[new_id]['alleles'],
a_index=mt[new_id]['a_index'],
was_split=mt[new_id]['was_split'],
old_locus=mt.locus,
old_alleles=mt.alleles,
).drop(new_id)
mt = mt._select_rows('split_multi', new_row_expr)
if rekey:
return mt.key_rows_by('locus', 'alleles')
else:
return MatrixTable(ir.MatrixKeyRowsBy(mt._mir, ['locus', 'alleles'], is_sorted=True))
else:
assert isinstance(ds, Table)
ht = ds.annotate(**{new_id: expr}).explode(new_id)
if rekey:
ht = ht.key_by()
else:
ht = ht.key_by('locus')
new_row_expr = ht.row.annotate(
locus=ht[new_id]['locus'],
alleles=ht[new_id]['alleles'],
a_index=ht[new_id]['a_index'],
was_split=ht[new_id]['was_split'],
old_locus=ht.locus,
old_alleles=ht.alleles,
).drop(new_id)
ht = ht._select('split_multi', new_row_expr)
if rekey:
return ht.key_by('locus', 'alleles')
else:
return Table(ir.TableKeyBy(ht._tir, ['locus', 'alleles'], is_sorted=True))
if left_aligned:
def make_struct(i):
def error_on_moved(v):
return (
hl.case()
.when(v.locus == old_row.locus, new_struct(v, i))
.or_error("Found non-left-aligned variant in split_multi")
)
return hl.bind(error_on_moved, hl.min_rep(old_row.locus, [old_row.alleles[0], old_row.alleles[i]]))
return split_rows(hl.sorted(kept_alleles.map(make_struct)), permit_shuffle)
else:
def make_struct(i, cond):
def struct_or_empty(v):
return hl.case().when(cond(v.locus), hl.array([new_struct(v, i)])).or_missing()
return hl.bind(struct_or_empty, hl.min_rep(old_row.locus, [old_row.alleles[0], old_row.alleles[i]]))
def make_array(cond):
return hl.sorted(kept_alleles.flatmap(lambda i: make_struct(i, cond)))
left = split_rows(make_array(lambda locus: locus == ds['locus']), permit_shuffle)
moved = split_rows(make_array(lambda locus: locus != ds['locus']), True)
return left.union(moved) if is_table else left.union_rows(moved, _check_cols=False)
[docs]@typecheck(ds=oneof(Table, MatrixTable), keep_star=bool, left_aligned=bool, vep_root=str, permit_shuffle=bool)
def split_multi_hts(ds, keep_star=False, left_aligned=False, vep_root='vep', *, permit_shuffle=False):
"""Split multiallelic variants for datasets that contain one or more fields
from a standard high-throughput sequencing entry schema.
.. code-block:: text
struct {
GT: call,
AD: array<int32>,
DP: int32,
GQ: int32,
PL: array<int32>,
PGT: call,
PID: str
}
For other entry fields, write your own splitting logic using
:meth:`.MatrixTable.annotate_entries`.
Examples
--------
>>> hl.split_multi_hts(dataset).write('output/split.mt')
Warning
-------
This method assumes `ds` contains at most one non-split variant per locus. This assumption permits the
most efficient implementation of the splitting algorithm. If your queries involving `split_multi_hts`
crash with errors about out-of-order keys, this assumption may be violated. Otherwise, this
warning likely does not apply to your dataset.
If each locus in `ds` contains one multiallelic variant and one or more biallelic variants, you
can filter to the multiallelic variants, split those, and then combine the split variants with
the original biallelic variants.
For example, the following code splits a dataset `mt` which contains a mixture of split and
non-split variants.
>>> bi = mt.filter_rows(hl.len(mt.alleles) == 2)
>>> bi = bi.annotate_rows(a_index=1, was_split=False)
>>> multi = mt.filter_rows(hl.len(mt.alleles) > 2)
>>> split = hl.split_multi_hts(multi)
>>> mt = split.union_rows(bi)
Notes
-----
We will explain by example. Consider a hypothetical 3-allelic
variant:
.. code-block:: text
A C,T 0/2:7,2,6:15:45:99,50,99,0,45,99
:func:`.split_multi_hts` will create two biallelic variants (one for each
alternate allele) at the same position
.. code-block:: text
A C 0/0:13,2:15:45:0,45,99
A T 0/1:9,6:15:50:50,0,99
Each multiallelic `GT` or `PGT` field is downcoded once for each alternate allele. A
call for an alternate allele maps to 1 in the biallelic variant
corresponding to itself and 0 otherwise. For example, in the example above,
0/2 maps to 0/0 and 0/1. The genotype 1/2 maps to 0/1 and 0/1.
The biallelic alt `AD` entry is just the multiallelic `AD` entry
corresponding to the alternate allele. The ref AD entry is the sum of the
other multiallelic entries.
The biallelic `DP` is the same as the multiallelic `DP`.
The biallelic `PL` entry for a genotype g is the minimum over `PL` entries
for multiallelic genotypes that downcode to g. For example, the `PL` for (A,
C) at 0/1 is the minimum of the PLs for 0/1 (50) and 1/2 (45), and thus 45.
Fixing an alternate allele and biallelic variant, downcoding gives a map
from multiallelic to biallelic alleles and genotypes. The biallelic `AD` entry
for an allele is just the sum of the multiallelic `AD` entries for alleles
that map to that allele. Similarly, the biallelic `PL` entry for a genotype is
the minimum over multiallelic `PL` entries for genotypes that map to that
genotype.
`GQ` is recomputed from `PL` if `PL` is provided and is not
missing. If not, it is copied from the original GQ.
Here is a second example for a het non-ref
.. code-block:: text
A C,T 1/2:2,8,6:16:45:99,50,99,45,0,99
splits as
.. code-block:: text
A C 0/1:8,8:16:45:45,0,99
A T 0/1:10,6:16:50:50,0,99
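As a quick check of the downcoding rule above, a single call can be downcoded directly (an illustrative sketch using the standard ``hl.downcode`` and ``hl.eval`` functions; output omitted):
>>> downcoded = hl.downcode(hl.call(1, 2), 2)  # 1/2 with respect to alternate allele 2 maps to 0/1
>>> hl.eval(downcoded)  # doctest: +SKIP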
**VCF Info Fields**
Hail does not split fields in the info field. This means that if a
multiallelic site with `info.AC` value ``[10, 2]`` is split, each split
site will contain the same array ``[10, 2]``. The provided allele index
field `a_index` can be used to select the value corresponding to the split
allele's position:
>>> split_ds = hl.split_multi_hts(dataset)
>>> split_ds = split_ds.filter_rows(split_ds.info.AC[split_ds.a_index - 1] < 10,
... keep = False)
VCFs split by Hail and exported to new VCFs may be
incompatible with other tools if no further action is taken.
Since the "Number" of the arrays in split multiallelic
sites no longer matches the structure on import ("A" for one value per
alternate allele, for example), Hail will export these fields with
number ".".
If the desired output is one value per site, it is
possible to use :meth:`.MatrixTable.annotate_rows` to remap these
values. Here is an example:
>>> split_ds = hl.split_multi_hts(dataset)
>>> split_ds = split_ds.annotate_rows(info = split_ds.info.annotate(AC = split_ds.info.AC[split_ds.a_index - 1]))
>>> hl.export_vcf(split_ds, 'output/export.vcf') # doctest: +SKIP
The info field AC in *output/export.vcf* will have ``Number=1``.
**New Fields**
:func:`.split_multi_hts` adds the following fields:
- `was_split` (*bool*) -- ``True`` if this variant was originally
multiallelic, otherwise ``False``.
- `a_index` (*int*) -- The original index of this alternate allele in the
multiallelic representation (NB: 1 is the first alternate allele or the
only alternate allele in a biallelic variant). For example, 1:100:A:T,C
splits into two variants: 1:100:A:T with ``a_index = 1`` and 1:100:A:C
with ``a_index = 2``.
See Also
--------
:func:`.split_multi`
Parameters
----------
ds : :class:`.MatrixTable` or :class:`.Table`
An unsplit dataset.
keep_star : :obj:`bool`
Do not filter out * alleles.
left_aligned : :obj:`bool`
If ``True``, variants are assumed to be left
aligned and have unique loci. This avoids a shuffle. If the assumption
is violated, an error is generated.
vep_root : :class:`str`
Top-level location of vep data. All variable-length VEP fields
(intergenic_consequences, motif_feature_consequences,
regulatory_feature_consequences, and transcript_consequences)
will be split properly (i.e. a_index corresponding to the VEP allele_num).
permit_shuffle : :obj:`bool`
If ``True``, permit a data shuffle to sort out-of-order split results.
This will only be required if input data has duplicate loci, one of
which contains more than one alternate allele.
Returns
-------
:class:`.MatrixTable` or :class:`.Table`
A biallelic variant dataset.
"""
split = split_multi(ds, keep_star=keep_star, left_aligned=left_aligned, permit_shuffle=permit_shuffle)
row_fields = set(ds.row)
update_rows_expression = {}
if vep_root in row_fields:
update_rows_expression[vep_root] = split[vep_root].annotate(**{
x: split[vep_root][x].filter(lambda csq: csq.allele_num == split.a_index)
for x in (
'intergenic_consequences',
'motif_feature_consequences',
'regulatory_feature_consequences',
'transcript_consequences',
)
})
if isinstance(ds, Table):
return split.annotate(**update_rows_expression).drop('old_locus', 'old_alleles')
split = split.annotate_rows(**update_rows_expression)
entry_fields = ds.entry
expected_field_types = {
'GT': hl.tcall,
'AD': hl.tarray(hl.tint),
'DP': hl.tint,
'GQ': hl.tint,
'PL': hl.tarray(hl.tint),
'PGT': hl.tcall,
'PID': hl.tstr,
}
bad_fields = []
for field in entry_fields:
if field in expected_field_types and entry_fields[field].dtype != expected_field_types[field]:
bad_fields.append((field, entry_fields[field].dtype, expected_field_types[field]))
if bad_fields:
msg = '\n '.join([f"'{x[0]}'\tfound: {x[1]}\texpected: {x[2]}" for x in bad_fields])
raise TypeError("'split_multi_hts': Found invalid types for the following fields:\n " + msg)
update_entries_expression = {}
if 'GT' in entry_fields:
update_entries_expression['GT'] = hl.downcode(split.GT, split.a_index)
if 'DP' in entry_fields:
update_entries_expression['DP'] = split.DP
if 'AD' in entry_fields:
update_entries_expression['AD'] = hl.or_missing(
hl.is_defined(split.AD), [hl.sum(split.AD) - split.AD[split.a_index], split.AD[split.a_index]]
)
if 'PL' in entry_fields:
pl = hl.or_missing(
hl.is_defined(split.PL),
(
hl.range(0, 3).map(
lambda i: hl.min(
(
hl.range(0, hl.triangle(split.old_alleles.length()))
.filter(
lambda j: hl.downcode(
hl.unphased_diploid_gt_index_call(j), split.a_index
).unphased_diploid_gt_index()
== i
)
.map(lambda j: split.PL[j])
)
)
)
),
)
if 'GQ' in entry_fields:
update_entries_expression['PL'] = pl
update_entries_expression['GQ'] = hl.or_else(hl.gq_from_pl(pl), split.GQ)
else:
update_entries_expression['PL'] = pl
elif 'GQ' in entry_fields:
update_entries_expression['GQ'] = split.GQ
if 'PGT' in entry_fields:
update_entries_expression['PGT'] = hl.downcode(split.PGT, split.a_index)
if 'PID' in entry_fields:
update_entries_expression['PID'] = split.PID
return split.annotate_entries(**update_entries_expression).drop('old_locus', 'old_alleles')
[docs]@typecheck(call_expr=expr_call)
def realized_relationship_matrix(call_expr) -> BlockMatrix:
r"""Computes the realized relationship matrix (RRM).
Examples
--------
>>> rrm = hl.realized_relationship_matrix(dataset.GT)
Notes
-----
The realized relationship matrix (RRM) is defined as follows. Consider the
:math:`n \times m` matrix :math:`C` of raw genotypes, with rows indexed by
:math:`n` samples and columns indexed by the :math:`m` biallelic autosomal
variants; :math:`C_{ij}` is the number of alternate alleles of variant
:math:`j` carried by sample :math:`i`, which can be 0, 1, 2, or missing. For
each variant :math:`j`, the sample alternate allele frequency :math:`p_j` is
computed as half the mean of the non-missing entries of column :math:`j`.
Entries of :math:`M` are then mean-centered and variance-normalized as
.. math::
M_{ij} =
\frac{C_{ij}-2p_j}
{\sqrt{\frac{m}{n} \sum_{k=1}^n (C_{kj}-2p_j)^2}},
with :math:`M_{ij} = 0` for :math:`C_{ij}` missing (i.e. mean genotype
imputation). This scaling normalizes each variant column to have empirical
variance :math:`1/m`, which gives each sample row approximately unit total
variance (assuming linkage equilibrium) and yields the :math:`n \times n`
sample correlation or realized relationship matrix (RRM) :math:`K` as simply
.. math::
K = MM^T
Note that the only difference between the realized relationship matrix and
the genetic relatedness matrix (GRM) used in
:func:`.genetic_relatedness_matrix` is the variant (column) normalization:
where RRM uses empirical variance, GRM uses expected variance under
Hardy-Weinberg Equilibrium.
This method drops variants with zero variance before computing kinship.
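The normalization above can be sketched directly in NumPy (for illustration only; assumes a small, complete genotype matrix ``C`` with samples as rows and variants as columns):
>>> import numpy as np
>>> C = np.array([[0., 1., 2.],
...               [1., 1., 0.],
...               [2., 0., 1.],
...               [1., 2., 1.]])      # 4 samples x 3 variants
>>> n, m = C.shape
>>> p = C.mean(axis=0) / 2            # sample alternate allele frequencies
>>> M = (C - 2 * p) / np.sqrt((m / n) * ((C - 2 * p) ** 2).sum(axis=0))
>>> K = M @ M.T                       # n x n realized relationship matrix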
Parameters
----------
call_expr : :class:`.CallExpression`
Entry-indexed call expression on matrix table with columns corresponding
to samples.
Returns
-------
:class:`.BlockMatrix`
Realized relationship matrix for all samples. Row and column indices
correspond to matrix table column index.
"""
mt = matrix_table_source('realized_relationship_matrix/call_expr', call_expr)
raise_unless_entry_indexed('realized_relationship_matrix/call_expr', call_expr)
mt = mt.select_entries(__gt=call_expr.n_alt_alleles()).unfilter_entries()
mt = mt.select_rows(
__AC=agg.sum(mt.__gt), __ACsq=agg.sum(mt.__gt * mt.__gt), __n_called=agg.count_where(hl.is_defined(mt.__gt))
)
mt = mt.select_rows(
__mean_gt=mt.__AC / mt.__n_called, __centered_length=hl.sqrt(mt.__ACsq - (mt.__AC**2) / mt.__n_called)
)
fmt = mt.filter_rows(mt.__centered_length > 0.1) # truly non-zero values are at least sqrt(0.5)
normalized_gt = hl.or_else((fmt.__gt - fmt.__mean_gt) / fmt.__centered_length, 0.0)
try:
bm = BlockMatrix.from_entry_expr(normalized_gt)
return (bm.T @ bm) / (bm.n_rows / bm.n_cols)
except FatalError as fe:
raise FatalError(
"Could not convert MatrixTable to BlockMatrix. It's possible all variants were dropped by variance filter.\n"
"Check that the input MatrixTable has at least two samples in it: mt.count_cols()."
) from fe
[docs]@typecheck(entry_expr=expr_float64, block_size=nullable(int))
def row_correlation(entry_expr, block_size=None) -> BlockMatrix:
"""Computes the correlation matrix between row vectors.
Examples
--------
Consider the following dataset with three variants and four samples:
>>> data = [{'v': '1:1:A:C', 's': 'a', 'GT': hl.Call([0, 0])},
... {'v': '1:1:A:C', 's': 'b', 'GT': hl.Call([0, 0])},
... {'v': '1:1:A:C', 's': 'c', 'GT': hl.Call([0, 1])},
... {'v': '1:1:A:C', 's': 'd', 'GT': hl.Call([1, 1])},
... {'v': '1:2:G:T', 's': 'a', 'GT': hl.Call([0, 1])},
... {'v': '1:2:G:T', 's': 'b', 'GT': hl.Call([1, 1])},
... {'v': '1:2:G:T', 's': 'c', 'GT': hl.Call([0, 1])},
... {'v': '1:2:G:T', 's': 'd', 'GT': hl.Call([0, 0])},
... {'v': '1:3:C:G', 's': 'a', 'GT': hl.Call([0, 1])},
... {'v': '1:3:C:G', 's': 'b', 'GT': hl.Call([0, 0])},
... {'v': '1:3:C:G', 's': 'c', 'GT': hl.Call([1, 1])},
... {'v': '1:3:C:G', 's': 'd', 'GT': hl.missing(hl.tcall)}]
>>> ht = hl.Table.parallelize(data, hl.dtype('struct{v: str, s: str, GT: call}'))
>>> mt = ht.to_matrix_table(row_key=['v'], col_key=['s'])
Compute genotype correlation between all pairs of variants:
>>> ld = hl.row_correlation(mt.GT.n_alt_alleles())
>>> ld.to_numpy()
array([[ 1. , -0.85280287, 0.42640143],
[-0.85280287, 1. , -0.5 ],
[ 0.42640143, -0.5 , 1. ]])
Compute genotype correlation between consecutively-indexed variants:
>>> ld.sparsify_band(lower=0, upper=1).to_numpy()
array([[ 1. , -0.85280287, 0. ],
[ 0. , 1. , -0.5 ],
[ 0. , 0. , 1. ]])
Warning
-------
Rows with a constant value (i.e., zero variance) will result in ``nan``
correlation values. To avoid this, first check that all rows vary or filter
out constant rows (for example, with the help of :func:`.aggregators.stats`).
Notes
-----
In this method, each row of entries is regarded as a vector with elements
defined by `entry_expr` and missing values mean-imputed per row.
The ``(i, j)`` element of the resulting block matrix is the correlation
between rows ``i`` and ``j`` (as 0-indexed by order in the matrix table;
see :meth:`~hail.MatrixTable.add_row_index`).
The correlation of two vectors is defined as the
`Pearson correlation coefficient <https://en.wikipedia.org/wiki/Pearson_correlation_coefficient>`__
between the corresponding empirical distributions of elements,
or equivalently as the cosine of the angle between the vectors.
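As a cross-check of the small example above, the same correlations can be computed with NumPy after mean-imputing the single missing call by hand (a sketch: the third row's missing value for sample ``d`` is replaced by that row's mean, 1.0):
>>> import numpy as np
>>> gts = np.array([[0., 0., 1., 2.],
...                 [1., 2., 1., 0.],
...                 [1., 0., 2., 1.]])
>>> np.corrcoef(gts)  # doctest: +SKIP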
This method has two stages:
- writing the row-normalized block matrix to a temporary file on persistent
disk with :meth:`.BlockMatrix.from_entry_expr`. The parallelism is
``n_rows / block_size``.
- reading and multiplying this block matrix by its transpose. The
parallelism is ``(n_rows / block_size)^2`` if all blocks are computed.
Warning
-------
See all warnings on :meth:`.BlockMatrix.from_entry_expr`. In particular,
for large matrices, it may be preferable to run the two stages separately,
saving the row-normalized block matrix to a file on external storage with
:meth:`.BlockMatrix.write_from_entry_expr`.
The resulting number of matrix elements is the square of the number of rows
in the matrix table, so computing the full matrix may be infeasible. For
example, ten million rows would produce 800TB of float64 values. The
block-sparse representation on BlockMatrix may be used to work efficiently
with regions of such matrices, as in the second example above and
:meth:`ld_matrix`.
To prevent excessive re-computation, be sure to write and read the (possibly
block-sparsified) result before multiplication by another matrix.
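A minimal sketch of the write-then-read pattern suggested above (the output path is a placeholder):
>>> cor = hl.row_correlation(mt.GT.n_alt_alleles())             # doctest: +SKIP
>>> cor.write('output/correlation.bm', overwrite=True)          # doctest: +SKIP
>>> cor = hl.linalg.BlockMatrix.read('output/correlation.bm')   # doctest: +SKIP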
Parameters
----------
entry_expr : :class:`.Float64Expression`
Entry-indexed numeric expression on matrix table.
block_size : :obj:`int`, optional
Block size. Default given by :meth:`.BlockMatrix.default_block_size`.
Returns
-------
:class:`.BlockMatrix`
Correlation matrix between row vectors. Row and column indices
correspond to matrix table row index.
"""
bm = BlockMatrix.from_entry_expr(entry_expr, mean_impute=True, center=True, normalize=True, block_size=block_size)
return bm @ bm.T
[docs]@typecheck(
entry_expr=expr_float64,
locus_expr=expr_locus(),
radius=oneof(int, float),
coord_expr=nullable(expr_float64),
block_size=nullable(int),
)
def ld_matrix(entry_expr, locus_expr, radius, coord_expr=None, block_size=None) -> BlockMatrix:
"""Computes the windowed correlation (linkage disequilibrium) matrix between
variants.
Examples
--------
Consider the following dataset consisting of three variants with centimorgan
coordinates and four samples:
>>> data = [{'v': '1:1:A:C', 'cm': 0.1, 's': 'a', 'GT': hl.Call([0, 0])},
... {'v': '1:1:A:C', 'cm': 0.1, 's': 'b', 'GT': hl.Call([0, 0])},
... {'v': '1:1:A:C', 'cm': 0.1, 's': 'c', 'GT': hl.Call([0, 1])},
... {'v': '1:1:A:C', 'cm': 0.1, 's': 'd', 'GT': hl.Call([1, 1])},
... {'v': '1:2000000:G:T', 'cm': 0.9, 's': 'a', 'GT': hl.Call([0, 1])},
... {'v': '1:2000000:G:T', 'cm': 0.9, 's': 'b', 'GT': hl.Call([1, 1])},
... {'v': '1:2000000:G:T', 'cm': 0.9, 's': 'c', 'GT': hl.Call([0, 1])},
... {'v': '1:2000000:G:T', 'cm': 0.9, 's': 'd', 'GT': hl.Call([0, 0])},
... {'v': '2:1:C:G', 'cm': 0.2, 's': 'a', 'GT': hl.Call([0, 1])},
... {'v': '2:1:C:G', 'cm': 0.2, 's': 'b', 'GT': hl.Call([0, 0])},
... {'v': '2:1:C:G', 'cm': 0.2, 's': 'c', 'GT': hl.Call([1, 1])},
... {'v': '2:1:C:G', 'cm': 0.2, 's': 'd', 'GT': hl.missing(hl.tcall)}]
>>> ht = hl.Table.parallelize(data, hl.dtype('struct{v: str, s: str, cm: float64, GT: call}'))
>>> ht = ht.transmute(**hl.parse_variant(ht.v))
>>> mt = ht.to_matrix_table(row_key=['locus', 'alleles'], col_key=['s'], row_fields=['cm'])
Compute linkage disequilibrium between all pairs of variants on the same
contig and within two megabases:
>>> ld = hl.ld_matrix(mt.GT.n_alt_alleles(), mt.locus, radius=2e6)
>>> ld.to_numpy()
array([[ 1. , -0.85280287, 0. ],
[-0.85280287, 1. , 0. ],
[ 0. , 0. , 1. ]])
Within one megabase:
>>> ld = hl.ld_matrix(mt.GT.n_alt_alleles(), mt.locus, radius=1e6)
>>> ld.to_numpy()
array([[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]])
Within one centimorgan:
>>> ld = hl.ld_matrix(mt.GT.n_alt_alleles(), mt.locus, radius=1.0, coord_expr=mt.cm)
>>> ld.to_numpy()
array([[ 1. , -0.85280287, 0. ],
[-0.85280287, 1. , 0. ],
[ 0. , 0. , 1. ]])
Within one centimorgan, and only calculate the upper triangle:
>>> ld = hl.ld_matrix(mt.GT.n_alt_alleles(), mt.locus, radius=1.0, coord_expr=mt.cm)
>>> ld = ld.sparsify_triangle()
>>> ld.to_numpy()
array([[ 1. , -0.85280287, 0. ],
[ 0. , 1. , 0. ],
[ 0. , 0. , 1. ]])
Notes
-----
This method sparsifies the result of :meth:`row_correlation` using
:func:`.linalg.utils.locus_windows` and
:meth:`.BlockMatrix.sparsify_row_intervals`
in order to only compute linkage disequilibrium between nearby
variants. Use :meth:`row_correlation` directly to calculate correlation
without windowing.
More precisely, variants are 0-indexed by their order in the matrix table
(see :meth:`~hail.MatrixTable.add_row_index`). Each variant is regarded as a vector of
elements defined by `entry_expr`, typically the number of alternate alleles
or genotype dosage. Missing values are mean-imputed within variant.
The method produces a symmetric block-sparse matrix supported in a
neighborhood of the diagonal. If variants :math:`i` and :math:`j` are on the
same contig and within `radius` base pairs (inclusive) then the
:math:`(i, j)` element is their
`Pearson correlation coefficient <https://en.wikipedia.org/wiki/Pearson_correlation_coefficient>`__.
Otherwise, the :math:`(i, j)` element is ``0.0``.
Rows with a constant value (i.e., zero variance) will result in ``nan``
correlation values. To avoid this, first check that all variants vary or
filter out constant variants (for example, with the help of
:func:`.aggregators.stats`).
If the :meth:`.global_position` on `locus_expr` is not in ascending order,
this method will fail. Ascending order should hold for a matrix table keyed
by locus or variant (and the associated row table), or for a table that's
been ordered by `locus_expr`.
Set `coord_expr` to use a value other than position to define the windows.
This row-indexed numeric expression must be non-missing, non-``nan``, on the
same source as `locus_expr`, and ascending with respect to locus
position for each contig; otherwise the method will raise an error.
Warning
-------
See the warnings in :meth:`row_correlation`. In particular, for large
matrices it may be preferable to run its stages separately.
`entry_expr` and `locus_expr` are implicitly aligned by row-index, though
they need not be on the same source. If their sources differ in the number
of rows, an error will be raised; otherwise, unintended misalignment may
silently produce unexpected results.
Parameters
----------
entry_expr : :class:`.Float64Expression`
Entry-indexed numeric expression on matrix table.
locus_expr : :class:`.LocusExpression`
Row-indexed locus expression on a table or matrix table that is
row-aligned with the matrix table of `entry_expr`.
radius: :obj:`int` or :obj:`float`
Radius of window for row values.
coord_expr: :class:`.Float64Expression`, optional
Row-indexed numeric expression for the row value on the same table or
matrix table as `locus_expr`.
By default, the row value is given by the locus position.
block_size : :obj:`int`, optional
Block size. Default given by :meth:`.BlockMatrix.default_block_size`.
Returns
-------
:class:`.BlockMatrix`
Windowed correlation matrix between variants.
Row and column indices correspond to matrix table variant index.
"""
starts_and_stops = hl.linalg.utils.locus_windows(locus_expr, radius, coord_expr, _localize=False)
starts_and_stops = hl.tuple([
starts_and_stops[0].map(lambda i: hl.int64(i)),
starts_and_stops[1].map(lambda i: hl.int64(i)),
])
ld = hl.row_correlation(entry_expr, block_size)
return ld._sparsify_row_intervals_expr(starts_and_stops, blocks_only=False)
[docs]@typecheck(
n_populations=int,
n_samples=int,
n_variants=int,
n_partitions=nullable(int),
pop_dist=nullable(sequenceof(numeric)),
fst=nullable(sequenceof(numeric)),
af_dist=nullable(expr_any),
reference_genome=reference_genome_type,
mixture=bool,
phased=bool,
)
def balding_nichols_model(
n_populations: int,
n_samples: int,
n_variants: int,
n_partitions: Optional[int] = None,
pop_dist: Optional[List[int]] = None,
fst: Optional[List[Union[float, int]]] = None,
af_dist: Optional[Expression] = None,
reference_genome: str = 'default',
mixture: bool = False,
*,
phased: bool = False,
) -> MatrixTable:
r"""Generate a matrix table of variants, samples, and genotypes using the
Balding-Nichols or Pritchard-Stephens-Donnelly model.
Examples
--------
Generate a matrix table of genotypes with 1000 variants and 100 samples
across 3 populations:
>>> hl.reset_global_randomness()
>>> bn_ds = hl.balding_nichols_model(3, 100, 1000)
>>> bn_ds.show(n_rows=5, n_cols=5)
+---------------+------------+------+------+------+------+------+
| locus | alleles | 0.GT | 1.GT | 2.GT | 3.GT | 4.GT |
+---------------+------------+------+------+------+------+------+
| locus<GRCh37> | array<str> | call | call | call | call | call |
+---------------+------------+------+------+------+------+------+
| 1:1 | ["A","C"] | 0/1 | 0/0 | 0/1 | 0/0 | 0/0 |
| 1:2 | ["A","C"] | 1/1 | 1/1 | 1/1 | 1/1 | 0/1 |
| 1:3 | ["A","C"] | 0/1 | 0/1 | 1/1 | 0/1 | 1/1 |
| 1:4 | ["A","C"] | 0/1 | 0/0 | 0/1 | 0/0 | 0/1 |
| 1:5 | ["A","C"] | 0/1 | 0/1 | 0/1 | 0/0 | 0/0 |
+---------------+------------+------+------+------+------+------+
showing top 5 rows
showing the first 5 of 100 columns
Generate a dataset as above but with phased genotypes:
>>> hl.reset_global_randomness()
>>> bn_ds = hl.balding_nichols_model(3, 100, 1000, phased=True)
>>> bn_ds.show(n_rows=5, n_cols=5)
+---------------+------------+------+------+------+------+------+
| locus | alleles | 0.GT | 1.GT | 2.GT | 3.GT | 4.GT |
+---------------+------------+------+------+------+------+------+
| locus<GRCh37> | array<str> | call | call | call | call | call |
+---------------+------------+------+------+------+------+------+
| 1:1 | ["A","C"] | 0|0 | 0|0 | 0|0 | 0|0 | 1|0 |
| 1:2 | ["A","C"] | 1|1 | 1|1 | 1|1 | 1|1 | 1|1 |
| 1:3 | ["A","C"] | 1|1 | 1|1 | 0|1 | 1|1 | 1|1 |
| 1:4 | ["A","C"] | 0|0 | 1|0 | 0|0 | 1|0 | 0|0 |
| 1:5 | ["A","C"] | 0|0 | 0|1 | 0|0 | 0|0 | 0|0 |
+---------------+------------+------+------+------+------+------+
showing top 5 rows
showing the first 5 of 100 columns
Generate a matrix table using 4 populations, 40 samples, 150 variants, 3
partitions, population distribution ``[0.1, 0.2, 0.3, 0.4]``,
:math:`F_{ST}` values ``[.02, .06, .04, .12]``, ancestral allele
frequencies drawn from a truncated beta distribution with ``a = 0.01`` and
``b = 0.05`` over the interval ``[0.05, 1]``, and random seed 1:
>>> hl.reset_global_randomness()
>>> bn_ds = hl.balding_nichols_model(4, 40, 150, 3,
... pop_dist=[0.1, 0.2, 0.3, 0.4],
... fst=[.02, .06, .04, .12],
... af_dist=hl.rand_beta(a=0.01, b=2.0, lower=0.05, upper=1.0))
To guarantee reproducibility, we reset Hail's global randomness with
:func:`.reset_global_randomness` immediately prior to generating each dataset.
Notes
-----
This method simulates a matrix table of variants, samples, and genotypes
using the Balding-Nichols model, which we now define.
- :math:`K` populations are labeled by integers :math:`0, 1, \dots, K - 1`.
- :math:`N` samples are labeled by strings :math:`0, 1, \dots, N - 1`.
- :math:`M` variants are defined as ``1:1:A:C``, ``1:2:A:C``, ...,
``1:M:A:C``.
- The default distribution for population assignment :math:`\pi` is uniform.
- The default ancestral frequency distribution :math:`P_0` is uniform on
:math:`[0.1, 0.9]`.
- The default :math:`F_{ST}` values are all :math:`0.1`.
The Balding-Nichols model models genotypes of individuals from a structured
population comprising :math:`K` homogeneous modern populations that have
each diverged from a single ancestral population (a `star phylogeny`). Each
sample is assigned a population by sampling from the categorical
distribution :math:`\pi`. Note that the actual size of each population is
random.
Variants are modeled as biallelic and unlinked. Ancestral allele
frequencies are drawn independently for each variant from a frequency
spectrum :math:`P_0`. The extent of genetic drift of each modern population
from the ancestral population is defined by the corresponding :math:`F_{ST}`
parameter :math:`F_k` (here and below, lowercase indices run over a range
bounded by the corresponding uppercase parameter, e.g. :math:`k = 1, \ldots,
K`). For each variant and population, allele frequencies are drawn from a
`beta distribution <https://en.wikipedia.org/wiki/Beta_distribution>`__
whose parameters are determined by the ancestral allele frequency and
:math:`F_{ST}` parameter. The beta distribution gives a continuous
approximation of the effect of genetic drift. We denote sample population
assignments by :math:`k_n`, ancestral allele frequencies by :math:`p_m`,
population allele frequencies by :math:`p_{k, m}`, and diploid, unphased
genotype calls by :math:`g_{n, m}` (0, 1, and 2 correspond to homozygous
reference, heterozygous, and homozygous variant, respectively).
The generative model is then given by:
.. math::
\begin{aligned}
k_n \,&\sim\, \pi \\
p_m \,&\sim\, P_0 \\
p_{k,m} \mid p_m\,&\sim\, \mathrm{Beta}(\mu = p_m,\, \sigma^2 = F_k p_m (1 - p_m)) \\
g_{n,m} \mid k_n, p_{k, m} \,&\sim\, \mathrm{Binomial}(2, p_{k_n, m})
\end{aligned}
The beta distribution is parameterized by its mean and variance above; the usual parameters
are :math:`a = (1 - p) \frac{1 - F}{F}` and :math:`b = p \frac{1 - F}{F}` with
:math:`F = F_k` and :math:`p = p_m`.
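For example, with :math:`p_m = 0.5` and :math:`F_k = 0.1`, both shape parameters work out to
:math:`a = b = 0.5 \cdot \frac{1 - 0.1}{0.1} = 4.5`, so the population allele frequency is drawn
from :math:`\mathrm{Beta}(4.5, 4.5)`.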
The resulting dataset has the following fields.
Global fields:
- `bn.n_populations` (:py:data:`.tint32`) -- Number of populations.
- `bn.n_samples` (:py:data:`.tint32`) -- Number of samples.
- `bn.n_variants` (:py:data:`.tint32`) -- Number of variants.
- `bn.n_partitions` (:py:data:`.tint32`) -- Number of partitions.
- `bn.pop_dist` (:class:`.tarray` of :py:data:`.tfloat64`) -- Population distribution indexed by
population.
- `bn.fst` (:class:`.tarray` of :py:data:`.tfloat64`) -- :math:`F_{ST}` values indexed by
population.
- `bn.seed` (:py:data:`.tint32`) -- Random seed.
- `bn.mixture` (:py:data:`.tbool`) -- Value of `mixture` parameter.
Row fields:
- `locus` (:class:`.tlocus`) -- Variant locus (key field).
- `alleles` (:class:`.tarray` of :py:data:`.tstr`) -- Variant alleles (key field).
- `ancestral_af` (:py:data:`.tfloat64`) -- Ancestral allele frequency.
- `af` (:class:`.tarray` of :py:data:`.tfloat64`) -- Modern allele frequencies indexed by
population.
Column fields:
- `sample_idx` (:py:data:`.tint32`) -- Sample index (key field).
- `pop` (:py:data:`.tint32`) -- Population of sample.
Entry fields:
- `GT` (:py:data:`.tcall`) -- Genotype call (diploid, unphased).
For the `Pritchard-Stephens-Donnelly model <http://www.genetics.org/content/155/2/945.long>`__,
set `mixture` to ``True`` to treat `pop_dist` as the parameters of the
Dirichlet distribution describing admixture between the modern populations.
In this case, the type of `pop` is :class:`.tarray` of
:py:data:`.tfloat64` and the value is the mixture proportions.
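An illustrative usage sketch of the admixture variant just described (parameter values are arbitrary):
>>> psd_ds = hl.balding_nichols_model(3, 100, 1000,
...                                   pop_dist=[1.0, 2.0, 1.0],
...                                   mixture=True)  # doctest: +SKIP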
Parameters
----------
n_populations : :obj:`int`
Number of modern populations.
n_samples : :obj:`int`
Total number of samples.
n_variants : :obj:`int`
Number of variants.
n_partitions : :obj:`int`, optional
Number of partitions.
Default is 1 partition per 128 million entries, or 8, whichever is larger.
pop_dist : :obj:`list` of :obj:`float`, optional
Unnormalized population distribution, a list of length
`n_populations` with non-negative values.
Default is ``[1, ..., 1]``.
fst : :obj:`list` of :obj:`float`, optional
:math:`F_{ST}` values, a list of length `n_populations` with values
in (0, 1). Default is ``[0.1, ..., 0.1]``.
af_dist : :class:`.Float64Expression`, optional
Representing a random function. Ancestral allele frequency
distribution. Default is :func:`.rand_unif` over the range
`[0.1, 0.9]` with seed 0.
reference_genome : :class:`str` or :class:`.ReferenceGenome`
Reference genome to use.
mixture : :obj:`bool`
Treat `pop_dist` as the parameters of a Dirichlet distribution,
as in the Pritchard-Stephens-Donnelly model.
phased : :obj:`bool`
Generate phased genotypes.
Returns
-------
:class:`.MatrixTable`
Simulated matrix table of variants, samples, and genotypes.
"""
if pop_dist is None:
pop_dist = [1 for _ in range(n_populations)]
if fst is None:
fst = [0.1 for _ in range(n_populations)]
if af_dist is None:
af_dist = hl.rand_unif(0.1, 0.9, seed=0)
if n_partitions is None:
n_partitions = max(8, int(n_samples * n_variants / (128 * 1024 * 1024)))
# verify args
for name, var in {
"populations": n_populations,
"samples": n_samples,
"variants": n_variants,
"partitions": n_partitions,
}.items():
if var < 1:
raise ValueError("n_{} must be positive, got {}".format(name, var))
for name, var in {"pop_dist": pop_dist, "fst": fst}.items():
if len(var) != n_populations:
raise ValueError(
"{} must be of length n_populations={}, got length {}".format(name, n_populations, len(var))
)
if any(x < 0 for x in pop_dist):
raise ValueError("pop_dist must be non-negative, got {}".format(pop_dist))
if any(x <= 0 or x >= 1 for x in fst):
raise ValueError("elements of fst must satisfy 0 < x < 1, got {}".format(fst))
# verify af_dist
if not af_dist._is_scalar:
raise ExpressionException(
'balding_nichols_model expects af_dist to '
+ 'have scalar arguments: found expression '
+ 'from source {}'.format(af_dist._indices.source)
)
if af_dist.dtype != tfloat64:
raise ValueError("af_dist must be a hail function with return type tfloat64.")
info(
"balding_nichols_model: generating genotypes for {} populations, {} samples, and {} variants...".format(
n_populations, n_samples, n_variants
)
)
# generate matrix table
from numpy import linspace
n_partitions = min(n_partitions, n_variants)
start_idxs = [int(x) for x in linspace(0, n_variants, n_partitions + 1)]
idx_bounds = list(zip(start_idxs, start_idxs[1:]))
pop_f = hl.rand_dirichlet if mixture else hl.rand_cat
bn = hl.Table._generate(
contexts=idx_bounds,
globals=hl.struct(
bn=hl.struct(
n_populations=n_populations,
n_samples=n_samples,
n_variants=n_variants,
n_partitions=n_partitions,
pop_dist=pop_dist,
fst=fst,
mixture=mixture,
),
cols=hl.range(n_samples).map(lambda idx: hl.struct(sample_idx=idx, pop=pop_f(pop_dist))),
),
partitions=[
hl.Interval(**{
endpoint: hl.Struct(locus=reference_genome.locus_from_global_position(idx), alleles=['A', 'C'])
for endpoint, idx in [('start', lo), ('end', hi)]
})
for (lo, hi) in idx_bounds
],
rowfn=lambda idx_range, _: hl.range(idx_range[0], idx_range[1]).map(
lambda idx: hl.bind(
lambda ancestral: hl.struct(
locus=hl.locus_from_global_position(idx, reference_genome),
alleles=['A', 'C'],
ancestral_af=ancestral,
af=hl.array([(1 - x) / x for x in fst]).map(
lambda x: hl.rand_beta(ancestral * x, (1 - ancestral) * x)
),
entries=hl.repeat(hl.struct(), n_samples),
),
af_dist,
)
),
)
bn = bn._unlocalize_entries('entries', 'cols', ['sample_idx'])
# entry info
p = hl.sum(bn.pop * bn.af) if mixture else bn.af[bn.pop]
q = 1 - p
if phased:
mom = hl.rand_bool(p)
dad = hl.rand_bool(p)
return bn.select_entries(GT=hl.call(mom, dad, phased=True))
idx = hl.rand_cat([q**2, 2 * p * q, p**2])
return bn.select_entries(GT=hl.unphased_diploid_gt_index_call(idx))
[docs]@typecheck(mt=MatrixTable, f=anytype)
def filter_alleles(mt: MatrixTable, f: Callable) -> MatrixTable:
"""Filter alternate alleles.
.. include:: ../_templates/req_tvariant.rst
Examples
--------
Keep SNPs:
>>> ds_result = hl.filter_alleles(ds, lambda allele, i: hl.is_snp(ds.alleles[0], allele))
Keep alleles with AC > 0:
>>> ds_result = hl.filter_alleles(ds, lambda a, allele_index: ds.info.AC[allele_index - 1] > 0)
Update the AC field of the resulting dataset:
>>> updated_info = ds_result.info.annotate(AC = ds_result.new_to_old.map(lambda i: ds_result.info.AC[i-1]))
>>> ds_result = ds_result.annotate_rows(info = updated_info)
Notes
-----
The following new fields are generated:
- `old_locus` (``locus``) -- The old locus, before filtering and computing
the minimal representation.
- `old_alleles` (``array<str>``) -- The old alleles, before filtering and
computing the minimal representation.
- `old_to_new` (``array<int32>``) -- An array that maps old allele index to
new allele index. Its length is the same as `old_alleles`. Alleles that
are filtered are missing.
- `new_to_old` (``array<int32>``) -- An array that maps new allele index to
the old allele index. Its length is the same as the modified `alleles`
field.
If all alternate alleles of a variant are filtered out, the variant itself
is filtered out.
**Using** `f`
The `f` argument is a function or lambda evaluated per alternate allele to
determine whether that allele is kept. If `f` evaluates to ``True``, the
allele is kept. If `f` evaluates to ``False`` or missing, the allele is
removed.
`f` is a function that takes two arguments: the allele string (of type
:class:`.StringExpression`) and the allele index (of type
:class:`.Int32Expression`), and returns a boolean expression. This can
be either a defined function or a lambda. For example, these two usages
are equivalent:
(with a lambda)
>>> ds_result = hl.filter_alleles(ds, lambda allele, i: hl.is_snp(ds.alleles[0], allele))
(with a defined function)
>>> def filter_f(allele, allele_index):
... return hl.is_snp(ds.alleles[0], allele)
>>> ds_result = hl.filter_alleles(ds, filter_f)
Warning
-------
:func:`.filter_alleles` does not update any fields other than `locus` and
`alleles`. This means that row fields like allele count (AC) and entry
fields like allele depth (AD) can become meaningless unless they are also
updated. You can update them with :meth:`.annotate_rows` and
:meth:`.annotate_entries`.
See Also
--------
:func:`.filter_alleles_hts`
Parameters
----------
mt : :class:`.MatrixTable`
Dataset.
f : callable
Function from (allele: :class:`.StringExpression`, allele_index:
:class:`.Int32Expression`) to :class:`.BooleanExpression`
Returns
-------
:class:`.MatrixTable`
"""
require_row_key_variant(mt, 'filter_alleles')
inclusion = hl.range(0, hl.len(mt.alleles)).map(lambda i: (i == 0) | hl.bind(lambda ii: f(mt.alleles[ii], ii), i))
# old locus, old alleles, new to old, old to new
mt = mt.annotate_rows(__allele_inclusion=inclusion, old_locus=mt.locus, old_alleles=mt.alleles)
new_to_old = hl.enumerate(mt.__allele_inclusion).filter(lambda elt: elt[1]).map(lambda elt: elt[0])
old_to_new_dict = hl.dict(
hl.enumerate(hl.enumerate(mt.alleles).filter(lambda elt: mt.__allele_inclusion[elt[0]])).map(
lambda elt: (elt[1][1], elt[0])
)
)
old_to_new = hl.bind(lambda d: mt.alleles.map(lambda a: d.get(a)), old_to_new_dict)
mt = mt.annotate_rows(old_to_new=old_to_new, new_to_old=new_to_old)
new_locus_alleles = hl.min_rep(mt.locus, mt.new_to_old.map(lambda i: mt.alleles[i]))
mt = mt.annotate_rows(__new_locus=new_locus_alleles.locus, __new_alleles=new_locus_alleles.alleles)
mt = mt.filter_rows(hl.len(mt.__new_alleles) > 1)
left = mt.filter_rows((mt.locus == mt.__new_locus) & (mt.alleles == mt.__new_alleles))
right = mt.filter_rows((mt.locus != mt.__new_locus) | (mt.alleles != mt.__new_alleles))
right = right.key_rows_by(locus=right.__new_locus, alleles=right.__new_alleles)
return left.union_rows(right, _check_cols=False).drop('__allele_inclusion', '__new_locus', '__new_alleles')
[docs]@typecheck(mt=MatrixTable, f=anytype, subset=bool)
def filter_alleles_hts(mt: MatrixTable, f: Callable, subset: bool = False) -> MatrixTable:
"""Filter alternate alleles and update standard GATK entry fields.
Examples
--------
Filter to SNP alleles using the subset strategy:
>>> ds_result = hl.filter_alleles_hts(
... ds,
... lambda allele, _: hl.is_snp(ds.alleles[0], allele),
... subset=True)
Update the AC field of the resulting dataset:
>>> updated_info = ds_result.info.annotate(AC = ds_result.new_to_old.map(lambda i: ds_result.info.AC[i-1]))
>>> ds_result = ds_result.annotate_rows(info = updated_info)
Notes
-----
For usage of the `f` argument, see the :func:`.filter_alleles`
documentation.
:func:`.filter_alleles_hts` requires the dataset have the GATK VCF schema,
namely the following entry fields in this order:
.. code-block:: text
GT: call
AD: array<int32>
DP: int32
GQ: int32
PL: array<int32>
Use :meth:`.MatrixTable.select_entries` to rearrange these fields if
necessary.
The following new fields are generated:
- `old_locus` (``locus``) -- The old locus, before filtering and computing
the minimal representation.
- `old_alleles` (``array<str>``) -- The old alleles, before filtering and
computing the minimal representation.
- `old_to_new` (``array<int32>``) -- An array that maps old allele index to
new allele index. Its length is the same as `old_alleles`. Alleles that
are filtered are missing.
- `new_to_old` (``array<int32>``) -- An array that maps new allele index to
the old allele index. Its length is the same as the modified `alleles`
field.
**Downcode algorithm**
We will illustrate the behavior on the example genotype below
when filtering the first alternate allele (allele 1) at a site
with 1 reference allele and 2 alternate alleles.
.. code-block:: text
GT: 1/2
GQ: 10
AD: 0,50,35
0 | 1000
1 | 1000 10
2 | 1000 0 20
+-----------------
0 1 2
The downcode algorithm recodes occurrences of filtered alleles
to occurrences of the reference allele (e.g. 1 -> 0 in our
example). So the depths of filtered alleles in the AD field
are added to the depth of the reference allele. Where
downcoding filtered alleles merges distinct genotypes, the
minimum PL is used (since PL is on a log scale, this roughly
corresponds to adding probabilities). The PLs are then
re-normalized (shifted) so that the most likely genotype has a
PL of 0, and GT is set to this genotype. If an allele is
filtered, this algorithm acts similarly to
:func:`.split_multi_hts`.
The downcode algorithm would produce the following:
.. code-block:: text
GT: 0/1
GQ: 10
AD: 35,50
0 | 20
1 | 0 10
+-----------
0 1
In summary:
- GT: Downcode filtered alleles to reference.
- AD: Columns of filtered alleles are eliminated and their
values are added to the reference column, e.g., filtering
alleles 1 and 2 transforms ``25,5,10,20`` to ``40,20``.
- DP: No change.
- PL: Downcode filtered alleles to reference, combine PLs
using minimum for each overloaded genotype, and shift so
the overall minimum PL is 0.
- GQ: The second-lowest PL (after shifting).
**Subset algorithm**
We will illustrate the behavior on the example genotype below
when filtering the first alternate allele (allele 1) at a site
with 1 reference allele and 2 alternate alleles.
.. code-block:: text
GT: 1/2
GQ: 10
AD: 0,50,35
0 | 1000
1 | 1000 10
2 | 1000 0 20
+-----------------
0 1 2
The subset algorithm subsets the AD and PL arrays
(i.e. removes entries corresponding to filtered alleles) and
then sets GT to the genotype with the minimum PL. Note that
if the genotype changes (as in the example), the PLs are
re-normalized (shifted) so that the most likely genotype has a
PL of 0. Qualitatively, subsetting corresponds to the belief
that the filtered alleles are not real so we should discard
any probability mass associated with them.
The subset algorithm would produce the following:
.. code-block:: text
GT: 1/1
GQ: 980
AD: 0,50
0 | 980
1 | 980 0
+-----------
0 1
In summary:
- GT: Set to most likely genotype based on the PLs ignoring
the filtered allele(s).
- AD: The filtered alleles' columns are eliminated, e.g.,
filtering alleles 1 and 2 transforms ``25,5,10,20`` to
``25,20``.
- DP: Unchanged.
- PL: Columns involving filtered alleles are eliminated and
the remaining columns' values are shifted so the minimum
value is 0.
- GQ: The second-lowest PL (after shifting).
Warning
-------
:func:`.filter_alleles_hts` does not update any row fields other than
`locus` and `alleles`. This means that row fields like allele count (AC) can
become meaningless unless they are also updated. You can update them with
:meth:`.annotate_rows`.
See Also
--------
:func:`.filter_alleles`
Parameters
----------
mt : :class:`.MatrixTable`
f : callable
Function from (allele: :class:`.StringExpression`, allele_index:
:class:`.Int32Expression`) to :class:`.BooleanExpression`
subset : :obj:`.bool`
Subset PL field if ``True``, otherwise downcode PL field. The
calculation of GT and GQ also depend on whether one subsets or
downcodes the PL.
Returns
-------
:class:`.MatrixTable`
"""
if mt.entry.dtype != hl.hts_entry_schema:
raise FatalError(
"'filter_alleles_hts': entry schema must be the HTS entry schema:\n"
" found: {}\n"
" expected: {}\n"
" Use 'hl.filter_alleles' to split entries with non-HTS entry fields.".format(
mt.entry.dtype, hl.hts_entry_schema
)
)
mt = filter_alleles(mt, f)
if subset:
newPL = hl.if_else(
hl.is_defined(mt.PL),
hl.bind(
lambda unnorm: unnorm - hl.min(unnorm),
hl.range(0, hl.triangle(mt.alleles.length())).map(
lambda newi: hl.bind(
lambda newc: mt.PL[
hl.call(mt.new_to_old[newc[0]], mt.new_to_old[newc[1]]).unphased_diploid_gt_index()
],
hl.unphased_diploid_gt_index_call(newi),
)
),
),
hl.missing(tarray(tint32)),
)
return mt.annotate_entries(
GT=hl.unphased_diploid_gt_index_call(hl.argmin(newPL, unique=True)),
AD=hl.if_else(
hl.is_defined(mt.AD),
hl.range(0, mt.alleles.length()).map(lambda newi: mt.AD[mt.new_to_old[newi]]),
hl.missing(tarray(tint32)),
),
# DP unchanged
GQ=hl.gq_from_pl(newPL),
PL=newPL,
)
# otherwise downcode
else:
mt = mt.annotate_rows(__old_to_new_no_na=mt.old_to_new.map(lambda x: hl.or_else(x, 0)))
newPL = hl.if_else(
hl.is_defined(mt.PL),
(
hl.range(0, hl.triangle(hl.len(mt.alleles))).map(
lambda newi: hl.min(
hl.range(0, hl.triangle(hl.len(mt.old_alleles)))
.filter(
lambda oldi: hl.bind(
lambda oldc: hl.call(mt.__old_to_new_no_na[oldc[0]], mt.__old_to_new_no_na[oldc[1]])
== hl.unphased_diploid_gt_index_call(newi),
hl.unphased_diploid_gt_index_call(oldi),
)
)
.map(lambda oldi: mt.PL[oldi])
)
)
),
hl.missing(tarray(tint32)),
)
return mt.annotate_entries(
GT=hl.call(mt.__old_to_new_no_na[mt.GT[0]], mt.__old_to_new_no_na[mt.GT[1]]),
AD=hl.if_else(
hl.is_defined(mt.AD),
(
hl.range(0, hl.len(mt.alleles)).map(
lambda newi: hl.sum(
hl.range(0, hl.len(mt.old_alleles))
.filter(lambda oldi: mt.__old_to_new_no_na[oldi] == newi)
.map(lambda oldi: mt.AD[oldi])
)
)
),
hl.missing(tarray(tint32)),
),
# DP unchanged
GQ=hl.gq_from_pl(newPL),
PL=newPL,
).drop('__old_to_new_no_na')
@typecheck(mt=MatrixTable, call_field=str, r2=numeric, bp_window_size=int, memory_per_core=int)
def _local_ld_prune(mt, call_field, r2=0.2, bp_window_size=1000000, memory_per_core=256):
bytes_per_core = memory_per_core * 1024 * 1024
fraction_memory_to_use = 0.25
variant_byte_overhead = 50
genotypes_per_pack = 32
n_samples = mt.count_cols()
min_bytes_per_core = math.ceil((1 / fraction_memory_to_use) * 8 * n_samples + variant_byte_overhead)
if bytes_per_core < min_bytes_per_core:
raise ValueError("memory_per_core must be greater than {} MB".format(min_bytes_per_core // (1024 * 1024)))
bytes_per_variant = math.ceil(8 * n_samples / genotypes_per_pack) + variant_byte_overhead
bytes_available_per_core = bytes_per_core * fraction_memory_to_use
max_queue_size = int(max(1.0, math.ceil(bytes_available_per_core / bytes_per_variant)))
info(f'ld_prune: running local pruning stage with max queue size of {max_queue_size} variants')
return Table(
ir.MatrixToTableApply(
mt._mir,
{
'name': 'LocalLDPrune',
'callField': call_field,
'r2Threshold': float(r2),
'windowSize': bp_window_size,
'maxQueueSize': max_queue_size,
},
)
).persist()
[docs]@typecheck(
call_expr=expr_call,
r2=numeric,
bp_window_size=int,
memory_per_core=int,
keep_higher_maf=bool,
block_size=nullable(int),
)
def ld_prune(call_expr, r2=0.2, bp_window_size=1000000, memory_per_core=256, keep_higher_maf=True, block_size=None):
"""Returns a maximal subset of variants that are nearly uncorrelated within each window.
.. include:: ../_templates/req_diploid_gt.rst
.. include:: ../_templates/req_biallelic.rst
.. include:: ../_templates/req_tvariant.rst
Examples
--------
Prune variants in linkage disequilibrium by filtering a dataset to those variants returned
by :func:`.ld_prune`. If the dataset contains multiallelic variants, the multiallelic variants
must be filtered out or split before being passed to :func:`.ld_prune`.
>>> biallelic_dataset = dataset.filter_rows(hl.len(dataset.alleles) == 2)
>>> pruned_variant_table = hl.ld_prune(biallelic_dataset.GT, r2=0.2, bp_window_size=500000)
>>> filtered_ds = dataset.filter_rows(hl.is_defined(pruned_variant_table[dataset.row_key]))
Notes
-----
This method finds a maximal subset of variants such that the squared Pearson
correlation coefficient :math:`r^2` of any pair at most `bp_window_size`
base pairs apart is strictly less than `r2`. Each variant is represented as
a vector over samples with elements given by the (mean-imputed) number of
alternate alleles. In particular, even if present, **phase information is
ignored**. Variants that do not vary across samples are dropped.
The method prunes variants in linkage disequilibrium in three stages.
- The first, "local pruning" stage prunes correlated variants within each
partition, using a local variant queue whose size is determined by
`memory_per_core`. A larger queue may facilitate more local pruning in
this stage. Minor allele frequency is not taken into account. The
parallelism is the number of matrix table partitions.
- The second, "global correlation" stage uses block-sparse matrix
multiplication to compute correlation between each pair of remaining
variants within `bp_window_size` base pairs, and then forms a graph of
correlated variants. The parallelism of writing the locally-pruned matrix
table as a block matrix is ``n_locally_pruned_variants / block_size``.
- The third, "global pruning" stage applies :func:`.maximal_independent_set`
to prune variants from this graph until no edges remain. This algorithm
iteratively removes the variant with the highest vertex degree. If
`keep_higher_maf` is true, then in the case of a tie for highest degree,
the variant with lowest minor allele frequency is removed.
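An illustrative usage sketch that tunes the stage-related parameters described above (values are arbitrary; ``biallelic_dataset`` is the filtered dataset from the example above):
>>> pruned_variant_table = hl.ld_prune(biallelic_dataset.GT, r2=0.1, bp_window_size=250000,
...                                    memory_per_core=512, keep_higher_maf=False,
...                                    block_size=1024)  # doctest: +SKIP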
Warning
-------
The locally-pruned matrix table and block matrix are stored as temporary files
on persistent disk. See the warnings on :meth:`.BlockMatrix.from_entry_expr` with
regard to memory and Hadoop replication errors.
Parameters
----------
call_expr : :class:`.CallExpression`
Entry-indexed call expression on a matrix table with row-indexed
variants and column-indexed samples.
r2 : :obj:`float`
Squared correlation threshold (exclusive upper bound).
Must be in the range [0.0, 1.0].
bp_window_size: :obj:`int`
Window size in base pairs (inclusive upper bound).
memory_per_core : :obj:`int`
Memory in MB per core for local pruning queue.
keep_higher_maf: :obj:`bool`
If ``True``, break ties at each step of the global pruning stage by
preferring to keep variants with higher minor allele frequency.
block_size: :obj:`int`, optional
Block size for block matrices in the second stage.
Default given by :meth:`.BlockMatrix.default_block_size`.
Returns
-------
:class:`.Table`
Table of a maximal independent set of variants.
"""
if block_size is None:
block_size = BlockMatrix.default_block_size()
if not 0.0 <= r2 <= 1:
raise ValueError(f'r2 must be in the range [0.0, 1.0], found {r2}')
if bp_window_size < 0:
raise ValueError(f'bp_window_size must be non-negative, found {bp_window_size}')
raise_unless_entry_indexed('ld_prune/call_expr', call_expr)
mt = matrix_table_source('ld_prune/call_expr', call_expr)
require_row_key_variant(mt, 'ld_prune')
# FIXME: remove once select_entries on a field is free
if call_expr in mt._fields_inverse:
field = mt._fields_inverse[call_expr]
else:
field = Env.get_uid()
mt = mt.select_entries(**{field: call_expr})
mt = mt.select_rows().select_cols()
mt = mt.distinct_by_row()
locally_pruned_table_path = new_temp_file()
(
_local_ld_prune(require_biallelic(mt, 'ld_prune'), field, r2, bp_window_size, memory_per_core).write(
locally_pruned_table_path, overwrite=True
)
)
locally_pruned_table = hl.read_table(locally_pruned_table_path).add_index()
mt = mt.annotate_rows(info=locally_pruned_table[mt.row_key])
mt = mt.filter_rows(hl.is_defined(mt.info)).unfilter_entries()
std_gt_bm = BlockMatrix.from_entry_expr(
hl.or_else((mt[field].n_alt_alleles() - mt.info.mean) * mt.info.centered_length_rec, 0.0), block_size=block_size
)
r2_bm = (std_gt_bm @ std_gt_bm.T) ** 2
_, stops = hl.linalg.utils.locus_windows(locally_pruned_table.locus, bp_window_size)
entries = r2_bm.sparsify_row_intervals(range(stops.size), stops, blocks_only=True).entries(keyed=False)
entries = entries.filter((entries.entry >= r2) & (entries.i < entries.j))
entries = entries.select(i=hl.int32(entries.i), j=hl.int32(entries.j))
if keep_higher_maf:
fields = ['mean', 'locus']
else:
fields = ['locus']
info = locally_pruned_table.aggregate(
hl.agg.collect(locally_pruned_table.row.select('idx', *fields)), _localize=False
)
info = hl.sorted(info, key=lambda x: x.idx)
entries = entries.annotate_globals(info=info)
entries = entries.filter(
(entries.info[entries.i].locus.contig == entries.info[entries.j].locus.contig)
& (entries.info[entries.j].locus.position - entries.info[entries.i].locus.position <= bp_window_size)
)
if keep_higher_maf:
entries = entries.annotate(
i=hl.struct(
idx=entries.i, twice_maf=hl.min(entries.info[entries.i].mean, 2.0 - entries.info[entries.i].mean)
),
j=hl.struct(
idx=entries.j, twice_maf=hl.min(entries.info[entries.j].mean, 2.0 - entries.info[entries.j].mean)
),
)
def tie_breaker(left, right):
return hl.sign(right.twice_maf - left.twice_maf)
else:
tie_breaker = None
variants_to_remove = hl.maximal_independent_set(
entries.i, entries.j, keep=False, tie_breaker=tie_breaker, keyed=False
)
locally_pruned_table = locally_pruned_table.annotate_globals(
variants_to_remove=variants_to_remove.aggregate(
hl.agg.collect_as_set(variants_to_remove.node.idx), _localize=False
)
)
return (
locally_pruned_table.filter(
locally_pruned_table.variants_to_remove.contains(hl.int32(locally_pruned_table.idx)), keep=False
)
.select()
.persist()
)
def _warn_if_no_intercept(caller, covariates):
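# Scalar covariates (e.g. the constant 1.0) have no index axes, so if every
# covariate is row- or column-indexed, the model likely lacks an intercept term.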
if all([e._indices.axes for e in covariates]):
warning(
f'{caller}: model appears to have no intercept covariate.'
'\n To include an intercept, add 1.0 to the list of covariates.'
)
return True
return False