import json
import os
import hail as hl
from hail.genetics import ReferenceGenome
from hail.matrixtable import MatrixTable
from hail.typecheck import typecheck_method
from hail.utils.java import info, warning

extra_ref_globals_file = 'extra_reference_globals.json'


def read_vds(
path,
*,
intervals=None,
n_partitions=None,
_assert_reference_type=None,
_assert_variant_type=None,
_warn_no_ref_block_max_length=True,
_drop_end=False,
) -> 'VariantDataset':
"""Read in a :class:`.VariantDataset` written with :meth:`.VariantDataset.write`.

    Parameters
    ----------
    path : :obj:`str`
        Path to the on-disk VariantDataset.

    Returns
    -------
    :class:`.VariantDataset`
    """
if intervals or not n_partitions:
reference_data = hl.read_matrix_table(VariantDataset._reference_path(path), _intervals=intervals)
variant_data = hl.read_matrix_table(VariantDataset._variants_path(path), _intervals=intervals)
else:
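        # When n_partitions is requested without explicit intervals, derive new
        # intervals from the reference data's partitioning, then reread both components.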
assert n_partitions is not None
reference_data = hl.read_matrix_table(VariantDataset._reference_path(path))
intervals = reference_data._calculate_new_partitions(n_partitions)
assert len(intervals) > 0
reference_data = hl.read_matrix_table(VariantDataset._reference_path(path), _intervals=intervals)
variant_data = hl.read_matrix_table(VariantDataset._variants_path(path), _intervals=intervals)
    # If LEN is missing, add it; _add_len is a no-op when LEN is already present.
reference_data = VariantDataset._add_len(reference_data)
if _drop_end:
if 'END' in reference_data.entry:
reference_data = reference_data.drop('END')
    else:  # If END is missing, add it; _add_end is a no-op when END is already present.
reference_data = VariantDataset._add_end(reference_data)
vds = VariantDataset(reference_data, variant_data)
if VariantDataset.ref_block_max_length_field not in vds.reference_data.globals:
fs = hl.current_backend().fs
metadata_file = os.path.join(path, extra_ref_globals_file)
if fs.exists(metadata_file):
with fs.open(metadata_file, 'r') as f:
metadata = json.load(f)
vds.reference_data = vds.reference_data.annotate_globals(**metadata)
elif _warn_no_ref_block_max_length:
warning(
"You are reading a VDS written with an older version of Hail."
"\n Hail now supports much faster interval filters on VDS, but you'll need to run either"
"\n `hl.vds.truncate_reference_blocks(vds, ...)` and write a copy (see docs) or patch the"
"\n existing VDS in place with `hl.vds.store_ref_block_max_length(vds_path)`."
)
return vds


def store_ref_block_max_length(vds_path):
"""Patches an existing VDS file to store the max reference block length for faster interval filters.
This method permits :func:`.vds.filter_intervals` to remove reference data not overlapping a target interval.
This method is able to patch an existing VDS file in-place, without copying all the data. However,
if significant downstream interval filtering is anticipated, it may be advantageous to run
:func:`.vds.truncate_reference_blocks` to truncate long reference blocks and make interval filters
even faster. However, truncation requires rewriting the entire VDS.

    Examples
    --------
    >>> hl.vds.store_ref_block_max_length('gs://path/to/my.vds') # doctest: +SKIP
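
    A sketch of the rewrite-based alternative (the truncation threshold argument shown
    here is illustrative; see the :func:`.vds.truncate_reference_blocks` docs for the
    exact parameters):

    >>> vds = hl.vds.read_vds('gs://path/to/my.vds') # doctest: +SKIP
    >>> vds = hl.vds.truncate_reference_blocks(vds, max_ref_block_base_pairs=5000) # doctest: +SKIP
    >>> vds.write('gs://path/to/my.truncated.vds') # doctest: +SKIP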

    See Also
    --------
    :func:`.vds.filter_intervals`, :func:`.vds.truncate_reference_blocks`.

    Parameters
    ----------
    vds_path : :obj:`str`
    """
vds = read_vds(vds_path, _warn_no_ref_block_max_length=False)
if VariantDataset.ref_block_max_length_field in vds.reference_data.globals:
warning(f"VDS at {vds_path} already contains a global annotation with the max reference block length")
return
rd = vds.reference_data
fs = hl.current_backend().fs
ref_block_max_len = rd.aggregate_entries(hl.agg.max(rd.LEN))
with fs.open(os.path.join(vds_path, extra_ref_globals_file), 'w') as f:
json.dump({VariantDataset.ref_block_max_length_field: ref_block_max_len}, f)


class VariantDataset:
"""Class for representing cohort-level genomic data.
This class facilitates a sparse, split representation of genomic data in
which reference block data and variant data are contained in separate
:class:`.MatrixTable` objects.
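
    Examples
    --------
    A VDS is usually obtained from :func:`.vds.read_vds` (the path below is a placeholder):

    >>> vds = hl.vds.read_vds('gs://path/to/my.vds') # doctest: +SKIP
    >>> vds.reference_data.describe() # doctest: +SKIP
    >>> vds.variant_data.describe() # doctest: +SKIP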

    Parameters
    ----------
    reference_data : :class:`.MatrixTable`
        MatrixTable containing only reference block data.
    variant_data : :class:`.MatrixTable`
        MatrixTable containing only variant data.
    """

    #: Name of global field that indicates max reference block length.
ref_block_max_length_field = 'ref_block_max_length'

    @staticmethod
def _reference_path(base: str) -> str:
return os.path.join(base, 'reference_data')

    @staticmethod
def _variants_path(base: str) -> str:
return os.path.join(base, 'variant_data')

    @staticmethod
def from_merged_representation(
mt, *, ref_block_indicator_field='END', ref_block_fields=(), infer_ref_block_fields: bool = True, is_split=False
):
"""Create a VariantDataset from a sparse MatrixTable containing variant and reference data."""
if ref_block_indicator_field not in ('END', 'LEN'):
raise ValueError(
                f'Invalid `ref_block_indicator_field` `{ref_block_indicator_field}`: expected one of `LEN` or `END`'
)
if ref_block_indicator_field not in mt.entry:
raise ValueError(
f'VariantDataset.from_merged_representation: expect field `{ref_block_indicator_field}` in matrix table entry'
)
if 'LA' not in mt.entry and not is_split:
raise ValueError(
'VariantDataset.from_merged_representation: expect field `LA` in matrix table entry.'
'\n If this dataset is already split into biallelics, use `is_split=True` to permit a conversion'
' with no `LA` field.'
)
if 'GT' not in mt.entry and 'LGT' not in mt.entry:
raise ValueError(
'VariantDataset.from_merged_representation: expect field `LGT` or `GT` in matrix table entry'
)
n_rows_to_use = 100
info(f"inferring reference block fields from missingness patterns in first {n_rows_to_use} rows")
used_ref_block_fields = set(ref_block_fields)
used_ref_block_fields.add(ref_block_indicator_field)
if infer_ref_block_fields:
mt_head = mt.head(n_rows=n_rows_to_use)
for k, any_present in zip(
list(mt_head.entry),
mt_head.aggregate_entries(
hl.agg.filter(
hl.is_defined(mt_head[ref_block_indicator_field]),
tuple(hl.agg.any(hl.is_defined(mt_head[x])) for x in mt_head.entry),
)
),
):
if any_present:
used_ref_block_fields.add(k)
gt_field = 'LGT' if 'LGT' in mt.entry else 'GT'
# remove the LA field, which is trivial for reference blocks and does not need to be represented
if 'LA' in used_ref_block_fields:
used_ref_block_fields.remove('LA')
info(
"Including the following fields in reference block table:"
+ "".join(f"\n {k!r}" for k in mt.entry if k in used_ref_block_fields)
)
rmt = mt.filter_entries(
hl.case()
.when(hl.is_missing(mt[ref_block_indicator_field]), False)
.when(hl.is_defined(mt[ref_block_indicator_field]) & mt[gt_field].is_hom_ref(), True)
.or_error(
hl.str(
f'cannot create VDS from merged representation - found {ref_block_indicator_field} field with non-reference genotype at '
)
+ hl.str(mt.locus)
+ hl.str(' / ')
+ hl.str(mt.col_key[0])
)
)
rmt = rmt.select_entries(*(x for x in rmt.entry if x in used_ref_block_fields))
rmt = rmt.filter_rows(hl.agg.count() > 0)
rmt = rmt.key_rows_by('locus').select_rows().select_cols()
if ref_block_indicator_field == 'END':
rmt = VariantDataset._add_len(rmt)
else: # ref_block_indicator_field is 'LEN'
rmt = VariantDataset._add_end(rmt)
if is_split:
rmt = rmt.distinct_by_row()
vmt = (
mt.filter_entries(hl.is_missing(mt[ref_block_indicator_field]))
.drop(ref_block_indicator_field)
._key_rows_by_assert_sorted('locus', 'alleles')
)
vmt = vmt.filter_rows(hl.agg.count() > 0)
return VariantDataset(rmt, vmt)

    def __init__(self, reference_data: MatrixTable, variant_data: MatrixTable):
self.reference_data: MatrixTable = reference_data
self.variant_data: MatrixTable = variant_data
self.validate(check_data=False)

    def write(self, path, **kwargs):
"""Write to `path`.
Any optional parameter from :meth:`.MatrixTable.write` can be passed as
a keyword paramter to this method.
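
        Examples
        --------
        ``overwrite`` below is one of the optional :meth:`.MatrixTable.write` parameters:

        >>> vds.write('output/my.vds', overwrite=True) # doctest: +SKIP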
"""
        # NOTE: Populate LEN and drop END from reference data to align with VCF 4.5.
        # Furthermore, since LEN values are smaller and more likely to be close to
        # or the same as neighboring values, we expect that after small-integer
        # compression and general-purpose data compression, reference data will be
        # smaller using LEN rather than END.
rd = self.reference_data
if 'LEN' not in rd.entry:
rd = VariantDataset._add_len(rd)
if 'END' in rd.entry:
rd = rd.drop('END')
rd.write(VariantDataset._reference_path(path), **kwargs)
self.variant_data.write(VariantDataset._variants_path(path), **kwargs)

    def checkpoint(self, path, **kwargs) -> 'VariantDataset':
        """Write to `path` and then read from `path`.
self.write(path, **kwargs)
return read_vds(path)

    def n_samples(self) -> int:
"""The number of samples present."""
return self.reference_data.count_cols()

    @property
def reference_genome(self) -> ReferenceGenome:
"""Dataset reference genome.
Returns
-------
:class:`.ReferenceGenome`
"""
return self.reference_data.locus.dtype.reference_genome

    @typecheck_method(check_data=bool)
def validate(self, *, check_data: bool = True):
"""Eagerly checks necessary representational properties of the VDS."""
rd = self.reference_data
vd = self.variant_data
def error(msg):
raise ValueError(f'VDS.validate: {msg}')
rd_row_key = rd.row_key.dtype
if (
not isinstance(rd_row_key, hl.tstruct)
or len(rd_row_key) != 1
            or rd_row_key.fields[0] != 'locus'
or not isinstance(rd_row_key.types[0], hl.tlocus)
):
error(f"expect reference data to have a single row key 'locus' of type locus, found {rd_row_key}")
vd_row_key = vd.row_key.dtype
if (
not isinstance(vd_row_key, hl.tstruct)
or len(vd_row_key) != 2
            or vd_row_key.fields != ('locus', 'alleles')
or not isinstance(vd_row_key.types[0], hl.tlocus)
or vd_row_key.types[1] != hl.tarray(hl.tstr)
):
error(
f"expect variant data to have a row key {{'locus': locus<rg>, alleles: array<str>}}, found {vd_row_key}"
)
rd_col_key = rd.col_key.dtype
        if not isinstance(rd_col_key, hl.tstruct) or len(rd_col_key) != 1 or rd_col_key.types[0] != hl.tstr:
error(f"expect reference data to have a single col key of type string, found {rd_col_key}")
vd_col_key = vd.col_key.dtype
if not isinstance(vd_col_key, hl.tstruct) or len(vd_col_key) != 1 or vd_col_key.types[0] != hl.tstr:
error(f"expect variant data to have a single col key of type string, found {vd_col_key}")
end_exists = 'END' in rd.entry
len_exists = 'LEN' in rd.entry
if not (end_exists or len_exists):
error("expect at least one of 'END' or 'LEN' in entry of reference data")
if end_exists and rd.END.dtype != hl.tint32:
error("'END' field in entry of reference data must have type tint32")
if len_exists and rd.LEN.dtype != hl.tint32:
error("'LEN' field in entry of reference data must have type tint32")
if check_data:
# check cols
ref_cols = rd.col_key.collect()
var_cols = vd.col_key.collect()
if len(ref_cols) != len(var_cols):
error(
f"mismatch in number of columns: reference data has {ref_cols} columns, variant data has {var_cols} columns"
)
if ref_cols != var_cols:
first_mismatch = 0
while ref_cols[first_mismatch] == var_cols[first_mismatch]:
first_mismatch += 1
error(
f"mismatch in columns keys: ref={ref_cols[first_mismatch]}, var={var_cols[first_mismatch]} at position {first_mismatch}"
)
# check locus distinctness
n_rd_rows = rd.count_rows()
n_distinct = rd.distinct_by_row().count_rows()
if n_distinct != n_rd_rows:
error(f'reference data loci are not distinct: found {n_rd_rows} rows, but {n_distinct} distinct loci')
# check END field
end_exprs = dict(
missing_end=hl.agg.filter(hl.is_missing(rd.END), hl.agg.take((rd.row_key, rd.col_key), 5)),
end_before_position=hl.agg.filter(rd.END < rd.locus.position, hl.agg.take((rd.row_key, rd.col_key), 5)),
)
if VariantDataset.ref_block_max_length_field in rd.globals:
rbml = rd[VariantDataset.ref_block_max_length_field]
end_exprs['blocks_too_long'] = hl.agg.filter(
rd.END - rd.locus.position + 1 > rbml, hl.agg.take((rd.row_key, rd.col_key), 5)
)
res = rd.aggregate_entries(hl.struct(**end_exprs))
if res.missing_end:
error(
'found records in reference data with missing END field\n '
+ '\n '.join(str(x) for x in res.missing_end)
)
if res.end_before_position:
error(
'found records in reference data with END before locus position\n '
+ '\n '.join(str(x) for x in res.end_before_position)
)
blocks_too_long = res.get('blocks_too_long', [])
if blocks_too_long:
error(
'found records in reference data with blocks larger than `ref_block_max_length`\n '
+ '\n '.join(str(x) for x in blocks_too_long)
)

    def _same(self, other: 'VariantDataset'):
return self.reference_data._same(other.reference_data) and self.variant_data._same(other.variant_data)

    @staticmethod
def _add_len(rd):
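        # A reference block spans LEN bases starting at the locus: LEN = END - POS + 1,
        # where END is the inclusive end position stored in the END field.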
if 'LEN' in rd.entry:
return rd
if 'END' in rd.entry:
return rd.annotate_entries(LEN=rd.END - rd.locus.position + 1)
raise ValueError('Need `END` to compute `LEN` in reference data')

    @staticmethod
def _add_end(rd):
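        # Inverse of _add_len: END = POS + LEN - 1 (inclusive end coordinate).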
if 'END' in rd.entry:
return rd
if 'LEN' in rd.entry:
return rd.annotate_entries(END=rd.LEN + rd.locus.position - 1)
raise ValueError('Need `LEN` to compute `END` in reference data')

    def union_rows(*vdses):
"""Combine many VDSes with the same samples but disjoint variants.
**Examples**
If a dataset is imported as VDS in chromosome-chunks, the following will combine them into
one VDS:
>>> vds_paths = ['chr1.vds', 'chr2.vds'] # doctest: +SKIP
... vds_per_chrom = [hl.vds.read_vds(path) for path in vds_paths) # doctest: +SKIP
... hl.vds.VariantDataset.union_rows(*vds_per_chrom) # doctest: +SKIP
"""
fd = hl.vds.VariantDataset.ref_block_max_length_field
mts = [vds.reference_data for vds in vdses]
n_with_ref_max_len = len([mt for mt in mts if fd in mt.globals])
any_ref_max = n_with_ref_max_len > 0
all_ref_max = n_with_ref_max_len == len(mts)
        if all_ref_max:
            # all inputs store the max ref block length; take the max of the maxes
            new_ref_mt = hl.MatrixTable.union_rows(*mts).annotate_globals(**{
                fd: hl.max([mt.index_globals()[fd] for mt in mts])
            })
        else:
            # if some mts have max ref len but not all, drop it
            if any_ref_max:
                mts = [mt.drop(fd) if fd in mt.globals else mt for mt in mts]
            new_ref_mt = hl.MatrixTable.union_rows(*mts)
new_var_mt = hl.MatrixTable.union_rows(*(vds.variant_data for vds in vdses))
return hl.vds.VariantDataset(new_ref_mt, new_var_mt)