import itertools
from pprint import pprint
from plotly.subplots import make_subplots
import hail as hl
from .aes import Aesthetic, aes
from .coord_cartesian import CoordCartesian
from .facets import Faceter
from .geoms import FigureAttribute, Geom
from .labels import Labels
from .scale import (
Scale,
ScaleContinuous,
ScaleDiscrete,
scale_color_continuous,
scale_color_discrete,
scale_fill_continuous,
scale_fill_discrete,
scale_shape_auto,
scale_x_continuous,
scale_x_discrete,
scale_x_genomic,
scale_y_continuous,
scale_y_discrete,
)
from .utils import check_scale_continuity, is_continuous_type, is_genomic_type
class GGPlot:
    """The class representing a figure created using the ``hail.ggplot`` module.

    Create one by using :func:`.ggplot`.

    .. automethod:: to_plotly
    .. automethod:: show
    .. automethod:: write_image
    """

    def __init__(self, ht, aes, geoms=None, labels=None, coord_cartesian=None, scales=None, facet=None):
        """Store the plot components and register default scales for ``aes``.

        Parameters
        ----------
        ht
            The table holding the data to plot.
        aes
            Figure-level aesthetic mapping (aesthetic name -> expression).
        geoms
            List of geoms to draw. Defaults to a fresh empty list.
        labels
            Plot labels. Defaults to a fresh ``Labels()``.
        coord_cartesian
            Optional coordinate settings applied to the final figure.
        scales
            Dict of aesthetic name -> scale. Defaults to a fresh empty dict.
        facet
            Optional faceter.
        """
        # Defaults are created per instance: a literal ``[]`` / ``Labels()`` in
        # the signature would be built once and shared by every plot.
        self.ht = ht
        self.aes = aes
        self.geoms = [] if geoms is None else geoms
        self.labels = Labels() if labels is None else labels
        self.coord_cartesian = coord_cartesian
        self.scales = {} if scales is None else scales
        self.facet = facet
        self.add_default_scales(aes)

    def __add__(self, other):
        """Return a copy of this plot with ``other`` applied.

        ``other`` may be a geom, labels, coordinate settings, a scale, an
        aesthetic, or a faceter. The plot on the left-hand side is not
        modified.

        Raises
        ------
        AssertionError
            If ``other`` is neither a ``FigureAttribute`` nor an ``Aesthetic``.
        ValueError
            If ``other`` is a ``FigureAttribute`` of a kind not handled here.
        """
        assert isinstance(other, (FigureAttribute, Aesthetic))
        copied = self.copy()
        if isinstance(other, Geom):
            copied.geoms.append(other)
            copied.add_default_scales(other.aes)
        elif isinstance(other, Labels):
            copied.labels = copied.labels.merge(other)
        elif isinstance(other, CoordCartesian):
            copied.coord_cartesian = other
        elif isinstance(other, Scale):
            copied.scales[other.aesthetic_name] = other
        elif isinstance(other, Aesthetic):
            copied.aes = copied.aes.merge(other)
        elif isinstance(other, Faceter):
            copied.facet = other
        else:
            raise ValueError("Not implemented")
        return copied

    def add_default_scales(self, aesthetic):
        """Register a default scale for every mapped aesthetic that lacks one.

        Scales already present in ``self.scales`` are never replaced.

        Raises
        ------
        ValueError
            If ``y`` is mapped to a genomic type, or ``shape`` is mapped to a
            continuous type.
        """
        for aesthetic_str, mapped_expr in aesthetic.items():
            dtype = mapped_expr.dtype
            if aesthetic_str in self.scales:
                continue

            is_continuous = is_continuous_type(dtype)
            # We only know how to come up with a few default scales.
            if aesthetic_str == "x":
                if is_continuous:
                    default_scale = scale_x_continuous()
                elif is_genomic_type(dtype):
                    default_scale = scale_x_genomic(reference_genome=dtype.reference_genome)
                else:
                    default_scale = scale_x_discrete()
            elif aesthetic_str == "y":
                if is_continuous:
                    default_scale = scale_y_continuous()
                elif is_genomic_type(dtype):
                    raise ValueError("Don't yet support y axis genomic")
                else:
                    default_scale = scale_y_discrete()
            elif aesthetic_str == "color":
                default_scale = scale_color_continuous() if is_continuous else scale_color_discrete()
            elif aesthetic_str == "fill":
                default_scale = scale_fill_continuous() if is_continuous else scale_fill_discrete()
            elif aesthetic_str == "shape":
                if is_continuous:
                    raise ValueError(
                        "The 'shape' aesthetic does not support continuous "
                        "types. Specify values of a discrete type instead."
                    )
                default_scale = scale_shape_auto()
            elif is_continuous:
                default_scale = ScaleContinuous(aesthetic_str)
            else:
                default_scale = ScaleDiscrete(aesthetic_str)
            self.scales[aesthetic_str] = default_scale

    def copy(self):
        """Return a new ``GGPlot`` with its own ``geoms`` list and ``scales`` dict.

        The containers are shallow-copied so that ``__add__`` can append geoms
        and set scales on the copy without altering this plot; the contained
        objects themselves are shared.
        """
        return GGPlot(
            self.ht,
            self.aes,
            self.geoms[:],
            self.labels,
            self.coord_cartesian,
            dict(self.scales),
            self.facet,
        )

    def verify_scales(self):
        """Run ``check_scale_continuity`` for every figure-level and geom-level aesthetic."""
        for aes_key in self.aes.keys():
            check_scale_continuity(self.scales[aes_key], self.aes[aes_key].dtype, aes_key)
        for geom in self.geoms:
            aesthetic_dict = geom.aes.properties
            for aes_key in aesthetic_dict.keys():
                check_scale_continuity(self.scales[aes_key], aesthetic_dict[aes_key].dtype, aes_key)

    def to_plotly(self):
        """Turn the hail plot into a Plotly plot.

        Returns
        -------
        A Plotly figure that can be updated with plotly methods.
        """

        def make_geom_label(geom_idx):
            # Field name under which a geom's mapping / results are stored.
            return f"geom{geom_idx}"

        def select_table():
            # Build a table holding only the expressions the plot needs: the
            # figure-level mapping, the optional facet key, and one struct of
            # aesthetics per geom.
            fields_to_select = {"figure_mapping": hl.struct(**self.aes)}
            if self.facet is not None:
                fields_to_select["facet"] = self.facet.get_expr_to_group_by()
            for geom_idx, geom in enumerate(self.geoms):
                geom_label = make_geom_label(geom_idx)
                fields_to_select[geom_label] = hl.struct(**geom.aes.properties)
            name, ht = hl.struct(**fields_to_select)._to_table('__fallback')
            return ht.select(**{field: ht[name][field] for field in fields_to_select})

        def collect_mappings_and_precomputed(selected):
            # For each geom: overlay its aesthetics on the figure-level ones,
            # apply each scale's data transform, and gather the stat's
            # precompute aggregations.
            mapping_per_geom = []
            precomputes = {}
            for geom_idx, geom in enumerate(self.geoms):
                geom_label = make_geom_label(geom_idx)
                combined_mapping = selected["figure_mapping"].annotate(**selected[geom_label])
                for key in combined_mapping:
                    if key in self.scales:
                        combined_mapping = combined_mapping.annotate(**{
                            key: self.scales[key].transform_data(combined_mapping[key])
                        })
                mapping_per_geom.append(combined_mapping)
                precomputes[geom_label] = geom.get_stat().get_precomputes(combined_mapping)

            # Is there anything to precompute?
            should_precompute = any(len(precompute) > 0 for precompute in precomputes.values())
            if should_precompute:
                precomputed = selected.aggregate(hl.struct(**precomputes))
            else:
                precomputed = hl.Struct(**{key: hl.Struct() for key in precomputes.keys()})
            return mapping_per_geom, precomputed

        def get_aggregation_result(selected, mapping_per_geom, precomputed):
            # Run every geom's stat aggregation in a single pass over the
            # table, grouped by facet when faceting is enabled.
            aggregators = {}
            labels_to_stats = {}
            use_faceting = self.facet is not None
            for geom_idx, combined_mapping in enumerate(mapping_per_geom):
                stat = self.geoms[geom_idx].get_stat()
                geom_label = make_geom_label(geom_idx)
                agg = stat.make_agg(combined_mapping, precomputed[geom_label], self.scales)
                if use_faceting:
                    agg = hl.agg.group_by(selected.facet, agg)
                aggregators[geom_label] = agg
                labels_to_stats[geom_label] = stat

            all_agg_results = selected.aggregate(hl.struct(**aggregators))

            if use_faceting:
                # Assign each distinct facet value a subplot index shared by all geoms.
                facet_list = list(set(itertools.chain(*[list(x.keys()) for x in all_agg_results.values()])))
                facet_to_idx = {facet: idx for idx, facet in enumerate(facet_list)}
                facet_idx_to_agg_result = {
                    geom_label: {facet_to_idx[facet]: agg_result for facet, agg_result in facet_to_agg_result.items()}
                    for geom_label, facet_to_agg_result in all_agg_results.items()
                }
                num_facets = len(facet_list)
            else:
                facet_idx_to_agg_result = {
                    geom_label: {0: agg_result} for geom_label, agg_result in all_agg_results.items()
                }
                num_facets = 1
                facet_list = None
            return labels_to_stats, facet_idx_to_agg_result, num_facets, facet_list

        self.verify_scales()
        selected = select_table()
        mapping_per_geom, precomputed = collect_mappings_and_precomputed(selected)
        labels_to_stats, aggregated, num_facets, facet_list = get_aggregation_result(
            selected, mapping_per_geom, precomputed
        )

        geoms_and_grouped_dfs_by_facet_idx = []
        for geom, (geom_label, agg_result_by_facet) in zip(self.geoms, aggregated.items()):
            dfs_by_facet_idx = {
                facet_idx: labels_to_stats[geom_label].listify(agg_result)
                for facet_idx, agg_result in agg_result_by_facet.items()
            }
            geoms_and_grouped_dfs_by_facet_idx.append((geom, geom_label, dfs_by_facet_idx))

        # Create scaling functions based on all the data:
        transformers = {}
        for scale in self.scales.values():
            all_dfs = list(
                itertools.chain.from_iterable(
                    facet_to_dfs_dict.values() for _, _, facet_to_dfs_dict in geoms_and_grouped_dfs_by_facet_idx
                )
            )
            transformers[scale.aesthetic_name] = scale.create_local_transformer(all_dfs)

        is_faceted = self.facet is not None
        if is_faceted:
            n_facet_rows, n_facet_cols = self.facet.get_facet_nrows_and_ncols(num_facets)
            subplot_args = {
                "rows": n_facet_rows,
                "cols": n_facet_cols,
                "subplot_titles": [
                    ", ".join([str(fs_value) for fs_value in facet_struct.values()]) for facet_struct in facet_list
                ],
                **self.facet.get_shared_axis_kwargs(),
            }
        else:
            n_facet_rows = 1
            n_facet_cols = 1
            subplot_args = {
                "rows": 1,
                "cols": 1,
            }
        fig = make_subplots(**subplot_args)

        # Need to know what I've added to legend already so we don't do it more than once.
        legend_cache = {}
        for geom, geom_label, facet_to_grouped_dfs in geoms_and_grouped_dfs_by_facet_idx:
            for facet_idx, grouped_dfs in facet_to_grouped_dfs.items():
                scaled_grouped_dfs = []
                for df in grouped_dfs:
                    scales_to_consider = list(df.columns) + list(df.attrs)
                    relevant_aesthetics = [scale_name for scale_name in scales_to_consider if scale_name in self.scales]
                    scaled_df = df
                    for relevant_aesthetic in relevant_aesthetics:
                        scaled_df = transformers[relevant_aesthetic](scaled_df)
                    scaled_grouped_dfs.append(scaled_df)

                # Facets fill the grid row by row; plotly rows/cols are 1-based.
                facet_row = facet_idx // n_facet_cols + 1
                facet_col = facet_idx % n_facet_cols + 1
                geom.apply_to_fig(
                    scaled_grouped_dfs, fig, precomputed[geom_label], facet_row, facet_col, legend_cache, is_faceted
                )

        # Important to update axes after labels, axes names take precedence.
        self.labels.apply_to_fig(fig)
        if self.scales.get("x") is not None:
            self.scales["x"].apply_to_fig(self, fig)
        if self.scales.get("y") is not None:
            self.scales["y"].apply_to_fig(self, fig)
        if self.coord_cartesian is not None:
            self.coord_cartesian.apply_to_fig(fig)

        fig = fig.update_xaxes(title_font_size=18, ticks="outside")
        fig = fig.update_yaxes(title_font_size=18, ticks="outside")
        # A rows x cols grid has rows * cols axis pairs, so that product (not
        # rows + cols) is the highest axis index that needs styling.
        n_subplots = n_facet_rows * n_facet_cols
        fig.update_layout(
            plot_bgcolor="white",
            font_family='Arial, "Open Sans", verdana, sans-serif',
            title_font_size=26,
            xaxis=dict(linecolor="black", showticklabels=True),
            yaxis=dict(linecolor="black", showticklabels=True),
            # axes for plotly subplots are numbered following the pattern [xaxis, xaxis2, xaxis3, ...]
            **{
                f"{var}axis{idx}": {"linecolor": "black", "showticklabels": True}
                for idx in range(2, n_subplots + 1)
                for var in ["x", "y"]
            },
        )
        return fig

    def show(self):
        """Render and show the plot, either in a browser or notebook."""
        self.to_plotly().show()

    def write_image(self, path):
        """Write out this plot as an image.

        This requires you to have installed the python package kaleido from pypi.

        Parameters
        ----------
        path: :class:`str`
            The path to write the file to.
        """
        self.to_plotly().write_image(path)

    def _repr_html_(self):
        """Return the HTML representation of the rendered Plotly figure."""
        return self.to_plotly()._repr_html_()

    def _debug_print(self):
        """Print the plot's aesthetics, scales and geoms for debugging."""
        print("Ggplot Object:")
        print("Aesthetics")
        pprint(self.aes)
        # Plain print: pprint on a str would show the header wrapped in quotes.
        print("Scales:")
        pprint(self.scales)
        print("Geoms:")
        pprint(self.geoms)
def ggplot(table, mapping=aes()):
    """Create the initial plot object.

    This function is the beginning of all plots using the ``hail.ggplot`` interface. Plots are constructed
    by calling this function, then adding attributes to the plot to get the desired result.

    Examples
    --------

    Create a y = x^2 scatter plot

    >>> ht = hl.utils.range_table(10)
    >>> ht = ht.annotate(squared = ht.idx**2)
    >>> my_plot = hl.ggplot.ggplot(ht, hl.ggplot.aes(x=ht.idx, y=ht.squared)) + hl.ggplot.geom_point()

    Parameters
    ----------
    table
        The table containing the data to plot.
    mapping
        Default list of aesthetic mappings from table data to plot attributes.

    Returns
    -------
    :class:`.GGPlot`
    """
    assert isinstance(mapping, Aesthetic)
    plot = GGPlot(table, mapping)
    return plot