Source code for hail.vds.methods

import hail as hl
from hail import ir
from hail.expr import expr_any, expr_array, expr_bool, expr_interval, expr_locus, expr_str
from hail.matrixtable import MatrixTable
from hail.table import Table
from hail.typecheck import dictof, enumeration, func_spec, nullable, oneof, sequenceof, typecheck
from hail.utils.java import Env, info, warning
from hail.utils.misc import new_temp_file, wrap_to_list
from hail.vds.variant_dataset import VariantDataset


def write_variant_datasets(vdss, paths, *, overwrite=False, stage_locally=False, codec_spec=None):
    """Write many `vdses` to their corresponding path in `paths`."""
    ref_writer = ir.MatrixNativeMultiWriter(
        [f"{p}/reference_data" for p in paths], overwrite, stage_locally, codec_spec
    )
    var_writer = ir.MatrixNativeMultiWriter([f"{p}/variant_data" for p in paths], overwrite, stage_locally, codec_spec)
    Env.backend().execute(ir.MatrixMultiWrite([vds.reference_data._mir for vds in vdss], ref_writer))
    Env.backend().execute(ir.MatrixMultiWrite([vds.variant_data._mir for vds in vdss], var_writer))


@typecheck(vds=VariantDataset)
def to_dense_mt(vds: 'VariantDataset') -> 'MatrixTable':
    """Creates a single, dense :class:`.MatrixTable` from the split
    :class:`.VariantDataset` representation.

    Parameters
    ----------
    vds : :class:`.VariantDataset`
        Dataset in VariantDataset representation.

    Returns
    -------
    :class:`.MatrixTable`
        Dataset in dense MatrixTable representation.
    """
    ref = vds.reference_data
    # FIXME(chrisvittal) consider changing END semantics on VDS to make this better
    # see https://github.com/hail-is/hail/issues/13183 for why this is here and more discussion
    # we assume that END <= contig.length
    ref = ref.annotate_rows(_locus_global_pos=ref.locus.global_position(), _locus_pos=ref.locus.position)
    ref = ref.transmute_entries(_END_GLOBAL=ref._locus_global_pos + (ref.END - ref._locus_pos))

    to_drop = 'alleles', 'rsid', 'ref_allele', '_locus_global_pos', '_locus_pos'
    ref = ref.drop(*(x for x in to_drop if x in ref.row))
    var = vds.variant_data
    refl = ref.localize_entries('_ref_entries')
    varl = var.localize_entries('_var_entries', '_var_cols')
    varl = varl.annotate(_variant_defined=True)
    joined = varl.key_by('locus').join(refl, how='outer')
    dr = joined.annotate(
        dense_ref=hl.or_missing(
            joined._variant_defined, hl.scan._densify(hl.len(joined._var_cols), joined._ref_entries)
        )
    )
    dr = dr.filter(dr._variant_defined)

    def coalesce_join(ref, var):
        call_field = 'GT' if 'GT' in var else 'LGT'
        assert call_field in var, var.dtype

        if call_field not in ref:
            ref_call_field = 'GT' if 'GT' in ref else 'LGT'
            if ref_call_field not in ref:
                ref = ref.annotate(**{call_field: hl.call(0, 0)})
            else:
                ref = ref.annotate(**{call_field: ref[ref_call_field]})

        # call_field is now in both ref and var
        ref_set, var_set = set(ref.dtype), set(var.dtype)
        shared_fields, var_fields = var_set & ref_set, var_set - ref_set

        return hl.if_else(
            hl.is_defined(var),
            var.select(*shared_fields, *var_fields),
            ref.select(*shared_fields, **{f: hl.missing(var[f].dtype) for f in var_fields}),
        )

    dr = dr.annotate(
        _dense=hl.rbind(
            dr._ref_entries,
            lambda refs_at_this_row: hl.enumerate(hl.zip(dr._var_entries, dr.dense_ref)).map(
                lambda tup: coalesce_join(
                    hl.coalesce(
                        refs_at_this_row[tup[0]],
                        hl.or_missing(tup[1][1]._END_GLOBAL >= dr.locus.global_position(), tup[1][1]),
                    ),
                    tup[1][0],
                )
            ),
        ),
    )

    dr = dr._key_by_assert_sorted('locus', 'alleles')
    fields_to_drop = ['_var_entries', '_ref_entries', 'dense_ref', '_variant_defined']

    if hl.vds.VariantDataset.ref_block_max_length_field in dr.globals:
        fields_to_drop.append(hl.vds.VariantDataset.ref_block_max_length_field)
    if 'ref_allele' in dr.row:
        fields_to_drop.append('ref_allele')
    dr = dr.drop(*fields_to_drop)
    return dr._unlocalize_entries('_dense', '_var_cols', list(var.col_key))

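
# Illustrative usage sketch (not part of the library API): densify a VDS for analyses
# that expect a dense MatrixTable. The dataset path below is hypothetical.
def _example_to_dense_mt():
    vds = hl.vds.read_vds('gs://my-bucket/my-dataset.vds')  # hypothetical path
    dense_mt = to_dense_mt(vds)
    # Each sample now has an entry at every variant row, filled in from the
    # overlapping reference block when no variant call was made.
    return dense_mt
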
@typecheck(vds=VariantDataset, ref_allele_function=nullable(func_spec(1, expr_str)))
def to_merged_sparse_mt(vds: 'VariantDataset', *, ref_allele_function=None) -> 'MatrixTable':
    """Creates a single, merged sparse :class:`.MatrixTable` from the split
    :class:`.VariantDataset` representation.

    Parameters
    ----------
    vds : :class:`.VariantDataset`
        Dataset in VariantDataset representation.

    Returns
    -------
    :class:`.MatrixTable`
        Dataset in the merged sparse MatrixTable representation.
    """
    rht = vds.reference_data.localize_entries('_ref_entries', '_ref_cols')
    vht = vds.variant_data.localize_entries('_var_entries', '_var_cols')

    # drop 'alleles' key for join
    vht = vht.key_by('locus')

    merged_schema = {}
    for e in vds.reference_data.entry:
        merged_schema[e] = vds.reference_data[e].dtype
    for e in vds.variant_data.entry:
        if e in merged_schema:
            if not merged_schema[e] == vds.variant_data[e].dtype:
                raise TypeError(f"cannot unify field {e!r}: {merged_schema[e]}, {vds.variant_data[e].dtype}")
        else:
            merged_schema[e] = vds.variant_data[e].dtype

    ht = vht.join(rht, how='outer').drop('_ref_cols')

    def merge_arrays(r_array, v_array):
        def rewrite_ref(r):
            ref_block_selector = {}
            for k, t in merged_schema.items():
                if k == 'LA':
                    ref_block_selector[k] = hl.literal([0])
                elif k in ('LGT', 'GT') and k not in r:
                    ref_block_selector[k] = hl.call(0, 0)
                else:
                    ref_block_selector[k] = r[k] if k in r else hl.missing(t)
            return r.select(**ref_block_selector)

        def rewrite_var(v):
            return v.select(**{k: v[k] if k in v else hl.missing(t) for k, t in merged_schema.items()})

        return (
            hl.case()
            .when(hl.is_missing(r_array), v_array.map(rewrite_var))
            .when(hl.is_missing(v_array), r_array.map(rewrite_ref))
            .default(hl.zip(r_array, v_array).map(lambda t: hl.coalesce(rewrite_var(t[1]), rewrite_ref(t[0]))))
        )

    if ref_allele_function is None:
        rg = ht.locus.dtype.reference_genome
        if 'ref_allele' in ht.row:

            def ref_allele_function(ht):
                return ht.ref_allele

        elif rg.has_sequence():

            def ref_allele_function(ht):
                return ht.locus.sequence_context()

            info("to_merged_sparse_mt: using locus sequence context to fill in reference alleles at monomorphic loci.")
        else:
            raise ValueError(
                "to_merged_sparse_mt: in order to construct a ref allele for reference-only sites, "
                "either pass a function to fill in reference alleles (e.g. ref_allele_function=lambda locus: hl.missing('str'))"
                " or add a sequence file with 'hl.get_reference(RG_NAME).add_sequence(FASTA_PATH)'."
            )
    ht = ht.select(
        alleles=hl.coalesce(ht['alleles'], hl.array([ref_allele_function(ht)])),
        # handle cases where vmt is not keyed by alleles
        **{k: ht[k] for k in vds.variant_data.row_value if k != 'alleles'},
        _entries=merge_arrays(ht['_ref_entries'], ht['_var_entries']),
    )
    ht = ht._key_by_assert_sorted('locus', 'alleles')
    return ht._unlocalize_entries('_entries', '_var_cols', list(vds.variant_data.col_key))

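
# Illustrative usage sketch (not part of the library API): recover the older merged
# sparse MatrixTable representation. The path is hypothetical; the ref_allele_function
# shown mirrors the suggestion in the error message above and leaves reference-only
# sites with a missing REF allele.
def _example_to_merged_sparse_mt():
    vds = hl.vds.read_vds('gs://my-bucket/my-dataset.vds')  # hypothetical path
    sparse_mt = to_merged_sparse_mt(vds, ref_allele_function=lambda ht: hl.missing('str'))
    return sparse_mt
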
@typecheck(vds=VariantDataset, samples=oneof(Table, expr_array(expr_str)), keep=bool, remove_dead_alleles=bool)
def filter_samples(
    vds: 'VariantDataset', samples, *, keep: bool = True, remove_dead_alleles: bool = False
) -> 'VariantDataset':
    """Filter samples in a :class:`.VariantDataset`.

    Parameters
    ----------
    vds : :class:`.VariantDataset`
        Dataset in VariantDataset representation.
    samples : :class:`.Table` or list of str
        Samples to keep or remove.
    keep : :obj:`bool`
        Whether to keep (default), or filter out, the samples in `samples`.
    remove_dead_alleles : :obj:`bool`
        If true, remove alleles observed in no samples. Alleles with AC == 0 will be
        removed, and LA values recalculated.

    Returns
    -------
    :class:`.VariantDataset`
    """
    if not isinstance(samples, hl.Table):
        samples = hl.Table.parallelize(samples.map(lambda s: hl.struct(s=s)), key='s')
    if not list(samples[x].dtype for x in samples.key) == [hl.tstr]:
        raise TypeError(f'invalid key: {samples.key.dtype}')
    samples_to_keep = samples.aggregate(hl.agg.collect_as_set(samples.key[0]), _localize=False)._persist()
    reference_data = vds.reference_data.filter_cols(samples_to_keep.contains(vds.reference_data.col_key[0]), keep=keep)
    reference_data = reference_data.filter_rows(hl.agg.count() > 0)
    variant_data = vds.variant_data.filter_cols(samples_to_keep.contains(vds.variant_data.col_key[0]), keep=keep)

    if remove_dead_alleles:
        vd = variant_data
        vd = vd.annotate_rows(__allele_counts=hl.agg.explode(lambda x: hl.agg.counter(x), vd.LA), __n=hl.agg.count())
        vd = vd.filter_rows(vd.__n > 0)
        vd = vd.drop('__n')

        vd = vd.annotate_rows(
            __kept_indices=hl.dict(
                hl.enumerate(
                    hl.range(hl.len(vd.alleles)).filter(lambda idx: (idx == 0) | (vd.__allele_counts.get(idx, 0) > 0)),
                    index_first=False,
                )
            )
        )

        vd = vd.annotate_rows(
            __old_to_new_LA=hl.range(hl.len(vd.alleles)).map(lambda idx: vd.__kept_indices.get(idx, -1))
        )

        def new_la_index(old_idx):
            raw_idx = vd.__old_to_new_LA[old_idx]
            return (
                hl.case()
                .when(raw_idx >= 0, raw_idx)
                .or_error("'filter_samples': unexpected local allele: old index=" + hl.str(old_idx))
            )

        vd = vd.annotate_entries(LA=vd.LA.map(lambda la: new_la_index(la)))
        vd = vd.key_rows_by('locus')
        vd = vd.annotate_rows(alleles=vd.__kept_indices.keys().map(lambda i: vd.alleles[i]))
        vd = vd._key_rows_by_assert_sorted('locus', 'alleles')
        vd = vd.drop('__allele_counts', '__kept_indices', '__old_to_new_LA')
        return VariantDataset(reference_data, vd)

    variant_data = variant_data.filter_rows(hl.agg.count() > 0)
    return VariantDataset(reference_data, variant_data)

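
# Illustrative usage sketch (not part of the library API): subset a VDS to a few
# samples and drop alleles no longer observed. Sample IDs and path are hypothetical.
def _example_filter_samples():
    vds = hl.vds.read_vds('gs://my-bucket/my-dataset.vds')  # hypothetical path
    subset = filter_samples(vds, ['NA12878', 'NA12891', 'NA12892'], keep=True, remove_dead_alleles=True)
    return subset
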
@typecheck(mt=MatrixTable, normalization_contig=str)
def impute_sex_chr_ploidy_from_interval_coverage(
    mt: 'MatrixTable',
    normalization_contig: str,
) -> 'Table':
    """Impute sex chromosome ploidy from a precomputed interval coverage MatrixTable.

    The input MatrixTable must have the following row fields:

    - ``interval`` (*interval*): Genomic interval of interest.
    - ``interval_size`` (*int32*): Size of interval, in bases.

    And the following entry fields:

    - ``sum_dp`` (*int64*): Sum of depth values by base across the interval.

    Returns a :class:`.Table` with sample ID keys, with the following fields:

    - ``autosomal_mean_dp`` (*float64*): Mean depth on calling intervals on normalization contig.
    - ``x_mean_dp`` (*float64*): Mean depth on calling intervals on X chromosome.
    - ``x_ploidy`` (*float64*): Estimated ploidy on X chromosome. Equal to ``2 * x_mean_dp / autosomal_mean_dp``.
    - ``y_mean_dp`` (*float64*): Mean depth on calling intervals on Y chromosome.
    - ``y_ploidy`` (*float64*): Estimated ploidy on Y chromosome. Equal to ``2 * y_mean_dp / autosomal_mean_dp``.

    Parameters
    ----------
    mt : :class:`.MatrixTable`
        Interval-by-sample MatrixTable with sum of depth values across the interval.
    normalization_contig : str
        Autosomal contig for depth comparison.

    Returns
    -------
    :class:`.Table`
    """

    rg = mt.interval.start.dtype.reference_genome
    if len(rg.x_contigs) != 1:
        raise NotImplementedError(
            f"reference genome {rg.name!r} has multiple X contigs, this is not supported in 'impute_sex_chr_ploidy_from_interval_coverage'"
        )
    chr_x = rg.x_contigs[0]
    if len(rg.y_contigs) != 1:
        raise NotImplementedError(
            f"reference genome {rg.name!r} has multiple Y contigs, this is not supported in 'impute_sex_chr_ploidy_from_interval_coverage'"
        )
    chr_y = rg.y_contigs[0]

    mt = mt.annotate_rows(contig=mt.interval.start.contig)
    mt = mt.annotate_cols(__mean_dp=hl.agg.group_by(mt.contig, hl.agg.sum(mt.sum_dp) / hl.agg.sum(mt.interval_size)))

    mean_dp_dict = mt.__mean_dp
    auto_dp = mean_dp_dict.get(normalization_contig, 0.0)
    x_dp = mean_dp_dict.get(chr_x, 0.0)
    y_dp = mean_dp_dict.get(chr_y, 0.0)
    per_sample = mt.transmute_cols(
        autosomal_mean_dp=auto_dp,
        x_mean_dp=x_dp,
        x_ploidy=2 * x_dp / auto_dp,
        y_mean_dp=y_dp,
        y_ploidy=2 * y_dp / auto_dp,
    )
    info("'impute_sex_chromosome_ploidy': computing and checkpointing coverage and karyotype metrics")
    return per_sample.cols().checkpoint(new_temp_file('impute_sex_karyotype', extension='ht'))

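
# Illustrative usage sketch (not part of the library API): estimate sex chromosome
# ploidy from a previously computed and saved interval-coverage MatrixTable. The
# path and normalization contig are hypothetical.
def _example_ploidy_from_precomputed_coverage():
    coverage_mt = hl.read_matrix_table('gs://my-bucket/interval_coverage.mt')  # hypothetical path
    return impute_sex_chr_ploidy_from_interval_coverage(coverage_mt, normalization_contig='chr20')
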
@typecheck(
    vds=VariantDataset,
    calling_intervals=oneof(Table, expr_array(expr_interval(expr_locus()))),
    normalization_contig=str,
    use_variant_dataset=bool,
)
def impute_sex_chromosome_ploidy(
    vds: VariantDataset, calling_intervals, normalization_contig: str, use_variant_dataset: bool = False
) -> Table:
    """Impute sex chromosome ploidy from depth of reference or variant data within calling intervals.

    Returns a :class:`.Table` with sample ID keys, with the following fields:

    - ``autosomal_mean_dp`` (*float64*): Mean depth on calling intervals on normalization contig.
    - ``x_mean_dp`` (*float64*): Mean depth on calling intervals on X chromosome.
    - ``x_ploidy`` (*float64*): Estimated ploidy on X chromosome. Equal to ``2 * x_mean_dp / autosomal_mean_dp``.
    - ``y_mean_dp`` (*float64*): Mean depth on calling intervals on Y chromosome.
    - ``y_ploidy`` (*float64*): Estimated ploidy on Y chromosome. Equal to ``2 * y_mean_dp / autosomal_mean_dp``.

    Parameters
    ----------
    vds : :class:`.VariantDataset`
        Dataset.
    calling_intervals : :class:`.Table` or :class:`.ArrayExpression`
        Calling intervals with consistent read coverage (for exomes, trim the capture intervals).
    normalization_contig : str
        Autosomal contig for depth comparison.
    use_variant_dataset : bool
        Whether to use depth of variant data within calling intervals instead of reference data.
        Default will use reference data.

    Returns
    -------
    :class:`.Table`
    """
    if not isinstance(calling_intervals, Table):
        calling_intervals = hl.Table.parallelize(
            hl.map(lambda i: hl.struct(interval=i), calling_intervals),
            schema=hl.tstruct(interval=calling_intervals.dtype.element_type),
            key='interval',
        )
    else:
        key_dtype = calling_intervals.key.dtype
        if (
            len(key_dtype) != 1
            or not isinstance(calling_intervals.key[0].dtype, hl.tinterval)
            or calling_intervals.key[0].dtype.point_type != vds.reference_data.locus.dtype
        ):
            raise ValueError(
                f"'impute_sex_chromosome_ploidy': expect calling_intervals to be list of intervals or"
                f" table with single key of type interval<locus>, found table with key: {key_dtype}"
            )

    rg = vds.reference_data.locus.dtype.reference_genome

    par_boundaries = []
    for par_interval in rg.par:
        par_boundaries.append(par_interval.start)
        par_boundaries.append(par_interval.end)

    # segment on PAR interval boundaries
    calling_intervals = hl.segment_intervals(calling_intervals, par_boundaries)

    # remove intervals overlapping PAR
    calling_intervals = calling_intervals.filter(
        hl.all(lambda x: ~x.overlaps(calling_intervals.interval), hl.literal(rg.par))
    )

    # checkpoint for efficient multiple downstream usages
    info("'impute_sex_chromosome_ploidy': checkpointing calling intervals")
    calling_intervals = calling_intervals.checkpoint(new_temp_file(extension='ht'))

    interval = calling_intervals.key[0]
    (any_bad_intervals, chrs_represented) = calling_intervals.aggregate((
        hl.agg.any(interval.start.contig != interval.end.contig),
        hl.agg.collect_as_set(interval.start.contig),
    ))
    if any_bad_intervals:
        raise ValueError(
            "'impute_sex_chromosome_ploidy' does not support calling intervals that span chromosome boundaries"
        )

    if len(rg.x_contigs) != 1:
        raise NotImplementedError(
            f"reference genome {rg.name!r} has multiple X contigs, this is not supported in 'impute_sex_chromosome_ploidy'"
        )

    if len(rg.y_contigs) != 1:
        raise NotImplementedError(
            f"reference genome {rg.name!r} has multiple Y contigs, this is not supported in 'impute_sex_chromosome_ploidy'"
        )

    kept_contig_filter = hl.array(chrs_represented).map(lambda x: hl.parse_locus_interval(x, reference_genome=rg))
    vds = VariantDataset(
        hl.filter_intervals(vds.reference_data, kept_contig_filter),
        hl.filter_intervals(vds.variant_data, kept_contig_filter),
    )

    if use_variant_dataset:
        mt = vds.variant_data
        calling_intervals = calling_intervals.annotate(interval_dup=interval)
        mt = mt.annotate_rows(interval=calling_intervals[mt.locus].interval_dup)
        mt = mt.filter_rows(hl.is_defined(mt.interval))
        coverage = mt.select_entries(sum_dp=mt.DP, interval_size=hl.is_defined(mt.DP))
    else:
        coverage = interval_coverage(vds, calling_intervals, gq_thresholds=()).drop('gq_thresholds')

    return impute_sex_chr_ploidy_from_interval_coverage(coverage, normalization_contig)

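
# Illustrative usage sketch (not part of the library API): estimate X/Y ploidy from
# reference-block depth over whole-chromosome calling intervals. The path is
# hypothetical, and 'chr20' is just one reasonable autosomal normalization contig.
def _example_impute_sex_chromosome_ploidy():
    vds = hl.vds.read_vds('gs://my-bucket/my-dataset.vds')  # hypothetical path
    intervals = [hl.parse_locus_interval(c, reference_genome='GRCh38') for c in ['chr20', 'chrX', 'chrY']]
    return impute_sex_chromosome_ploidy(vds, intervals, normalization_contig='chr20')
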
@typecheck(vds=VariantDataset, variants_table=Table, keep=bool)
def filter_variants(vds: 'VariantDataset', variants_table: 'Table', *, keep: bool = True) -> 'VariantDataset':
    """Filter variants in a :class:`.VariantDataset`, without removing reference data.

    Parameters
    ----------
    vds : :class:`.VariantDataset`
        Dataset in VariantDataset representation.
    variants_table : :class:`.Table`
        Variants to filter on.
    keep : :obj:`bool`
        Whether to keep (default), or filter out the variants from `variants_table`.

    Returns
    -------
    :class:`.VariantDataset`
    """
    if keep:
        variant_data = vds.variant_data.semi_join_rows(variants_table)
    else:
        variant_data = vds.variant_data.anti_join_rows(variants_table)
    return VariantDataset(vds.reference_data, variant_data)

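
# Illustrative usage sketch (not part of the library API): remove a table of known
# failing variants while keeping all reference blocks. The table path is hypothetical
# and the table is assumed to be keyed by (locus, alleles).
def _example_filter_variants():
    vds = hl.vds.read_vds('gs://my-bucket/my-dataset.vds')  # hypothetical path
    failing_variants = hl.read_table('gs://my-bucket/failing_variants.ht')
    return filter_variants(vds, failing_variants, keep=False)
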
@typecheck(
    vds=VariantDataset,
    intervals=oneof(Table, expr_array(expr_interval(expr_any))),
    keep=bool,
    mode=enumeration('variants_only', 'split_at_boundaries', 'unchecked_filter_both'),
)
def _parameterized_filter_intervals(vds: 'VariantDataset', intervals, keep: bool, mode: str) -> 'VariantDataset':
    intervals_table = None
    if isinstance(intervals, Table):
        expected = hl.tinterval(hl.tlocus(vds.reference_genome))
        if len(intervals.key) != 1 or intervals.key[0].dtype != hl.tinterval(hl.tlocus(vds.reference_genome)):
            raise ValueError(
                f"'filter_intervals': expect a table with a single key of type {expected}; "
                f"found {list(intervals.key.dtype.values())}"
            )
        intervals_table = intervals
        intervals = hl.literal(intervals.aggregate(hl.agg.collect(intervals.key[0]), _localize=False))

    if mode == 'unchecked_filter_both':
        return VariantDataset(
            hl.filter_intervals(vds.reference_data, intervals, keep),
            hl.filter_intervals(vds.variant_data, intervals, keep),
        )

    reference_data = vds.reference_data

    if keep:
        rbml = hl.vds.VariantDataset.ref_block_max_length_field
        if rbml in vds.reference_data.globals:
            max_len = hl.eval(vds.reference_data.index_globals()[rbml])
            ref_intervals = intervals.map(
                lambda interval: hl.interval(
                    interval.start - (max_len - 1), interval.end, interval.includes_start, interval.includes_end
                )
            )
            reference_data = hl.filter_intervals(reference_data, ref_intervals, keep)
        else:
            warning(
                "'hl.vds.filter_intervals': filtering intervals without a known max reference block length"
                "\n  (computed by `hl.vds.store_ref_block_max_length` or 'hl.vds.truncate_reference_blocks')"
                "\n  requires a full pass over the reference data (expensive!)"
            )

    if mode == 'variants_only':
        variant_data = hl.filter_intervals(vds.variant_data, intervals, keep)
        return VariantDataset(reference_data, variant_data)
    if mode == 'split_at_boundaries':
        if not keep:
            raise ValueError("filter_intervals mode 'split_at_boundaries' not implemented for keep=False")
        par_intervals = intervals_table or hl.Table.parallelize(
            intervals.map(lambda x: hl.struct(interval=x)),
            schema=hl.tstruct(interval=intervals.dtype.element_type),
            key='interval',
        )
        ref = segment_reference_blocks(reference_data, par_intervals).drop(
            'interval_end', next(iter(par_intervals.key))
        )
        return VariantDataset(ref, hl.filter_intervals(vds.variant_data, intervals, keep))

@typecheck(
    vds=VariantDataset,
    keep=nullable(oneof(str, sequenceof(str))),
    remove=nullable(oneof(str, sequenceof(str))),
    keep_autosomes=bool,
)
def filter_chromosomes(vds: 'VariantDataset', *, keep=None, remove=None, keep_autosomes=False) -> 'VariantDataset':
    """Filter chromosomes of a :class:`.VariantDataset` in several possible modes.

    Notes
    -----
    There are three modes for :func:`filter_chromosomes`, based on which argument is
    passed to the function. Exactly one of the below arguments must be passed by keyword.

    - ``keep``: This argument expects a single chromosome identifier or a list of
      chromosome identifiers, and the function returns a :class:`.VariantDataset` with
      only those chromosomes.
    - ``remove``: This argument expects a single chromosome identifier or a list of
      chromosome identifiers, and the function returns a :class:`.VariantDataset` with
      those chromosomes removed.
    - ``keep_autosomes``: This argument expects the value ``True``, and returns a dataset
      without sex and mitochondrial chromosomes.

    Parameters
    ----------
    vds : :class:`.VariantDataset`
        Dataset.
    keep
        Keep a specified list of contigs.
    remove
        Remove a specified list of contigs.
    keep_autosomes
        If true, keep only autosomal chromosomes.

    Returns
    -------
    :class:`.VariantDataset`
    """
    n_args_passed = (keep is not None) + (remove is not None) + keep_autosomes
    if n_args_passed == 0:
        raise ValueError("filter_chromosomes: expect one of 'keep', 'remove', or 'keep_autosomes' arguments")
    if n_args_passed > 1:
        raise ValueError(
            "filter_chromosomes: expect ONLY one of 'keep', 'remove', or 'keep_autosomes' arguments"
            "\n  In order to use 'keep_autosomes' with 'keep' or 'remove', call the function twice"
        )

    rg = vds.reference_genome

    to_keep = []

    if keep is not None:
        keep = wrap_to_list(keep)
        to_keep.extend(keep)
    elif remove is not None:
        remove = set(wrap_to_list(remove))
        for c in rg.contigs:
            if c not in remove:
                to_keep.append(c)
    elif keep_autosomes:
        to_remove = set(rg.x_contigs + rg.y_contigs + rg.mt_contigs)
        for c in rg.contigs:
            if c not in to_remove:
                to_keep.append(c)

    parsed_intervals = hl.literal(to_keep, hl.tarray(hl.tstr)).map(
        lambda c: hl.parse_locus_interval(c, reference_genome=rg)
    )
    return _parameterized_filter_intervals(vds, intervals=parsed_intervals, keep=True, mode='unchecked_filter_both')

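
# Illustrative usage sketch (not part of the library API): the three modes of
# filter_chromosomes on a hypothetical GRCh38 VDS.
def _example_filter_chromosomes():
    vds = hl.vds.read_vds('gs://my-bucket/my-dataset.vds')  # hypothetical path
    chr22_only = filter_chromosomes(vds, keep='chr22')
    without_chrM = filter_chromosomes(vds, remove=['chrM'])
    autosomes_only = filter_chromosomes(vds, keep_autosomes=True)
    return chr22_only, without_chrM, autosomes_only
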
@typecheck(
    vds=VariantDataset,
    intervals=oneof(Table, expr_array(expr_interval(expr_any))),
    split_reference_blocks=bool,
    keep=bool,
)
def filter_intervals(
    vds: 'VariantDataset', intervals, *, split_reference_blocks: bool = False, keep: bool = True
) -> 'VariantDataset':
    """Filter intervals in a :class:`.VariantDataset`.

    Parameters
    ----------
    vds : :class:`.VariantDataset`
        Dataset in VariantDataset representation.
    intervals : :class:`.Table` or :class:`.ArrayExpression` of type :class:`.tinterval`
        Intervals to filter on.
    split_reference_blocks : :obj:`bool`
        If true, remove reference data outside the given intervals by segmenting reference
        blocks at interval boundaries. Results in a smaller result, but this filter mode
        is more computationally expensive to evaluate.
    keep : :obj:`bool`
        Whether to keep (default), or filter out, rows that fall within any interval in
        `intervals`.

    Returns
    -------
    :class:`.VariantDataset`
    """
    if split_reference_blocks and not keep:
        raise ValueError("'filter_intervals': cannot use 'split_reference_blocks' with keep=False")
    return _parameterized_filter_intervals(
        vds, intervals, keep=keep, mode='split_at_boundaries' if split_reference_blocks else 'variants_only'
    )

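
# Illustrative usage sketch (not part of the library API): restrict a VDS to a single
# region. The interval is hypothetical; split_reference_blocks=True additionally trims
# reference blocks to the interval boundaries, at extra compute cost.
def _example_filter_intervals():
    vds = hl.vds.read_vds('gs://my-bucket/my-dataset.vds')  # hypothetical path
    region = [hl.parse_locus_interval('chr1:1M-2M', reference_genome='GRCh38')]
    return filter_intervals(vds, region, split_reference_blocks=True)
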
@typecheck(vds=VariantDataset, filter_changed_loci=bool)
def split_multi(vds: 'VariantDataset', *, filter_changed_loci: bool = False) -> 'VariantDataset':
    """Split the multiallelic variants in a :class:`.VariantDataset`.

    Parameters
    ----------
    vds : :class:`.VariantDataset`
        Dataset in VariantDataset representation.
    filter_changed_loci : :obj:`bool`
        If any REF/ALT pair changes locus under :func:`.min_rep`, filter that variant
        instead of throwing an error.

    Returns
    -------
    :class:`.VariantDataset`
    """
    variant_data = hl.experimental.sparse_split_multi(vds.variant_data, filter_changed_loci=filter_changed_loci)
    reference_data = vds.reference_data
    if 'LGT' in reference_data.entry:
        if 'GT' in reference_data.entry:
            reference_data = reference_data.drop('LGT')
        else:
            reference_data = reference_data.transmute_entries(GT=reference_data.LGT)
    return VariantDataset(reference_data=reference_data, variant_data=variant_data)

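
# Illustrative usage sketch (not part of the library API): split multiallelic variants
# before an analysis that expects biallelic rows. The path is hypothetical.
def _example_split_multi():
    vds = hl.vds.read_vds('gs://my-bucket/my-dataset.vds')  # hypothetical path
    return split_multi(vds, filter_changed_loci=True)
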
@typecheck(ref=MatrixTable, intervals=Table)
def segment_reference_blocks(ref: 'MatrixTable', intervals: 'Table') -> 'MatrixTable':
    """Returns a matrix table of reference blocks segmented according to intervals.

    Loci outside the given intervals are discarded. Reference blocks that start before
    but span an interval will appear at the interval start locus.

    Note
    ----
    Assumes disjoint intervals which do not span contigs.

    Requires start-inclusive intervals.

    Parameters
    ----------
    ref : :class:`.MatrixTable`
        MatrixTable of reference blocks.
    intervals : :class:`.Table`
        Table of intervals at which to segment reference blocks.

    Returns
    -------
    :class:`.MatrixTable`
    """
    interval_field = next(iter(intervals.key))
    if not intervals[interval_field].dtype == hl.tinterval(ref.locus.dtype):
        raise ValueError(
            f"expect intervals to be keyed by intervals of loci matching the VariantDataset:"
            f" found {intervals[interval_field].dtype} / {ref.locus.dtype}"
        )
    intervals = intervals.select(_interval_dup=intervals[interval_field])

    if not intervals.aggregate(
        hl.agg.all(
            intervals[interval_field].includes_start
            & (intervals[interval_field].start.contig == intervals[interval_field].end.contig)
        )
    ):
        raise ValueError("expect intervals to be start-inclusive")

    starts = intervals.key_by(_start_locus=intervals[interval_field].start)
    starts = starts.annotate(_include_locus=True)
    refl = ref.localize_entries('_ref_entries', '_ref_cols')
    joined = refl.join(starts, how='outer')
    rg = ref.locus.dtype.reference_genome
    contigs = rg.contigs
    contig_idx_map = hl.literal({contigs[i]: i for i in range(len(contigs))}, 'dict<str, int32>')
    joined = joined.annotate(__contig_idx=contig_idx_map[joined.locus.contig])
    joined = joined.annotate(
        _ref_entries=joined._ref_entries.map(lambda e: e.annotate(__contig_idx=joined.__contig_idx))
    )
    dense = joined.annotate(
        dense_ref=hl.or_missing(
            joined._include_locus,
            hl.rbind(
                joined.locus.position,
                lambda pos: hl.enumerate(hl.scan._densify(hl.len(joined._ref_cols), joined._ref_entries)).map(
                    lambda idx_and_e: hl.rbind(
                        idx_and_e[0],
                        idx_and_e[1],
                        lambda idx, e: hl.coalesce(
                            joined._ref_entries[idx],
                            hl.or_missing((e.__contig_idx == joined.__contig_idx) & (e.END >= pos), e),
                        ),
                    ).drop('__contig_idx')
                ),
            ),
        )
    )
    dense = dense.filter(dense._include_locus).drop('_interval_dup', '_include_locus', '__contig_idx')

    # at this point, 'dense' is a table with dense rows of reference blocks, keyed by locus

    refl_filtered = refl.annotate(**{interval_field: intervals[refl.locus]._interval_dup})

    # remove rows that are not contained in an interval, and rows that are the start of an
    # interval (interval starts come from the 'dense' table)
    refl_filtered = refl_filtered.filter(
        hl.is_defined(refl_filtered[interval_field]) & (refl_filtered.locus != refl_filtered[interval_field].start)
    )

    # union dense interval starts with filtered table
    refl_filtered = refl_filtered.union(dense.transmute(_ref_entries=dense.dense_ref))

    # rewrite reference blocks to end at the first of (interval end, reference block end)
    refl_filtered = refl_filtered.annotate(
        interval_end=refl_filtered[interval_field].end.position - ~refl_filtered[interval_field].includes_end
    )
    refl_filtered = refl_filtered.annotate(
        _ref_entries=refl_filtered._ref_entries.map(
            lambda entry: entry.annotate(END=hl.min(entry.END, refl_filtered.interval_end))
        )
    )

    return refl_filtered._unlocalize_entries('_ref_entries', '_ref_cols', list(ref.col_key))

@typecheck(
    vds=VariantDataset,
    intervals=Table,
    gq_thresholds=sequenceof(int),
    dp_thresholds=sequenceof(int),
    dp_field=nullable(str),
)
def interval_coverage(
    vds: VariantDataset,
    intervals: Table,
    gq_thresholds=(0, 10, 20),
    dp_thresholds=(0, 1, 10, 20, 30),
    dp_field=None,
) -> 'MatrixTable':
    """Compute statistics about base coverage by interval.

    Returns a :class:`.MatrixTable` with interval row keys and sample column keys.

    Contains the following row fields:

    - ``interval`` (*interval*): Genomic interval of interest.
    - ``interval_size`` (*int32*): Size of interval, in bases.

    Computes the following entry fields:

    - ``bases_over_gq_threshold`` (*tuple of int64*): Number of bases in the interval over
      each GQ threshold.
    - ``fraction_over_gq_threshold`` (*tuple of float64*): Fraction of interval (in bases)
      above each GQ threshold. Computed by dividing each member of
      *bases_over_gq_threshold* by *interval_size*.
    - ``bases_over_dp_threshold`` (*tuple of int64*): Number of bases in the interval over
      each DP threshold.
    - ``fraction_over_dp_threshold`` (*tuple of float64*): Fraction of interval (in bases)
      above each DP threshold. Computed by dividing each member of
      *bases_over_dp_threshold* by *interval_size*.
    - ``sum_dp`` (*int64*): Sum of depth values by base across the interval.
    - ``mean_dp`` (*float64*): Mean depth of bases across the interval. Computed by
      dividing *sum_dp* by *interval_size*.

    If the `dp_field` parameter is not specified, the ``DP`` field is used for depth if
    present. If no ``DP`` field is present, the ``MIN_DP`` field is used. If no ``DP`` or
    ``MIN_DP`` field is present, no depth statistics will be calculated.

    Note
    ----
    The metrics computed by this method are computed **only from reference blocks**. Most
    variant callers produce data where non-reference calls interrupt reference blocks, and
    so the metrics computed here are slight underestimates of the true values (which would
    include the quality/depth of non-reference calls). This is likely a negligible
    difference, but is something to be aware of, especially as it interacts with samples
    from ancestral backgrounds with more or fewer non-reference calls.

    Parameters
    ----------
    vds : :class:`.VariantDataset`
    intervals : :class:`.Table`
        Table of intervals. Must be start-inclusive, and cannot span contigs.
    gq_thresholds : tuple of int
        GQ thresholds.
    dp_thresholds : tuple of int
        DP thresholds.
    dp_field : str, optional
        Field for depth calculation. Uses DP or MIN_DP by default (with priority for DP if present).

    Returns
    -------
    :class:`.MatrixTable`
        Interval-by-sample matrix
    """
    ref = vds.reference_data
    split = segment_reference_blocks(ref, intervals)
    intervals = intervals.annotate(interval_dup=intervals.key[0])

    if 'DP' in ref.entry:
        dp_field_to_use = 'DP'
    elif 'MIN_DP' in ref.entry:
        dp_field_to_use = 'MIN_DP'
    else:
        dp_field_to_use = dp_field

    ref_block_length = split.END - split.locus.position + 1
    if dp_field_to_use is not None:
        dp = split[dp_field_to_use]
        dp_field_dict = {
            'sum_dp': hl.agg.sum(ref_block_length * dp),
            'bases_over_dp_threshold': tuple(
                hl.agg.filter(dp >= dp_threshold, hl.agg.sum(ref_block_length)) for dp_threshold in dp_thresholds
            ),
        }
    else:
        dp_field_dict = dict()

    per_interval = split.group_rows_by(interval=intervals[split.row_key[0]].interval_dup).aggregate(
        bases_over_gq_threshold=tuple(
            hl.agg.filter(split.GQ >= gq_threshold, hl.agg.sum(ref_block_length)) for gq_threshold in gq_thresholds
        ),
        **dp_field_dict,
    )

    interval = per_interval.interval
    interval_size = (
        interval.end.position + interval.includes_end - interval.start.position - 1 + interval.includes_start
    )
    per_interval = per_interval.annotate_rows(interval_size=interval_size)

    dp_mod_dict = {}
    if dp_field_to_use is not None:
        dp_mod_dict['fraction_over_dp_threshold'] = tuple(
            hl.float(x) / per_interval.interval_size for x in per_interval.bases_over_dp_threshold
        )
        dp_mod_dict['mean_dp'] = per_interval.sum_dp / per_interval.interval_size

    per_interval = per_interval.annotate_entries(
        fraction_over_gq_threshold=tuple(
            hl.float(x) / per_interval.interval_size for x in per_interval.bases_over_gq_threshold
        ),
        **dp_mod_dict,
    )

    per_interval = per_interval.annotate_globals(gq_thresholds=hl.tuple(gq_thresholds))

    return per_interval

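
# Illustrative usage sketch (not part of the library API): per-sample coverage metrics
# over capture intervals imported from a BED file. The paths are hypothetical.
def _example_interval_coverage():
    vds = hl.vds.read_vds('gs://my-bucket/my-dataset.vds')  # hypothetical path
    intervals = hl.import_locus_intervals('gs://my-bucket/capture_intervals.bed', reference_genome='GRCh38')
    cov = interval_coverage(vds, intervals, gq_thresholds=(0, 20), dp_thresholds=(0, 10, 20))
    cov.entries().show()  # inspect sum_dp, mean_dp, and the per-threshold fractions
    return cov
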
@typecheck(
    ds=oneof(MatrixTable, VariantDataset),
    max_ref_block_base_pairs=nullable(int),
    ref_block_winsorize_fraction=nullable(float),
)
def truncate_reference_blocks(ds, *, max_ref_block_base_pairs=None, ref_block_winsorize_fraction=None):
    """Cap reference blocks at a maximum length in order to permit faster interval filtering.

    Examples
    --------
    Truncate reference blocks to 5 kilobases:

    >>> vds2 = hl.vds.truncate_reference_blocks(vds, max_ref_block_base_pairs=5000)  # doctest: +SKIP

    Truncate the longest 1% of reference blocks to the length of the 99th percentile block:

    >>> vds2 = hl.vds.truncate_reference_blocks(vds, ref_block_winsorize_fraction=0.01)  # doctest: +SKIP

    Notes
    -----
    After this function has been run, the reference blocks have a known maximum length
    `ref_block_max_length`, stored in the global fields, which permits
    :func:`.vds.filter_intervals` to filter to intervals of the reference data by reading
    `ref_block_max_length` bases ahead of each interval. This allows narrow interval
    queries to run in roughly O(data kept) work rather than O(all reference data) work.

    It is also possible to patch an existing VDS to store the max reference block length
    with :func:`.vds.store_ref_block_max_length`.

    See Also
    --------
    :func:`.vds.store_ref_block_max_length`

    Parameters
    ----------
    ds : :class:`.VariantDataset` or :class:`.MatrixTable`
    max_ref_block_base_pairs
        Maximum size of reference blocks, in base pairs.
    ref_block_winsorize_fraction
        Fraction of reference block length distribution to truncate / winsorize.

    Returns
    -------
    :class:`.VariantDataset` or :class:`.MatrixTable`
    """
    if isinstance(ds, VariantDataset):
        rd = ds.reference_data
    else:
        rd = ds

    fd_name = hl.vds.VariantDataset.ref_block_max_length_field
    if fd_name in rd.globals:
        rd = rd.drop(fd_name)

    if int(ref_block_winsorize_fraction is None) + int(max_ref_block_base_pairs is None) != 1:
        raise ValueError(
            'truncate_reference_blocks: require exactly one of "max_ref_block_base_pairs", "ref_block_winsorize_fraction"'
        )

    if ref_block_winsorize_fraction is not None:
        assert (
            ref_block_winsorize_fraction > 0 and ref_block_winsorize_fraction < 1
        ), 'truncate_reference_blocks: "ref_block_winsorize_fraction" must be between 0 and 1 (e.g. 0.01 to truncate the top 1% of reference blocks)'
        if ref_block_winsorize_fraction > 0.1:
            warning(
                f"'truncate_reference_blocks': ref_block_winsorize_fraction of {ref_block_winsorize_fraction} will lead to significant data duplication,"
                f" recommended values are <0.05."
            )
        max_ref_block_base_pairs = rd.aggregate_entries(
            hl.agg.approx_quantiles(rd.END - rd.locus.position + 1, 1 - ref_block_winsorize_fraction, k=200)
        )

    assert (
        max_ref_block_base_pairs > 0
    ), 'truncate_reference_blocks: "max_ref_block_base_pairs" must be greater than zero'

    info(f"splitting VDS reference blocks at {max_ref_block_base_pairs} base pairs")

    rd_under_limit = rd.filter_entries(rd.END - rd.locus.position < max_ref_block_base_pairs).localize_entries(
        'fixed_blocks', 'cols'
    )

    rd_over_limit = rd.filter_entries(rd.END - rd.locus.position >= max_ref_block_base_pairs).key_cols_by(
        col_idx=hl.scan.count()
    )
    rd_over_limit = rd_over_limit.select_rows().select_cols().key_rows_by().key_cols_by()
    es = rd_over_limit.entries()
    es = es.annotate(new_start=hl.range(es.locus.position, es.END + 1, max_ref_block_base_pairs))
    es = es.explode('new_start')
    es = es.transmute(
        locus=hl.locus(es.locus.contig, es.new_start, reference_genome=es.locus.dtype.reference_genome),
        END=hl.min(es.new_start + max_ref_block_base_pairs - 1, es.END),
    )
    es = es.key_by(es.locus).collect_by_key("new_blocks")
    es = es.transmute(moved_blocks_dict=hl.dict(es.new_blocks.map(lambda x: (x.col_idx, x.drop('col_idx')))))

    joined = rd_under_limit.join(es, how='outer')
    joined = joined.transmute(
        merged_blocks=hl.range(hl.len(joined.cols)).map(
            lambda idx: hl.coalesce(joined.moved_blocks_dict.get(idx), joined.fixed_blocks[idx])
        )
    )
    new_rd = joined._unlocalize_entries(
        entries_field_name='merged_blocks', cols_field_name='cols', col_key=list(rd.col_key)
    )
    new_rd = new_rd.annotate_globals(**{fd_name: max_ref_block_base_pairs})

    if isinstance(ds, hl.vds.VariantDataset):
        return VariantDataset(reference_data=new_rd, variant_data=ds.variant_data)
    return new_rd

@typecheck(
    ds=oneof(MatrixTable, VariantDataset),
    equivalence_function=func_spec(2, expr_bool),
    merge_functions=nullable(dictof(str, oneof(str, func_spec(1, expr_any)))),
)
def merge_reference_blocks(ds, equivalence_function, merge_functions=None):
    """Merge adjacent reference blocks according to user equivalence criteria.

    Examples
    --------
    Coarsen GQ granularity into bins of 10 and merge blocks with the same GQ in order to
    compress reference data:

    >>> rd = vds.reference_data  # doctest: +SKIP
    >>> vds.reference_data = rd.annotate_entries(GQ = rd.GQ - rd.GQ % 10)  # doctest: +SKIP
    >>> vds2 = hl.vds.merge_reference_blocks(vds,
    ...     equivalence_function=lambda block1, block2: block1.GQ == block2.GQ,
    ...     merge_functions={'MIN_DP': 'min'})  # doctest: +SKIP

    Notes
    -----
    The `equivalence_function` argument expects a function from two reference blocks to a
    boolean value indicating whether they should be combined. Adjacency checks are built
    into the method (two reference blocks are 'adjacent' if the END of one block is one
    base before the beginning of the next).

    The `merge_functions` argument expects a dictionary from entry field name to the
    strategy used to combine that field when two blocks are merged; the strings ``'min'``,
    ``'max'``, and ``'sum'`` are supported.

    Parameters
    ----------
    ds : :class:`.VariantDataset` or :class:`.MatrixTable`
        Variant dataset or reference block matrix table.

    Returns
    -------
    :class:`.VariantDataset` or :class:`.MatrixTable`
    """
    if isinstance(ds, VariantDataset):
        rd = ds.reference_data
    else:
        rd = ds
    rd = rd.annotate_rows(contig_idx_row=rd.locus.contig_idx, start_pos_row=rd.locus.position)
    rd = rd.annotate_entries(contig_idx=rd.contig_idx_row, start_pos=rd.start_pos_row)
    ht = rd.localize_entries('entries', 'cols')

    def merge(block1, block2):
        new_fields = {'END': block2.END}
        if merge_functions:
            for k, f in merge_functions.items():
                if isinstance(f, str):
                    _f = f.lower()
                    if _f == 'min':

                        def __f(b1, b2):
                            return hl.min(block1[k], block2[k])

                    elif _f == 'max':

                        def __f(b1, b2):
                            return hl.max(block1[k], block2[k])

                    elif _f == 'sum':

                        def __f(b1, b2):
                            return block1[k] + block2[k]

                    else:
                        raise ValueError(
                            f"merge_reference_blocks: unknown merge function {_f!r},"
                            f" support 'min', 'max', and 'sum' in addition to custom lambdas"
                        )
                new_value = __f(block1, block2)
                if new_value.dtype != block1[k].dtype:
                    raise ValueError(
                        f'merge_reference_blocks: merge_function for {k!r}: new type {new_value.dtype!r} '
                        f'differs from original type {block1[k].dtype!r}'
                    )
                new_fields[k] = new_value
        return block1.annotate(**new_fields)

    def keep_last(t1, t2):
        e1 = t1[0]
        e2 = t2[0]
        are_adjacent = (e1.contig_idx == e2.contig_idx) & (e1.END + 1 == e2.start_pos)
        return hl.if_else(
            hl.is_defined(e1) & hl.is_defined(e2) & are_adjacent & equivalence_function(e1, e2),
            (merge(e1, e2), True),
            t2,
        )

    # approximate a scan that merges before result
    ht = ht.annotate(
        prev_block=hl.zip(
            hl.scan.array_agg(
                lambda elt: hl.scan.fold(
                    (hl.missing(rd.entry.dtype), False), lambda acc: keep_last(acc, (elt, False)), keep_last
                ),
                ht.entries,
            ),
            ht.entries,
        ).map(lambda tup: keep_last(tup[0], (tup[1], False)))
    )

    ht_join = ht

    ht = ht.key_by()
    ht = ht.select(
        to_shuffle=hl.enumerate(ht.prev_block).filter(
            lambda idx_and_elt: hl.is_defined(idx_and_elt[1]) & idx_and_elt[1][1]
        )
    )
    ht = ht.explode('to_shuffle')
    rg = rd.locus.dtype.reference_genome
    ht = ht.transmute(col_idx=ht.to_shuffle[0], entry=ht.to_shuffle[1][0])
    ht_shuf = ht.key_by(
        locus=hl.locus(hl.literal(rg.contigs)[ht.entry.contig_idx], ht.entry.start_pos, reference_genome=rg)
    )
    ht_shuf = ht_shuf.collect_by_key("new_starts")

    # new_starts can contain multiple records for a collapsed ref block, one for each folded block.
    # We want to keep the one with the highest END
    ht_shuf = ht_shuf.select(
        moved_blocks_dict=hl.group_by(lambda elt: elt.col_idx, ht_shuf.new_starts).map_values(
            lambda arr: arr[hl.argmax(arr.map(lambda x: x.entry.END))].entry.drop('contig_idx', 'start_pos')
        )
    )
    ht_joined = ht_join.join(ht_shuf.select_globals(), 'left')

    def merge_f(tup):
        (idx, original_entry) = tup
        return (
            hl.case()
            .when(
                ~(hl.coalesce(ht_joined.prev_block[idx][1], False)),
                hl.coalesce(ht_joined.moved_blocks_dict.get(idx), original_entry.drop('contig_idx', 'start_pos')),
            )
            .or_missing()
        )

    ht_joined = ht_joined.annotate(new_entries=hl.enumerate(ht_joined.entries).map(lambda tup: merge_f(tup)))
    ht_joined = ht_joined.drop('moved_blocks_dict', 'entries', 'prev_block', 'contig_idx_row', 'start_pos_row')
    new_rd = ht_joined._unlocalize_entries(
        entries_field_name='new_entries', cols_field_name='cols', col_key=list(rd.col_key)
    )

    rbml = hl.vds.VariantDataset.ref_block_max_length_field
    if rbml in new_rd.globals:
        new_rd = new_rd.drop(rbml)

    if isinstance(ds, VariantDataset):
        return VariantDataset(reference_data=new_rd, variant_data=ds.variant_data)
    return new_rd