# Source code for hail.experimental.vcf_combiner.vcf_combiner

"""An experimental library for combining (g)VCFs into sparse matrix tables"""
# these are necessary for the driver script included at the end of this file
import math
import uuid
from typing import Optional, List, Tuple, Dict

import hail as hl
from hail import MatrixTable, Table
from hail.expr import StructExpression
from hail.expr.expressions import expr_bool, expr_str
from hail.genetics.reference_genome import reference_genome_type
from hail.ir import Apply, TableMapRows, MatrixKeyRowsBy, TopLevelReference
from hail.typecheck import oneof, sequenceof, typecheck
from hail.utils.java import info, warning

_transform_rows_function_map = {}
_merge_function_map = {}

@typecheck(string=expr_str, has_non_ref=expr_bool)
def parse_as_ints(string, has_non_ref):
    """Parse a pipe-delimited allele-specific annotation into an int32 array.

    Parameters
    ----------
    string : :class:`.StringExpression`
        Pipe (``|``) delimited values.
    has_non_ref : :class:`.BooleanExpression`
        If true, the last element (belonging to the ``<NON_REF>`` allele) is dropped.

    Returns
    -------
    :class:`.ArrayExpression` of :py:data:`.tint32`
        Empty strings and ``'.'`` are converted to missing values.
    """
    # ``split`` takes a regex, hence the escaped pipe.
    ints = string.split(r'\|')
    ints = hl.cond(has_non_ref, ints[:-1], ints)
    return ints.map(lambda i: hl.cond((hl.len(i) == 0) | (i == '.'), hl.null(hl.tint32), hl.int32(i)))

@typecheck(string=expr_str, has_non_ref=expr_bool)
def parse_as_doubles(string, has_non_ref):
    """Parse a pipe-delimited allele-specific annotation into a float64 array.

    Parameters
    ----------
    string : :class:`.StringExpression`
        Pipe (``|``) delimited values.
    has_non_ref : :class:`.BooleanExpression`
        If true, the last element (belonging to the ``<NON_REF>`` allele) is dropped.

    Returns
    -------
    :class:`.ArrayExpression` of :py:data:`.tfloat64`
        Empty strings and ``'.'`` are converted to missing values.
    """
    # ``split`` takes a regex, hence the escaped pipe.
    doubles = string.split(r'\|')
    doubles = hl.cond(has_non_ref, doubles[:-1], doubles)
    return doubles.map(lambda d: hl.cond((hl.len(d) == 0) | (d == '.'), hl.null(hl.tfloat64), hl.float64(d)))

@typecheck(string=expr_str, has_non_ref=expr_bool)
def parse_as_sb_table(string, has_non_ref):
    """Parse an ``AS_SB_TABLE``-style annotation into an array of int32 arrays.

    The string is split on ``|`` into groups, and each group is split on ``,``
    into int32 values.

    Parameters
    ----------
    string : :class:`.StringExpression`
        Pipe (``|``) delimited groups of comma separated integers.
    has_non_ref : :class:`.BooleanExpression`
        If true, the last group (belonging to the ``<NON_REF>`` allele) is dropped.

    Returns
    -------
    :class:`.ArrayExpression` of arrays of :py:data:`.tint32`
    """
    # ``split`` takes a regex, hence the escaped pipe.
    groups = string.split(r'\|')
    groups = hl.cond(has_non_ref, groups[:-1], groups)
    return groups.map(lambda xs: xs.split(",").map(hl.int32))

@typecheck(string=expr_str, has_non_ref=expr_bool)
def parse_as_ranksum(string, has_non_ref):
    """Parse a pipe-delimited rank-sum annotation into ``(float64, int32)`` tuples.

    Each ``|``-separated item is expected to be ``"<value>,<count>"``.

    Parameters
    ----------
    string : :class:`.StringExpression`
        Pipe (``|``) delimited items.
    has_non_ref : :class:`.BooleanExpression`
        If true, the last item (belonging to the ``<NON_REF>`` allele) is dropped.

    Returns
    -------
    :class:`.ArrayExpression` of ``tuple(float64, int32)``
        Empty or ``'.'`` items, and items that do not split into exactly two
        comma separated parts, become missing.
    """
    typ = hl.ttuple(hl.tfloat64, hl.tint32)
    # ``split`` takes a regex, hence the escaped pipe.
    items = string.split(r'\|')
    items = hl.cond(has_non_ref, items[:-1], items)
    return items.map(lambda s: hl.cond(
        (hl.len(s) == 0) | (s == '.'),
        hl.null(typ),
        hl.rbind(s.split(','), lambda ss: hl.cond(
            hl.len(ss) != 2,  # bad field, possibly 'NaN', just set it null
            hl.null(typ),
            hl.tuple([hl.float64(ss[0]), hl.int32(ss[1])])))))

# Allele-specific INFO fields that ``parse_as_fields`` converts from
# pipe-delimited strings, mapped to the parser used for each.
_as_function_map = {
    'AS_QUALapprox': parse_as_ints,
    'AS_RAW_MQ': parse_as_doubles,
    'AS_RAW_MQRankSum': parse_as_ranksum,
    'AS_RAW_ReadPosRankSum': parse_as_ranksum,
    'AS_SB_TABLE': parse_as_sb_table,
    'AS_VarDP': parse_as_ints,
}

def parse_as_fields(info, has_non_ref):
    """Return a struct with every field of ``info``, parsing allele-specific ones.

    Fields whose names appear in ``_as_function_map`` are replaced with the
    result of their parser (called with ``has_non_ref``); all other fields are
    copied through unchanged, in their original order.
    """
    fields = {}
    for name in info:
        parser = _as_function_map.get(name)
        if parser is None:
            fields[name] = info[name]
        else:
            fields[name] = parser(info[name], has_non_ref)
    return hl.struct(**fields)

def localize(mt):
    """Localize a :class:`.MatrixTable` into a Table; return anything else as-is.

    Entries are placed in the row field ``__entries`` and column values in the
    global field ``__cols``.
    """
    if not isinstance(mt, MatrixTable):
        return mt
    return mt._localize_entries('__entries', '__cols')

def unlocalize(mt):
    """Turn a localized :class:`.Table` back into a MatrixTable; return anything else as-is.

    The table is expected to hold entries in ``__entries`` and column values in
    the global ``__cols`` (the layout ``localize`` produces); ``['s']`` is
    passed as the column key.
    """
    if not isinstance(mt, Table):
        return mt
    return mt._unlocalize_entries('__entries', '__cols', ['s'])

@typecheck(mt=oneof(Table, MatrixTable), info_to_keep=sequenceof(str))
def transform_gvcf(mt, info_to_keep=[]) -> Table:
    """Transforms a gvcf into a sparse matrix table

    The input to this should be some result of either :func:`.import_vcf` or
    :func:`.import_gvcfs` with ``array_elements_required=False``.

    There is an assumption that this function will be called on a matrix table
    with one column (or a localized table version of the same).

    Parameters
    ----------
    mt : :obj:`Union[Table, MatrixTable]`
        The gvcf being transformed, if it is a table, then it must be a localized matrix table with
        the entries array named ``__entries``
    info_to_keep : :obj:`List[str]`
        Any ``INFO`` fields in the gvcf that are to be kept and put in the ``gvcf_info`` entry
        field. By default, all ``INFO`` fields except ``END`` and ``DP`` are kept.

    Returns
    -------
    :obj:`.Table`
        A localized matrix table that can be used as part of the input to :func:`.combine_gvcfs`

    Notes
    -----
    This function will parse the following allele specific annotations from
    pipe delimited strings into proper values. ::

        AS_QUALapprox
        AS_RAW_MQ
        AS_RAW_MQRankSum
        AS_RAW_ReadPosRankSum
        AS_SB_TABLE
        AS_VarDP

    A trailing ``<NON_REF>`` allele is removed from ``alleles`` and from the
    local entry fields (``LA``, ``LAD``, ``LPL``); the ``PL`` of the
    ref/``<NON_REF>`` genotype is kept in ``RGQ``.

    Raises
    ------
    :class:`.FatalError`
        If the input has no ``END`` INFO field or no ``GT`` FORMAT field.
    """
    if not info_to_keep:
        info_to_keep = [name for name in mt.info if name not in ['END', 'DP']]
    mt = localize(mt)

    # The registered row function closes over ``info_to_keep``, so it must be
    # part of the cache key; keying on the row type alone would silently reuse
    # the INFO selection of an earlier call that passed a different list.
    cache_key = (mt.row.dtype, tuple(info_to_keep))
    if cache_key not in _transform_rows_function_map:
        def get_lgt(e, n_alleles, has_non_ref, row):
            # Calls involving <NON_REF> become missing; an index beyond every
            # possible diploid genotype is an error.
            index = e.GT.unphased_diploid_gt_index()
            n_no_nonref = n_alleles - hl.cond(has_non_ref, 1, 0)
            triangle_without_nonref = hl.triangle(n_no_nonref)
            return (hl.case()
                    .when(index < triangle_without_nonref, e.GT)
                    .when(index < hl.triangle(n_alleles), hl.null('call'))
                    .or_error('invalid GT ' + hl.str(e.GT) + ' at site ' + hl.str(row.locus)))

        def make_entry_struct(e, alleles_len, has_non_ref, row):
            handled_fields = dict()
            handled_names = {'LA', 'gvcf_info',
                             'LAD', 'AD',
                             'LGT', 'GT',
                             'LPL', 'PL',
                             'LPGT', 'PGT'}

            if 'END' not in row.info:
                raise hl.utils.FatalError("the Hail GVCF combiner expects GVCFs to have an 'END' field in INFO.")
            if 'GT' not in e:
                raise hl.utils.FatalError("the Hail GVCF combiner expects GVCFs to have a 'GT' field in FORMAT.")

            handled_fields['LA'] = hl.range(0, alleles_len - hl.cond(has_non_ref, 1, 0))
            handled_fields['LGT'] = get_lgt(e, alleles_len, has_non_ref, row)
            if 'AD' in e:
                handled_fields['LAD'] = hl.cond(has_non_ref, e.AD[:-1], e.AD)
            if 'PGT' in e:
                handled_fields['LPGT'] = e.PGT
            if 'PL' in e:
                # In VCF genotype order the genotypes containing the last allele
                # are the final ``alleles_len`` PL entries, so slicing them off
                # removes <NON_REF>.
                handled_fields['LPL'] = hl.cond(has_non_ref,
                                                hl.cond(alleles_len > 2,
                                                        e.PL[:-alleles_len],
                                                        hl.null(e.PL.dtype)),
                                                hl.cond(alleles_len > 1,
                                                        e.PL,
                                                        hl.null(e.PL.dtype)))
                # PL of the 0/<NON_REF> genotype.
                handled_fields['RGQ'] = hl.cond(
                    has_non_ref,
                    e.PL[hl.call(0, alleles_len - 1).unphased_diploid_gt_index()],
                    hl.null(e.PL.dtype.element_type))

            handled_fields['END'] = row.info.END
            # INFO is carried only for records without END (variant records);
            # reference blocks get a missing ``gvcf_info``.
            handled_fields['gvcf_info'] = hl.or_missing(
                hl.is_missing(row.info.END),
                parse_as_fields(row.info.select(*info_to_keep), has_non_ref))

            pass_through_fields = {k: v for k, v in e.items() if k not in handled_names}
            return hl.struct(**handled_fields, **pass_through_fields)

        f = hl.experimental.define_function(
            lambda row: hl.rbind(
                hl.len(row.alleles), '<NON_REF>' == row.alleles[-1],
                lambda alleles_len, has_non_ref: hl.struct(
                    locus=row.locus,
                    alleles=hl.cond(has_non_ref, row.alleles[:-1], row.alleles),
                    rsid=row.rsid,
                    __entries=row.__entries.map(
                        lambda e: make_entry_struct(e, alleles_len, has_non_ref, row)))),
            mt.row.dtype)
        _transform_rows_function_map[cache_key] = f
    transform_row = _transform_rows_function_map[cache_key]
    return Table(TableMapRows(mt._tir, Apply(transform_row._name, transform_row._ret_type, TopLevelReference('row'))))

def transform_one(mt, info_to_keep=[]) -> Table:
    """Alias of :func:`transform_gvcf`; both arguments are forwarded unchanged."""
    transformed = transform_gvcf(mt, info_to_keep)
    return transformed

def combine(ts):
    """Merge the rows of a multi-way zip join into single sparse rows.

    ``ts`` is expected to be the result of :meth:`.Table.multi_way_zip_join`
    over localized inputs, with the joined rows in ``data`` and the joined
    globals in ``g`` (this is how :func:`.combine_gvcfs` calls it).

    For each row:

    - every input's alleles are rewritten against the longest reference allele;
    - a combined (global) allele list is built;
    - each input's local allele indices (``LA``) are renumbered into it;
    - inputs with no record at the row contribute one missing entry per column.

    The per-input ``__cols`` globals are concatenated into a single ``__cols``.
    """
    def merge_alleles(alleles):
        # Returns struct(globl=<combined allele list, reference first>,
        #                local=<each input's alleles against the longest reference>).
        from hail.expr.functions import _num_allele_type, _allele_ints
        return hl.rbind(
            alleles.map(lambda a: hl.or_else(a[0], ''))
            .fold(lambda s, t: hl.cond(hl.len(s) > hl.len(t), s, t), ''),
            lambda ref:
            hl.rbind(
                alleles.map(
                    lambda al: hl.rbind(
                        al[0],
                        lambda r:
                        hl.array([ref]).extend(
                            al[1:].map(
                                lambda a:
                                hl.rbind(
                                    _num_allele_type(r, a),
                                    lambda at:
                                    hl.cond(
                                        (_allele_ints['SNP'] == at)
                                        | (_allele_ints['Insertion'] == at)
                                        | (_allele_ints['Deletion'] == at)
                                        | (_allele_ints['MNP'] == at)
                                        | (_allele_ints['Complex'] == at),
                                        # append the part of the longest reference
                                        # that this input's reference lacks
                                        a + ref[hl.len(r):],
                                        # any other allele type is kept unchanged
                                        a)))))),
                lambda lal:
                hl.struct(
                    globl=hl.array([ref]).extend(hl.array(hl.set(hl.flatten(lal)).remove(ref))),
                    local=lal)))

    def renumber_entry(entry, old_to_new) -> StructExpression:
        # global index of alternate (non-ref) alleles
        return entry.annotate(LA=entry.LA.map(lambda lak: old_to_new[lak]))

    cache_key = (ts.row.dtype, ts.globals.dtype)
    if cache_key not in _merge_function_map:
        f = hl.experimental.define_function(
            lambda row, gbl:
            hl.rbind(
                merge_alleles(row.data.map(lambda d: d.alleles)),
                lambda alleles:
                hl.struct(
                    locus=row.locus,
                    alleles=alleles.globl,
                    rsid=hl.find(hl.is_defined, row.data.map(lambda d: d.rsid)),
                    __entries=hl.bind(
                        lambda combined_allele_index:
                        hl.range(0, hl.len(row.data)).flatmap(
                            lambda i:
                            hl.cond(hl.is_missing(row.data[i].__entries),
                                    # input i has no record here: one missing entry per column
                                    hl.range(0, hl.len(gbl.g[i].__cols))
                                    .map(lambda _: hl.null(row.data[i].__entries.dtype.element_type)),
                                    hl.bind(
                                        lambda old_to_new: row.data[i].__entries.map(
                                            lambda e: renumber_entry(e, old_to_new)),
                                        hl.range(0, hl.len(alleles.local[i])).map(
                                            lambda j: combined_allele_index[alleles.local[i][j]])))),
                        hl.dict(hl.range(0, hl.len(alleles.globl)).map(
                            lambda j: hl.tuple([alleles.globl[j], j])))))),
            ts.row.dtype, ts.globals.dtype)
        _merge_function_map[cache_key] = f
    merge_function = _merge_function_map[cache_key]
    ts = Table(TableMapRows(ts._tir, Apply(merge_function._name,
                                           merge_function._ret_type,
                                           TopLevelReference('row'),
                                           TopLevelReference('global'))))
    return ts.transmute_globals(__cols=hl.flatten(ts.g.map(lambda g: g.__cols)))

@typecheck(mts=sequenceof(oneof(Table, MatrixTable)))
def combine_gvcfs(mts):
    """Merges gvcfs and/or sparse matrix tables

    Parameters
    ----------
    mts : :obj:`List[Union[Table, MatrixTable]]`
        The matrix tables (or localized versions) to combine

    Returns
    -------
    :class:`.MatrixTable`

    Notes
    -----
    All of the input tables/matrix tables must have the same partitioning. This
    module provides no method of repartitioning data.
    """
    # Join all inputs row-wise: rows go to ``data``, globals to ``g``.
    ts = hl.Table.multi_way_zip_join([localize(mt) for mt in mts], 'data', 'g')
    combined = combine(ts)
    return unlocalize(combined)

@typecheck(ht=hl.Table, n=int, reference_genome=reference_genome_type)
def calculate_new_intervals(ht, n, reference_genome):
    """takes a table, keyed by ['locus', ...] and produces a list of intervals suitable
    for repartitioning a combiner matrix table

    Parameters
    ----------
    ht : :class:`.Table`
        Table / Rows Table to compute new intervals for
    n : :obj:`int`
        Number of rows each partition should have, (last partition may be smaller)
    reference_genome: :obj:`str` or :class:`.ReferenceGenome`, optional
        Reference genome to use.

    Returns
    -------
    :obj:`List[Interval]`
        Contiguous, end-inclusive locus intervals. The first starts at the
        beginning of the genome and the last ends at the final position of the
        reference's last contig.

    Raises
    ------
    ValueError
        If ``ht`` has no rows.
    """
    assert list(ht.key) == ['locus']
    assert ht.locus.dtype == hl.tlocus(reference_genome=reference_genome)
    last_contig = reference_genome.contigs[-1]
    end = hl.Locus(last_contig,
                   reference_genome.lengths[last_contig],
                   reference_genome=reference_genome)

    n_rows = ht.count()

    if n_rows == 0:
        raise ValueError('empty table!')

    ht = ht.select()
    ht = ht.annotate(x=hl.scan.count())
    ht = ht.annotate(y=ht.x + 1)
    # Keep the last row of every block of ``n`` rows, plus the final row.
    ht = ht.filter((ht.x // n != ht.y // n) | (ht.x == (n_rows - 1)))
    ht = ht.select()
    # Each interval starts one position after the previously kept locus
    # (or at the start of the genome for the first one).
    ht = ht.annotate(start=hl.or_else(
        hl.scan._prev_nonnull(hl.locus_from_global_position(ht.locus.global_position() + 1,
                                                            reference_genome=reference_genome)),
        hl.locus_from_global_position(0, reference_genome=reference_genome)))
    ht = ht.key_by()
    ht = ht.select(interval=hl.interval(start=ht.start, end=ht.locus, includes_end=True))

    intervals = ht.aggregate(hl.agg.collect(ht.interval))

    # Cover everything after the last row up to the end of the genome.
    last_st = hl.eval(
        hl.locus_from_global_position(hl.literal(intervals[-1].end).global_position() + 1,
                                      reference_genome=reference_genome))
    interval = hl.Interval(start=last_st, end=end, includes_end=True)
    intervals.append(interval)
    return intervals

@typecheck(reference_genome=reference_genome_type, interval_size=int)
def calculate_even_genome_partitioning(reference_genome, interval_size) -> List[hl.utils.Interval]:
    """create a list of locus intervals suitable for importing and merging gvcfs.

    Parameters
    ----------
    reference_genome: :obj:`str` or :class:`.ReferenceGenome`,
        Reference genome to use. NOTE: only GRCh37 and GRCh38 references
        are supported.
    interval_size: :obj:`int` The ceiling and rough target of interval size.
        Intervals will never be larger than this, but may be smaller.

    Returns
    -------
    :obj:`List[Interval]`
        End-inclusive intervals that tile every listed contig, contig by contig.

    Raises
    ------
    ValueError
        If `interval_size` is not positive, or the reference genome is unsupported.
    """
    if interval_size <= 0:
        raise ValueError(f"'interval_size' must be positive, found {interval_size}")

    def calc_parts(contig):
        def locus_interval(start, end):
            return hl.Interval(
                start=hl.Locus(contig=contig, position=start, reference_genome=reference_genome),
                end=hl.Locus(contig=contig, position=end, reference_genome=reference_genome),
                includes_end=True)

        contig_length = reference_genome.lengths[contig]
        n_parts = math.ceil(contig_length / interval_size)
        # Spread the contig evenly over ``n_parts`` intervals of at most
        # ``real_size`` (<= interval_size) positions each.
        real_size = math.ceil(contig_length / n_parts)
        n = 1
        intervals = []
        # Positions are 1-based and ``end`` is inclusive. Using ``<=`` guarantees
        # the contig's final position is covered even when the previous interval
        # ends one position short of it.
        while n <= contig_length:
            start = n
            end = min(n + real_size - 1, contig_length)
            intervals.append(locus_interval(start, end))
            n = end + 1

        return intervals

    if reference_genome.name == 'GRCh37':
        contigs = [f'{i}' for i in range(1, 23)] + ['X', 'Y', 'MT']
    elif reference_genome.name == 'GRCh38':
        contigs = [f'chr{i}' for i in range(1, 23)] + ['chrX', 'chrY', 'chrM']
    else:
        raise ValueError(
            f"Unsupported reference genome '{reference_genome.name}', "
            "only 'GRCh37' and 'GRCh38' are supported")

    intervals = []
    for ctg in contigs:
        intervals.extend(calc_parts(ctg))
    return intervals


class Merge(object):
    """A single merge step in a combiner plan.

    ``inputs`` holds the indices (into the previous layer's files) of the files
    being merged, and ``input_total_size`` is the sum of their sizes.
    """

    def __init__(self,
                 inputs: List[int],
                 input_total_size: int):
        self.input_total_size: int = input_total_size
        self.inputs: List[int] = inputs

class Job(object):
    """A batch of merges executed together.

    ``input_total_size`` is the sum of the ``input_total_size`` of its merges.
    """

    def __init__(self, merges: List[Merge]):
        self.merges: List[Merge] = merges
        total = 0
        for merge in merges:
            total += merge.input_total_size
        self.input_total_size = total

class Phase(object):
    """The jobs executed in one phase (tree level) of a combiner plan."""

    def __init__(self, jobs: List[Job]):
        self.jobs: List[Job] = jobs

class CombinerPlan(object):
    """The result of :meth:`CombinerConfig.plan`.

    ``file_size`` holds one list per layer; ``phases`` holds the phases to run.
    ``merge_per_phase`` is the length of the first layer of ``file_size``, and
    ``total_merge`` is that length times the number of phases.
    """

    def __init__(self,
                 file_size: List[List[int]],
                 phases: List[Phase]):
        self.file_size = file_size
        self.phases = phases
        first_layer_len = len(file_size[0])
        self.merge_per_phase = first_layer_len
        self.total_merge = first_layer_len * len(phases)

class CombinerConfig(object):
    """Tuning parameters of the hierarchical combiner, and the planner that uses them."""
    default_branch_factor = 100
    default_batch_size = 100
    default_target_records = 30_000

    # These are used to calculate intervals for reading GVCFs in the combiner
    # The genome interval size results in 2568 partitions for GRCh38. The exome
    # interval size assumes that they are around 2% the size of a genome and
    # result in 65 partitions for GRCh38.
    default_genome_interval_size = 1_200_000
    default_exome_interval_size = 60_000_000

    def __init__(self,
                 branch_factor: int = default_branch_factor,
                 batch_size: int = default_batch_size,
                 target_records: int = default_target_records):
        self.branch_factor: int = branch_factor
        self.batch_size: int = batch_size
        self.target_records: int = target_records

    @classmethod
    def default(cls) -> 'CombinerConfig':
        """Return a config that uses all the class-level defaults."""
        return cls()

    def plan(self, n_inputs: int) -> CombinerPlan:
        """Plan a hierarchical merge of ``n_inputs`` files.

        In each phase, the previous layer's files are grouped into merges of at
        most ``branch_factor`` files, and the merges are grouped into jobs of at
        most ``batch_size`` merges. This repeats until a single file remains.

        Returns
        -------
        :class:`CombinerPlan`
            ``file_size[k]`` lists, for every file in layer ``k``, how many
            original inputs it contains (layer 0 is the inputs themselves).

        Raises
        ------
        ValueError
            If ``branch_factor < 2`` or ``batch_size < 1`` (planning could not
            make progress).
        """
        assert n_inputs > 0
        if self.branch_factor < 2:
            raise ValueError(f"'branch_factor' must be at least 2, found {self.branch_factor}")
        if self.batch_size < 1:
            raise ValueError(f"'batch_size' must be at least 1, found {self.batch_size}")

        # Number of phases = number of ceiling divisions by the branch factor
        # needed to reach one file. Use integer arithmetic:
        # ceil(math.log(n, b)) is wrong when the float log overshoots an exact
        # power (math.log(125, 5) == 3.0000000000000004).
        tree_height = 0
        remaining = n_inputs
        while remaining > 1:
            remaining = -(-remaining // self.branch_factor)
            tree_height += 1

        phases: List[Phase] = []
        file_size: List[List[int]] = []  # List of file size per phase

        file_size.append([1 for _ in range(n_inputs)])
        while len(file_size[-1]) > 1:
            last_stage_files = file_size[-1]
            n = len(last_stage_files)
            i = 0
            jobs = []
            while i < n:
                job = []
                job_i = 0
                while job_i < self.batch_size and i < n:
                    merge = []
                    merge_i = 0
                    merge_size = 0
                    while merge_i < self.branch_factor and i < n:
                        merge_size += last_stage_files[i]
                        merge.append(i)
                        merge_i += 1
                        i += 1
                    job.append(Merge(merge, merge_size))
                    job_i += 1
                jobs.append(Job(job))
            file_size.append([merge.input_total_size for job in jobs for merge in job.merges])
            phases.append(Phase(jobs))

        assert len(phases) == tree_height
        for layer in file_size:
            assert sum(layer) == n_inputs

        phase_strs = []
        total_jobs = 0
        for i, phase in enumerate(phases):
            n = len(phase.jobs)
            job_str = hl.utils.misc.plural('job', n)
            n_files_produced = len(file_size[i + 1])
            adjective = 'final' if n_files_produced == 1 else 'intermediate'
            file_str = hl.utils.misc.plural('file', n_files_produced)
            phase_strs.append(
                f'\n        Phase {i + 1}: {n} {job_str} corresponding to {n_files_produced} {adjective} output {file_str}.')
            total_jobs += n

        info(f"GVCF combiner plan:\n"
             f"    Branch factor: {self.branch_factor}\n"
             f"    Batch size: {self.batch_size}\n"
             f"    Combining {n_inputs} input files in {tree_height} phases with {total_jobs} total jobs.{''.join(phase_strs)}\n")
        return CombinerPlan(file_size, phases)

def run_combiner(sample_paths: List[str],
                 out_file: str,
                 tmp_path: str,
                 *,
                 intervals: Optional[List[hl.utils.Interval]] = None,
                 import_interval_size: Optional[int] = None,
                 use_genome_default_intervals: bool = False,
                 use_exome_default_intervals: bool = False,
                 header: Optional[str] = None,
                 sample_names: Optional[List[str]] = None,
                 branch_factor: int = CombinerConfig.default_branch_factor,
                 batch_size: int = CombinerConfig.default_batch_size,
                 target_records: int = CombinerConfig.default_target_records,
                 overwrite: bool = False,
                 reference_genome: str = 'default',
                 contig_recoding: Optional[Dict[str, str]] = None,
                 key_by_locus_and_alleles: bool = False):
    """Run the Hail VCF combiner, performing a hierarchical merge to create a combined sparse matrix table.

    **Partitioning**

    The partitioning of input GVCFs, which determines the maximum parallelism per file,
    is determined by the four parameters below. One of these parameters must be passed to
    this function.

    - `intervals` -- User-supplied intervals.
    - `import_interval_size` -- Use intervals of this uniform size across the genome.
    - `use_genome_default_intervals` -- Use intervals of typical uniform size for whole genome GVCFs.
    - `use_exome_default_intervals` -- Use intervals of typical uniform size for exome GVCFs.

    It is recommended that new users include either `use_genome_default_intervals` or
    `use_exome_default_intervals`.

    Note also that the partitioning of the final, combined matrix table does not depend on
    the GVCF input partitioning.

    Parameters
    ----------
    sample_paths : :obj:`list` of :obj:`str`
        Paths to individual GVCFs.
    out_file : :obj:`str`
        Path to final combined matrix table.
    tmp_path : :obj:`str`
        Path for intermediate output.
    intervals : list of :class:`.Interval` or None
        Import GVCFs with specified partition intervals.
    import_interval_size : :obj:`int` or None
        Import GVCFs with uniform partition intervals of specified size.
    use_genome_default_intervals : :obj:`bool`
        Import GVCFs with uniform partition intervals of default size for
        whole-genome data.
    use_exome_default_intervals : :obj:`bool`
        Import GVCFs with uniform partition intervals of default size for
        exome data.
    header : :obj:`str` or None
        External header file to use as GVCF header for all inputs. If defined,
        `sample_names` must be defined as well.
    sample_names: list of :obj:`str` or None
        Sample names, to be used with `header`.
    branch_factor : :obj:`int`
        Combiner branch factor.
    batch_size : :obj:`int`
        Combiner batch size.
    target_records : :obj:`int`
        Target records per partition in each combiner phase after the first.
    overwrite : :obj:`bool`
        Overwrite output file, if it exists.
    reference_genome : :obj:`str`
        Reference genome for GVCF import.
    contig_recoding: :obj:`dict` of (:obj:`str`, :obj:`str`), optional
        Mapping from contig name in gVCFs to contig name in the reference
        genome. All contigs must be present in the `reference_genome`, so this
        is useful for mapping differently-formatted data onto known references.
    key_by_locus_and_alleles : :obj:`bool`
        Key by both locus and alleles in the final output.

    Returns
    -------
    None
    """
    tmp_path += f'/combiner-temporary/{uuid.uuid4()}/'
    if header is not None:
        assert sample_names is not None
        assert len(sample_names) == len(sample_paths)

    n_partition_args = (int(intervals is not None)
                        + int(import_interval_size is not None)
                        + int(use_genome_default_intervals)
                        + int(use_exome_default_intervals))

    if n_partition_args == 0:
        raise ValueError("'run_combiner': require one argument from 'intervals', 'import_interval_size', "
                         "'use_genome_default_intervals', or 'use_exome_default_intervals' to choose GVCF partitioning")
    if n_partition_args > 1:
        warning("'run_combiner': multiple colliding arguments found from 'intervals', 'import_interval_size', "
                "'use_genome_default_intervals', or 'use_exome_default_intervals'."
                "\n The argument found first in the list in this warning will be used, and others ignored.")

    if intervals is not None:
        info(f"Using {len(intervals)} user-supplied intervals as partitioning for GVCF import")
    elif import_interval_size is not None:
        intervals = calculate_even_genome_partitioning(reference_genome, import_interval_size)
        info(f"Using {len(intervals)} intervals with user-supplied size"
             f" {import_interval_size} as partitioning for GVCF import")
    elif use_genome_default_intervals:
        size = CombinerConfig.default_genome_interval_size
        intervals = calculate_even_genome_partitioning(reference_genome, size)
        info(f"Using {len(intervals)} intervals with default whole-genome size"
             f" {size} as partitioning for GVCF import")
    elif use_exome_default_intervals:
        size = CombinerConfig.default_exome_interval_size
        intervals = calculate_even_genome_partitioning(reference_genome, size)
        info(f"Using {len(intervals)} intervals with default exome size"
             f" {size} as partitioning for GVCF import")

    assert intervals is not None

    config = CombinerConfig(branch_factor=branch_factor,
                            batch_size=batch_size,
                            target_records=target_records)
    plan = config.plan(len(sample_paths))

    files_to_merge = sample_paths
    n_phases = len(plan.phases)
    total_ops = len(files_to_merge) * n_phases
    total_work_done = 0
    for phase_i, phase in enumerate(plan.phases):
        phase_i += 1  # used for info messages, 1-indexed for readability

        n_jobs = len(phase.jobs)
        merge_str = 'input GVCFs' if phase_i == 1 else 'intermediate sparse matrix tables'
        job_str = hl.utils.misc.plural('job', n_jobs)
        info(f"Starting phase {phase_i}/{n_phases}, merging {len(files_to_merge)} {merge_str} in {n_jobs} {job_str}.")

        if phase_i > 1:
            # Later phases read intermediate matrix tables, repartitioned to
            # roughly ``target_records`` rows per partition.
            intervals = calculate_new_intervals(hl.read_matrix_table(files_to_merge[0]).rows(),
                                                config.target_records,
                                                reference_genome=reference_genome)

        new_files_to_merge = []

        for job_i, job in enumerate(phase.jobs):
            job_i += 1  # used for info messages, 1-indexed for readability

            n_merges = len(job.merges)
            merge_str = hl.utils.misc.plural('file', n_merges)
            pct_total = 100 * job.input_total_size / total_ops
            info(
                f"Starting phase {phase_i}/{n_phases}, job {job_i}/{len(phase.jobs)} to create {n_merges} merged {merge_str}, corresponding to ~{pct_total:.1f}% of total I/O.")
            merge_mts: List[MatrixTable] = []
            for merge in job.merges:
                inputs = [files_to_merge[i] for i in merge.inputs]

                if phase_i == 1:
                    mts = [transform_gvcf(vcf)
                           for vcf in hl.import_gvcfs(inputs, intervals, array_elements_required=False,
                                                      _external_header=header,
                                                      _external_sample_ids=[sample_names[i] for i in merge.inputs] if header is not None else None,
                                                      reference_genome=reference_genome,
                                                      contig_recoding=contig_recoding)]
                else:
                    mts = [hl.read_matrix_table(path, _intervals=intervals) for path in inputs]

                merge_mts.append(combine_gvcfs(mts))

            if phase_i == n_phases:  # final merge!
                assert n_jobs == 1
                assert len(merge_mts) == 1
                [final_mt] = merge_mts

                if key_by_locus_and_alleles:
                    final_mt = MatrixTable(MatrixKeyRowsBy(final_mt._mir, ['locus', 'alleles'], is_sorted=True))
                final_mt.write(out_file, overwrite=overwrite)
                new_files_to_merge = [out_file]
                info(f"Finished phase {phase_i}/{n_phases}, job {job_i}/{len(phase.jobs)}, 100% of total I/O finished.")
                break

            tmp = f'{tmp_path}_phase{phase_i}_job{job_i}/'
            hl.experimental.write_matrix_tables(merge_mts, tmp, overwrite=True)
            pad = len(str(len(merge_mts)))
            new_files_to_merge.extend(tmp + str(n).zfill(pad) + '.mt' for n in range(len(merge_mts)))
            total_work_done += job.input_total_size
            info(
                f"Finished {phase_i}/{n_phases}, job {job_i}/{len(phase.jobs)}, {100 * total_work_done / total_ops:.1f}% of total I/O finished.")

        info(f"Finished phase {phase_i}/{n_phases}.")

        files_to_merge = new_files_to_merge

    assert files_to_merge == [out_file]

    info("Finished!")
def parse_sample_mapping(sample_map_path: str) -> Tuple[List[str], List[str]]:
    """Read a sample map file with one ``name<TAB>path`` pair per line.

    Blank lines (including a trailing one) are ignored.

    Parameters
    ----------
    sample_map_path : :obj:`str`
        Path to the sample map, opened with :func:`.hadoop_open`.

    Returns
    -------
    (sample_names, sample_paths) : :obj:`tuple` of two :obj:`list` of :obj:`str`
        Names and paths in file order.

    Raises
    ------
    ValueError
        If a non-blank line does not have exactly two tab-separated fields.
    """
    sample_names: List[str] = list()
    sample_paths: List[str] = list()
    with hl.hadoop_open(sample_map_path) as f:
        for line_number, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            fields = line.split('\t')
            if len(fields) != 2:
                raise ValueError(
                    f"malformed sample map '{sample_map_path}', line {line_number}: "
                    f"expected 2 tab-separated fields, found {len(fields)}")
            name, path = fields
            sample_names.append(name)
            sample_paths.append(path)
    return sample_names, sample_paths