import collections
import hashlib
import json
import os
import sys
import uuid
from itertools import chain
from math import floor, log
from typing import ClassVar, Collection, Dict, List, NamedTuple, Optional, Union
import hail as hl
from hail.expr import HailType, tmatrix
from hail.genetics.reference_genome import ReferenceGenome
from hail.utils import FatalError, Interval
from hail.utils.java import info, warning
from ..variant_dataset import VariantDataset
from .combine import (
calculate_even_genome_partitioning,
calculate_new_intervals,
combine,
combine_r,
combine_variant_datasets,
defined_entry_fields,
make_reference_stream,
make_variant_stream,
transform_gvcf,
)
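# Metadata record used throughout the combiner to pair a Variant Dataset path with its
# sample count (see the `vdses` parameter, `to_dict`, and `_step_vdses`); a minimal
# definition matching that usage.
class VDSMetadata(NamedTuple):
    """The path to a Variant Dataset and the number of samples within it."""

    path: str
    n_samples: int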
class CombinerOutType(NamedTuple):
"""A container for the types of a VDS"""
reference_type: tmatrix
variant_type: tmatrix
FAST_CODEC_SPEC = """{
"name": "LEB128BufferSpec",
"child": {
"name": "BlockingBufferSpec",
"blockSize": 65536,
"child": {
"name": "ZstdBlockBufferSpec",
"blockSize": 65536,
"child": {
"name": "StreamBlockBufferSpec"
}
}
}
}"""
class VariantDatasetCombiner:  # pylint: disable=too-many-instance-attributes
"""A restartable and failure-tolerant method for combining one or more GVCFs and Variant Datasets.
Examples
--------
A Variant Dataset comprises one or more sequences. A new Variant Dataset is constructed from
GVCF files and/or extant Variant Datasets. For example, the following produces a new Variant
Dataset from four GVCF files containing whole genome sequences ::
gvcfs = [
'gs://bucket/sample_10123.g.vcf.bgz',
'gs://bucket/sample_10124.g.vcf.bgz',
'gs://bucket/sample_10125.g.vcf.bgz',
'gs://bucket/sample_10126.g.vcf.bgz',
]
combiner = hl.vds.new_combiner(
output_path='gs://bucket/dataset.vds',
temp_path='gs://1-day-temp-bucket/',
gvcf_paths=gvcfs,
use_genome_default_intervals=True,
)
combiner.run()
        vds = hl.vds.read_vds('gs://bucket/dataset.vds')
The following combines four new samples from GVCFs with multiple extant Variant Datasets::
gvcfs = [
'gs://bucket/sample_10123.g.vcf.bgz',
'gs://bucket/sample_10124.g.vcf.bgz',
'gs://bucket/sample_10125.g.vcf.bgz',
'gs://bucket/sample_10126.g.vcf.bgz',
]
vdses = [
'gs://bucket/hgdp.vds',
'gs://bucket/1kg.vds'
]
combiner = hl.vds.new_combiner(
output_path='gs://bucket/dataset.vds',
temp_path='gs://1-day-temp-bucket/',
save_path='gs://1-day-temp-bucket/combiner-plan.json',
gvcf_paths=gvcfs,
vds_paths=vdses,
use_genome_default_intervals=True,
)
combiner.run()
        vds = hl.vds.read_vds('gs://bucket/dataset.vds')
The speed of the Variant Dataset Combiner critically depends on data partitioning. Although the
partitioning is fully customizable, two high-quality partitioning strategies are available by
default, one for exomes and one for genomes. These partitioning strategies can be enabled,
respectively, with the parameters: ``use_exome_default_intervals=True`` and
``use_genome_default_intervals=True``.
The combiner serializes itself to `save_path` so that it can be restarted after failure.
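
    A failed or interrupted combine can be resumed by loading the plan saved at `save_path`,
    for example::

        combiner = hl.vds.load_combiner('gs://1-day-temp-bucket/combiner-plan.json')
        combiner.run()
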
Parameters
----------
save_path : :class:`str`
The file path to store this VariantDatasetCombiner plan. A failed or interrupted
execution can be restarted using this plan.
output_path : :class:`str`
The location to store the new VariantDataset.
temp_path : :class:`str`
The location to store temporary intermediates. We recommend using a bucket with an automatic
deletion or lifecycle policy.
reference_genome : :class:`.ReferenceGenome`
The reference genome to which all inputs (GVCFs and Variant Datasets) are aligned.
branch_factor : :class:`int`
The number of Variant Datasets to combine at once.
target_records : :class:`int`
The target number of variants per partition.
gvcf_batch_size : :class:`int`
The number of GVCFs to combine into a Variant Dataset at once.
contig_recoding : :class:`dict` mapping :class:`str` to :class:`str` or :obj:`None`
This mapping is applied to GVCF contigs before importing them into Hail. This is used to
handle GVCFs containing invalid contig names. For example, GRCh38 GVCFs which contain the
contig "1" instead of the correct "chr1".
vdses : :class:`list` of :class:`.VDSMetadata`
A list of Variant Datasets to combine. Each dataset is identified by a
:class:`.VDSMetadata`, which is a pair of a path and the number of samples in said Variant
Dataset.
    gvcfs : :class:`list` of :class:`str`
A list of paths of GVCF files to combine.
gvcf_sample_names : :class:`list` of :class:`str` or :obj:`None`
List of names to use for the samples from the GVCF files. Must be the same length as
`gvcfs`. Must be specified if `gvcf_external_header` is specified.
gvcf_external_header : :class:`str` or :obj:`None`
A path to a file containing a VCF header which is applied to all GVCFs. Must be specified if
`gvcf_sample_names` is specified.
gvcf_import_intervals : :class:`list` of :class:`.Interval`
A list of intervals defining how to partition the GVCF files. The same partitioning is used
for all GVCF files. Finer partitioning yields more parallelism but less work per task.
gvcf_info_to_keep : :class:`list` of :class:`str` or :obj:`None`
GVCF ``INFO`` fields to keep in the ``gvcf_info`` entry field. By default, all fields
except ``END`` and ``DP`` are kept.
    gvcf_reference_entry_fields_to_keep : :class:`list` of :class:`str` or :obj:`None`
        Genotype fields to keep in the reference table. If unspecified, the reference block rows of
        the first GVCF are sampled and all fields found to be defined, other than ``GT``, ``AD``,
        and ``PL``, become entry fields in the reference data of the resulting dataset.
"""
_default_gvcf_batch_size = 50
_default_branch_factor = 100
_default_target_records = 24_000
_gvcf_merge_task_limit = 150_000
    # These are used to calculate intervals for reading GVCFs in the combiner.
    # The genome interval size results in 2568 partitions for GRCh38. The exome
    # interval size assumes that exomes are around 2% the size of a genome and
    # results in 65 partitions for GRCh38.
default_genome_interval_size = 1_200_000
"A reasonable partition size in basepairs given the density of genomes."
default_exome_interval_size = 60_000_000
"A reasonable partition size in basepairs given the density of exomes."
__serialized_slots__: ClassVar = [
'_save_path',
'_output_path',
'_temp_path',
'_reference_genome',
'_dataset_type',
'_gvcf_type',
'_branch_factor',
'_target_records',
'_gvcf_batch_size',
'_contig_recoding',
'_vdses',
'_gvcfs',
'_gvcf_external_header',
'_gvcf_sample_names',
'_gvcf_import_intervals',
'_gvcf_info_to_keep',
'_gvcf_reference_entry_fields_to_keep',
'_call_fields',
]
__slots__ = tuple([*__serialized_slots__, "_uuid", "_job_id", "__intervals_cache"])
def __init__(
self,
*,
save_path: str,
output_path: str,
temp_path: str,
reference_genome: ReferenceGenome,
dataset_type: CombinerOutType,
gvcf_type: Optional[tmatrix] = None,
branch_factor: int = _default_branch_factor,
target_records: int = _default_target_records,
gvcf_batch_size: int = _default_gvcf_batch_size,
contig_recoding: Optional[Dict[str, str]] = None,
call_fields: Collection[str],
vdses: List[VDSMetadata],
gvcfs: List[str],
gvcf_sample_names: Optional[List[str]] = None,
gvcf_external_header: Optional[str] = None,
gvcf_import_intervals: List[Interval],
gvcf_info_to_keep: Optional[Collection[str]] = None,
gvcf_reference_entry_fields_to_keep: Optional[Collection[str]] = None,
):
if gvcf_import_intervals:
interval = gvcf_import_intervals[0]
if not isinstance(interval.point_type, hl.tlocus):
raise ValueError(f'intervals point type must be a locus, found {interval.point_type}')
if interval.point_type.reference_genome != reference_genome:
raise ValueError(
f'mismatch in intervals ({interval.point_type.reference_genome}) '
f'and reference genome ({reference_genome}) types'
)
if (gvcf_sample_names is None) != (gvcf_external_header is None):
raise ValueError("both 'gvcf_sample_names' and 'gvcf_external_header' must be set or unset")
if gvcf_sample_names is not None and len(gvcf_sample_names) != len(gvcfs):
raise ValueError(
"'gvcf_sample_names' and 'gvcfs' must have the same length " f'{len(gvcf_sample_names)} != {len(gvcfs)}'
)
if branch_factor < 2:
raise ValueError(f"'branch_factor' must be at least 2, found {branch_factor}")
if gvcf_batch_size < 1:
raise ValueError(f"'gvcf_batch_size' must be at least 1, found {gvcf_batch_size}")
self._save_path = save_path
self._output_path = output_path
self._temp_path = temp_path
self._reference_genome = reference_genome
self._dataset_type = dataset_type
self._gvcf_type = gvcf_type
self._branch_factor = branch_factor
self._target_records = target_records
self._contig_recoding = contig_recoding
self._call_fields = list(call_fields)
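        # Bin input VDSes by floor(log_{branch_factor}(n_samples)) so that datasets of a
        # similar size are merged together; bin k roughly holds datasets with about
        # branch_factor**k samples.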
self._vdses = collections.defaultdict(list)
for vds in vdses:
self._vdses[max(1, floor(log(vds.n_samples, self._branch_factor)))].append(vds)
self._gvcfs = gvcfs
self._gvcf_sample_names = gvcf_sample_names
self._gvcf_external_header = gvcf_external_header
self._gvcf_import_intervals = gvcf_import_intervals
self._gvcf_info_to_keep = set(gvcf_info_to_keep) if gvcf_info_to_keep is not None else None
self._gvcf_reference_entry_fields_to_keep = (
set(gvcf_reference_entry_fields_to_keep) if gvcf_reference_entry_fields_to_keep is not None else None
)
self._uuid = uuid.uuid4()
self._job_id = 1
self.__intervals_cache = {}
self._gvcf_batch_size = gvcf_batch_size
@property
def gvcf_batch_size(self):
"""The number of GVCFs to combine into a Variant Dataset at once."""
return self._gvcf_batch_size
@gvcf_batch_size.setter
def gvcf_batch_size(self, value: int):
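        # Clamp the batch size so that a single GVCF-merge job never exceeds
        # _gvcf_merge_task_limit tasks (tasks = batch size * number of import intervals).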
if value * len(self._gvcf_import_intervals) > VariantDatasetCombiner._gvcf_merge_task_limit:
old_value = value
value = VariantDatasetCombiner._gvcf_merge_task_limit // len(self._gvcf_import_intervals)
            warning(f'gvcf_batch_size of {old_value} would produce too many tasks, using {value} instead')
self._gvcf_batch_size = value
    def __eq__(self, other):
if other.__class__ != VariantDatasetCombiner:
return False
for slot in self.__serialized_slots__:
if getattr(self, slot) != getattr(other, slot):
return False
return True
@property
def finished(self) -> bool:
"""Have all GVCFs and input Variant Datasets been combined?"""
return not self._gvcfs and not self._vdses
    def save(self):
"""Save a :class:`.VariantDatasetCombiner` to its `save_path`."""
fs = hl.current_backend().fs
try:
backup_path = self._save_path + '.bak'
if fs.exists(self._save_path):
fs.copy(self._save_path, backup_path)
with fs.open(self._save_path, 'w') as out:
json.dump(self, out, indent=2, cls=Encoder)
if fs.exists(backup_path):
fs.remove(backup_path)
except OSError as e:
# these messages get printed, because there is absolutely no guarantee
# that the hail context is in a sane state if any of the above operations
# fail
print(f'Failed saving {self.__class__.__name__} state at {self._save_path}')
print(f'An attempt was made to copy {self._save_path} to {backup_path}')
print('An old version of this state may be there.')
print(
'Dumping current state as json to standard output, you may wish '
'to save this output in order to resume the combiner.'
)
json.dump(self, sys.stdout, indent=2, cls=Encoder)
print()
raise e
    def run(self):
"""Combine the specified GVCFs and Variant Datasets."""
flagname = 'no_ir_logging'
prev_flag_value = hl._get_flags(flagname).get(flagname)
hl._set_flags(**{flagname: '1'})
vds_samples = sum(vds.n_samples for vdses in self._vdses.values() for vds in vdses)
info(
'Running VDS combiner:\n'
f' VDS arguments: {self._num_vdses} datasets with {vds_samples} samples\n'
f' GVCF arguments: {len(self._gvcfs)} inputs/samples\n'
f' Branch factor: {self._branch_factor}\n'
f' GVCF merge batch size: {self._gvcf_batch_size}'
)
while not self.finished:
self.save()
self.step()
self.save()
info('Finished VDS combiner!')
hl._set_flags(**{flagname: prev_flag_value})
    @staticmethod
def load(path) -> 'VariantDatasetCombiner':
"""Load a :class:`.VariantDatasetCombiner` from `path`."""
fs = hl.current_backend().fs
with fs.open(path) as stream:
combiner = json.load(stream, cls=Decoder)
combiner._raise_if_output_exists()
if combiner._save_path != path:
warning(
'path/save_path mismatch in loaded VariantDatasetCombiner, using '
f'{path} as the new save_path for this combiner'
)
combiner._save_path = path
return combiner
def _raise_if_output_exists(self):
if self.finished:
return
fs = hl.current_backend().fs
ref_success_path = os.path.join(VariantDataset._reference_path(self._output_path), '_SUCCESS')
var_success_path = os.path.join(VariantDataset._variants_path(self._output_path), '_SUCCESS')
if fs.exists(ref_success_path) and fs.exists(var_success_path):
raise FatalError(
f'combiner output already exists at {self._output_path}\n' 'move or delete it before continuing'
)
    def to_dict(self) -> dict:
"""A serializable representation of this combiner."""
intervals_typ = hl.tarray(hl.tinterval(hl.tlocus(self._reference_genome)))
return {
'name': self.__class__.__name__,
'save_path': self._save_path,
'output_path': self._output_path,
'temp_path': self._temp_path,
'reference_genome': str(self._reference_genome),
'dataset_type': self._dataset_type,
'gvcf_type': self._gvcf_type,
'branch_factor': self._branch_factor,
'target_records': self._target_records,
'gvcf_batch_size': self._gvcf_batch_size,
'gvcf_external_header': self._gvcf_external_header, # put this here for humans
'contig_recoding': self._contig_recoding,
'gvcf_info_to_keep': None if self._gvcf_info_to_keep is None else list(self._gvcf_info_to_keep),
'gvcf_reference_entry_fields_to_keep': None
if self._gvcf_reference_entry_fields_to_keep is None
else list(self._gvcf_reference_entry_fields_to_keep),
'call_fields': self._call_fields,
'vdses': [md for i in sorted(self._vdses, reverse=True) for md in self._vdses[i]],
'gvcfs': self._gvcfs,
'gvcf_sample_names': self._gvcf_sample_names,
'gvcf_import_intervals': intervals_typ._convert_to_json(self._gvcf_import_intervals),
}
@property
def _num_vdses(self):
return sum(len(v) for v in self._vdses.values())
    def step(self):
"""Run one layer of combinations.
:meth:`.run` effectively runs :meth:`.step` until all GVCFs and Variant Datasets have been
combined.
"""
if self.finished:
return
if self._gvcfs:
self._step_gvcfs()
else:
self._step_vdses()
if not self.finished:
self._job_id += 1
def _write_final(self, vds):
vds.write(self._output_path)
if VariantDataset.ref_block_max_length_field not in vds.reference_data.globals:
info("VDS combiner: computing reference block max length...")
hl.vds.store_ref_block_max_length(self._output_path)
def _step_vdses(self):
current_bin = original_bin = min(self._vdses)
files_to_merge = self._vdses[current_bin][: self._branch_factor]
if len(files_to_merge) == len(self._vdses[current_bin]):
del self._vdses[current_bin]
else:
self._vdses[current_bin] = self._vdses[current_bin][self._branch_factor :]
remaining = self._branch_factor - len(files_to_merge)
while self._num_vdses > 0 and remaining > 0:
current_bin = min(self._vdses)
extra = self._vdses[current_bin][-remaining:]
if len(extra) == len(self._vdses[current_bin]):
del self._vdses[current_bin]
else:
self._vdses[current_bin] = self._vdses[current_bin][:-remaining]
files_to_merge = extra + files_to_merge
remaining = self._branch_factor - len(files_to_merge)
new_n_samples = sum(f.n_samples for f in files_to_merge)
info(f'VDS Combine (job {self._job_id}): merging {len(files_to_merge)} datasets with {new_n_samples} samples')
temp_path = self._temp_out_path(f'vds-combine_job{self._job_id}')
largest_vds = max(files_to_merge, key=lambda vds: vds.n_samples)
vds = hl.vds.read_vds(
largest_vds.path,
_assert_reference_type=self._dataset_type.reference_type,
_assert_variant_type=self._dataset_type.variant_type,
_warn_no_ref_block_max_length=False,
)
interval_bin = floor(log(new_n_samples, self._branch_factor))
intervals = self.__intervals_cache.get(interval_bin)
if intervals is None:
# we use the reference data since it generally has more rows than the variant data
intervals, _ = calculate_new_intervals(
vds.reference_data, self._target_records, os.path.join(temp_path, 'interval_checkpoint.ht')
)
self.__intervals_cache[interval_bin] = intervals
paths = [f.path for f in files_to_merge]
vdss = self._read_variant_datasets(paths, intervals)
combined = combine_variant_datasets(vdss)
if self.finished:
self._write_final(combined)
return
new_path = os.path.join(temp_path, 'dataset.vds')
combined.write(new_path, overwrite=True, _codec_spec=FAST_CODEC_SPEC)
new_bin = floor(log(new_n_samples, self._branch_factor))
        # Ensure the merged VDS is not appended back onto the end of the bin it was
        # just drawn from, which would create an ordering problem.
if new_bin <= original_bin:
new_bin = original_bin + 1
self._vdses[new_bin].append(VDSMetadata(path=new_path, n_samples=new_n_samples))
def _step_gvcfs(self):
step = self._branch_factor
files_to_merge = self._gvcfs[: self._gvcf_batch_size * step]
self._gvcfs = self._gvcfs[self._gvcf_batch_size * step :]
info(
f'GVCF combine (job {self._job_id}): merging {len(files_to_merge)} GVCFs into '
f'{(len(files_to_merge) + step - 1) // step} datasets'
)
if self._gvcf_external_header is not None:
sample_names = self._gvcf_sample_names[: self._gvcf_batch_size * step]
self._gvcf_sample_names = self._gvcf_sample_names[self._gvcf_batch_size * step :]
else:
sample_names = None
header_file = self._gvcf_external_header or files_to_merge[0]
header_info = hl.eval(hl.get_vcf_header_info(header_file))
merge_vds = []
merge_n_samples = []
intervals_literal = hl.literal([
hl.Struct(contig=i.start.contig, start=i.start.position, end=i.end.position)
for i in self._gvcf_import_intervals
])
partition_interval_point_type = hl.tstruct(locus=hl.tlocus(self._reference_genome))
partition_intervals = [
hl.Interval(
start=hl.Struct(locus=i.start),
end=hl.Struct(locus=i.end),
includes_start=i.includes_start,
includes_end=i.includes_end,
point_type=partition_interval_point_type,
)
for i in self._gvcf_import_intervals
]
vcfs = files_to_merge
if sample_names is None:
vcfs_lit = hl.literal(vcfs)
range_ht = hl.utils.range_table(len(vcfs), n_partitions=min(len(vcfs), 32))
range_ht = range_ht.annotate(
sample_id=hl.rbind(hl.get_vcf_header_info(vcfs_lit[range_ht.idx]), lambda header: header.sampleIDs[0])
)
sample_ids = range_ht.aggregate(hl.agg.collect(range_ht.sample_id))
else:
sample_ids = sample_names
for start in range(0, len(vcfs), step):
ids = sample_ids[start : start + step]
merging = vcfs[start : start + step]
reference_ht = hl.Table._generate(
contexts=intervals_literal,
partitions=partition_intervals,
rowfn=lambda interval, globals: hl._zip_join_producers(
hl.enumerate(hl.literal(merging)),
lambda idx_and_path: make_reference_stream(
hl.import_gvcf_interval(
idx_and_path[1],
idx_and_path[0],
interval.contig,
interval.start,
interval.end,
header_info,
call_fields=self._call_fields,
array_elements_required=False,
reference_genome=self._reference_genome,
contig_recoding=self._contig_recoding,
),
self._gvcf_reference_entry_fields_to_keep,
),
['locus'],
lambda k, v: k.annotate(data=v),
),
globals=hl.struct(g=hl.literal(ids).map(lambda s: hl.struct(__cols=[hl.struct(s=s)]))),
)
reference_ht = combine_r(reference_ht, ref_block_max_len_field=None) # compute max length at the end
variant_ht = hl.Table._generate(
contexts=intervals_literal,
partitions=partition_intervals,
rowfn=lambda interval, globals: hl._zip_join_producers(
hl.enumerate(hl.literal(merging)),
lambda idx_and_path: make_variant_stream(
hl.import_gvcf_interval(
idx_and_path[1],
idx_and_path[0],
interval.contig,
interval.start,
interval.end,
header_info,
call_fields=self._call_fields,
array_elements_required=False,
reference_genome=self._reference_genome,
contig_recoding=self._contig_recoding,
),
self._gvcf_info_to_keep,
),
['locus'],
lambda k, v: k.annotate(data=v),
),
globals=hl.struct(g=hl.literal(ids).map(lambda s: hl.struct(__cols=[hl.struct(s=s)]))),
)
variant_ht = combine(variant_ht)
vds = VariantDataset(
reference_ht._unlocalize_entries('__entries', '__cols', ['s']),
variant_ht._unlocalize_entries('__entries', '__cols', ['s'])._key_rows_by_assert_sorted(
'locus', 'alleles'
),
)
merge_vds.append(vds)
merge_n_samples.append(len(merging))
if self.finished and len(merge_vds) == 1:
self._write_final(merge_vds[0])
return
temp_path = self._temp_out_path(f'gvcf-combine_job{self._job_id}/dataset_')
pad = len(str(len(merge_vds) - 1))
merge_metadata = [
VDSMetadata(path=temp_path + str(count).rjust(pad, '0') + '.vds', n_samples=n_samples)
for count, n_samples in enumerate(merge_n_samples)
]
paths = [md.path for md in merge_metadata]
hl.vds.write_variant_datasets(merge_vds, paths, overwrite=True, codec_spec=FAST_CODEC_SPEC)
for md in merge_metadata:
self._vdses[max(1, floor(log(md.n_samples, self._branch_factor)))].append(md)
def _temp_out_path(self, extra):
return os.path.join(self._temp_path, 'combiner-intermediates', f'{self._uuid}_{extra}')
def _read_variant_datasets(self, inputs: List[str], intervals: List[Interval]):
reference_type = self._dataset_type.reference_type
variant_type = self._dataset_type.variant_type
return [
hl.vds.read_vds(
path,
intervals=intervals,
_assert_reference_type=reference_type,
_assert_variant_type=variant_type,
_warn_no_ref_block_max_length=False,
)
for path in inputs
]
def new_combiner(
*,
output_path: str,
temp_path: str,
save_path: Optional[str] = None,
gvcf_paths: Optional[List[str]] = None,
vds_paths: Optional[List[str]] = None,
vds_sample_counts: Optional[List[int]] = None,
intervals: Optional[List[Interval]] = None,
import_interval_size: Optional[int] = None,
use_genome_default_intervals: bool = False,
use_exome_default_intervals: bool = False,
gvcf_external_header: Optional[str] = None,
gvcf_sample_names: Optional[List[str]] = None,
gvcf_info_to_keep: Optional[Collection[str]] = None,
gvcf_reference_entry_fields_to_keep: Optional[Collection[str]] = None,
call_fields: Collection[str] = ['PGT'],
branch_factor: int = VariantDatasetCombiner._default_branch_factor,
target_records: int = VariantDatasetCombiner._default_target_records,
gvcf_batch_size: Optional[int] = None,
batch_size: Optional[int] = None,
reference_genome: Union[str, ReferenceGenome] = 'default',
contig_recoding: Optional[Dict[str, str]] = None,
force: bool = False,
) -> VariantDatasetCombiner:
"""Create a new :class:`.VariantDatasetCombiner` or load one from `save_path`."""
if not (gvcf_paths or vds_paths):
raise ValueError("at least one of 'gvcf_paths' or 'vds_paths' must be nonempty")
if gvcf_paths is None:
gvcf_paths = []
if len(gvcf_paths) > 0:
if len(set(gvcf_paths)) != len(gvcf_paths):
duplicates = [gvcf for gvcf, count in collections.Counter(gvcf_paths).items() if count > 1]
duplicates = '\n '.join(duplicates)
raise ValueError(f'gvcf paths should be unique, the following paths are repeated:{duplicates}')
if gvcf_sample_names is not None and len(set(gvcf_sample_names)) != len(gvcf_sample_names):
duplicates = [gvcf for gvcf, count in collections.Counter(gvcf_sample_names).items() if count > 1]
duplicates = '\n '.join(duplicates)
raise ValueError(
"provided sample names ('gvcf_sample_names') should be unique, "
f'the following names are repeated:{duplicates}'
)
if vds_paths is None:
vds_paths = []
if vds_sample_counts is not None and len(vds_paths) != len(vds_sample_counts):
raise ValueError(
"'vds_paths' and 'vds_sample_counts' (if present) must have the same length "
f'{len(vds_paths)} != {len(vds_sample_counts)}'
)
if (gvcf_sample_names is None) != (gvcf_external_header is None):
raise ValueError("both 'gvcf_sample_names' and 'gvcf_external_header' must be set or unset")
if gvcf_sample_names is not None and len(gvcf_sample_names) != len(gvcf_paths):
raise ValueError(
"'gvcf_sample_names' and 'gvcf_paths' must have the same length "
f'{len(gvcf_sample_names)} != {len(gvcf_paths)}'
)
if batch_size is None:
if gvcf_batch_size is None:
gvcf_batch_size = VariantDatasetCombiner._default_gvcf_batch_size
else:
pass
elif gvcf_batch_size is None:
warning(
'The batch_size parameter is deprecated. '
'The batch_size parameter will be removed in a future version of Hail. '
'Please use gvcf_batch_size instead.'
)
gvcf_batch_size = batch_size
else:
raise ValueError(
'Specify only one of batch_size and gvcf_batch_size. ' f'Received {batch_size} and {gvcf_batch_size}.'
)
del batch_size
def maybe_load_from_saved_path(save_path: str) -> Optional[VariantDatasetCombiner]:
if force:
return None
fs = hl.current_backend().fs
if fs.exists(save_path):
try:
combiner = load_combiner(save_path)
warning(f'found existing combiner plan at {save_path}, using it')
                # These values are serialized but are not part of the hash used for an
                # autogenerated save path, and users should be able to override them when
                # resuming a combine (a common reason to resume is a failure caused by a
                # branch factor that is too large).
combiner._branch_factor = branch_factor
combiner._target_records = target_records
combiner._gvcf_batch_size = gvcf_batch_size
return combiner
except (ValueError, TypeError, OSError, KeyError) as e:
warning(
f'file exists at {save_path}, but it is not a valid combiner plan, overwriting\n'
f' caused by: {e}'
)
return None
# We do the first save_path check now after validating the arguments
if save_path is not None:
saved_combiner = maybe_load_from_saved_path(save_path)
if saved_combiner is not None:
return saved_combiner
if len(gvcf_paths) > 0:
n_partition_args = (
int(intervals is not None)
+ int(import_interval_size is not None)
+ int(use_genome_default_intervals)
+ int(use_exome_default_intervals)
)
if n_partition_args == 0:
raise ValueError(
"'new_combiner': require one argument from 'intervals', 'import_interval_size', "
"'use_genome_default_intervals', or 'use_exome_default_intervals' to choose GVCF partitioning"
)
if n_partition_args > 1:
warning(
"'new_combiner': multiple colliding arguments found from 'intervals', 'import_interval_size', "
"'use_genome_default_intervals', or 'use_exome_default_intervals'."
"\n The argument found first in the list in this warning will be used, and others ignored."
)
if intervals is not None:
pass
elif import_interval_size is not None:
intervals = calculate_even_genome_partitioning(reference_genome, import_interval_size)
elif use_genome_default_intervals:
size = VariantDatasetCombiner.default_genome_interval_size
intervals = calculate_even_genome_partitioning(reference_genome, size)
elif use_exome_default_intervals:
size = VariantDatasetCombiner.default_exome_interval_size
intervals = calculate_even_genome_partitioning(reference_genome, size)
assert intervals is not None
else:
intervals = []
if isinstance(reference_genome, str):
reference_genome = hl.get_reference(reference_genome)
    # We need to compute the type that the combiner will have. This allows us to read matrix
    # tables quickly, especially in an asynchronous environment like Query-on-Batch, where
    # typing a read requires a blocking round trip.
vds = None
gvcf_type = None
if vds_paths:
        # sync up gvcf_reference_entry_fields_to_keep and the reference entry types from the VDS
vds = hl.vds.read_vds(vds_paths[0], _warn_no_ref_block_max_length=False)
vds_ref_entry = set(
name[1:] if name in ('LGT', 'LPGT') else name for name in vds.reference_data.entry if name != 'END'
)
if gvcf_reference_entry_fields_to_keep is not None and vds_ref_entry != gvcf_reference_entry_fields_to_keep:
warning(
"Mismatch between 'gvcf_reference_entry_fields' to keep and VDS reference data "
"entry types. Overwriting with reference entry fields from supplied VDS.\n"
f" VDS reference entry fields : {sorted(vds_ref_entry)}\n"
f" requested reference entry fields: {sorted(gvcf_reference_entry_fields_to_keep)}"
)
gvcf_reference_entry_fields_to_keep = vds_ref_entry
# sync up call_fields and call fields present in the VDS
all_entry_types = chain(vds.reference_data._type.entry_type.items(), vds.variant_data._type.entry_type.items())
vds_call_fields = {
name[1:] if name == 'LPGT' else name
for name, typ in all_entry_types
if typ == hl.tcall and name not in ('LGT', 'GT')
}
if set(call_fields) != vds_call_fields:
warning(
"Mismatch between 'call_fields' and VDS call fields. "
"Overwriting with call fields from supplied VDS.\n"
f" VDS call fields : {sorted(vds_call_fields)}\n"
f" requested call fields: {sorted(call_fields)}\n"
)
call_fields = vds_call_fields
if gvcf_paths:
mt = hl.import_vcf(
gvcf_paths[0],
header_file=gvcf_external_header,
force_bgz=True,
array_elements_required=False,
reference_genome=reference_genome,
contig_recoding=contig_recoding,
)
gvcf_type = mt._type
if gvcf_reference_entry_fields_to_keep is None:
rmt = mt.filter_rows(hl.is_defined(mt.info.END))
gvcf_reference_entry_fields_to_keep = defined_entry_fields(rmt, 100_000) - {'PGT', 'PL'}
if vds is None:
vds = transform_gvcf(
mt._key_rows_by_assert_sorted('locus'), gvcf_reference_entry_fields_to_keep, gvcf_info_to_keep
)
dataset_type = CombinerOutType(reference_type=vds.reference_data._type, variant_type=vds.variant_data._type)
if save_path is None:
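        # Derive a deterministic save path from the combiner's arguments so that rerunning
        # new_combiner with the same arguments finds and resumes the same plan file.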
sha = hashlib.sha256()
sha.update(output_path.encode())
sha.update(temp_path.encode())
sha.update(str(reference_genome).encode())
sha.update(str(dataset_type).encode())
if gvcf_type is not None:
sha.update(str(gvcf_type).encode())
for path in vds_paths:
sha.update(path.encode())
for path in gvcf_paths:
sha.update(path.encode())
if gvcf_external_header is not None:
sha.update(gvcf_external_header.encode())
if gvcf_sample_names is not None:
for name in gvcf_sample_names:
sha.update(name.encode())
if gvcf_info_to_keep is not None:
for kept_info in sorted(gvcf_info_to_keep):
sha.update(kept_info.encode())
if gvcf_reference_entry_fields_to_keep is not None:
for field in sorted(gvcf_reference_entry_fields_to_keep):
sha.update(field.encode())
for call_field in sorted(call_fields):
sha.update(call_field.encode())
if contig_recoding is not None:
for key, value in sorted(contig_recoding.items()):
sha.update(key.encode())
sha.update(value.encode())
for interval in intervals:
sha.update(str(interval).encode())
digest = sha.hexdigest()
name = f'vds-combiner-plan_{digest}_{hl.__pip_version__}.json'
save_path = os.path.join(temp_path, 'combiner-plans', name)
saved_combiner = maybe_load_from_saved_path(save_path)
if saved_combiner is not None:
return saved_combiner
warning(f'generated combiner save path of {save_path}')
if vds_sample_counts:
vdses = [VDSMetadata(path, n_samples) for path, n_samples in zip(vds_paths, vds_sample_counts)]
else:
vdses = []
for path in vds_paths:
vds = hl.vds.read_vds(
path,
_assert_reference_type=dataset_type.reference_type,
_assert_variant_type=dataset_type.variant_type,
_warn_no_ref_block_max_length=False,
)
n_samples = vds.n_samples()
vdses.append(VDSMetadata(path, n_samples))
vdses.sort(key=lambda x: x.n_samples, reverse=True)
combiner = VariantDatasetCombiner(
save_path=save_path,
output_path=output_path,
temp_path=temp_path,
reference_genome=reference_genome,
dataset_type=dataset_type,
branch_factor=branch_factor,
target_records=target_records,
gvcf_batch_size=gvcf_batch_size,
contig_recoding=contig_recoding,
call_fields=call_fields,
vdses=vdses,
gvcfs=gvcf_paths,
gvcf_import_intervals=intervals,
gvcf_external_header=gvcf_external_header,
gvcf_sample_names=gvcf_sample_names,
gvcf_info_to_keep=gvcf_info_to_keep,
gvcf_reference_entry_fields_to_keep=gvcf_reference_entry_fields_to_keep,
)
combiner._raise_if_output_exists()
return combiner
def load_combiner(path: str) -> VariantDatasetCombiner:
"""Load a :class:`.VariantDatasetCombiner` from `path`."""
return VariantDatasetCombiner.load(path)
class Encoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, VariantDatasetCombiner):
return o.to_dict()
if isinstance(o, HailType):
return str(o)
if isinstance(o, tmatrix):
return o.to_dict()
return json.JSONEncoder.default(self, o)
class Decoder(json.JSONDecoder):
def __init__(self, **kwargs):
super().__init__(object_hook=Decoder._object_hook, **kwargs)
@staticmethod
def _object_hook(obj):
if 'name' not in obj:
return obj
name = obj['name']
if name == VariantDatasetCombiner.__name__:
del obj['name']
obj['vdses'] = [VDSMetadata(*x) for x in obj['vdses']]
obj['dataset_type'] = CombinerOutType(*(tmatrix._from_json(ty) for ty in obj['dataset_type']))
if 'gvcf_type' in obj and obj['gvcf_type']:
obj['gvcf_type'] = tmatrix._from_json(obj['gvcf_type'])
rg = hl.get_reference(obj['reference_genome'])
obj['reference_genome'] = rg
intervals_type = hl.tarray(hl.tinterval(hl.tlocus(rg)))
intervals = intervals_type._convert_from_json(obj['gvcf_import_intervals'])
obj['gvcf_import_intervals'] = intervals
return VariantDatasetCombiner(**obj)
return obj