from collections.abc import Sequence
from typing import Optional
import hail as hl
from hail.expr.expressions import Expression
from hail.expr.expressions.typed_expressions import (
ArrayExpression,
CallExpression,
LocusExpression,
NumericExpression,
StructExpression,
)
from hail.genetics.allele_type import AlleleType
from hail.methods.misc import require_first_key_field_locus
from hail.methods.qc import _qc_allele_type
from hail.table import Table
from hail.typecheck import nullable, sequenceof, typecheck
from hail.utils.java import Env
from hail.utils.misc import divide_null
from hail.vds.variant_dataset import VariantDataset
@typecheck(global_gt=Expression, alleles=ArrayExpression)
def vmt_sample_qc_variant_annotations(
    *,
    global_gt: 'Expression',
    alleles: 'ArrayExpression',
) -> tuple['Expression', 'Expression']:
    """Build the per-variant annotations consumed by :func:`.vmt_sample_qc`.

    Two annotations are produced: the allele count (AC) and an integer
    representation of the allele type of every alternate allele.

    Parameters
    ----------
    global_gt : :class:`.Expression`
        Call expression of the global GT of a variants matrix table, usually
        generated by :func:`.lgt_to_gt`.
    alleles : :class:`.ArrayExpression`
        Array expression of the alleles of a variants matrix table
        (generally ``vds.variant_data.alleles``).

    Returns
    -------
    :class:`tuple`
        ``(AC, allele_types)``: the ``AC`` field of ``hl.agg.call_stats`` first,
        then one allele type per alternate allele (``alleles[1:]``), each
        classified against the reference allele ``alleles[0]``.
    """
    allele_counts = hl.agg.call_stats(global_gt, alleles).AC

    reference_allele = alleles[0]
    alternate_alleles = alleles[1:]
    allele_types = alternate_alleles.map(lambda alternate: _qc_allele_type(reference_allele, alternate))

    return (allele_counts, allele_types)
@typecheck(
    global_gt=Expression,
    gq=Expression,
    variant_ac=ArrayExpression,
    variant_atypes=ArrayExpression,
    dp=nullable(Expression),
    gq_bins=sequenceof(int),
    dp_bins=sequenceof(int),
)
def vmt_sample_qc(
    *,
    global_gt: 'CallExpression',
    gq: 'Expression',
    variant_ac: 'ArrayExpression',
    variant_atypes: 'ArrayExpression',
    dp: Optional['Expression'] = None,
    gq_bins: 'Sequence[int]' = (0, 20, 60),
    dp_bins: 'Sequence[int]' = (0, 1, 10, 20, 30),
) -> 'Expression':
    """Computes sample quality metrics from variant data of a VDS

    Parameters
    ----------
    global_gt : :class:`.CallExpression`
        Global GT of a variants matrix table or subset thereof (ex. ``hl.agg.group_by``).
    gq : :class:`.Expression`
        GQ of a variants matrix table.
    variant_ac : :class:`.ArrayExpression`
        Allele counts of the genotypes of a variants matrix table. This can
        be generated by ``hl.agg.call_stats`` or alternatively
        :func:`.vmt_sample_qc_variant_annotations` (which calls ``call_stats``
        internally). It is indexed by allele index (``0`` is the reference).
    variant_atypes : :class:`.ArrayExpression`
        Allele types of the alternate alleles of a variants matrix table. This
        must be generated with :func:`.vmt_sample_qc_variant_annotations` in
        order to return correct results. It is indexed by
        ``allele index - 1`` because the reference allele has no entry.
    dp : :class:`.Expression` or :obj:`NoneType`
        DP of a variants matrix table (or ``None``)
    gq_bins : :class:`tuple` of :obj:`int`
        Tuple containing cutoffs for genotype quality (GQ) scores.
    dp_bins : :class:`tuple` of :obj:`int`
        Tuple containing cutoffs for depth (DP) scores.

    Returns
    -------
    :class:`.StructExpression`
        A struct expression of type::

            struct{
              bases_over_gq_threshold: tuple(int64 * len(gq_bins)),
              bases_over_dp_threshold: tuple(int64 * len(dp_bins)), # present if dp is not None
              n_het: int64,
              n_hom_var: int64,
              n_non_ref: int64,
              n_singleton: int64,
              n_singleton_ti: int64,
              n_singleton_tv: int64,
              n_snp: int64,
              n_insertion: int64,
              n_deletion: int64,
              n_transition: int64,
              n_transversion: int64,
              n_star: int64,
              r_ti_tv: float64,
              r_ti_tv_singleton: float64,
              r_het_hom_var: float64,
              r_insertion_deletion: float64,
            }

        ``n_non_ref`` is ``n_het + n_hom_var``, ``n_snp`` is
        ``n_transition + n_transversion`` and every ``r_*`` ratio is computed
        with ``divide_null``.
    """

    def _singleton_allele_count(extra_condition=None):
        # Aggregated number of called alleles that are non-reference and have
        # an allele count of exactly one, optionally restricted further by
        # ``extra_condition(allele_index)``.
        def _is_counted(allele_idx):
            keep = (allele_idx != 0) & (variant_ac[allele_idx] == 1)
            if extra_condition is not None:
                keep = keep & extra_condition(allele_idx)
            return keep

        return hl.agg.sum(
            hl.rbind(
                global_gt,
                lambda gt: hl.sum(hl.range(0, gt.ploidy).map(lambda i: hl.rbind(gt[i], _is_counted))),
            )
        )

    def _alt_has_type(allele_type):
        # ``variant_atypes`` only covers alternate alleles, hence the ``- 1``.
        return lambda allele_idx: variant_atypes[allele_idx - 1] == allele_type

    # Aggregations that are referenced more than once in the final struct are
    # bound once here and reused through ``hl.rbind`` below.
    bound_exprs = {
        'n_het': hl.agg.count_where(global_gt.is_het()),
        'n_hom_var': hl.agg.count_where(global_gt.is_hom_var()),
        'n_singleton': _singleton_allele_count(),
        'n_singleton_ti': _singleton_allele_count(_alt_has_type(AlleleType.TRANSITION)),
        'n_singleton_tv': _singleton_allele_count(_alt_has_type(AlleleType.TRANSVERSION)),
        # One counter per AlleleType member, fed with the type of every called
        # non-reference allele.
        'allele_type_counts': hl.agg.explode(
            lambda allele_type: hl.tuple(hl.agg.count_where(allele_type == i) for i in range(len(AlleleType))),
            (
                hl.range(0, global_gt.ploidy)
                .map(lambda i: global_gt[i])
                .filter(lambda allele_idx: allele_idx > 0)
                .map(lambda allele_idx: variant_atypes[allele_idx - 1])
            ),
        ),
    }

    # Depth statistics are optional: the field is omitted entirely without DP.
    dp_exprs = {}
    if dp is not None:
        dp_exprs['bases_over_dp_threshold'] = hl.tuple(hl.agg.count_where(dp >= x) for x in dp_bins)
    gq_dp_exprs = {'bases_over_gq_threshold': hl.tuple(hl.agg.count_where(gq >= x) for x in gq_bins), **dp_exprs}

    return hl.rbind(
        hl.struct(**bound_exprs),
        lambda x: hl.rbind(
            hl.struct(**{
                **gq_dp_exprs,
                'n_het': x.n_het,
                'n_hom_var': x.n_hom_var,
                'n_non_ref': x.n_het + x.n_hom_var,
                'n_singleton': x.n_singleton,
                'n_singleton_ti': x.n_singleton_ti,
                'n_singleton_tv': x.n_singleton_tv,
                'n_snp': x.allele_type_counts[AlleleType.TRANSITION] + x.allele_type_counts[AlleleType.TRANSVERSION],
                'n_insertion': x.allele_type_counts[AlleleType.INSERTION],
                'n_deletion': x.allele_type_counts[AlleleType.DELETION],
                'n_transition': x.allele_type_counts[AlleleType.TRANSITION],
                'n_transversion': x.allele_type_counts[AlleleType.TRANSVERSION],
                'n_star': x.allele_type_counts[AlleleType.STAR],
            }),
            # Ratios are derived from the counts already present in the struct.
            lambda s: s.annotate(
                r_ti_tv=divide_null(hl.float64(s.n_transition), s.n_transversion),
                r_ti_tv_singleton=divide_null(hl.float64(s.n_singleton_ti), s.n_singleton_tv),
                r_het_hom_var=divide_null(hl.float64(s.n_het), s.n_hom_var),
                r_insertion_deletion=divide_null(hl.float64(s.n_insertion), s.n_deletion),
            ),
        ),
    )
@typecheck(
    locus=LocusExpression,
    gq=NumericExpression,
    end=NumericExpression,
    dp=nullable(Expression),
    gq_bins=sequenceof(int),
    dp_bins=sequenceof(int),
)
def rmt_sample_qc(
    *,
    locus: 'LocusExpression',
    end: 'NumericExpression',
    gq: 'NumericExpression',
    dp: Optional['Expression'] = None,
    gq_bins: 'Sequence[int]' = (0, 20, 60),
    dp_bins: 'Sequence[int]' = (0, 1, 10, 20, 30),
) -> 'StructExpression':
    """Computes sample quality metrics from reference data of a VDS

    Each reference entry contributes ``1 + end - locus.position`` bases to
    every bin whose cutoff it meets.

    Parameters
    ----------
    locus : :class:`.LocusExpression`
        Locus of a reference matrix table
    end : :class:`.NumericExpression`
        END of a reference matrix table
    gq : :class:`.NumericExpression`
        GQ of a reference matrix table.
    dp : :class:`.Expression` or :obj:`NoneType`
        DP of a reference matrix table (or ``None``)
    gq_bins : :class:`tuple` of :obj:`int`
        Tuple containing cutoffs for genotype quality (GQ) scores.
    dp_bins : :class:`tuple` of :obj:`int`
        Tuple containing cutoffs for depth (DP) scores.

    Returns
    -------
    :class:`.StructExpression`
        A struct expression of type::

            struct{
              bases_over_gq_threshold: tuple(int64 * len(gq_bins)),
              bases_over_dp_threshold: tuple(int64 * len(dp_bins)), # present if dp is not None
            }
    """
    # Number of bases covered by one reference entry; the ``1 +`` counts both
    # the start position and END.
    block_length = 1 + end - locus.position

    def _bases_over(metric, cutoffs):
        # One aggregated base count per cutoff, restricted to entries whose
        # metric reaches that cutoff.
        return hl.tuple(hl.agg.filter(metric >= cutoff, hl.agg.sum(block_length)) for cutoff in cutoffs)

    # Depth statistics are optional: the field is omitted entirely without DP.
    ref_dp_expr = {}
    if dp is not None:
        ref_dp_expr['bases_over_dp_threshold'] = _bases_over(dp, dp_bins)

    return hl.struct(
        bases_over_gq_threshold=_bases_over(gq, gq_bins),
        **ref_dp_expr,
    )
def combine_sample_qc(
    rmt_sample_qc: Expression,
    vmt_sample_qc: Expression,
) -> Expression:
    """Merge reference and variant sample quality results bin by bin.

    Parameters
    ----------
    rmt_sample_qc : :class:`.Expression`
        A struct expression produced by :func:`.rmt_sample_qc`
    vmt_sample_qc : :class:`.Expression`
        A struct expression produced by :func:`.vmt_sample_qc`

    Returns
    -------
    :class:`.StructExpression`
        A struct expression of type::

            struct{
              bases_over_gq_threshold:
                tuple(int64 * len(rmt_sample_qc.bases_over_gq_threshold)),
              bases_over_dp_threshold:  # present if dp was present for qc stats generation
                tuple(int64 * len(rmt_sample_qc.bases_over_dp_threshold)),
            }

        Every element is the sum of the variant and reference values at the
        same bin position.

    Raises
    ------
    ValueError
        If ``bases_over_gq_threshold`` is missing from either argument, if
        ``bases_over_dp_threshold`` is present in exactly one of them, or if
        the number of GQ or DP bins differs between the two.

    Note
    ----
    It is the responsibility of the caller of this function to make sure that
    the ``gq_bins`` and ``dp_bins`` that are used for the generation of both of
    the arguments to this function are the same. Incorrect results will occur
    if the bins are not the same. This function checks the length of the bins
    used, but cannot check the bin values themselves.
    """
    gq_name = 'bases_over_gq_threshold'
    dp_name = 'bases_over_dp_threshold'

    # GQ statistics are mandatory on both sides.
    if gq_name not in rmt_sample_qc:
        raise ValueError("Expect 'bases_over_gq_threshold' field in 'rmt_sample_qc' expression")
    if gq_name not in vmt_sample_qc:
        raise ValueError("Expect 'bases_over_gq_threshold' field in 'vmt_sample_qc' expression")

    # DP statistics are optional, but must be present on both sides or neither.
    reference_has_dp = dp_name in rmt_sample_qc
    variant_has_dp = dp_name in vmt_sample_qc
    if reference_has_dp != variant_has_dp:
        raise ValueError(
            "Expect 'bases_over_dp_threshold' field in both or neither of 'rmt_sample_qc' and 'vmt_sample_qc'"
        )

    # Only the bin counts can be compared here, not the cutoffs themselves.
    reference_gq = rmt_sample_qc.bases_over_gq_threshold
    variant_gq = vmt_sample_qc.bases_over_gq_threshold
    if len(reference_gq) != len(variant_gq):
        raise ValueError("Expect same number of GQ bins for both variant and reference qc results")
    if reference_has_dp and len(rmt_sample_qc.bases_over_dp_threshold) != len(
        vmt_sample_qc.bases_over_dp_threshold
    ):
        raise ValueError("Expect same number of DP bins for both variant and reference qc results")

    def _add_bins(variant_bins, reference_bins):
        return hl.tuple(v + r for v, r in zip(variant_bins, reference_bins))

    combined = {gq_name: _add_bins(variant_gq, reference_gq)}
    if variant_has_dp:
        combined[dp_name] = _add_bins(
            vmt_sample_qc.bases_over_dp_threshold, rmt_sample_qc.bases_over_dp_threshold
        )
    return hl.struct(**combined)
@typecheck(vds=VariantDataset, gq_bins=sequenceof(int), dp_bins=sequenceof(int), dp_field=nullable(str))
def sample_qc(
    vds: 'VariantDataset',
    *,
    gq_bins: 'Sequence[int]' = (0, 20, 60),
    dp_bins: 'Sequence[int]' = (0, 1, 10, 20, 30),
    dp_field: Optional[str] = None,
) -> 'Table':
    """Compute sample quality metrics about a :class:`.VariantDataset`.

    If the `dp_field` parameter is not specified, the ``DP`` is used for depth
    if present. If no ``DP`` field is present, the ``MIN_DP`` field is used. If no ``DP``
    or ``MIN_DP`` field is present, no depth statistics will be calculated.
    This lookup is performed on the entries of the reference data.

    Parameters
    ----------
    vds : :class:`.VariantDataset`
        Dataset in VariantDataset representation.
    gq_bins : :class:`tuple` of :obj:`int`
        Tuple containing cutoffs for genotype quality (GQ) scores.
    dp_bins : :class:`tuple` of :obj:`int`
        Tuple containing cutoffs for depth (DP) scores.
    dp_field : :obj:`str`
        Name of depth field. If not supplied, DP or MIN_DP will be used, in that order.

    Returns
    -------
    :class:`.Table`
        Hail Table of results, keyed by sample. The ``gq_bins`` used are stored
        as a global field, as are the ``dp_bins`` when depth statistics were
        calculated.
    """
    require_first_key_field_locus(vds.reference_data, 'sample_qc')
    require_first_key_field_locus(vds.variant_data, 'sample_qc')

    # Pick the reference-data depth field: explicit name, then DP, then MIN_DP.
    if dp_field is not None:
        ref_dp_field_to_use = dp_field
    elif 'DP' in vds.reference_data.entry:
        ref_dp_field_to_use = 'DP'
    elif 'MIN_DP' in vds.reference_data.entry:
        ref_dp_field_to_use = 'MIN_DP'
    else:
        ref_dp_field_to_use = None

    vmt = vds.variant_data
    # The variant metrics need global genotypes; derive them from the local
    # representation when GT is absent.
    if 'GT' not in vmt.entry:
        vmt = vmt.annotate_entries(GT=hl.vds.lgt_to_gt(vmt.LGT, vmt.LA))

    # Stash the per-variant annotations under unique row field names so they
    # cannot clash with existing fields.
    allele_count, atypes = vmt_sample_qc_variant_annotations(global_gt=vmt.GT, alleles=vmt.alleles)
    variant_ac = Env.get_uid()
    variant_atypes = Env.get_uid()
    vmt = vmt.annotate_rows(**{variant_ac: allele_count, variant_atypes: atypes})

    # The variant side always reads the entry field named DP (never `dp_field`
    # or MIN_DP), and only when the reference data has a usable depth field.
    # NOTE(review): if the reference data has a depth field but the variant
    # data has no DP entry, only the reference results carry
    # 'bases_over_dp_threshold' and combine_sample_qc raises ValueError.
    # Confirm whether that is the intended behaviour.
    vmt_dp = vmt['DP'] if ref_dp_field_to_use is not None and 'DP' in vmt.entry else None
    variant_results = vmt.select_cols(
        **vmt_sample_qc(
            global_gt=vmt.GT,
            gq=vmt.GQ,
            variant_ac=vmt[variant_ac],
            variant_atypes=vmt[variant_atypes],
            dp=vmt_dp,
            gq_bins=gq_bins,
            dp_bins=dp_bins,
        )
    ).cols()

    rmt = vds.reference_data
    rmt_dp = rmt[ref_dp_field_to_use] if ref_dp_field_to_use is not None else None
    reference_results = rmt.select_cols(
        **rmt_sample_qc(
            locus=rmt.locus,
            gq=rmt.GQ,
            end=rmt.END,
            dp=rmt_dp,
            gq_bins=gq_bins,
            dp_bins=dp_bins,
        )
    ).cols()

    # Look up each sample's reference results by the variant table's key.
    joined = reference_results[variant_results.key]

    dp_bins_field = {}
    if ref_dp_field_to_use is not None:
        dp_bins_field['dp_bins'] = hl.tuple(dp_bins)

    # Replace the variant-only threshold counts with the combined
    # variant + reference counts; all other variant metrics are kept.
    joined_results = variant_results.transmute(**combine_sample_qc(joined, variant_results.row))
    joined_results = joined_results.annotate_globals(gq_bins=hl.tuple(gq_bins), **dp_bins_field)
    return joined_results