import abc
from collections.abc import Mapping
import plotly
import plotly.express as px
from hail.context import get_reference
from hail.expr.types import tstr
from .geoms import FigureAttribute
from .utils import continuous_nums_to_colors, is_continuous_type, is_discrete_type
class Scale(FigureAttribute):
    """Base class for all scales.

    A scale is bound to a single aesthetic (for example ``"x"``, ``"y"``,
    ``"color"``, ``"fill"`` or ``"shape"``). It may rewrite the expression that
    feeds the aesthetic (:meth:`transform_data`) and the locally collected data
    frames (:meth:`create_local_transformer`).
    """

    def __init__(self, aesthetic_name):
        # The aesthetic this scale is attached to.
        self.aesthetic_name = aesthetic_name

    @abc.abstractmethod
    def transform_data(self, field_expr):
        """Return the expression to use for this aesthetic in place of ``field_expr``."""

    def create_local_transformer(self, groups_of_dfs):
        """Return a function that is applied to each local data frame.

        The base implementation ignores ``groups_of_dfs`` and returns the
        identity function, so local data is left untouched.
        """

        def identity(df):
            return df

        return identity

    @abc.abstractmethod
    def is_discrete(self):
        """Return whether this scale treats its aesthetic as discrete."""

    @abc.abstractmethod
    def is_continuous(self):
        """Return whether this scale treats its aesthetic as continuous."""

    def valid_dtype(self, dtype):
        """Return whether ``dtype`` is accepted by this scale.

        The base implementation returns ``None`` (falsy); subclasses override it.
        """
        return None
class PositionScale(Scale):
    """Base class for scales that control a positional axis.

    Parameters
    ----------
    aesthetic_name:
        The axis this scale controls. Only ``"x"`` and ``"y"`` are supported by
        :meth:`update_axis`.
    name:
        Axis title, or ``None`` to leave the title unchanged.
    breaks:
        Tick positions, or ``None`` to leave them unchanged.
    labels:
        Tick labels, or ``None`` to leave them unchanged.
    """

    def __init__(self, aesthetic_name, name, breaks, labels):
        super().__init__(aesthetic_name)
        self.name = name
        self.breaks = breaks
        self.labels = labels

    def update_axis(self, fig):
        """Return the plotly method that updates this scale's axis on ``fig``.

        Raises
        ------
        ValueError
            If the aesthetic is neither ``"x"`` nor ``"y"``.
        """
        if self.aesthetic_name == "x":
            return fig.update_xaxes
        if self.aesthetic_name == "y":
            return fig.update_yaxes
        # Fail loudly here: every caller immediately invokes the returned value, so
        # returning None would only surface as "'NoneType' object is not callable".
        raise ValueError(
            f"Cannot update an axis for aesthetic {self.aesthetic_name!r}; expected 'x' or 'y'."
        )

    # TODO: factor out any other behavior shared by discrete and continuous position scales.
    def apply_to_fig(self, parent, fig_so_far):
        """Apply whichever of the axis title, tick positions and tick labels were given.

        Settings left as ``None`` are not touched. When nothing was given, the
        figure is not modified at all.
        """
        updates = {}
        if self.name is not None:
            updates["title"] = self.name
        if self.breaks is not None:
            updates["tickvals"] = self.breaks
        if self.labels is not None:
            updates["ticktext"] = self.labels
        if updates:
            # The properties are independent, so a single call is equivalent to one per setting.
            self.update_axis(fig_so_far)(**updates)

    def valid_dtype(self, dtype):
        """Position scales accept any dtype."""
        return True
class PositionScaleGenomic(PositionScale):
    """Positional scale for loci, plotted by their global position in the genome.

    Parameters
    ----------
    aesthetic_name:
        The axis this scale controls.
    reference_genome:
        A reference genome object, or its name. A string is resolved with
        ``get_reference``.
    name:
        Axis title, or ``None`` to leave the title unchanged.
    """

    # Only this many leading contigs of the reference receive labeled ticks.
    # NOTE(review): presumably chosen to cover chromosomes 1-22, X and Y of human
    # references; confirm before relying on it for other genomes.
    _NUM_LABELED_CONTIGS = 24

    def __init__(self, aesthetic_name, reference_genome, name=None):
        # Breaks and labels are derived from the reference genome, never user-supplied.
        super().__init__(aesthetic_name, name, None, None)

        if isinstance(reference_genome, str):
            reference_genome = get_reference(reference_genome)
        self.reference_genome = reference_genome

    def apply_to_fig(self, parent, fig_so_far):
        """Apply the axis title (if any) and label one tick per leading contig.

        Tick positions and texts are taken from the first ``_NUM_LABELED_CONTIGS``
        entries of ``reference_genome.global_positions_dict``.
        """
        # Applies ``self.name``. Breaks and labels are always None here, so the
        # parent touches nothing else.
        super().apply_to_fig(parent, fig_so_far)

        contig_offsets = list(self.reference_genome.global_positions_dict.items())[: self._NUM_LABELED_CONTIGS]
        self.update_axis(fig_so_far)(
            tickvals=[offset for _, offset in contig_offsets],
            ticktext=[contig for contig, _ in contig_offsets],
        )

    def transform_data(self, field_expr):
        """Plot each locus at its global position."""
        return field_expr.global_position()

    def is_discrete(self):
        return False

    def is_continuous(self):
        return False
class PositionScaleContinuous(PositionScale):
    """Positional scale for continuous data, with an optional axis transformation.

    ``transformation`` must be ``"identity"``, ``"log10"`` (log-scaled axis) or
    ``"reverse"`` (reversed axis). Any other value raises :class:`ValueError`
    when the scale is applied to a figure.
    """

    # Extra plotly axis settings for each supported transformation. "identity"
    # needs none, so the axis is not touched for it.
    _TRANSFORMATION_AXIS_UPDATES = {
        "identity": {},
        "log10": {"type": "log"},
        "reverse": {"autorange": "reversed"},
    }

    def __init__(self, axis=None, name=None, breaks=None, labels=None, transformation="identity"):
        super().__init__(axis, name, breaks, labels)
        self.transformation = transformation

    def apply_to_fig(self, parent, fig_so_far):
        """Apply title/ticks via the parent, then the transformation's axis settings."""
        super().apply_to_fig(parent, fig_so_far)
        for known, axis_updates in self._TRANSFORMATION_AXIS_UPDATES.items():
            if self.transformation == known:
                break
        else:
            raise ValueError(f"Unrecognized transformation {self.transformation}")
        if axis_updates:
            self.update_axis(fig_so_far)(**axis_updates)

    def transform_data(self, field_expr):
        """Continuous positions are plotted as-is."""
        return field_expr

    def is_discrete(self):
        return False

    def is_continuous(self):
        return True
class PositionScaleDiscrete(PositionScale):
    """Positional scale for discrete (categorical) data.

    Axis title, tick positions and tick labels are applied by the inherited
    ``PositionScale.apply_to_fig``; the data expression passes through unchanged.
    """

    def __init__(self, axis=None, name=None, breaks=None, labels=None):
        super().__init__(axis, name, breaks, labels)

    def transform_data(self, field_expr):
        return field_expr

    def is_discrete(self):
        return True

    def is_continuous(self):
        return False
class ScaleContinuous(Scale):
    """Base class for non-positional scales over continuous values.

    The data expression is passed through unchanged. Only dtypes accepted by
    ``is_continuous_type`` are valid.
    """

    def transform_data(self, field_expr):
        return field_expr

    def is_discrete(self):
        return False

    def is_continuous(self):
        return True

    def valid_dtype(self, dtype):
        return is_continuous_type(dtype)
class ScaleDiscrete(Scale):
    """Base class for non-positional scales over a finite set of categories.

    Subclasses override :meth:`get_values` to return either a mapping from
    category to output value, or a list/tuple of output values that is
    assigned to the observed categories one-to-one.
    """

    def __init__(self, aesthetic_name):
        super().__init__(aesthetic_name)

    def get_values(self, categories):
        """Return output values for ``categories`` (a set); ``None`` leaves the data unchanged."""
        return None

    def transform_data(self, field_expr):
        return field_expr

    def is_discrete(self):
        return True

    def is_continuous(self):
        return False

    def valid_dtype(self, dtype):
        return is_discrete_type(dtype)

    @staticmethod
    def _ordered_categories(first_seen):
        """Return categories sorted when they are mutually comparable, else in first-seen order."""
        try:
            return sorted(first_seen)
        except TypeError:
            return list(first_seen)

    def create_local_transformer(self, groups_of_dfs):
        """Build a transformer that maps each frame's category to its scale value.

        Categories are read from ``df.attrs[self.aesthetic_name]`` of every frame
        in ``groups_of_dfs``. Frames without that attr are ignored.

        Returns
        -------
        callable
            The identity transformer if :meth:`get_values` returns ``None``.
            Otherwise, a function that modifies each frame's ``attrs`` in place
            and returns the frame. For a frame carrying the aesthetic, it copies
            the original category to ``f"{aesthetic}_legend"`` and replaces the
            aesthetic with its mapped value. Frames without the aesthetic are
            returned unchanged.

        Raises
        ------
        ValueError
            If a list/tuple of values is shorter than the number of categories.
        TypeError
            If :meth:`get_values` returns anything other than ``None``, a
            mapping, a list or a tuple.
        """
        name = self.aesthetic_name
        # dict.fromkeys de-duplicates while keeping first-seen order. A set's
        # iteration order for strings changes between interpreter runs.
        first_seen = dict.fromkeys(
            df.attrs[name] for group_of_dfs in groups_of_dfs for df in group_of_dfs if name in df.attrs
        )
        categories = set(first_seen)

        values = self.get_values(categories)
        if values is None:
            return super().create_local_transformer(groups_of_dfs)
        if isinstance(values, Mapping):
            mapping = values
        elif isinstance(values, (list, tuple)):
            if len(categories) > len(values):
                raise ValueError(
                    f"Not enough scale values specified. Found {len(categories)} "
                    f"distinct categories in {categories} and only {len(values)} "
                    f"scale values were provided in {values}."
                )
            # Assign values in a deterministic order, so the same data always gets
            # the same colors/shapes from run to run.
            mapping = dict(zip(self._ordered_categories(first_seen), values))
        else:
            raise TypeError(
                "Expected scale values to be a Mapping, list, or tuple, but received a(n) "
                f"{type(values)}: {values}."
            )

        def transform(df):
            if name not in df.attrs:
                return df
            category = df.attrs[name]
            df.attrs[f"{name}_legend"] = category
            df.attrs[name] = mapping[category]
            return df

        return transform
class ScaleDiscreteManual(ScaleDiscrete):
    """Discrete scale whose output values are supplied explicitly by the user.

    ``values`` is returned verbatim from :meth:`get_values`, whatever the observed
    categories are.
    """

    def __init__(self, aesthetic_name, values):
        super().__init__(aesthetic_name)
        # User-supplied pool (or mapping) of output values.
        self.values = values

    def get_values(self, categories):
        """Return the user-supplied values; ``categories`` is not consulted."""
        return self.values
class ScaleColorContinuous(ScaleContinuous):
    """Continuous color scale that maps numbers onto the Viridis colorscale.

    The color range is fitted to the minimum and maximum of the aesthetic
    column across all data frames, so every frame shares one color mapping.
    """

    def create_local_transformer(self, groups_of_dfs):
        """Build a transformer that replaces the aesthetic column with colors.

        Parameters
        ----------
        groups_of_dfs:
            Iterable of iterables of data frames. Frames without the aesthetic
            column, or whose column holds only missing values, do not affect the
            range.

        Returns
        -------
        callable
            A function that overwrites the aesthetic column of a frame (when
            present) with the mapped colors and returns the frame. If no frame
            provides any non-missing value, the identity transformer is returned
            instead.
        """
        overall_min = None
        overall_max = None
        for group_of_dfs in groups_of_dfs:
            for df in group_of_dfs:
                if self.aesthetic_name not in df.columns:
                    continue
                # Drop missing values: an empty or all-NaN column would yield NaN
                # bounds, and min()/max() with NaN depend on argument order.
                series = df[self.aesthetic_name].dropna()
                if series.empty:
                    continue
                series_min = series.min()
                series_max = series.max()
                overall_min = series_min if overall_min is None else min(series_min, overall_min)
                overall_max = series_max if overall_max is None else max(series_max, overall_max)

        if overall_min is None:
            # Nothing to color; do not hand None bounds to the color mapper.
            return super().create_local_transformer(groups_of_dfs)

        color_mapping = continuous_nums_to_colors(overall_min, overall_max, plotly.colors.sequential.Viridis)

        def transform(df):
            if self.aesthetic_name in df.columns:
                df[self.aesthetic_name] = df[self.aesthetic_name].map(color_mapping)
            return df

        return transform
class ScaleColorHue(ScaleDiscrete):
    """Discrete color scale assigning evenly spaced hues from plotly's HSV colorscale."""

    def get_values(self, categories):
        """Return one color per category, sampled evenly from the ``HSV`` colorscale.

        Samples are taken at ``i / n`` for ``i`` in ``0 .. n - 1``. The endpoint
        ``1.0`` is excluded because the HSV colorscale is cyclic (its first and
        last colors coincide), so sampling it would repeat the first color.

        Returns an empty list when there are no categories.
        """
        num_categories = len(categories)
        if num_categories == 0:
            # Nothing to color; also avoids dividing by zero below.
            return []
        interpolation_values = [i / num_categories for i in range(num_categories)]
        hsv_scale = px.colors.get_colorscale("HSV")
        return px.colors.sample_colorscale(hsv_scale, interpolation_values)
class ScaleShapeAuto(ScaleDiscrete):
    """Discrete shape scale that draws from a fixed pool of plotly marker symbols.

    The pool holds 53 symbols. Data with more distinct categories than that
    cannot be fully assigned.
    """

    # Plotly marker symbol names, in the order they are handed out.
    _SYMBOLS = (
        "circle", "square", "diamond", "cross", "x",
        "triangle-up", "triangle-down", "triangle-left", "triangle-right",
        "triangle-ne", "triangle-se", "triangle-sw", "triangle-nw",
        "pentagon", "hexagon", "hexagon2", "octagon",
        "star", "hexagram",
        "star-triangle-up", "star-triangle-down", "star-square", "star-diamond",
        "diamond-tall", "diamond-wide", "hourglass", "bowtie",
        "circle-cross", "circle-x", "square-cross", "square-x",
        "diamond-cross", "diamond-x", "cross-thin", "x-thin",
        "asterisk", "hash",
        "y-up", "y-down", "y-left", "y-right",
        "line-ew", "line-ns", "line-ne", "line-nw",
        "arrow-up", "arrow-down", "arrow-left", "arrow-right",
        "arrow-bar-up", "arrow-bar-down", "arrow-bar-left", "arrow-bar-right",
    )

    def get_values(self, categories):
        """Return a fresh list of every available symbol; ``categories`` is not used."""
        return list(self._SYMBOLS)
class ScaleColorContinuousIdentity(ScaleContinuous):
    """Color scale that uses the aesthetic's values directly as colors."""

    def valid_dtype(self, dtype):
        """Only string expressions are accepted, since each value must itself be a color."""
        is_string = dtype == tstr
        return is_string
def scale_x_log10(name=None):
    """Put the x-axis on a base-10 logarithmic scale.

    Parameters
    ----------
    name: :class:`str`
        The label to show on the x-axis.

    Returns
    -------
    :class:`.FigureAttribute`
        The scale to be applied.
    """
    return PositionScaleContinuous(axis="x", name=name, transformation="log10")
def scale_y_log10(name=None):
    """Put the y-axis on a base-10 logarithmic scale.

    Parameters
    ----------
    name: :class:`str`
        The label to show on the y-axis.

    Returns
    -------
    :class:`.FigureAttribute`
        The scale to be applied.
    """
    return PositionScaleContinuous(axis="y", name=name, transformation="log10")
def scale_x_reverse(name=None):
    """Reverse the x-axis, so that values increase from right to left.

    Parameters
    ----------
    name: :class:`str`
        The label to show on the x-axis.

    Returns
    -------
    :class:`.FigureAttribute`
        The scale to be applied.
    """
    return PositionScaleContinuous("x", name=name, transformation="reverse")
def scale_y_reverse(name=None):
    """Reverse the y-axis, so that values increase from top to bottom.

    Parameters
    ----------
    name: :class:`str`
        The label to show on the y-axis.

    Returns
    -------
    :class:`.FigureAttribute`
        The scale to be applied.
    """
    return PositionScaleContinuous(axis="y", name=name, transformation="reverse")
def scale_x_continuous(name=None, breaks=None, labels=None, trans="identity"):
    """Build the default continuous x scale.

    Parameters
    ----------
    name: :class:`str`
        The label to show on the x-axis.
    breaks: :class:`list` of :class:`float`
        Where to draw ticks on the x-axis.
    labels: :class:`list` of :class:`str`
        The text shown at each tick.
    trans: :class:`str`
        Transformation applied to the x-axis: ``"identity"``, ``"reverse"`` or ``"log10"``.

    Returns
    -------
    :class:`.FigureAttribute`
        The scale to be applied.
    """
    return PositionScaleContinuous(axis="x", name=name, breaks=breaks, labels=labels, transformation=trans)
def scale_y_continuous(name=None, breaks=None, labels=None, trans="identity"):
    """Build the default continuous y scale.

    Parameters
    ----------
    name: :class:`str`
        The label to show on the y-axis.
    breaks: :class:`list` of :class:`float`
        Where to draw ticks on the y-axis.
    labels: :class:`list` of :class:`str`
        The text shown at each tick.
    trans: :class:`str`
        Transformation applied to the y-axis: ``"identity"``, ``"reverse"`` or ``"log10"``.

    Returns
    -------
    :class:`.FigureAttribute`
        The scale to be applied.
    """
    return PositionScaleContinuous(axis="y", name=name, breaks=breaks, labels=labels, transformation=trans)
def scale_x_discrete(name=None, breaks=None, labels=None):
    """Build the default discrete x scale.

    Parameters
    ----------
    name: :class:`str`
        The label to show on the x-axis.
    breaks: :class:`list` of :class:`str`
        Where to draw ticks on the x-axis.
    labels: :class:`list` of :class:`str`
        The text shown at each tick.

    Returns
    -------
    :class:`.FigureAttribute`
        The scale to be applied.
    """
    return PositionScaleDiscrete(axis="x", name=name, breaks=breaks, labels=labels)
def scale_y_discrete(name=None, breaks=None, labels=None):
    """Build the default discrete y scale.

    Parameters
    ----------
    name: :class:`str`
        The label to show on the y-axis.
    breaks: :class:`list` of :class:`str`
        Where to draw ticks on the y-axis.
    labels: :class:`list` of :class:`str`
        The text shown at each tick.

    Returns
    -------
    :class:`.FigureAttribute`
        The scale to be applied.
    """
    return PositionScaleDiscrete(axis="y", name=name, breaks=breaks, labels=labels)
def scale_x_genomic(reference_genome, name=None):
    """The default genomic x scale. This is used when the ``x`` aesthetic corresponds to a :class:`.LocusExpression`.

    Parameters
    ----------
    reference_genome: :class:`str` or :class:`.ReferenceGenome`
        The reference genome being used, given either as an object or by name.
    name: :class:`str`
        The label to show on the x-axis.

    Returns
    -------
    :class:`.FigureAttribute`
        The scale to be applied.
    """
    return PositionScaleGenomic("x", reference_genome, name=name)
def scale_color_discrete():
    """Build the default discrete color scale, which gives each distinct value its own color.

    This is the same scale as :func:`scale_color_hue`.

    Returns
    -------
    :class:`.FigureAttribute`
        The scale to be applied.
    """
    hue_scale = scale_color_hue()
    return hue_scale
def scale_color_hue():
    """Give each discrete color value an evenly spaced hue from the color wheel.

    Returns
    -------
    :class:`.FigureAttribute`
        The scale to be applied.
    """
    return ScaleColorHue(aesthetic_name="color")
def scale_color_continuous():
    """Build the default continuous color scale.

    Colors are interpolated between the smallest and largest observed values.

    Returns
    -------
    :class:`.FigureAttribute`
        The scale to be applied.
    """
    return ScaleColorContinuous(aesthetic_name="color")
def scale_color_identity():
    """Use the values of the ``color`` aesthetic directly as colors.

    Returns
    -------
    :class:`.FigureAttribute`
        The scale to be applied.
    """
    return ScaleColorContinuousIdentity(aesthetic_name="color")
def scale_color_manual(*, values):
    """A color scale that assigns the distinct values of the ``color`` aesthetic to user-chosen colors.

    Parameters
    ----------
    values: :class:`list` of :class:`str`, or :class:`dict`
        Either a pool of colors, assigned one-to-one to the distinct values, or
        a dict mapping each value to its color. A list must hold at least as many
        colors as there are distinct values, otherwise a :class:`ValueError` is
        raised. A dict must have an entry for every value that occurs.

    Returns
    -------
    :class:`.FigureAttribute`
        The scale to be applied.
    """
    return ScaleDiscreteManual("color", values=values)
def scale_fill_discrete():
    """Build the default discrete fill scale, which gives each distinct value its own fill color.

    This is the same scale as :func:`scale_fill_hue`.

    Returns
    -------
    :class:`.FigureAttribute`
        The scale to be applied.
    """
    hue_scale = scale_fill_hue()
    return hue_scale
def scale_fill_continuous():
    """Build the default continuous fill scale.

    Fill colors are interpolated between the smallest and largest observed values.

    Returns
    -------
    :class:`.FigureAttribute`
        The scale to be applied.
    """
    return ScaleColorContinuous(aesthetic_name="fill")
def scale_fill_identity():
    """Use the values of the ``fill`` aesthetic directly as fill colors.

    Returns
    -------
    :class:`.FigureAttribute`
        The scale to be applied.
    """
    return ScaleColorContinuousIdentity(aesthetic_name="fill")
def scale_fill_hue():
    """Give each discrete fill value an evenly spaced hue from the color wheel.

    Returns
    -------
    :class:`.FigureAttribute`
        The scale to be applied.
    """
    return ScaleColorHue(aesthetic_name="fill")
def scale_fill_manual(*, values):
    """A fill scale that assigns the distinct values of the ``fill`` aesthetic to user-chosen colors.

    Parameters
    ----------
    values: :class:`list` of :class:`str`, or :class:`dict`
        Either a pool of colors, assigned one-to-one to the distinct values, or
        a dict mapping each value to its color. A list must hold at least as many
        colors as there are distinct values, otherwise a :class:`ValueError` is
        raised. A dict must have an entry for every value that occurs.

    Returns
    -------
    :class:`.FigureAttribute`
        The scale to be applied.
    """
    return ScaleDiscreteManual("fill", values=values)
def scale_shape_manual(*, values):
    """A scale that assigns shapes to discrete aesthetics. See `the plotly documentation <https://plotly.com/python-api-reference/generated/plotly.graph_objects.scatter.html#plotly.graph_objects.scatter.Marker.symbol>`__ for a list of supported shapes.

    Parameters
    ----------
    values: :class:`list` of :class:`str`, or :class:`dict`
        Either a pool of shapes, assigned one-to-one to the distinct values of
        the ``shape`` aesthetic, or a dict mapping each value to its shape. A
        list must hold at least as many shapes as there are distinct values,
        otherwise a :class:`ValueError` is raised. A dict must have an entry for
        every value that occurs.

    Returns
    -------
    :class:`.FigureAttribute`
        The scale to be applied.
    """
    return ScaleDiscreteManual("shape", values=values)
def scale_shape_auto():
    """Pick a distinct marker shape for each discrete value automatically.

    Returns
    -------
    :class:`.FigureAttribute`
        The scale to be applied.
    """
    return ScaleShapeAuto(aesthetic_name="shape")