Source code for quail.plot

from __future__ import division
from builtins import str
import numpy as np
import pandas as pd
import seaborn as sns
from .helpers import *
import matplotlib.pyplot as plt
import matplotlib as mpl

# Ask matplotlib's PDF backend to embed fonts as Type 42 (TrueType), per the
# documented meaning of the `pdf.fonttype` rcParam. This is a process-wide
# setting applied as soon as this module is imported.
mpl.rcParams['pdf.fonttype'] = 42

# Default font sizes
# Font size passed to set_xlabel/set_ylabel by the plotting code in this module.
DEFAULT_LABEL_FONTSIZE = 14

def _categorical_plot_func(plot_style, default):
    """Return the seaborn categorical plot function for ``plot_style``.

    ``default`` is used when ``plot_style`` is None. Raises ValueError for an
    unrecognised style instead of failing later with an unbound local.
    """
    style = default if plot_style is None else plot_style
    plot_funcs = {
        'bar': sns.barplot,
        'swarm': sns.swarmplot,
        'violin': sns.violinplot,
    }
    if style not in plot_funcs:
        raise ValueError(
            "Did not recognize plot_style {!r}; expected 'bar', 'swarm' "
            "or 'violin'.".format(style))
    return plot_funcs[style]


def plot(results, subjgroup=None, subjname='Subject Group', listgroup=None,
         listname='List', subjconds=None, listconds=None, plot_type=None,
         plot_style=None, title=None, legend=None, xlim=None, ylim=None,
         save_path=None, show=True, ax=None, **kwargs):
    """
    General plot function that groups data by subject/list number and
    plots the analysis stored in ``results``.

    Parameters
    ----------
    results : quail.FriedEgg
        Object containing results. Its ``data`` attribute is read but never
        modified by this function.

    subjgroup : list of strings or ints
        String/int variables indicating how to group over subjects. Must be
        the length of the number of subjects. Defaults to the values of the
        first index level of ``results.data``.

    subjname : string
        Name of the subject grouping variable

    listgroup : list of strings or ints
        String/int variables indicating how to group over list. Must be the
        length of the number of lists. NOTE: accepted for API compatibility;
        this function does not currently forward it to the tidy-data
        formatting step.

    listname : string
        Name of the list grouping variable

    subjconds : list
        List of subject hues (str) to plot. A single non-list value is
        wrapped in a list.

    listconds : list
        List of list hues (str) to plot. A single non-list value is wrapped
        in a list.

    plot_type : string
        Specifies the type of plot. If 'list' (default), the list groupings
        determine the plot grouping. If 'subject', the subject groupings
        determine the plot grouping. If 'split' (accuracy and temporal plots
        only), subjects are placed on the x axis and lists are used as hue.
        Fingerprint plots draw without a hue for any other value.

    plot_style : string
        Specifies the style of the plot: 'bar' (default for accuracy and
        temporal plots), 'violin' (default for fingerprint plots) or 'swarm'.
        Ignored by the line-based analyses (spc, pfr, pnr, lagcrp).

    title : string
        The title of the plot

    legend : bool
        If True, a legend is plotted. If None (default), the legend is shown
        unless the grouping variable selected by ``plot_type`` has at most
        one unique value.

    xlim : list of numbers
        A xmin/max can be specified by a list of the form [xmin, xmax]

    ylim : list of numbers
        A ymin/max can be specified by a list of the form [ymin, ymax]

    save_path : str
        Path to save out figure. Include the file extension, e.g.
        save_path='figure.pdf'

    show : bool
        Currently not used by this function; retained so existing callers
        that pass it keep working.

    ax : matplotlib Axes object or None
        A plot object to draw to. If None, a new one is created and returned.

    **kwargs
        Forwarded to the underlying seaborn plotting function.

    Returns
    ----------
    ax : matplotlib Axes
        An axis handle for the figure

    Raises
    ------
    ValueError
        If ``results.analysis`` is not recognized, if ``plot_style`` is not
        one of the supported styles, or if ``plot_type`` is not supported by
        the requested analysis.
    """

    sns.set_palette("viridis")
    plot_type = plot_type if plot_type is not None else 'list'

    def line_hue():
        # Line-based analyses can only be grouped per subject or per list.
        if plot_type == 'subject':
            return subjname
        if plot_type == 'list':
            return listname
        raise ValueError(
            "plot_type {!r} is not supported for {!r} plots; expected "
            "'list' or 'subject'.".format(plot_type, results.analysis))

    def plot_score(data, score, **kwargs):
        # Shared by the accuracy and temporal analyses: one scalar score per
        # row, whose column name doubles as the y-axis label.
        plot_func = _categorical_plot_func(plot_style, 'bar')
        if plot_type == 'list':
            axis = plot_func(data=data, x=listname, y=score, hue=listname,
                             legend=False, **kwargs)
        elif plot_type == 'subject':
            axis = plot_func(data=data, x=subjname, y=score, **kwargs)
        elif plot_type == 'split':
            axis = plot_func(data=data, x=subjname, y=score, hue=listname,
                             **kwargs)
        else:
            raise ValueError(
                "plot_type {!r} is not supported for {!r} plots; expected "
                "'list', 'subject' or 'split'.".format(plot_type,
                                                       results.analysis))
        axis.set_ylabel(score, fontsize=DEFAULT_LABEL_FONTSIZE)
        axis.set_xlabel(axis.get_xlabel(), fontsize=DEFAULT_LABEL_FONTSIZE)
        return axis

    def plot_fingerprint(data, **kwargs):
        plot_func = _categorical_plot_func(plot_style, 'violin')
        if plot_type == 'list':
            axis = plot_func(data=data, x="Feature", y="Clustering Score",
                             hue=listname, legend=legend, **kwargs)
        elif plot_type == 'subject':
            axis = plot_func(data=data, x="Feature", y="Clustering Score",
                             hue=subjname, legend=legend, **kwargs)
        else:
            axis = plot_func(data=data, x="Feature", y="Clustering Score",
                             **kwargs)
        axis.set_ylabel("Clustering score", fontsize=DEFAULT_LABEL_FONTSIZE)
        axis.set_xlabel("Feature", fontsize=DEFAULT_LABEL_FONTSIZE)
        return axis

    def plot_fingerprint_temporal(data, **kwargs):
        plot_func = _categorical_plot_func(plot_style, 'violin')
        # Keep features in their order of first appearance in the data.
        order = list(data['Feature'].unique())
        if plot_type == 'list':
            axis = plot_func(data=data, x="Feature", y="Clustering Score",
                             hue=listname, order=order, **kwargs)
        elif plot_type == 'subject':
            axis = plot_func(data=data, x="Feature", y="Clustering Score",
                             hue=subjname, order=order, **kwargs)
        else:
            axis = plot_func(data=data, x="Feature", y="Clustering Score",
                             order=order, **kwargs)
        axis.set_ylabel("Clustering score", fontsize=DEFAULT_LABEL_FONTSIZE)
        axis.set_xlabel("Feature", fontsize=DEFAULT_LABEL_FONTSIZE)
        return axis

    def plot_spc(data, **kwargs):
        axis = sns.lineplot(data=data, x="Position", y="Proportion Recalled",
                            hue=line_hue(), **kwargs)
        axis.set_xlim(0, data['Position'].max())
        axis.set_ylabel("Proportion recalled", fontsize=DEFAULT_LABEL_FONTSIZE)
        axis.set_xlabel("Position", fontsize=DEFAULT_LABEL_FONTSIZE)
        return axis

    def plot_pnr(data, position, list_length, **kwargs):
        axis = sns.lineplot(
            data=data, x="Position",
            y='Probability of Recall: Position ' + str(position),
            hue=line_hue(), **kwargs)
        axis.set_xlim(0, list_length - 1)
        axis.set_ylabel('Probability of recall: position ' + str(position),
                        fontsize=DEFAULT_LABEL_FONTSIZE)
        axis.set_xlabel("Position", fontsize=DEFAULT_LABEL_FONTSIZE)
        return axis

    def plot_lagcrp(data, **kwargs):
        hue = line_hue()
        score = "Conditional Response Probability"
        # Negative and positive lags are drawn as two separate lines so that
        # no segment is drawn across lag 0.
        axis = sns.lineplot(data=data[data['Position'] < 0], x="Position",
                            y=score, hue=hue, legend=False, **kwargs)
        # The second call must target the axes returned by the first one.
        kwargs.pop('ax', None)
        sns.lineplot(data=data[data['Position'] > 0], x="Position", y=score,
                     hue=hue, ax=axis, legend=False, **kwargs)
        if legend:
            # Deduplicate legend entries coming from the two line sets.
            handles, labels = axis.get_legend_handles_labels()
            if handles:
                by_label = dict(zip(labels, handles))
                axis.legend(by_label.values(), by_label.keys(), title=hue)
        axis.set_xlim(-5, 5)
        axis.set_ylabel("Conditional response probability",
                        fontsize=DEFAULT_LABEL_FONTSIZE)
        axis.set_xlabel("Lag", fontsize=DEFAULT_LABEL_FONTSIZE)
        return axis

    # if no grouping, set default to iterate over each subject independently
    if subjgroup is None:
        subjgroup = results.data.index.levels[0].values

    # Work on a local reference so the caller's results object is not
    # permanently filtered by a plotting call.
    data = results.data
    idx = pd.IndexSlice

    if subjconds:
        # make sure it's a list
        if not isinstance(subjconds, list):
            subjconds = [subjconds]
        # label-based slicing on a MultiIndex needs a sorted index
        data = data.sort_index()
        data = data.loc[idx[subjconds, :], :]
        # keep only the group labels that survive the filter
        subjgroup = [group for group in subjgroup if group in subjconds]

    if listconds:
        # make sure it's a list
        if not isinstance(listconds, list):
            listconds = [listconds]
        data = data.sort_index()
        data = data.loc[idx[:, listconds], :]

    # convert to tidy format for plotting
    tidy_data = format2tidy(data, subjname, listname, subjgroup,
                            analysis=results.analysis,
                            position=results.position)

    # Auto-suppress the legend if there is only one group and the user
    # didn't specify a preference.
    if legend is None:
        legend = True
        if plot_type == 'list':
            group_column = listname
        elif plot_type == 'subject':
            group_column = subjname
        else:
            group_column = None
        if group_column is not None:
            try:
                n_groups = len(tidy_data[group_column].unique())
            except (KeyError, AttributeError, TypeError):
                # Best effort only: keep the legend if groups can't be counted.
                n_groups = None
            if n_groups is not None and n_groups <= 1:
                legend = False

    if ax is not None:
        kwargs['ax'] = ax

    # plot!
    analysis = results.analysis
    if analysis == 'accuracy':
        ax = plot_score(tidy_data, "Accuracy", **kwargs)
    elif analysis == 'temporal':
        ax = plot_score(tidy_data, "Temporal clustering score", **kwargs)
    elif analysis == 'fingerprint':
        ax = plot_fingerprint(tidy_data, **kwargs)
    elif analysis == 'fingerprint_temporal':
        ax = plot_fingerprint_temporal(tidy_data, **kwargs)
    elif analysis == 'spc':
        ax = plot_spc(tidy_data, **kwargs)
    elif analysis == 'pfr' or analysis == 'pnr':
        ax = plot_pnr(tidy_data, position=results.position,
                      list_length=results.list_length, **kwargs)
    elif analysis == 'lagcrp':
        ax = plot_lagcrp(tidy_data, **kwargs)
    else:
        raise ValueError("Did not recognize analysis.")

    # Decorate the axes that were actually drawn on, which may not be
    # pyplot's current axes when the caller supplied `ax`.
    if title:
        ax.set_title(title)

    if legend is False:
        existing_legend = ax.get_legend()
        if existing_legend is not None:
            existing_legend.remove()

    if xlim:
        ax.set_xlim(xlim)

    if ylim:
        ax.set_ylim(ylim)

    sns.despine(ax=ax, top=True, right=True)

    if save_path:
        mpl.rcParams['pdf.fonttype'] = 42
        # Save the figure owning `ax`; fall back to pyplot when the parent
        # object (e.g. a SubFigure) does not provide savefig.
        save_figure = getattr(ax.get_figure(), 'savefig', plt.savefig)
        save_figure(save_path)

    return ax