"""
========
Plotting
========
:Authors: Orion Cohen and Lauren Lee
:Year: 2023
:Copyright: GNU Public License v3
The plotting functions are a convenient way to visualize data by taking solutions
as their input and generating a Plotly.Figure object.
"""
from typing import Union, Optional, Any, Callable
from copy import deepcopy
import plotly
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.express as px
import pandas as pd
from solvation_analysis.solute import Solute
from solvation_analysis.networking import Networking
from solvation_analysis.speciation import Speciation
# single solution
[docs]
def plot_network_size_histogram(networking: Union[Networking, Solute]) -> go.Figure:
    """
    Draw a bar chart of how networks are distributed over network sizes.

    Parameters
    ----------
    networking : Networking | Solute
        A Networking analysis object, or a Solute whose ``networking``
        attribute holds one.

    Returns
    -------
    fig : Plotly.Figure

    Raises
    ------
    ValueError
        If a Solute is passed that has no ``networking`` attribute.
    """
    analysis = networking
    if isinstance(analysis, Solute):
        if not hasattr(analysis, "networking"):
            raise ValueError("Solute networking analysis class must be instantiated.")
        analysis = analysis.networking

    # Column-wise totals of network_sizes; their index becomes the x-axis.
    size_counts = analysis.network_sizes.sum(axis=0)
    size_fractions = size_counts.values / size_counts.sum()

    fig = go.Figure(data=[go.Bar(x=size_counts.index, y=size_fractions)])
    fig.update_layout(
        xaxis_title_text="Network Size",
        yaxis_title_text="Fraction of All Networks",
        title="Histogram of Network Sizes",
        template="plotly_white",
    )
    # Treat the sizes as discrete labels rather than a continuous axis.
    fig.update_xaxes(type="category")
    return fig
[docs]
def plot_shell_composition_by_size(speciation: Union[Speciation, Solute]) -> go.Figure:
    """
    Plot, for each shell size, the fraction of the shell made up by each solvent.

    Parameters
    ----------
    speciation : Speciation | Solute
        A Speciation analysis object, or a Solute whose ``speciation``
        attribute holds one.

    Returns
    -------
    fig : Plotly.Figure

    Raises
    ------
    ValueError
        If a Solute is passed that has no ``speciation`` attribute.
    """
    analysis = speciation
    if isinstance(analysis, Solute):
        if not hasattr(analysis, "speciation"):
            raise ValueError("Solute speciation analysis class must be instantiated.")
        analysis = analysis.speciation

    # Work on a copy so the helper "total" column never leaks into the analysis object.
    shells = analysis.speciation_data.copy()
    shells["total"] = shells.sum(axis=1)
    by_size = shells.groupby("total").sum()

    # Number of molecules (all solvents together) found in shells of each size.
    molecules_per_size = by_size.sum(axis=1)

    bars = [
        go.Bar(
            x=by_size.index.values,
            y=by_size[solvent].values / molecules_per_size,
            name=solvent,
        )
        for solvent in by_size.columns
    ]
    fig = go.Figure(data=bars)
    fig.update_layout(
        xaxis_title_text="Shell Size",
        yaxis_title_text="Fraction of Total Molecules",
        title="Fraction of Solvents in Shells of Different Sizes",
        template="plotly_white",
    )
    fig.update_xaxes(type="category")
    return fig
[docs]
def plot_co_occurrence(
    speciation: Union[Speciation, Solute], colorscale: Optional[Any] = None
) -> go.Figure:
    """
    Plot the co-occurrence matrix of the solute using Plotly.

    Co-occurrence represents the extent to which solvents occur with each other
    relative to random. Values higher than 1 mean that solvents occur together
    more often than random and values lower than 1 mean solvents occur together
    less often than random. "Random" is calculated based on the total number of
    solvents participating in solvation, it ignores solvents in the diluent.

    Parameters
    ----------
    speciation : Speciation | Solute
        A Speciation analysis object, or a Solute whose ``speciation``
        attribute holds one.
    colorscale : any valid argument to Plotly colorscale, optional
        If omitted (or falsy), a blue-white-red scale is built with white
        placed at a co-occurrence of 1.

    Returns
    -------
    fig : plotly.graph_objs.Figure

    Raises
    ------
    ValueError
        If a Solute is passed that has no ``speciation`` attribute.
    """
    if isinstance(speciation, Solute):
        if not hasattr(speciation, "speciation"):
            raise ValueError("Solute speciation analysis class must be instantiated.")
        speciation = speciation.speciation

    solvent_names = speciation.speciation_data.columns.values
    co_occurrence = speciation.solvent_co_occurrence
    n_solvents = len(solvent_names)

    if not colorscale:
        min_val = co_occurrence.min().min()
        max_val = co_occurrence.max().max()
        range_val = max_val - min_val
        if range_val > 0:
            # Position of the value 1 on the normalized [0, 1] color axis.
            # Colorscale stops must lie in [0, 1], so clamp when every value
            # sits on one side of 1.
            neutral_stop = float(min(max((1 - min_val) / range_val, 0.0), 1.0))
        else:
            # All values identical (or no data): there is no gradient to
            # anchor, so fall back to a symmetric scale.
            neutral_stop = 0.5
        colorscale = [
            [0, "rgb(67,147,195)"],
            [neutral_stop, "white"],
            [1, "rgb(214,96,77)"],
        ]

    # Create a heatmap trace with text annotations
    trace = go.Heatmap(
        x=solvent_names,
        y=solvent_names[::-1],  # Reverse the order of the y-axis labels
        z=co_occurrence.values,  # Keep the data in the original order
        # Keep the text annotations in the original order
        text=co_occurrence.round(2).to_numpy(dtype=str),
        hoverinfo="none",
        colorscale=colorscale,
    )

    # Update layout to display tick labels and text annotations
    layout = go.Layout(
        title="Solvent Co-Occurrence Matrix",
        xaxis=dict(
            tickmode="array",
            tickvals=list(range(n_solvents)),
            ticktext=solvent_names,
            tickangle=-30,
            side="top",
        ),
        yaxis=dict(
            tickmode="array",
            tickvals=list(range(n_solvents)),
            ticktext=solvent_names,
            autorange="reversed",
        ),
        margin=dict(l=60, r=60, b=60, t=100, pad=4),
        # One numeric label per cell: column i on x, row j on y.
        annotations=[
            dict(
                x=i,
                y=j,
                text=str(round(co_occurrence.iloc[j, i], 2)),
                font=dict(size=14, color="black"),
                showarrow=False,
            )
            for i in range(n_solvents)
            for j in range(n_solvents)
        ],
    )

    # Create and return the Figure object
    return go.Figure(data=[trace], layout=layout)
def _make_rectangle(x: float, y: float, color: str) -> dict:
"""
Create a rectangle shape for Plotly.
Parameters
----------
x : float
The x-coordinate of the center of the rectangle.
y : float
The y-coordinate of the center of the rectangle.
color : str
The color of the rectangle.
Returns
-------
go.layout.Shape
The rectangle shape for Plotly.
"""
x0 = x - 0.18
y0 = y - 0.43
x1 = x + 0.18
y1 = y + 0.43
h = 0.09
rounded_bottom_left = f" M {x0 + h}, {y0} Q {x0}, {y0} {x0}, {y0 + h}" #
rounded_top_left = f" L {x0}, {y1 - h} Q {x0}, {y1} {x0 + h}, {y1}"
rounded_top_right = f" L {x1 - h}, {y1} Q {x1}, {y1} {x1}, {y1 - h}"
rounded_bottom_right = f" L {x1}, {y0 + h} Q {x1}, {y0} {x1 - h}, {y0}Z"
path = (
rounded_bottom_left
+ rounded_top_left
+ rounded_top_right
+ rounded_bottom_right
)
return dict(
type="path",
path=path,
line=dict(color=color, width=2),
fillcolor=color,
layer="between",
)
def _get_shell_name(row):
result = []
for column, value in row.items():
result.append(f"{column} {value}")
return "<br>".join(result)
[docs]
def plot_speciation(
    speciation: Union[Speciation, Solute], shells: int = 6
) -> go.Figure:
    """
    Plot the composition and fraction of the first ``shells`` solvation shells.

    Each shell is drawn as a stack of colored rounded squares, one per solvent
    molecule, with the shell fraction overlaid as a line on a second y-axis.

    Parameters
    ----------
    speciation : Speciation or Solute
        A Speciation analysis object, or a Solute whose ``speciation``
        attribute holds one.
    shells : int, optional
        The number of leading rows of ``speciation_fraction`` to plot.
        Default is 6.

    Returns
    -------
    fig : plotly.graph_objs.Figure
        The plot of the solvation shell composition and fraction.

    Raises
    ------
    ValueError
        If a Solute is passed that has no ``speciation`` attribute.
    """
    analysis = speciation
    if isinstance(analysis, Solute):
        if not hasattr(analysis, "speciation"):
            raise ValueError("Solute speciation analysis class must be instantiated.")
        analysis = analysis.speciation

    # Split the table into the per-shell fraction and the solvent counts.
    top_shells = analysis.speciation_fraction.head(shells)
    shell_fractions = top_shells["fraction"]
    composition = top_shells.drop("fraction", axis=1)

    # One color per solvent; repeat the palette if it is too short.
    solvent_list = composition.columns.tolist()
    palette = px.colors.qualitative.Plotly
    if len(solvent_list) > len(palette):
        palette = palette * (len(solvent_list) // len(palette) + 1)
    solvent_colors = dict(zip(solvent_list, palette))

    # Flatten every shell into one (x, y, solvent, color) entry per molecule.
    square_x, square_y = [], []
    square_solvents, square_colors = [], []
    tick_labels = []
    for shell_ix, shell in composition.iterrows():
        tick_labels.append(_get_shell_name(shell))
        stack_height = 0
        for solvent, n_molecules in shell.items():
            for level in range(n_molecules):
                square_x.append(shell_ix)
                # Centers sit at 0.5, 1.5, ... so stacked squares do not overlap.
                square_y.append(0.5 + level + stack_height)
                square_solvents.append(solvent)
                square_colors.append(solvent_colors[solvent])
            stack_height += n_molecules

    # Invisible markers: they only provide hover text, the squares are shapes.
    hover_points = go.Scatter(
        x=square_x,
        y=square_y,
        mode="markers",
        marker={"size": 25, "color": square_colors, "opacity": 0},
        text=square_solvents,
        hoverinfo="text",
        name="Solvents",
        legendgroup="solvents",
        showlegend=False,
    )
    fraction_line = go.Scatter(
        x=composition.index,
        y=shell_fractions,
        mode="lines+markers",
        name="Fraction",
        yaxis="y2",
        line={"color": "black"},
    )
    # Data-less traces whose only job is to create one legend entry per solvent.
    legend_entries = [
        go.Scatter(
            x=[None],
            y=[None],
            mode="markers",
            marker={"size": 10, "color": shade},
            name=solvent,
            legendgroup="solvents",
            showlegend=True,
        )
        for solvent, shade in solvent_colors.items()
    ]
    fig = go.Figure(data=[hover_points, fraction_line, *legend_entries])

    # Draw the rounded squares with the shapes API.
    for center_x, center_y, shade in zip(square_x, square_y, square_colors):
        fig.add_shape(**_make_rectangle(center_x, center_y, shade))

    tallest = max(square_y)
    fig.update_layout(
        title="Top Solvation Shell Compositions",
        xaxis_title="Solvation Shell",
        xaxis={
            "tickmode": "array",
            "tickvals": composition.index,
            "ticktext": tick_labels,
        },
        yaxis={
            "title": "Shell Size",
            "tickmode": "array",
            "tickvals": list(range(1, int(tallest) + 1)),
            "range": [0, tallest + 1],  # leave headroom above the tallest stack
            "showgrid": False,
            "side": "right",
        },
        yaxis2={
            "title": "Shell Fraction",
            "overlaying": "y",
            "side": "left",
            "range": [0, max(shell_fractions) * 1.1],  # 10% headroom
        },
        template="plotly_white",
        margin={"l": 20, "r": 20, "t": 60, "b": 20},
        # Horizontal legend anchored at the top right of the plot.
        legend={
            "orientation": "h",
            "yanchor": "bottom",
            "y": 1,
            "xanchor": "right",
            "x": 1,
        },
    )
    return fig
[docs]
def plot_rdfs(
    solute: Solute,
    show_cutoff: bool = True,
    x_axis_solute: bool = False,
    merge_on_x: bool = False,
    merge_on_y: bool = False,
) -> go.Figure:
    """
    Plot the radial distribution functions (RDFs) of solute-solvent pairs.

    Parameters
    ----------
    solute : Solute
        The Solute object containing the RDF data. ``solute.rdf_data`` is read
        as a two-level mapping whose inner values unpack to ``(x, y)``.
    show_cutoff : bool, optional
        Whether to display the solvation radius cutoff lines. Only applied when
        neither ``merge_on_x`` nor ``merge_on_y`` is set. Default is True.
    x_axis_solute : bool, optional
        Whether to place the solute on the x-axis. Default is False.
    merge_on_x : bool, optional
        Whether to merge subplots along the x-axis. Default is False.
    merge_on_y : bool, optional
        Whether to merge subplots along the y-axis. Default is False.

    Returns
    -------
    fig : plotly.graph_objs.Figure
        The plot of the radial distribution functions.
    """
    # Determine the grid dimensions based on merge settings
    data = solute.rdf_data
    n_cols = 1 if merge_on_y else len(data)
    n_rows = 1 if merge_on_x else len(list(data.values())[0])
    x_title, y_title = "Solvent", "Solute"
    if x_axis_solute:
        n_rows, n_cols = n_cols, n_rows
        x_title, y_title = y_title, x_title

    # Create subplots
    fig = make_subplots(
        rows=n_rows,
        cols=n_cols,
        shared_xaxes=True,
        shared_yaxes=True,
        x_title=x_title,
        y_title=y_title,
    )

    # Each inner key gets one color, shared by all of its traces.
    color_map = {}
    colors = plotly.colors.qualitative.Plotly

    # Iterate over the data and add traces to the subplots
    for i, (key, values) in enumerate(data.items()):
        for j, (sub_key, sub_values) in enumerate(values.items()):
            x, y = sub_values
            # Subplot coordinates are 1-indexed; a merged axis collapses to 1.
            col = 1 if merge_on_y else i + 1
            row = 1 if merge_on_x else j + 1
            if x_axis_solute:
                row, col = col, row

            # Only the first trace of each sub-key shows up in the legend.
            show_legend = sub_key not in color_map
            if show_legend:
                color_map[sub_key] = colors[len(color_map) % len(colors)]

            fig.add_trace(
                go.Scatter(
                    x=x,
                    y=y,
                    name=sub_key,
                    line=dict(color=color_map[sub_key]),
                    legendgroup=sub_key,
                    showlegend=show_legend,
                ),
                row=row,
                col=col,
            )
            fig.update_yaxes(title_text=key, row=row, col=1)
            fig.update_xaxes(title_text=sub_key, row=n_rows, col=col)

    # Update the layout
    fig.update_layout(
        title_text="Radial Distribution Functions of Solute-Solvent Pairs",
        template="plotly_white",
        margin=dict(
            l=100,
            b=80,
        ),
    )
    # Nudge the figure-level axis titles away from the per-subplot titles.
    fig.update_annotations(x=0.5, y=-0.05, selector={"text": x_title})
    fig.update_annotations(y=0.5, x=-0.03, selector={"text": y_title})

    if show_cutoff and not (merge_on_x or merge_on_y):
        # Enumerate from 1 so the lines land in the same 1-indexed subplots as
        # the traces above. The swapped coordinates get their own names so the
        # outer loop variable is never overwritten inside the inner loop.
        for col, atom_solute in enumerate(solute.atom_solutes.values(), start=1):
            for row, (_solvent, radius) in enumerate(
                atom_solute.radii.items(), start=1
            ):
                if x_axis_solute:
                    subplot_row, subplot_col = col, row
                else:
                    subplot_row, subplot_col = row, col
                fig.add_vline(
                    x=radius,
                    row=subplot_row,
                    col=subplot_col,
                    label=dict(
                        text="solvation radius",
                        textposition="top center",
                        yanchor="top",
                    ),
                )
    return fig
def compare_networking(solutions, series=False):
    """
    Compare the fraction of isolated, paired, and networked solutes.

    Parameters
    ----------
    solutions : dict
        Maps a solution name to an object exposing ``networking.solute_status``.
    series : bool
        Passed through to ``compare_solvent_dicts``: False for a bar graph,
        True for a line graph.

    Returns
    -------
    fig : Plotly.Figure

    Raises
    ------
    ValueError
        If any value in ``solutions`` has no ``networking`` attribute.
    """
    status_by_solution = {}
    for name, solution in solutions.items():
        if not hasattr(solution, "networking"):
            raise ValueError("Solute networking analysis class must be instantiated.")
        status_by_solution[name] = solution.networking.solute_status

    # The three statuses play the role of "solvents" in the generic plotter.
    statuses = ["isolated", "paired", "networked"]
    fig = compare_solvent_dicts(
        property_dict=status_by_solution,
        rename_solvent_dict={},
        solvents_to_plot=statuses,
        legend_label="Solute Status",
        x_axis_solute=True,
        series=series,
    )
    layout_options = {
        "xaxis_title_text": "Solute",
        "yaxis_title_text": "Solute Status Fraction",
        "title": "Fraction of Solutes Isolated, Paired, and Networked",
        "template": "plotly_white",
    }
    fig.update_layout(**layout_options)
    return fig
[docs]
def compare_solvent_dicts(
    property_dict: dict[str, dict[str, float]],
    rename_solvent_dict: dict[str, str],
    solvents_to_plot: list[str],
    legend_label: str,
    x_axis_solute: Union[bool, str] = False,
    series: bool = False,
) -> go.Figure:
    """
    A generic plotting function that can compare dictionary data between multiple solutes.

    Parameters
    ----------
    property_dict : dict of {str: dict}
        a dictionary with the solution name as keys and a dict of {str: float} as values, where each key
        is the name of the solvent of each solution and each value is the property of interest.
        The input is deep-copied and never modified.
    rename_solvent_dict : dict of {str: str}
        Renames solvents within the plot, useful for comparing similar solvents in different solutes.
        The keys are the original solvent names and the values are the new name
        e.g. {"EAf" : "EAx", "fEAf" : "EAx"}. The renaming is applied to every solution that
        contains the solvent; if several solvents of one solution map to the same new name,
        the last one wins. May be empty or None.
    solvents_to_plot : List[str]
        List of solvent names to be plotted, they will be plotted in given order.
        The solvents must be common to all systems in question. Renaming in `rename_solvent_dict`
        is applied first, so the solvent names in `solvents_to_plot` should match the values of that dict.
        If empty or None, all solvents are plotted.
    legend_label : str
        title of legend as a string
    x_axis_solute : bool
        True to put the solutions on the x-axis, False to put the solvents on the x-axis.
        For backward compatibility the strings "solute" and "solvent" are also accepted.
    series : bool
        False for a bar graph, True for a line graph

    Returns
    -------
    fig : Plotly.Figure

    Raises
    ------
    ValueError
        If `solvents_to_plot` names a solvent that is not present in every solution.
    """
    property_dict = deepcopy(property_dict)

    # A non-empty string is always truthy, so "solvent" must be mapped explicitly.
    if isinstance(x_axis_solute, str):
        x_axis_solute = x_axis_solute == "solute"

    # coerce solvents to a common name in every solution that contains them
    for solution_dict in property_dict.values():
        for old_name, new_name in (rename_solvent_dict or {}).items():
            if old_name in solution_dict:
                solution_dict[new_name] = solution_dict.pop(old_name)

    # keep only the requested solvents, in the requested order
    if solvents_to_plot:
        valid_solvents = set.intersection(
            *(set(solution_dict) for solution_dict in property_dict.values())
        )
        if not set(solvents_to_plot).issubset(valid_solvents):
            raise ValueError(
                f"solvents_to_plot must only include solvents that are "
                f"present in all solutes. Valid values are {valid_solvents}."
            )
        property_dict = {
            solute_name: {
                solvent: solution_dict[solvent] for solvent in solvents_to_plot
            }
            for solute_name, solution_dict in property_dict.items()
        }

    # one row per solution, one column per solvent
    df = pd.DataFrame(
        data=list(property_dict.values()), index=list(property_dict.keys())
    )

    if x_axis_solute:
        # solutions on the x-axis (the DataFrame index), one bar/line per solvent
        axes = {"y": df.columns}
    else:
        # solvents on the x-axis, one bar/line per solution
        df = df.transpose()
        axes = {"x": df.index, "y": df.columns}

    if series:
        fig = px.line(df, labels={"variable": legend_label}, markers=True, **axes)
        # keep the names as discrete categories instead of a numeric axis
        fig.update_xaxes(type="category")
    else:
        fig = px.bar(df, barmode="group", labels={"variable": legend_label}, **axes)
    return fig
def _compare_function_generator(
    analysis_object: str,
    attribute: str,
    title: str,
    top_level_docstring: str,
) -> Callable:
    """
    Build a ``compare_*`` plotting function for one analysis attribute.

    Parameters
    ----------
    analysis_object : str
        Name of the analysis attribute on each solute, e.g. "pairing".
    attribute : str
        Name of the attribute on that analysis object that holds the data to
        plot; it is also used to derive the default y-axis label.
    title : str
        Default figure title of the generated function.
    top_level_docstring : str
        First line of the generated function's docstring.

    Returns
    -------
    Callable
        A function taking a dict of solutes and returning a Plotly figure.
    """

    def compare_func(
        solutions,
        rename_solvent_dict=None,
        solvents_to_plot=None,
        x_axis_solute=False,
        series=False,
        title=title,
        x_label=None,
        y_label=attribute.replace("_", " ").title(),
        legend_label=None,
    ):
        # The legend distinguishes whichever dimension is NOT on the x-axis.
        if x_axis_solute:
            x_axis, legend_axis = "solute", "solvent"
        else:
            x_axis, legend_axis = "solvent", "solute"
        x_label = x_label or x_axis
        legend_label = legend_label or legend_axis

        property_dict = {}
        for solute_name, solute in solutions.items():
            if not hasattr(solute, analysis_object):
                raise ValueError(
                    f"Solute {analysis_object} analysis class must be instantiated."
                )
            analysis = getattr(solute, analysis_object)
            property_dict[solute_name] = getattr(analysis, attribute)

        fig = compare_solvent_dicts(
            property_dict=property_dict,
            rename_solvent_dict=rename_solvent_dict or {},
            solvents_to_plot=solvents_to_plot,
            legend_label=legend_label.title(),
            # Pass the boolean itself: the axis name is a non-empty string and
            # would always be truthy.
            x_axis_solute=bool(x_axis_solute),
            series=series,
        )
        fig.update_layout(
            xaxis_title_text=x_label.title(),
            yaxis_title_text=y_label.title(),
            title=title.title(),
            template="plotly_white",
        )
        return fig

    arguments_docstring = """

    Parameters
    ----------
    solutions : dict of {str: Solute}
        a dictionary with the solution name as keys and the analyzed solute as values
    rename_solvent_dict : dict of {str: str}
        Renames solvents within the plot, useful for comparing similar solvents in different solutes.
        The keys are the original solvent names and the values are the new name
        e.g. {"EAf" : "EAx", "fEAf" : "EAx"}
    solvents_to_plot : List[str]
        List of solvent names to be plotted.
        The solvents must be common to all systems in question. Renaming in `rename_solvent_dict`
        is applied first, so the solvent names in `solvents_to_plot` should match the values of that dict.
    x_axis_solute : bool
        True to plot the solutes on the x-axis, False to plot the solvents on the x-axis
    series : bool
        False for a bar graph, True for a line graph
    title : str
        title of figure
    x_label : str
        title of the x-axis
    y_label : str
        title of the y-axis
    legend_label : str
        title of legend

    Returns
    -------
    fig : Plotly.Figure
    """
    compare_func.__doc__ = top_level_docstring + arguments_docstring
    return compare_func
# Public comparison functions, one per analysis attribute. Each is produced by
# `_compare_function_generator`; `title` is the default figure title and
# `top_level_docstring` the first line of the generated docstring.
compare_pairing = _compare_function_generator(
    analysis_object="pairing",
    attribute="solvent_pairing",
    title="Fractional Pairing of Solvents",
    top_level_docstring="Compare the solute-solvent pairing.",
)
compare_free_solvents = _compare_function_generator(
    analysis_object="pairing",
    attribute="fraction_free_solvents",
    title="Free Solvents in Solutes",
    top_level_docstring="Compare the relative fraction of free solvents.",
)
compare_diluent = _compare_function_generator(
    analysis_object="pairing",
    attribute="diluent_composition",
    title="Diluent Composition of Solutes",
    top_level_docstring="Compare the diluent composition.",
)
compare_coordination_numbers = _compare_function_generator(
    analysis_object="coordination",
    attribute="coordination_numbers",
    title="Coordination Numbers of Solvents",
    top_level_docstring="Compare the coordination numbers.",
)
compare_coordination_vs_random = _compare_function_generator(
    analysis_object="coordination",
    attribute="coordination_vs_random",
    title="Coordination Compared to Random Distribution of Solvents",
    top_level_docstring=(
        "Compare the coordination relative to a random distribution of solvents."
    ),
)
compare_residence_times_cutoff = _compare_function_generator(
    analysis_object="residence",
    attribute="residence_times_cutoff",
    title="Solute-Solvent Residence Time",
    top_level_docstring="Compare the solute-solvent residence times.",
)
compare_residence_times_fit = _compare_function_generator(
    analysis_object="residence",
    attribute="residence_times_fit",
    title="Solute-Solvent Residence Time",
    top_level_docstring="Compare the solute-solvent residence times.",
)
# TODO: work on rdfs; make them tiled.
# This will have to be implemented post-merge.
# Use iba_small_solutes (will return a solute that has three atom solutes);
# solvents are on one axis and solutions are on the other.
def compare_rdfs(solutions, atoms):
    """
    Placeholder for comparing RDFs across solutions; not implemented yet.

    Both arguments are currently unused and the function returns None.
    """
    # Open question: can atom groups be matched to solutions / universes behind
    # the scenes? Yes, we can use atom.u, which is the universe.
    return