Source code for hail.vds.variant_dataset

import json
import os

import hail as hl
from hail.genetics import ReferenceGenome
from hail.matrixtable import MatrixTable
from hail.typecheck import typecheck_method
from hail.utils.java import info, warning

extra_ref_globals_file = 'extra_reference_globals.json'


def read_vds(
    path,
    *,
    intervals=None,
    n_partitions=None,
    _assert_reference_type=None,
    _assert_variant_type=None,
    _warn_no_ref_block_max_length=True,
) -> 'VariantDataset':
    """Read in a :class:`.VariantDataset` written with :meth:`.VariantDataset.write`.

    Parameters
    ----------
    path: :obj:`str`

    Returns
    -------
    :class:`.VariantDataset`
    """
    if intervals or not n_partitions:
        reference_data = hl.read_matrix_table(VariantDataset._reference_path(path), _intervals=intervals)
        variant_data = hl.read_matrix_table(VariantDataset._variants_path(path), _intervals=intervals)
    else:
        assert n_partitions is not None
        reference_data = hl.read_matrix_table(VariantDataset._reference_path(path))
        intervals = reference_data._calculate_new_partitions(n_partitions)
        assert len(intervals) > 0
        reference_data = hl.read_matrix_table(VariantDataset._reference_path(path), _intervals=intervals)
        variant_data = hl.read_matrix_table(VariantDataset._variants_path(path), _intervals=intervals)

    vds = VariantDataset(reference_data, variant_data)

    if VariantDataset.ref_block_max_length_field not in vds.reference_data.globals:
        fs = hl.current_backend().fs
        metadata_file = os.path.join(path, extra_ref_globals_file)
        if fs.exists(metadata_file):
            with fs.open(metadata_file, 'r') as f:
                metadata = json.load(f)
                vds.reference_data = vds.reference_data.annotate_globals(**metadata)
        elif _warn_no_ref_block_max_length:
            warning(
                "You are reading a VDS written with an older version of Hail."
                "\n  Hail now supports much faster interval filters on VDS, but you'll need to run either"
                "\n  `hl.vds.truncate_reference_blocks(vds, ...)` and write a copy (see docs) or patch the"
                "\n  existing VDS in place with `hl.vds.store_ref_block_max_length(vds_path)`."
            )

    return vds

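# A usage sketch for `read_vds` (editorial addition; the bucket path and partition
# count below are hypothetical):
#
#     import hail as hl
#
#     # Read using the partitioning stored on disk:
#     vds = hl.vds.read_vds('gs://bucket/cohort.vds')
#
#     # Or repartition on read: `n_partitions` recomputes intervals from the
#     # reference data and applies them to both component MatrixTables.
#     vds = hl.vds.read_vds('gs://bucket/cohort.vds', n_partitions=256)
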
def store_ref_block_max_length(vds_path):
    """Patches an existing VDS file to store the max reference block length for faster interval filters.

    This method permits :func:`.vds.filter_intervals` to remove reference data not overlapping a target interval.

    This method is able to patch an existing VDS file in-place, without copying all the data. However, if
    significant downstream interval filtering is anticipated, it may be advantageous to run
    :func:`.vds.truncate_reference_blocks` to truncate long reference blocks and make interval filters even
    faster. Note, however, that truncation requires rewriting the entire VDS.

    Examples
    --------
    >>> hl.vds.store_ref_block_max_length('gs://path/to/my.vds')  # doctest: +SKIP

    See Also
    --------
    :func:`.vds.filter_intervals`, :func:`.vds.truncate_reference_blocks`.

    Parameters
    ----------
    vds_path : :obj:`str`
    """
    vds = read_vds(vds_path, _warn_no_ref_block_max_length=False)

    if VariantDataset.ref_block_max_length_field in vds.reference_data.globals:
        warning(f"VDS at {vds_path} already contains a global annotation with the max reference block length")
        return

    rd = vds.reference_data
    rd = rd.annotate_rows(__start_pos=rd.locus.position)
    fs = hl.current_backend().fs
    ref_block_max_len = rd.aggregate_entries(hl.agg.max(rd.END - rd.__start_pos + 1))
    with fs.open(os.path.join(vds_path, extra_ref_globals_file), 'w') as f:
        json.dump({VariantDataset.ref_block_max_length_field: ref_block_max_len}, f)

class VariantDataset:
    """Class for representing cohort-level genomic data.

    This class facilitates a sparse, split representation of genomic data in which reference block data and
    variant data are contained in separate :class:`.MatrixTable` objects.

    Parameters
    ----------
    reference_data : :class:`.MatrixTable`
        MatrixTable containing only reference block data.
    variant_data : :class:`.MatrixTable`
        MatrixTable containing only variant data.
    """

    #: Name of global field that indicates max reference block length.
    ref_block_max_length_field = 'ref_block_max_length'

    @staticmethod
    def _reference_path(base: str) -> str:
        return os.path.join(base, 'reference_data')

    @staticmethod
    def _variants_path(base: str) -> str:
        return os.path.join(base, 'variant_data')

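    # A sketch of the split representation (editorial addition; the path is
    # hypothetical):
    #
    #     vds = hl.vds.read_vds('gs://bucket/cohort.vds')
    #     vds.reference_data.describe()  # rows keyed by 'locus'; entries carry 'END'
    #     vds.variant_data.describe()    # rows keyed by 'locus' and 'alleles'
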
    @staticmethod
    def from_merged_representation(mt, *, ref_block_fields=(), infer_ref_block_fields: bool = True, is_split=False):
        """Create a VariantDataset from a sparse MatrixTable containing variant and reference data."""

        if 'END' not in mt.entry:
            raise ValueError("VariantDataset.from_merged_representation: expect field 'END' in matrix table entry")

        if 'LA' not in mt.entry and not is_split:
            raise ValueError(
                "VariantDataset.from_merged_representation: expect field 'LA' in matrix table entry."
                "\n  If this dataset is already split into biallelics, use `is_split=True` to permit a conversion"
                " with no LA field."
            )

        if 'GT' not in mt.entry and 'LGT' not in mt.entry:
            raise ValueError(
                "VariantDataset.from_merged_representation: expect field 'LGT' or 'GT' in matrix table entry"
            )

        n_rows_to_use = 100
        info(f"inferring reference block fields from missingness patterns in first {n_rows_to_use} rows")
        used_ref_block_fields = set(ref_block_fields)
        used_ref_block_fields.add('END')

        if infer_ref_block_fields:
            mt_head = mt.head(n_rows=n_rows_to_use)
            for k, any_present in zip(
                list(mt_head.entry),
                mt_head.aggregate_entries(
                    hl.agg.filter(
                        hl.is_defined(mt_head.END),
                        tuple(hl.agg.any(hl.is_defined(mt_head[x])) for x in mt_head.entry),
                    )
                ),
            ):
                if any_present:
                    used_ref_block_fields.add(k)

        gt_field = 'LGT' if 'LGT' in mt.entry else 'GT'

        # remove LGT/GT and LA fields, which are trivial for reference blocks and do not need to be represented
        if gt_field in used_ref_block_fields:
            used_ref_block_fields.remove(gt_field)
        if 'LA' in used_ref_block_fields:
            used_ref_block_fields.remove('LA')

        info(
            "Including the following fields in reference block table:"
            + "".join(f"\n  {k!r}" for k in mt.entry if k in used_ref_block_fields)
        )

        rmt = mt.filter_entries(
            hl.case()
            .when(hl.is_missing(mt.END), False)
            .when(hl.is_defined(mt.END) & mt[gt_field].is_hom_ref(), True)
            .or_error(
                hl.str(
                    'cannot create VDS from merged representation -'
                    ' found END field with non-reference genotype at '
                )
                + hl.str(mt.locus)
                + hl.str(' / ')
                + hl.str(mt.col_key[0])
            )
        )
        rmt = rmt.select_entries(*(x for x in rmt.entry if x in used_ref_block_fields))
        rmt = rmt.filter_rows(hl.agg.count() > 0)
        rmt = rmt.key_rows_by('locus').select_rows().select_cols()

        if is_split:
            rmt = rmt.distinct_by_row()

        vmt = mt.filter_entries(hl.is_missing(mt.END)).drop('END')._key_rows_by_assert_sorted('locus', 'alleles')
        vmt = vmt.filter_rows(hl.agg.count() > 0)

        return VariantDataset(rmt, vmt)

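    # A conversion sketch (editorial addition; the input path is hypothetical):
    #
    #     mt = hl.read_matrix_table('gs://bucket/sparse.mt')  # entries carry END, LA, LGT
    #     vds = hl.vds.VariantDataset.from_merged_representation(mt)
    #     vds.write('gs://bucket/cohort.vds')
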
    def __init__(self, reference_data: MatrixTable, variant_data: MatrixTable):
        self.reference_data: MatrixTable = reference_data
        self.variant_data: MatrixTable = variant_data

        self.validate(check_data=False)

    def write(self, path, **kwargs):
        """Write to `path`."""
        self.reference_data.write(VariantDataset._reference_path(path), **kwargs)
        self.variant_data.write(VariantDataset._variants_path(path), **kwargs)

    def checkpoint(self, path, **kwargs) -> 'VariantDataset':
        """Write to `path` and then read from `path`."""
        self.write(path, **kwargs)
        return read_vds(path)

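    # `checkpoint` is shorthand for `write` followed by `read_vds` (editorial
    # example; the path is hypothetical):
    #
    #     vds = vds.checkpoint('gs://bucket/tmp/cohort-checkpoint.vds', overwrite=True)
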
    def n_samples(self) -> int:
        """The number of samples present."""
        return self.reference_data.count_cols()

    @property
    def reference_genome(self) -> ReferenceGenome:
        """Dataset reference genome.

        Returns
        -------
        :class:`.ReferenceGenome`
        """
        return self.reference_data.locus.dtype.reference_genome

    @typecheck_method(check_data=bool)
    def validate(self, *, check_data: bool = True):
        """Eagerly checks necessary representational properties of the VDS."""

        rd = self.reference_data
        vd = self.variant_data

        def error(msg):
            raise ValueError(f'VDS.validate: {msg}')

        rd_row_key = rd.row_key.dtype
        if (
            not isinstance(rd_row_key, hl.tstruct)
            or len(rd_row_key) != 1
            or not rd_row_key.fields[0] == 'locus'
            or not isinstance(rd_row_key.types[0], hl.tlocus)
        ):
            error(f"expect reference data to have a single row key 'locus' of type locus, found {rd_row_key}")

        vd_row_key = vd.row_key.dtype
        if (
            not isinstance(vd_row_key, hl.tstruct)
            or len(vd_row_key) != 2
            or not vd_row_key.fields == ('locus', 'alleles')
            or not isinstance(vd_row_key.types[0], hl.tlocus)
            or vd_row_key.types[1] != hl.tarray(hl.tstr)
        ):
            error(
                f"expect variant data to have a row key {{'locus': locus<rg>, alleles: array<str>}}, found {vd_row_key}"
            )

        rd_col_key = rd.col_key.dtype
        if not isinstance(rd_col_key, hl.tstruct) or len(rd_col_key) != 1 or rd_col_key.types[0] != hl.tstr:
            error(f"expect reference data to have a single col key of type string, found {rd_col_key}")

        vd_col_key = vd.col_key.dtype
        if not isinstance(vd_col_key, hl.tstruct) or len(vd_col_key) != 1 or vd_col_key.types[0] != hl.tstr:
            error(f"expect variant data to have a single col key of type string, found {vd_col_key}")

        if 'END' not in rd.entry or rd.END.dtype != hl.tint32:
            error("expect field 'END' in entry of reference data with type int32")

        if check_data:
            # check cols
            ref_cols = rd.col_key.collect()
            var_cols = vd.col_key.collect()
            if len(ref_cols) != len(var_cols):
                error(
                    f"mismatch in number of columns: reference data has {len(ref_cols)} columns,"
                    f" variant data has {len(var_cols)} columns"
                )

            if ref_cols != var_cols:
                first_mismatch = 0
                while ref_cols[first_mismatch] == var_cols[first_mismatch]:
                    first_mismatch += 1
                error(
                    f"mismatch in column keys: ref={ref_cols[first_mismatch]}, var={var_cols[first_mismatch]}"
                    f" at position {first_mismatch}"
                )

            # check locus distinctness
            n_rd_rows = rd.count_rows()
            n_distinct = rd.distinct_by_row().count_rows()

            if n_distinct != n_rd_rows:
                error(f'reference data loci are not distinct: found {n_rd_rows} rows, but {n_distinct} distinct loci')

            # check END field
            end_exprs = dict(
                missing_end=hl.agg.filter(hl.is_missing(rd.END), hl.agg.take((rd.row_key, rd.col_key), 5)),
                end_before_position=hl.agg.filter(
                    rd.END < rd.locus.position, hl.agg.take((rd.row_key, rd.col_key), 5)
                ),
            )
            if VariantDataset.ref_block_max_length_field in rd.globals:
                rbml = rd[VariantDataset.ref_block_max_length_field]
                end_exprs['blocks_too_long'] = hl.agg.filter(
                    rd.END - rd.locus.position + 1 > rbml, hl.agg.take((rd.row_key, rd.col_key), 5)
                )
            res = rd.aggregate_entries(hl.struct(**end_exprs))

            if res.missing_end:
                error(
                    'found records in reference data with missing END field\n  '
                    + '\n  '.join(str(x) for x in res.missing_end)
                )
            if res.end_before_position:
                error(
                    'found records in reference data with END before locus position\n  '
                    + '\n  '.join(str(x) for x in res.end_before_position)
                )
            blocks_too_long = res.get('blocks_too_long', [])
            if blocks_too_long:
                error(
                    'found records in reference data with blocks larger than `ref_block_max_length`\n  '
                    + '\n  '.join(str(x) for x in blocks_too_long)
                )

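    # `validate` runs eagerly; with `check_data=True` it triggers several Hail
    # actions (collects, counts, and entry aggregations), so it can be expensive
    # on large datasets (editorial note):
    #
    #     vds.validate()                  # full check, including data scans
    #     vds.validate(check_data=False)  # schema and key checks only
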
    def _same(self, other: 'VariantDataset'):
        return self.reference_data._same(other.reference_data) and self.variant_data._same(other.variant_data)

    def union_rows(*vdses):
        """Combine many VDSes with the same samples but disjoint variants.

        Examples
        --------
        If a dataset is imported as VDS in chromosome-chunks, the following will combine them into
        one VDS:

        >>> vds_paths = ['chr1.vds', 'chr2.vds']  # doctest: +SKIP
        >>> vds_per_chrom = [hl.vds.read_vds(path) for path in vds_paths]  # doctest: +SKIP
        >>> combined = hl.vds.VariantDataset.union_rows(*vds_per_chrom)  # doctest: +SKIP
        """

        fd = hl.vds.VariantDataset.ref_block_max_length_field
        mts = [vds.reference_data for vds in vdses]
        n_with_ref_max_len = len([mt for mt in mts if fd in mt.globals])
        any_ref_max = n_with_ref_max_len > 0
        all_ref_max = n_with_ref_max_len == len(mts)

        # if some mts have max ref len but not all, drop it
        if all_ref_max:
            new_ref_mt = hl.MatrixTable.union_rows(*mts).annotate_globals(**{
                fd: hl.max([mt.index_globals()[fd] for mt in mts])
            })
        else:
            if any_ref_max:
                mts = [mt.drop(fd) if fd in mt.globals else mt for mt in mts]
            new_ref_mt = hl.MatrixTable.union_rows(*mts)

        new_var_mt = hl.MatrixTable.union_rows(*(vds.variant_data for vds in vdses))
        return hl.vds.VariantDataset(new_ref_mt, new_var_mt)