import hail as hl
from hail import ir
from hail.backend.spark_backend import SparkBackend
from hail.expr import analyze
from hail.expr.expressions import expr_float64
from hail.linalg import BlockMatrix
from hail.matrixtable import MatrixTable
from hail.methods.misc import require_biallelic, require_col_key_str
from hail.table import Table
from hail.typecheck import nullable, numeric, typecheck
from hail.utils.java import Env
@typecheck(dataset=MatrixTable, maf=nullable(expr_float64), bounded=bool, min=nullable(numeric), max=nullable(numeric))
def identity_by_descent(dataset, maf=None, bounded=True, min=None, max=None) -> Table:
"""Compute matrix of identity-by-descent estimates.
.. include:: ../_templates/req_tstring.rst
.. include:: ../_templates/req_tvariant.rst
.. include:: ../_templates/req_biallelic.rst
Examples
--------

To calculate a full IBD matrix, using minor allele frequencies computed
from the dataset itself:

>>> hl.identity_by_descent(dataset)

To calculate an IBD matrix containing only pairs of samples with
``PI_HAT`` in :math:`[0.2, 0.9]`, using minor allele frequencies stored in
the row field `panel_maf`:

>>> hl.identity_by_descent(dataset, maf=dataset['panel_maf'], min=0.2, max=0.9)
Notes
-----

The dataset must have a column field named `s` which is a :class:`.StringExpression`
and which uniquely identifies a column.

The implementation is based on the IBD algorithm described in the `PLINK
paper <http://www.ncbi.nlm.nih.gov/pmc/articles/PMC1950838>`__.

:func:`.identity_by_descent` requires the dataset to be biallelic and does
not perform LD pruning. Linkage disequilibrium may bias the result, so
consider filtering variants first.
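One possible pre-filtering sketch (not performed by this function, and the
variable names here are illustrative) uses :func:`.ld_prune` with its
default parameters and keeps only the retained variants:

>>> pruned_variants = hl.ld_prune(dataset.GT)  # doctest: +SKIP
>>> dataset = dataset.filter_rows(hl.is_defined(pruned_variants[dataset.row_key]))  # doctest: +SKIP
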
The resulting :class:`.Table` entries have the type: *{ i: String,
j: String, ibd: { Z0: Double, Z1: Double, Z2: Double, PI_HAT: Double },
ibs0: Long, ibs1: Long, ibs2: Long }*. The key list is: *i: String,
j: String*.

Conceptually, the output is a symmetric, sample-by-sample matrix. The
output table has the following form:

.. code-block:: text

    i        j        ibd.Z0  ibd.Z1  ibd.Z2  ibd.PI_HAT  ibs0  ibs1  ibs2
    sample1  sample2  1.0000  0.0000  0.0000  0.0000      ...
    sample1  sample3  1.0000  0.0000  0.0000  0.0000      ...
    sample1  sample4  0.6807  0.0000  0.3193  0.3193      ...
    sample1  sample5  0.1966  0.0000  0.8034  0.8034      ...

Parameters
----------
dataset : :class:`.MatrixTable`
    Variant-keyed and sample-keyed :class:`.MatrixTable` containing genotype information.
maf : :class:`.Float64Expression`, optional
    Row-indexed expression for the minor allele frequency.
bounded : :obj:`bool`
    Forces the estimates for ``Z0``, ``Z1``, ``Z2``, and ``PI_HAT`` to take
    on biologically meaningful values (in the range :math:`[0,1]`).
min : :obj:`float` or :obj:`None`
    Sample pairs with a ``PI_HAT`` below this value will
    not be included in the output. Must be in :math:`[0,1]`.
max : :obj:`float` or :obj:`None`
    Sample pairs with a ``PI_HAT`` above this value will
    not be included in the output. Must be in :math:`[0,1]`.

Returns
-------
:class:`.Table`
"""
require_col_key_str(dataset, 'identity_by_descent')
if not isinstance(dataset.GT, hl.CallExpression):
raise Exception('GT field must be of type Call')
if maf is not None:
analyze('identity_by_descent/maf', maf, dataset._row_indices)
dataset = dataset.select_rows(__maf=maf)
dataset = dataset.filter_rows(hl.is_defined(dataset.__maf))
else:
dataset = dataset.select_rows()
dataset = dataset.select_cols().select_globals().select_entries('GT')
dataset = require_biallelic(dataset, 'ibd')
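    # On the Spark backend, delegate to the native IBD implementation via a
    # MatrixToTableApply IR node; other backends fall back to the pure
    # BlockMatrix implementation below.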
if isinstance(Env.backend(), SparkBackend):
return Table(
ir.MatrixToTableApply(
dataset._mir,
{
'name': 'IBD',
'mafFieldName': '__maf' if maf is not None else None,
'bounded': bounded,
'min': min,
'max': max,
},
)
).persist()
min = min or 0
max = max or 1
if not 0 <= min <= max <= 1:
raise Exception(f"invalid pi hat filters {min} {max}")
sample_ids = dataset.s.collect()
if len(sample_ids) != len(set(sample_ids)):
raise Exception('duplicate sample ids found')
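    # Encode each genotype as 0/1 indicators so genotype classes can be
    # counted per sample pair with block-matrix products below; missing
    # genotypes contribute 0 to every indicator.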
dataset = dataset.annotate_entries(
n_alt_alleles=hl.or_else(dataset.GT.n_alt_alleles(), 0),
        is_hom_ref=hl.or_else(hl.int(dataset.GT.is_hom_ref()), 0),
        is_het=hl.or_else(hl.int(dataset.GT.is_het()), 0),
        is_hom_var=hl.or_else(hl.int(dataset.GT.is_hom_var()), 0),
is_missing=hl.is_missing(dataset.GT),
is_not_missing=hl.is_defined(dataset.GT),
)
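    # Per-variant allele counts over defined genotypes: T is the total number
    # of observed alleles, X the alternate-allele count, and Y the
    # reference-allele count.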
T = 2 * hl.agg.count_where(hl.is_defined(dataset.GT))
X = hl.agg.sum(dataset.GT.n_alt_alleles())
Y = T - X
if maf is not None:
p = dataset.__maf
else:
p = X / T
q = 1 - p
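    # Expected IBS-sharing counts conditional on IBD state, following the
    # method-of-moments estimator from the PLINK paper: _e<m><n> is the
    # per-variant expectation of observing IBS=m for a random sample pair
    # given IBD=n, with finite-sample corrections; _e22 is identically 1.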
dataset = dataset.annotate_rows(
_e00=(2 * (p**2) * (q**2) * ((X - 1) / X) * ((Y - 1) / Y) * (T / (T - 1)) * (T / (T - 2)) * (T / (T - 3))),
_e10=(
4 * (p**3) * q * ((X - 1) / X) * ((X - 2) / X) * (T / (T - 1)) * (T / (T - 2)) * (T / (T - 3))
+ 4 * p * (q**3) * ((Y - 1) / Y) * ((Y - 2) / Y) * (T / (T - 1)) * (T / (T - 2)) * (T / (T - 3))
),
_e20=(
(p**4) * ((X - 1) / X) * ((X - 2) / X) * ((X - 3) / X) * (T / (T - 1)) * (T / (T - 2)) * (T / (T - 3))
+ (q**4) * ((Y - 1) / Y) * ((Y - 2) / Y) * ((Y - 3) / Y) * (T / (T - 1)) * (T / (T - 2)) * (T / (T - 3))
+ 4 * (p**2) * (q**2) * ((X - 1) / X) * ((Y - 1) / Y) * (T / (T - 1)) * (T / (T - 2)) * (T / (T - 3))
),
_e11=(
2 * (p**2) * q * ((X - 1) / X) * (T / (T - 1)) * (T / (T - 2))
+ 2 * p * (q**2) * ((Y - 1) / Y) * (T / (T - 1)) * (T / (T - 2))
),
_e21=(
(p**3) * ((X - 1) / X) * ((X - 2) / X) * (T / (T - 1)) * (T / (T - 2))
+ (q**3) * ((Y - 1) / Y) * ((Y - 2) / Y) * (T / (T - 1)) * (T / (T - 2))
+ (p**2) * q * ((X - 1) / X) * (T / (T - 1)) * (T / (T - 2))
+ p * (q**2) * ((Y - 1) / Y) * (T / (T - 1)) * (T / (T - 2))
),
_e22=1,
)
dataset = dataset.checkpoint(hl.utils.new_temp_file())
expectations = dataset.aggregate_rows(
hl.struct(
e00=hl.agg.sum(dataset._e00),
e10=hl.agg.sum(dataset._e10),
e20=hl.agg.sum(dataset._e20),
e11=hl.agg.sum(dataset._e11),
e21=hl.agg.sum(dataset._e21),
e22=hl.agg.sum(dataset._e22),
)
)
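    # Build (variant x sample) 0/1 indicator BlockMatrices. For each sample
    # pair, IBS counts are inner products over variants: IBS0 pairs a hom-ref
    # genotype with a hom-var, IBS1 pairs a het with a non-het, and IBS2 is
    # whatever remains among jointly non-missing genotypes.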
IS_HOM_REF = BlockMatrix.from_entry_expr(dataset.is_hom_ref).checkpoint(hl.utils.new_temp_file())
IS_HET = BlockMatrix.from_entry_expr(dataset.is_het).checkpoint(hl.utils.new_temp_file())
IS_HOM_VAR = BlockMatrix.from_entry_expr(dataset.is_hom_var).checkpoint(hl.utils.new_temp_file())
NOT_MISSING = (IS_HOM_REF + IS_HET + IS_HOM_VAR).checkpoint(hl.utils.new_temp_file())
total_possible_ibs = NOT_MISSING.T @ NOT_MISSING
ibs0_pre = (IS_HOM_REF.T @ IS_HOM_VAR).checkpoint(hl.utils.new_temp_file())
ibs0 = ibs0_pre + ibs0_pre.T
is_not_het = IS_HOM_REF + IS_HOM_VAR
ibs1_pre = (IS_HET.T @ is_not_het).checkpoint(hl.utils.new_temp_file())
ibs1 = ibs1_pre + ibs1_pre.T
ibs2 = total_possible_ibs - ibs0 - ibs1
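    # Method-of-moments solve for the IBD-state probabilities: Z0 comes from
    # the observed IBS0 count alone, then Z1 and Z2 are obtained by
    # subtracting the expected contributions of the already-estimated states.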
Z0 = ibs0 / expectations.e00
Z1 = (ibs1 - Z0 * expectations.e10) / expectations.e11
Z2 = (ibs2 - Z0 * expectations.e20 - Z1 * expectations.e21) / expectations.e22
def convert_to_table(bm, annotation_name):
t = bm.entries()
t = t.rename({'entry': annotation_name})
return t
z0 = convert_to_table(Z0, 'Z0').checkpoint(hl.utils.new_temp_file())
z1 = convert_to_table(Z1, 'Z1').checkpoint(hl.utils.new_temp_file())
z2 = convert_to_table(Z2, 'Z2').checkpoint(hl.utils.new_temp_file())
ibs0 = convert_to_table(ibs0, 'ibs0').checkpoint(hl.utils.new_temp_file())
ibs1 = convert_to_table(ibs1, 'ibs1').checkpoint(hl.utils.new_temp_file())
ibs2 = convert_to_table(ibs2, 'ibs2').checkpoint(hl.utils.new_temp_file())
result = z0.join(z1.join(z2).join(ibs0).join(ibs1).join(ibs2))
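    # When `bounded` is set, clamp estimates that fall outside [0, 1] back
    # onto the probability simplex: a component greater than 1 collapses the
    # estimate onto that state, and a negative component is set to 0 with the
    # remaining two renormalized.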
def bound_result(_ibd):
return (
hl.case()
.when(_ibd.Z0 > 1, hl.struct(Z0=hl.float(1), Z1=hl.float(0), Z2=hl.float(0)))
.when(_ibd.Z1 > 1, hl.struct(Z0=hl.float(0), Z1=hl.float(1), Z2=hl.float(0)))
.when(_ibd.Z2 > 1, hl.struct(Z0=hl.float(0), Z1=hl.float(0), Z2=hl.float(1)))
.when(
_ibd.Z0 < 0,
hl.struct(Z0=hl.float(0), Z1=_ibd.Z1 / (_ibd.Z1 + _ibd.Z2), Z2=_ibd.Z2 / (_ibd.Z1 + _ibd.Z2)),
)
.when(
_ibd.Z1 < 0,
hl.struct(Z0=_ibd.Z0 / (_ibd.Z0 + _ibd.Z2), Z1=hl.float(0), Z2=_ibd.Z2 / (_ibd.Z0 + _ibd.Z2)),
)
.when(
_ibd.Z2 < 0,
hl.struct(Z0=_ibd.Z0 / (_ibd.Z0 + _ibd.Z1), Z1=_ibd.Z1 / (_ibd.Z0 + _ibd.Z1), Z2=hl.float(0)),
)
.default(_ibd)
)
result = result.annotate(ibd=hl.struct(Z0=result.Z0, Z1=result.Z1, Z2=result.Z2))
result = result.drop('Z0', 'Z1', 'Z2')
if bounded:
result = result.annotate(ibd=bound_result(result.ibd))
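    # PI_HAT is the expected proportion of alleles shared IBD:
    # P(IBD=2) + 0.5 * P(IBD=1).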
result = result.annotate(ibd=result.ibd.annotate(PI_HAT=result.ibd.Z1 / 2 + result.ibd.Z2))
result = result.filter((result.i < result.j) & (min <= result.ibd.PI_HAT) & (result.ibd.PI_HAT <= max))
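    # BlockMatrix entries are keyed by integer row/column indices; translate
    # them back to sample IDs, which `collect` returns in column order.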
samples = hl.literal(dataset.s.collect())
result = result.key_by(i=samples[hl.int32(result.i)], j=samples[hl.int32(result.j)])
return result.persist()