Source code for hail.experimental.full_outer_join_mt

import hail as hl
from hail.matrixtable import MatrixTable


def full_outer_join_mt(left: MatrixTable, right: MatrixTable) -> MatrixTable:
    """Compute the full outer join of two matrix tables.

    Rows are joined on the row key and columns on the column key. Every row
    and every column found in either input appears in the result. The row,
    column, and entry fields of the inputs are replaced by:

     - `left_row` / `right_row`: structs of row fields from left and right.
     - `left_col` / `right_col`: structs of column fields from left and right.
     - `left_entry` / `right_entry`: structs of entry fields from left and
       right.

    A struct is missing wherever the corresponding side has no matching row
    or column. The column key fields of the result take their names from
    `left`.

    Examples
    --------

    Build two random datasets whose sample ids are disjoint but whose variant
    sets overlap, then join them. :func:`.or_else` picks whichever genotype is
    non-missing; when both are missing the genotype stays missing. Note that
    samples `2` and `3` have missing genotypes at loci 1:1 and 1:2: those loci
    are absent from `mt2`, and those samples are absent from `mt1`.

    >>> hl.reset_global_randomness()
    >>> mt1 = hl.balding_nichols_model(1, 2, 3)
    >>> mt2 = hl.balding_nichols_model(1, 2, 3)
    >>> mt2 = mt2.key_rows_by(locus=hl.locus(mt2.locus.contig,
    ...                                      mt2.locus.position+2),
    ...                       alleles=mt2.alleles)
    >>> mt2 = mt2.key_cols_by(sample_idx=mt2.sample_idx+2)
    >>> mt1.show()
    +---------------+------------+------+------+
    | locus         | alleles    | 0.GT | 1.GT |
    +---------------+------------+------+------+
    | locus<GRCh37> | array<str> | call | call |
    +---------------+------------+------+------+
    | 1:1           | ["A","C"]  | 0/0  | 0/0  |
    | 1:2           | ["A","C"]  | 0/1  | 0/1  |
    | 1:3           | ["A","C"]  | 0/0  | 0/1  |
    +---------------+------------+------+------+
    <BLANKLINE>
    >>> mt2.show()
    +---------------+------------+------+------+
    | locus         | alleles    | 2.GT | 3.GT |
    +---------------+------------+------+------+
    | locus<GRCh37> | array<str> | call | call |
    +---------------+------------+------+------+
    | 1:3           | ["A","C"]  | 0/1  | 1/1  |
    | 1:4           | ["A","C"]  | 1/1  | 0/1  |
    | 1:5           | ["A","C"]  | 0/0  | 0/0  |
    +---------------+------------+------+------+
    <BLANKLINE>
    >>> mt3 = hl.experimental.full_outer_join_mt(mt1, mt2)
    >>> mt3 = mt3.select_entries(GT=hl.or_else(mt3.left_entry.GT, mt3.right_entry.GT))
    >>> mt3.show()
    +---------------+------------+------+------+------+------+
    | locus         | alleles    | 0.GT | 1.GT | 2.GT | 3.GT |
    +---------------+------------+------+------+------+------+
    | locus<GRCh37> | array<str> | call | call | call | call |
    +---------------+------------+------+------+------+------+
    | 1:1           | ["A","C"]  | 0/0  | 0/0  | NA   | NA   |
    | 1:2           | ["A","C"]  | 0/1  | 0/1  | NA   | NA   |
    | 1:3           | ["A","C"]  | 0/0  | 0/1  | 0/1  | 1/1  |
    | 1:4           | ["A","C"]  | NA   | NA   | 1/1  | 0/1  |
    | 1:5           | ["A","C"]  | NA   | NA   | 0/0  | 0/0  |
    +---------------+------------+------+------+------+------+
    <BLANKLINE>

    Parameters
    ----------
    left : :class:`.MatrixTable`
    right : :class:`.MatrixTable`

    Returns
    -------
    :class:`.MatrixTable`

    Raises
    ------
    ValueError
        If the row key types, or the column key types, of `left` and `right`
        differ.
    """

    def key_types(key_struct):
        # Field types of a key struct, in key order.
        return [field.dtype for field in key_struct.values()]

    def column_indices_by_key(cols, key_struct):
        # Dict expression mapping each column-key tuple to the array of
        # positions in `cols` that carry that key.
        key_tuples = cols.map(lambda col: hl.tuple([col[name] for name in key_struct]))
        # index_first=False yields (key_tuple, position) pairs.
        keyed_positions = hl.enumerate(key_tuples, index_first=False)
        grouped = hl.group_by(lambda pair: pair[0], keyed_positions)
        return grouped.map_values(lambda pairs: pairs.map(lambda pair: pair[1]))

    def pair_column_indices(s):
        # Expand one key's left/right position arrays into output columns.
        # Key present on both sides: every left position is paired with every
        # right position.
        matched = hl.range(0, s.left_indices.length()).flatmap(
            lambda i: hl.range(0, s.right_indices.length()).map(
                lambda j: hl.struct(k=s.k, left_index=s.left_indices[i], right_index=s.right_indices[j])
            )
        )
        # Key present on one side only: the other index is missing.
        left_only = s.left_indices.map(
            lambda elt: hl.struct(k=s.k, left_index=elt, right_index=hl.missing('int32'))
        )
        right_only = s.right_indices.map(
            lambda elt: hl.struct(k=s.k, left_index=hl.missing('int32'), right_index=elt)
        )
        return (
            hl.case()
            .when(hl.is_defined(s.left_indices) & hl.is_defined(s.right_indices), matched)
            .when(hl.is_defined(s.left_indices), left_only)
            .when(hl.is_defined(s.right_indices), right_only)
            # Every key comes from the union of both key sets, so at least
            # one side must be defined.
            .or_error('assertion error')
        )

    if key_types(left.row_key) != key_types(right.row_key):
        raise ValueError(
            f"row key types do not match:\n"
            f" left: {list(left.row_key.values())}\n"
            f" right: {list(right.row_key.values())}"
        )
    if key_types(left.col_key) != key_types(right.col_key):
        raise ValueError(
            f"column key types do not match:\n"
            f" left: {list(left.col_key.values())}\n"
            f" right: {list(right.col_key.values())}"
        )

    # Collapse each side's row fields into one struct, then move entries into
    # a per-row array field and columns into a global array field so the two
    # inputs can be joined as ordinary tables.
    left = left.select_rows(left_row=left.row)
    right = right.select_rows(right_row=right.row)
    left_table = left.localize_entries('left_entries', 'left_cols')
    right_table = right.localize_entries('right_entries', 'right_cols')

    joined = left_table.join(right_table, how='outer')

    joined = joined.annotate_globals(
        left_keys=column_indices_by_key(joined.left_cols, left.col_key),
        right_keys=column_indices_by_key(joined.right_cols, right.col_key),
    )

    # One struct per distinct column key, then one struct per output column.
    all_col_keys = hl.array(joined.left_keys.key_set().union(joined.right_keys.key_set()))
    per_key_indices = all_col_keys.map(
        lambda key: hl.struct(
            k=key,
            left_indices=joined.left_keys.get(key),
            right_indices=joined.right_keys.get(key),
        )
    )
    joined = joined.annotate_globals(key_indices=per_key_indices.flatmap(pair_column_indices))

    # Indexing with a missing index (or into a missing array) yields a missing
    # struct, which is how the unmatched rows and columns become NA.
    joined = joined.annotate(
        __entries=joined.key_indices.map(
            lambda pair: hl.struct(
                left_entry=joined.left_entries[pair.left_index],
                right_entry=joined.right_entries[pair.right_index],
            )
        )
    )

    col_key_names = list(left.col_key)
    joined = joined.annotate_globals(
        __cols=joined.key_indices.map(
            lambda pair: hl.struct(
                **{name: pair.k[position] for position, name in enumerate(col_key_names)},
                left_col=joined.left_cols[pair.left_index],
                right_col=joined.right_cols[pair.right_index],
            )
        )
    )

    # Remove the intermediate fields before converting back to a matrix table.
    joined = joined.drop(
        'left_entries', 'left_cols', 'left_keys', 'right_entries', 'right_cols', 'right_keys', 'key_indices'
    )
    return joined._unlocalize_entries('__entries', '__cols', col_key_names)