from __future__ import division
from builtins import str
import numpy as np
import pandas as pd
import seaborn as sns
from .helpers import *
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['pdf.fonttype'] = 42
# Font size applied to the x/y axis labels of every plot drawn by this module.
DEFAULT_LABEL_FONTSIZE = 14
def _categorical_plot_func(plot_style):
    """Return the seaborn function that draws a ``plot_style`` categorical plot.

    Raises
    ------
    ValueError
        If ``plot_style`` is not 'bar', 'swarm' or 'violin'.
    """
    plot_funcs = {
        'bar': sns.barplot,
        'swarm': sns.swarmplot,
        'violin': sns.violinplot,
    }
    if plot_style not in plot_funcs:
        raise ValueError("plot_style must be 'bar', 'swarm' or 'violin', "
                         "got {!r}".format(plot_style))
    return plot_funcs[plot_style]


def _line_plot_hue(plot_type, listname, subjname):
    """Return the hue column for line plots (only 'list'/'subject' supported).

    Raises
    ------
    ValueError
        If ``plot_type`` is neither 'list' nor 'subject'.
    """
    if plot_type == 'subject':
        return subjname
    if plot_type == 'list':
        return listname
    raise ValueError("plot_type must be 'list' or 'subject' for this analysis, "
                     "got {!r}".format(plot_type))


def _plot_score(data, score_col, plot_style, plot_type, listname, subjname,
                **kwargs):
    """Categorical plot of one score column (accuracy and temporal analyses).

    ``plot_style`` defaults to 'bar'. ``plot_type`` selects the grouping:
    'list' (x=listname), 'subject' (x=subjname) or 'split'
    (x=subjname, hue=listname). Remaining kwargs go to seaborn.
    """
    plot_func = _categorical_plot_func(
        plot_style if plot_style is not None else 'bar')
    if plot_type == 'list':
        # hue repeats the x column so each list group gets its own colour;
        # a legend would only duplicate the tick labels, so it is disabled.
        ax = plot_func(data=data, x=listname, y=score_col, hue=listname,
                       legend=False, **kwargs)
    elif plot_type == 'subject':
        ax = plot_func(data=data, x=subjname, y=score_col, **kwargs)
    elif plot_type == 'split':
        ax = plot_func(data=data, x=subjname, y=score_col, hue=listname,
                       **kwargs)
    else:
        raise ValueError("plot_type must be 'list', 'subject' or 'split' for "
                         "this analysis, got {!r}".format(plot_type))
    ax.set_ylabel(score_col, fontsize=DEFAULT_LABEL_FONTSIZE)
    # Re-apply seaborn's own x label only to enlarge its font.
    ax.set_xlabel(ax.get_xlabel(), fontsize=DEFAULT_LABEL_FONTSIZE)
    return ax


def _plot_feature_scores(data, plot_style, plot_type, listname, subjname,
                         hue_kws=None, **kwargs):
    """Categorical plot of "Clustering Score" per "Feature" (fingerprint analyses).

    ``plot_style`` defaults to 'violin'. 'list'/'subject' plot types use
    listname/subjname as hue; any other plot_type draws no hue grouping.
    ``hue_kws`` holds extra seaborn keywords passed only when a hue is drawn.
    Remaining kwargs go to seaborn.
    """
    plot_func = _categorical_plot_func(
        plot_style if plot_style is not None else 'violin')
    if plot_type in ('list', 'subject'):
        hue = listname if plot_type == 'list' else subjname
        ax = plot_func(data=data, x="Feature", y="Clustering Score", hue=hue,
                       **dict(hue_kws or {}, **kwargs))
    else:
        ax = plot_func(data=data, x="Feature", y="Clustering Score", **kwargs)
    ax.set_ylabel("Clustering score", fontsize=DEFAULT_LABEL_FONTSIZE)
    ax.set_xlabel("Feature", fontsize=DEFAULT_LABEL_FONTSIZE)
    return ax


def _plot_spc(data, plot_type, listname, subjname, **kwargs):
    """Line plot of "Proportion Recalled" against "Position" (spc analysis)."""
    hue = _line_plot_hue(plot_type, listname, subjname)
    ax = sns.lineplot(data=data, x="Position", y="Proportion Recalled",
                      hue=hue, **kwargs)
    ax.set_xlim(0, data['Position'].max())
    ax.set_ylabel("Proportion recalled", fontsize=DEFAULT_LABEL_FONTSIZE)
    ax.set_xlabel("Position", fontsize=DEFAULT_LABEL_FONTSIZE)
    return ax


def _plot_pnr(data, plot_type, listname, subjname, position, list_length,
              **kwargs):
    """Line plot of 'Probability of Recall: Position <position>' (pfr/pnr)."""
    hue = _line_plot_hue(plot_type, listname, subjname)
    y_col = 'Probability of Recall: Position ' + str(position)
    ax = sns.lineplot(data=data, x="Position", y=y_col, hue=hue, **kwargs)
    ax.set_xlim(0, list_length - 1)
    ax.set_ylabel('Probability of recall: position ' + str(position),
                  fontsize=DEFAULT_LABEL_FONTSIZE)
    ax.set_xlabel("Position", fontsize=DEFAULT_LABEL_FONTSIZE)
    return ax


def _plot_lagcrp(data, plot_type, listname, subjname, legend, **kwargs):
    """Line plot of "Conditional Response Probability" against lag.

    Negative and positive lags are drawn by separate lineplot calls, so
    lag 0 is excluded and the curve is not joined across it.
    """
    hue = _line_plot_hue(plot_type, listname, subjname)
    y_col = "Conditional Response Probability"
    # Only this first call may add legend entries (the original passed
    # legend=False to both calls, so no legend could ever be built).
    ax = sns.lineplot(data=data[data['Position'] < 0], x="Position", y=y_col,
                      hue=hue, legend=legend, **kwargs)
    # The first call fixed the target axes; pass it explicitly from now on.
    kwargs.pop('ax', None)
    sns.lineplot(data=data[data['Position'] > 0], x="Position", y=y_col,
                 hue=hue, ax=ax, legend=False, **kwargs)
    if legend:
        handles, labels = ax.get_legend_handles_labels()
        if handles:
            # Deduplicate entries by label and title the legend after the hue.
            by_label = dict(zip(labels, handles))
            ax.legend(list(by_label.values()), list(by_label.keys()),
                      title=hue)
    ax.set_xlim(-5, 5)
    ax.set_ylabel("Conditional response probability",
                  fontsize=DEFAULT_LABEL_FONTSIZE)
    ax.set_xlabel("Lag", fontsize=DEFAULT_LABEL_FONTSIZE)
    return ax


def plot(results, subjgroup=None, subjname='Subject Group', listgroup=None,
         listname='List', subjconds=None, listconds=None, plot_type=None,
         plot_style=None, title=None, legend=None, xlim=None, ylim=None,
         save_path=None, show=True, ax=None, **kwargs):
    """
    General plot function that groups data by subject/list and plots an analysis.

    Side effect: sets seaborn's global colour palette to "viridis".

    Parameters
    ----------
    results : quail.FriedEgg
        Object containing results. It is not modified: filtering by
        ``subjconds``/``listconds`` is applied to a local reference.
    subjgroup : list of strings or ints
        String/int variables indicating how to group over subjects. Must be
        the length of the number of subjects. Defaults to the values of the
        first level of ``results.data.index``.
    subjname : string
        Name of the subject grouping variable
    listgroup : list of strings or ints
        Accepted for API compatibility; this function does not currently
        use it.
    listname : string
        Name of the list grouping variable
    subjconds : list or scalar
        Values of the first level of ``results.data.index`` to keep (a
        non-list value is wrapped in a list). ``subjgroup`` is filtered to
        the entries that appear in ``subjconds``.
    listconds : list or scalar
        Values of the second level of ``results.data.index`` to keep (a
        non-list value is wrapped in a list).
    plot_type : string
        'list' (default), 'subject' or 'split'. 'list' groups by
        ``listname``; 'subject' groups by ``subjname``. 'split' uses both
        for accuracy and temporal plots; for fingerprint plots any value
        other than 'list'/'subject' draws no hue grouping. spc, pfr/pnr and
        lagcrp plots accept only 'list' or 'subject'.
    plot_style : string
        'bar', 'swarm' or 'violin'. Only used for accuracy, temporal and
        fingerprint plots; defaults to 'bar' for accuracy/temporal and
        'violin' for fingerprint plots.
    title : string
        The title of the plot
    legend : bool or None
        If None (default), a legend is kept unless the 'list'/'subject'
        grouping column has at most one unique value. If False, any legend
        is removed.
    xlim : list of numbers
        A xmin/max can be specified by a list of the form [xmin, xmax]
    ylim : list of numbers
        A ymin/max can be specified by a list of the form [ymin, ymax]
    save_path : str
        Path to save the figure containing ``ax``. Include the file
        extension, e.g. save_path='figure.pdf'
    show : bool
        Not currently acted upon: this function never calls ``plt.show()``.
    ax : matplotlib Axes or None
        Axes to draw on. If None, seaborn picks the current axes.
    **kwargs
        Forwarded to the underlying seaborn plotting function.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The axes the plot was drawn on.

    Raises
    ------
    ValueError
        If ``results.analysis``, ``plot_style`` or ``plot_type`` is not
        supported.
    """
    sns.set_palette("viridis")
    plot_type = plot_type if plot_type is not None else 'list'

    # Without an explicit grouping, each first-level index value is its own group.
    if subjgroup is None:
        subjgroup = results.data.index.levels[0].values

    # Filter a local reference so the caller's results object stays untouched.
    data = results.data
    idx = pd.IndexSlice
    if subjconds:
        if not isinstance(subjconds, list):
            subjconds = [subjconds]
        # sort first: pandas may refuse to slice an unsorted MultiIndex
        data = data.sort_index().loc[idx[subjconds, :], :]
        subjgroup = [group for group in subjgroup if group in subjconds]
    if listconds:
        if not isinstance(listconds, list):
            listconds = [listconds]
        data = data.sort_index().loc[idx[:, listconds], :]

    # convert to tidy format for plotting
    tidy_data = format2tidy(data, subjname, listname, subjgroup,
                            analysis=results.analysis,
                            position=results.position)

    # Unless the caller decided, hide the legend when there is at most one group.
    if legend is None:
        legend = True
        if plot_type in ('list', 'subject'):
            group_col = listname if plot_type == 'list' else subjname
            try:
                legend = len(tidy_data[group_col].unique()) > 1
            except KeyError:
                pass  # grouping column absent from the tidy data: keep legend

    if ax is not None:
        kwargs['ax'] = ax

    analysis = results.analysis
    if analysis == 'accuracy':
        ax = _plot_score(tidy_data, "Accuracy", plot_style, plot_type,
                         listname, subjname, **kwargs)
    elif analysis == 'temporal':
        ax = _plot_score(tidy_data, "Temporal clustering score", plot_style,
                         plot_type, listname, subjname, **kwargs)
    elif analysis == 'fingerprint':
        ax = _plot_feature_scores(tidy_data, plot_style, plot_type, listname,
                                  subjname, hue_kws={'legend': legend},
                                  **kwargs)
    elif analysis == 'fingerprint_temporal':
        # keep features in the order they first appear in the tidy data
        order = list(tidy_data['Feature'].unique())
        ax = _plot_feature_scores(tidy_data, plot_style, plot_type, listname,
                                  subjname, order=order, **kwargs)
    elif analysis == 'spc':
        ax = _plot_spc(tidy_data, plot_type, listname, subjname, **kwargs)
    elif analysis in ('pfr', 'pnr'):
        ax = _plot_pnr(tidy_data, plot_type, listname, subjname,
                       position=results.position,
                       list_length=results.list_length, **kwargs)
    elif analysis == 'lagcrp':
        ax = _plot_lagcrp(tidy_data, plot_type, listname, subjname, legend,
                          **kwargs)
    else:
        raise ValueError("Did not recognize analysis.")

    # Operate on ``ax`` directly: the pyplot equivalents act on plt.gca() /
    # plt.gcf(), which need not be the axes the caller passed in.
    if title:
        ax.set_title(title)
    if legend is False and ax.get_legend() is not None:
        ax.get_legend().remove()
    if xlim:
        ax.set_xlim(xlim)
    if ylim:
        ax.set_ylim(ylim)
    sns.despine(ax=ax, top=True, right=True)
    if save_path:
        # embed TrueType fonts in PDF output
        mpl.rcParams['pdf.fonttype'] = 42
        ax.figure.savefig(save_path)
    return ax