import hail as hl
from hail.expr.expressions import analyze, expr_float64, expr_numeric
from hail.table import Table
from hail.typecheck import nullable, oneof, sequenceof, typecheck
from hail.utils import new_temp_file, wrap_to_list
[docs]@typecheck(
weight_expr=expr_float64,
ld_score_expr=expr_numeric,
chi_sq_exprs=oneof(expr_float64, sequenceof(expr_float64)),
n_samples_exprs=oneof(expr_numeric, sequenceof(expr_numeric)),
n_blocks=int,
two_step_threshold=int,
n_reference_panel_variants=nullable(int),
)
def ld_score_regression(
weight_expr,
ld_score_expr,
chi_sq_exprs,
n_samples_exprs,
n_blocks=200,
two_step_threshold=30,
n_reference_panel_variants=None,
) -> Table:
r"""Estimate SNP-heritability and level of confounding biases from genome-wide association study
(GWAS) summary statistics.
Given a set or multiple sets of GWAS summary statistics, :func:`.ld_score_regression` estimates the heritability
of a trait or set of traits and the level of confounding biases present in
the underlying studies by regressing chi-squared statistics on LD scores,
leveraging the model:
.. math::
\mathrm{E}[\chi_j^2] = 1 + Na + \frac{Nh_g^2}{M}l_j
* :math:`\mathrm{E}[\chi_j^2]` is the expected chi-squared statistic
for variant :math:`j` resulting from a test of association between
variant :math:`j` and a trait.
* :math:`l_j = \sum_{k} r_{jk}^2` is the LD score of variant
:math:`j`, calculated as the sum of squared correlation coefficients
between variant :math:`j` and nearby variants. See :func:`ld_score`
for further details.
* :math:`a` captures the contribution of confounding biases, such as
cryptic relatedness and uncontrolled population structure, to the
association test statistic.
* :math:`h_g^2` is the SNP-heritability, or the proportion of variation
in the trait explained by the effects of variants included in the
regression model above.
* :math:`M` is the number of variants used to estimate :math:`h_g^2`.
* :math:`N` is the number of samples in the underlying association study.
For more details on the method implemented in this function, see:
* `LD Score regression distinguishes confounding from polygenicity in genome-wide association studies (Bulik-Sullivan et al, 2015) <https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4495769/>`__
Examples
--------
Run the method on a matrix table of summary statistics, where the rows
are variants and the columns are different phenotypes:
>>> mt_gwas = ld_score_all_phenos_sumstats
>>> ht_results = hl.experimental.ld_score_regression(
... weight_expr=mt_gwas['ld_score'],
... ld_score_expr=mt_gwas['ld_score'],
... chi_sq_exprs=mt_gwas['chi_squared'],
... n_samples_exprs=mt_gwas['n'])
Run the method on a table with summary statistics for a single
phenotype:
>>> ht_gwas = ld_score_one_pheno_sumstats
>>> ht_results = hl.experimental.ld_score_regression(
... weight_expr=ht_gwas['ld_score'],
... ld_score_expr=ht_gwas['ld_score'],
... chi_sq_exprs=ht_gwas['chi_squared_50_irnt'],
... n_samples_exprs=ht_gwas['n_50_irnt'])
Run the method on a table with summary statistics for multiple
phenotypes:
>>> ht_gwas = ld_score_one_pheno_sumstats
>>> ht_results = hl.experimental.ld_score_regression(
... weight_expr=ht_gwas['ld_score'],
... ld_score_expr=ht_gwas['ld_score'],
... chi_sq_exprs=[ht_gwas['chi_squared_50_irnt'],
... ht_gwas['chi_squared_20160']],
... n_samples_exprs=[ht_gwas['n_50_irnt'],
... ht_gwas['n_20160']])
Notes
-----
The ``exprs`` provided as arguments to :func:`.ld_score_regression`
must all be from the same object, either a :class:`~.Table` or a
:class:`~.MatrixTable`.
**If the arguments originate from a table:**
* The table must be keyed by fields ``locus`` of type
:class:`.tlocus` and ``alleles``, a :class:`.tarray` of
:py:data:`.tstr` elements.
* ``weight_expr``, ``ld_score_expr``, ``chi_sq_exprs``, and
``n_samples_exprs`` are must be row-indexed fields.
* The number of expressions passed to ``n_samples_exprs`` must be
equal to one or the number of expressions passed to
``chi_sq_exprs``. If just one expression is passed to
``n_samples_exprs``, that sample size expression is assumed to
apply to all sets of statistics passed to ``chi_sq_exprs``.
Otherwise, the expressions passed to ``chi_sq_exprs`` and
``n_samples_exprs`` are matched by index.
* The ``phenotype`` field that keys the table returned by
:func:`.ld_score_regression` will have generic :obj:`int` values
``0``, ``1``, etc. corresponding to the ``0th``, ``1st``, etc.
expressions passed to the ``chi_sq_exprs`` argument.
**If the arguments originate from a matrix table:**
* The dimensions of the matrix table must be variants
(rows) by phenotypes (columns).
* The rows of the matrix table must be keyed by fields
``locus`` of type :class:`.tlocus` and ``alleles``,
a :class:`.tarray` of :py:data:`.tstr` elements.
* The columns of the matrix table must be keyed by a field
of type :py:data:`.tstr` that uniquely identifies phenotypes
represented in the matrix table. The column key must be a single
expression; compound keys are not accepted.
* ``weight_expr`` and ``ld_score_expr`` must be row-indexed
fields.
* ``chi_sq_exprs`` must be a single entry-indexed field
(not a list of fields).
* ``n_samples_exprs`` must be a single entry-indexed field
(not a list of fields).
* The ``phenotype`` field that keys the table returned by
:func:`.ld_score_regression` will have values corresponding to the
column keys of the input matrix table.
This function returns a :class:`~.Table` with one row per set of summary
statistics passed to the ``chi_sq_exprs`` argument. The following
row-indexed fields are included in the table:
* **phenotype** (:py:data:`.tstr`) -- The name of the phenotype. The
returned table is keyed by this field. See the notes below for
details on the possible values of this field.
* **mean_chi_sq** (:py:data:`.tfloat64`) -- The mean chi-squared
test statistic for the given phenotype.
* **intercept** (`Struct`) -- Contains fields:
- **estimate** (:py:data:`.tfloat64`) -- A point estimate of the
intercept :math:`1 + Na`.
- **standard_error** (:py:data:`.tfloat64`) -- An estimate of
the standard error of this point estimate.
* **snp_heritability** (`Struct`) -- Contains fields:
- **estimate** (:py:data:`.tfloat64`) -- A point estimate of the
SNP-heritability :math:`h_g^2`.
- **standard_error** (:py:data:`.tfloat64`) -- An estimate of
the standard error of this point estimate.
Warning
-------
:func:`.ld_score_regression` considers only the rows for which both row
fields ``weight_expr`` and ``ld_score_expr`` are defined. Rows with missing
values in either field are removed prior to fitting the LD score
regression model.
Parameters
----------
weight_expr : :class:`.Float64Expression`
Row-indexed expression for the LD scores used to derive
variant weights in the model.
ld_score_expr : :class:`.Float64Expression`
Row-indexed expression for the LD scores used as covariates
in the model.
chi_sq_exprs : :class:`.Float64Expression` or :obj:`list` of
:class:`.Float64Expression`
One or more row-indexed (if table) or entry-indexed
(if matrix table) expressions for chi-squared
statistics resulting from genome-wide association
studies (GWAS).
n_samples_exprs: :class:`.NumericExpression` or :obj:`list` of
:class:`.NumericExpression`
One or more row-indexed (if table) or entry-indexed
(if matrix table) expressions indicating the number of
samples used in the studies that generated the test
statistics supplied to ``chi_sq_exprs``.
n_blocks : :obj:`int`
The number of blocks used in the jackknife approach to
estimating standard errors.
two_step_threshold : :obj:`int`
Variants with chi-squared statistics greater than this
value are excluded in the first step of the two-step
procedure used to fit the model.
n_reference_panel_variants : :obj:`int`, optional
Number of variants used to estimate the
SNP-heritability :math:`h_g^2`.
Returns
-------
:class:`~.Table`
Table keyed by ``phenotype`` with intercept and heritability estimates
for each phenotype passed to the function."""
chi_sq_exprs = wrap_to_list(chi_sq_exprs)
n_samples_exprs = wrap_to_list(n_samples_exprs)
assert (len(chi_sq_exprs) == len(n_samples_exprs)) or (len(n_samples_exprs) == 1)
__k = 2 # number of covariates, including intercept
ds = chi_sq_exprs[0]._indices.source
analyze('ld_score_regression/weight_expr', weight_expr, ds._row_indices)
analyze('ld_score_regression/ld_score_expr', ld_score_expr, ds._row_indices)
# format input dataset
if isinstance(ds, hl.MatrixTable):
if len(chi_sq_exprs) != 1:
raise ValueError("""Only one chi_sq_expr allowed if originating
from a matrix table.""")
if len(n_samples_exprs) != 1:
raise ValueError("""Only one n_samples_expr allowed if
originating from a matrix table.""")
col_key = list(ds.col_key)
if len(col_key) != 1:
raise ValueError("""Matrix table must be keyed by a single
phenotype field.""")
analyze('ld_score_regression/chi_squared_expr', chi_sq_exprs[0], ds._entry_indices)
analyze('ld_score_regression/n_samples_expr', n_samples_exprs[0], ds._entry_indices)
ds = ds._select_all(
row_exprs={
'__locus': ds.locus,
'__alleles': ds.alleles,
'__w_initial': weight_expr,
'__w_initial_floor': hl.max(weight_expr, 1.0),
'__x': ld_score_expr,
'__x_floor': hl.max(ld_score_expr, 1.0),
},
row_key=['__locus', '__alleles'],
col_exprs={'__y_name': ds[col_key[0]]},
col_key=['__y_name'],
entry_exprs={'__y': chi_sq_exprs[0], '__n': n_samples_exprs[0]},
)
ds = ds.annotate_entries(**{'__w': ds.__w_initial})
ds = ds.filter_rows(
hl.is_defined(ds.__locus)
& hl.is_defined(ds.__alleles)
& hl.is_defined(ds.__w_initial)
& hl.is_defined(ds.__x)
)
else:
assert isinstance(ds, Table)
for y in chi_sq_exprs:
analyze('ld_score_regression/chi_squared_expr', y, ds._row_indices)
for n in n_samples_exprs:
analyze('ld_score_regression/n_samples_expr', n, ds._row_indices)
ys = ['__y{:}'.format(i) for i, _ in enumerate(chi_sq_exprs)]
ws = ['__w{:}'.format(i) for i, _ in enumerate(chi_sq_exprs)]
ns = ['__n{:}'.format(i) for i, _ in enumerate(n_samples_exprs)]
ds = ds.select(
**dict(
**{'__locus': ds.locus, '__alleles': ds.alleles, '__w_initial': weight_expr, '__x': ld_score_expr},
**{y: chi_sq_exprs[i] for i, y in enumerate(ys)},
**{w: weight_expr for w in ws},
**{n: n_samples_exprs[i] for i, n in enumerate(ns)},
)
)
ds = ds.key_by(ds.__locus, ds.__alleles)
table_tmp_file = new_temp_file()
ds.write(table_tmp_file)
ds = hl.read_table(table_tmp_file)
hts = [
ds.select(**{
'__w_initial': ds.__w_initial,
'__w_initial_floor': hl.max(ds.__w_initial, 1.0),
'__x': ds.__x,
'__x_floor': hl.max(ds.__x, 1.0),
'__y_name': i,
'__y': ds[ys[i]],
'__w': ds[ws[i]],
'__n': hl.int(ds[ns[i]]),
})
for i, y in enumerate(ys)
]
mts = [
ht.to_matrix_table(
row_key=['__locus', '__alleles'],
col_key=['__y_name'],
row_fields=['__w_initial', '__w_initial_floor', '__x', '__x_floor'],
)
for ht in hts
]
ds = mts[0]
for i in range(1, len(ys)):
ds = ds.union_cols(mts[i])
ds = ds.filter_rows(
hl.is_defined(ds.__locus)
& hl.is_defined(ds.__alleles)
& hl.is_defined(ds.__w_initial)
& hl.is_defined(ds.__x)
)
mt_tmp_file1 = new_temp_file()
ds.write(mt_tmp_file1)
mt = hl.read_matrix_table(mt_tmp_file1)
if not n_reference_panel_variants:
M = mt.count_rows()
else:
M = n_reference_panel_variants
mt = mt.annotate_entries(
__in_step1=(hl.is_defined(mt.__y) & (mt.__y < two_step_threshold)), __in_step2=hl.is_defined(mt.__y)
)
mt = mt.annotate_cols(
__col_idx=hl.int(hl.scan.count()),
__m_step1=hl.agg.count_where(mt.__in_step1),
__m_step2=hl.agg.count_where(mt.__in_step2),
)
col_keys = list(mt.col_key)
ht = mt.localize_entries(entries_array_field_name='__entries', columns_array_field_name='__cols')
ht = ht.annotate(
__entries=hl.rbind(
hl.scan.array_agg(lambda entry: hl.scan.count_where(entry.__in_step1), ht.__entries),
lambda step1_indices: hl.map(
lambda i: hl.rbind(
hl.int(hl.or_else(step1_indices[i], 0)),
ht.__cols[i].__m_step1,
ht.__entries[i],
lambda step1_idx, m_step1, entry: hl.rbind(
hl.map(lambda j: hl.int(hl.floor(j * (m_step1 / n_blocks))), hl.range(0, n_blocks + 1)),
lambda step1_separators: hl.rbind(
hl.set(step1_separators).contains(step1_idx),
hl.sum(hl.map(lambda s1: step1_idx >= s1, step1_separators)) - 1,
lambda is_separator, step1_block: entry.annotate(
__step1_block=step1_block,
__step2_block=hl.if_else(
~entry.__in_step1 & is_separator, step1_block - 1, step1_block
),
),
),
),
),
hl.range(0, hl.len(ht.__entries)),
),
)
)
mt = ht._unlocalize_entries('__entries', '__cols', col_keys)
mt_tmp_file2 = new_temp_file()
mt.write(mt_tmp_file2)
mt = hl.read_matrix_table(mt_tmp_file2)
# initial coefficient estimates
mt = mt.annotate_cols(__initial_betas=[1.0, (hl.agg.mean(mt.__y) - 1.0) / hl.agg.mean(mt.__x)])
mt = mt.annotate_cols(__step1_betas=mt.__initial_betas, __step2_betas=mt.__initial_betas)
# step 1 iteratively reweighted least squares
for i in range(3):
mt = mt.annotate_entries(
__w=hl.if_else(
mt.__in_step1,
1.0 / (mt.__w_initial_floor * 2.0 * (mt.__step1_betas[0] + mt.__step1_betas[1] * mt.__x_floor) ** 2),
0.0,
)
)
mt = mt.annotate_cols(
__step1_betas=hl.agg.filter(mt.__in_step1, hl.agg.linreg(y=mt.__y, x=[1.0, mt.__x], weight=mt.__w).beta)
)
mt = mt.annotate_cols(__step1_h2=hl.max(hl.min(mt.__step1_betas[1] * M / hl.agg.mean(mt.__n), 1.0), 0.0))
mt = mt.annotate_cols(__step1_betas=[mt.__step1_betas[0], mt.__step1_h2 * hl.agg.mean(mt.__n) / M])
# step 1 block jackknife
mt = mt.annotate_cols(
__step1_block_betas=hl.agg.array_agg(
lambda i: hl.agg.filter(
(mt.__step1_block != i) & mt.__in_step1, hl.agg.linreg(y=mt.__y, x=[1.0, mt.__x], weight=mt.__w).beta
),
hl.range(n_blocks),
)
)
mt = mt.annotate_cols(
__step1_block_betas_bias_corrected=hl.map(
lambda x: n_blocks * mt.__step1_betas - (n_blocks - 1) * x, mt.__step1_block_betas
)
)
mt = mt.annotate_cols(
__step1_jackknife_mean=hl.map(
lambda i: hl.mean(hl.map(lambda x: x[i], mt.__step1_block_betas_bias_corrected)), hl.range(0, __k)
),
__step1_jackknife_variance=hl.map(
lambda i: (
hl.sum(hl.map(lambda x: x[i] ** 2, mt.__step1_block_betas_bias_corrected))
- hl.sum(hl.map(lambda x: x[i], mt.__step1_block_betas_bias_corrected)) ** 2 / n_blocks
)
/ (n_blocks - 1)
/ n_blocks,
hl.range(0, __k),
),
)
# step 2 iteratively reweighted least squares
for i in range(3):
mt = mt.annotate_entries(
__w=hl.if_else(
mt.__in_step2,
1.0 / (mt.__w_initial_floor * 2.0 * (mt.__step2_betas[0] + +mt.__step2_betas[1] * mt.__x_floor) ** 2),
0.0,
)
)
mt = mt.annotate_cols(
__step2_betas=[
mt.__step1_betas[0],
hl.agg.filter(
mt.__in_step2, hl.agg.linreg(y=mt.__y - mt.__step1_betas[0], x=[mt.__x], weight=mt.__w).beta[0]
),
]
)
mt = mt.annotate_cols(__step2_h2=hl.max(hl.min(mt.__step2_betas[1] * M / hl.agg.mean(mt.__n), 1.0), 0.0))
mt = mt.annotate_cols(__step2_betas=[mt.__step1_betas[0], mt.__step2_h2 * hl.agg.mean(mt.__n) / M])
# step 2 block jackknife
mt = mt.annotate_cols(
__step2_block_betas=hl.agg.array_agg(
lambda i: hl.agg.filter(
(mt.__step2_block != i) & mt.__in_step2,
hl.agg.linreg(y=mt.__y - mt.__step1_betas[0], x=[mt.__x], weight=mt.__w).beta[0],
),
hl.range(n_blocks),
)
)
mt = mt.annotate_cols(
__step2_block_betas_bias_corrected=hl.map(
lambda x: n_blocks * mt.__step2_betas[1] - (n_blocks - 1) * x, mt.__step2_block_betas
)
)
mt = mt.annotate_cols(
__step2_jackknife_mean=hl.mean(mt.__step2_block_betas_bias_corrected),
__step2_jackknife_variance=(
hl.sum(mt.__step2_block_betas_bias_corrected**2)
- hl.sum(mt.__step2_block_betas_bias_corrected) ** 2 / n_blocks
)
/ (n_blocks - 1)
/ n_blocks,
)
# combine step 1 and step 2 block jackknifes
mt = mt.annotate_entries(
__step2_initial_w=1.0
/ (mt.__w_initial_floor * 2.0 * (mt.__initial_betas[0] + +mt.__initial_betas[1] * mt.__x_floor) ** 2)
)
mt = mt.annotate_cols(
__final_betas=[mt.__step1_betas[0], mt.__step2_betas[1]],
__c=(hl.agg.sum(mt.__step2_initial_w * mt.__x) / hl.agg.sum(mt.__step2_initial_w * mt.__x**2)),
)
mt = mt.annotate_cols(
__final_block_betas=hl.map(
lambda i: (mt.__step2_block_betas[i] - mt.__c * (mt.__step1_block_betas[i][0] - mt.__final_betas[0])),
hl.range(0, n_blocks),
)
)
mt = mt.annotate_cols(
__final_block_betas_bias_corrected=(n_blocks * mt.__final_betas[1] - (n_blocks - 1) * mt.__final_block_betas)
)
mt = mt.annotate_cols(
__final_jackknife_mean=[mt.__step1_jackknife_mean[0], hl.mean(mt.__final_block_betas_bias_corrected)],
__final_jackknife_variance=[
mt.__step1_jackknife_variance[0],
(
hl.sum(mt.__final_block_betas_bias_corrected**2)
- hl.sum(mt.__final_block_betas_bias_corrected) ** 2 / n_blocks
)
/ (n_blocks - 1)
/ n_blocks,
],
)
# convert coefficient to heritability estimate
mt = mt.annotate_cols(
phenotype=mt.__y_name,
mean_chi_sq=hl.agg.mean(mt.__y),
intercept=hl.struct(estimate=mt.__final_betas[0], standard_error=hl.sqrt(mt.__final_jackknife_variance[0])),
snp_heritability=hl.struct(
estimate=(M / hl.agg.mean(mt.__n)) * mt.__final_betas[1],
standard_error=hl.sqrt((M / hl.agg.mean(mt.__n)) ** 2 * mt.__final_jackknife_variance[1]),
),
)
# format and return results
ht = mt.cols()
ht = ht.key_by(ht.phenotype)
ht = ht.select(ht.mean_chi_sq, ht.intercept, ht.snp_heritability)
ht_tmp_file = new_temp_file()
ht.write(ht_tmp_file)
ht = hl.read_table(ht_tmp_file)
return ht