#!/usr/bin/env python
from __future__ import division
from __future__ import print_function
from builtins import zip
from builtins import range
import six
import numpy as np
import pandas as pd
import warnings
import six
from ..helpers import *
from ..distance import dist_funcs as dist_funcs_dict
from .recmat import recall_matrix
from .accuracy import accuracy_helper
from .spc import spc_helper
from .pnr import pnr_helper
from .lagcrp import lagcrp_helper
from .clustering import fingerprint_helper
# Dispatch table: public analysis name -> helper implementing it.
# 'pfr' shares the pnr helper and 'temporal' shares the fingerprint helper;
# the differences between those aliases are applied by the caller.
analyses = dict(
    accuracy=accuracy_helper,
    spc=spc_helper,
    pfr=pnr_helper,
    pnr=pnr_helper,
    lagcrp=lagcrp_helper,
    fingerprint=fingerprint_helper,
    temporal=fingerprint_helper,
)
# main analysis function
def analyze(egg, subjgroup=None, listgroup=None, subjname='Subject',
            listname='List', analysis=None, position=0, permute=False,
            n_perms=1000, parallel=False, match='exact',
            distance='euclidean', features=None, ts=None, n_jobs=-1):
    """
    General analysis function that groups data by subject/list number and
    performs analysis.

    Parameters
    ----------
    egg : Egg data object
        The data to be analyzed

    subjgroup : list of strings or ints
        String/int variables indicating how to group over subjects. Must be
        the length of the number of subjects. Overridden by ``egg.subjgroup``
        when the egg carries a non-None value.

    subjname : string
        Name of the subject grouping variable. Overridden by ``egg.subjname``
        when the egg carries a non-None value.

    listgroup : list of strings or ints
        String/int variables indicating how to group over list. Must be
        the length of the number of lists. Overridden by ``egg.listgroup``
        when the egg carries a non-None value.

    listname : string
        Name of the list grouping variable. Overridden by ``egg.listname``
        when the egg carries a non-None value.

    analysis : string
        This is the analysis you want to run. Can be accuracy, spc, pfr,
        pnr, lagcrp, temporal or fingerprint

    position : int
        Optional argument for pnr analysis. Defines encoding position of item
        to run pnr. Default is 0, and it is zero indexed. For pfr the position
        handed to the helper is always 0.

    permute : bool
        Optional argument for fingerprint/temporal cluster analyses. Determines
        whether to correct clustering scores by shuffling recall order for each
        list to create a distribution of clustering scores (for each feature).
        The "corrected" clustering score is the proportion of clustering scores
        in that random distribution that were lower than the clustering score
        for the observed recall sequence. Default is False.

    n_perms : int
        Optional argument for fingerprint/temporal cluster analyses. Number of
        permutations to run for "corrected" clustering scores. Default is 1000
        (per recall list).

    parallel : bool
        Option to use multiprocessing (this can help speed up the permutations
        tests in the clustering calculations)

    match : str (exact, best or smooth)
        Matching approach to compute recall matrix. If exact, the presented and
        recalled items must be identical (default). If best, the recalled item
        that is most similar to the presented items will be selected. If
        smooth, a weighted average of all presented items will be used, where
        the weights are derived from the similarity between the recalled item
        and each presented item.

    distance : str
        The distance function used to compare presented and recalled items.
        Applies only to 'best' and 'smooth' matching approaches. Can be any
        distance function supported by scipy.spatial.distance.cdist.

    features : list or None
        Features handed to the analysis helper. Defaults to
        ``egg.feature_names``. For the temporal analysis this is replaced by
        ``['Temporal']``.

    ts : int or None
        Optional argument for lagcrp analysis; only forwarded when
        ``analysis == 'lagcrp'``.

    n_jobs : int
        Number of parallel jobs for fingerprint analysis. Default is -1 (all
        cores). Only forwarded for fingerprint/temporal analyses.

    Returns
    ----------
    result : quail.FriedEgg
        Class instance containing the analysis results

    Raises
    ----------
    ValueError
        If ``analysis`` is None or is not one of the supported analyses.
    """
    if analysis is None:
        raise ValueError('You must pass an analysis type.')

    if analysis not in analyses:
        # Build the list from the dispatch table so it cannot drift out of
        # sync with the analyses that are actually supported.
        raise ValueError('Analysis not recognized. Choose one of the '
                         'following: ' + ', '.join(sorted(analyses)))

    # Function-scope import kept from the original.
    # NOTE(review): presumably avoids a circular import with ..egg - confirm.
    from ..egg import FriedEgg

    # Grouping metadata stored on the egg takes precedence over the arguments.
    if getattr(egg, 'subjgroup', None) is not None:
        subjgroup = egg.subjgroup
    if getattr(egg, 'subjname', None) is not None:
        subjname = egg.subjname
    if getattr(egg, 'listgroup', None) is not None:
        listgroup = egg.listgroup
    if getattr(egg, 'listname', None) is not None:
        listname = egg.listname

    if features is None:
        features = egg.feature_names

    opts = {
        'subjgroup': subjgroup,
        'listgroup': listgroup,
        'subjname': subjname,
        # Previously omitted, which made the result's list level always be
        # named 'List' regardless of the requested name.
        'listname': listname,
        'parallel': parallel,
        'match': match,
        'distance': distance,
        'features': features,
        'analysis_type': analysis,
        'analysis': analyses[analysis],
    }

    # Analysis-specific options.
    if analysis == 'pfr':
        # Probability of first recall is pnr at position 0.
        opts['position'] = 0
    elif analysis == 'pnr':
        opts['position'] = position

    if analysis == 'temporal':
        opts['features'] = ['Temporal']

    if analysis in ('temporal', 'fingerprint'):
        opts.update({'permute': permute, 'n_perms': n_perms,
                     'n_jobs': n_jobs})

    if analysis == 'lagcrp':
        opts['ts'] = ts

    return FriedEgg(data=_analyze_chunk(egg, **opts), analysis=analysis,
                    list_length=egg.list_length, n_lists=egg.n_lists,
                    n_subjects=egg.n_subjects, position=position)
def _analyze_chunk(data, subjgroup=None, subjname='Subject', listgroup=None,
listname='List', analysis=None, analysis_type=None,
pass_features=False, features=None, parallel=False,
**kwargs):
"""
Private function that groups data by subject/list number and performs
analysis for a chunk of data.
Parameters
----------
data : Egg data object
The data to be analyzed
subjgroup : list of strings or ints
String/int variables indicating how to group over subjects. Must be
the length of the number of subjects
subjname : string
Name of the subject grouping variable
listgroup : list of strings or ints
String/int variables indicating how to group over list. Must be
the length of the number of lists
listname : string
Name of the list grouping variable
analysis : function
This function analyzes data and returns it.
pass_features : bool
Logical indicating whether the analyses uses the features field of the
Egg
Returns
----------
analyzed_data : Pandas DataFrame
DataFrame containing the analysis results
"""
# perform the analysis
def _analysis(c):
subj, lst = c
subjects = [s for s in subjdict[subj]]
lists = [l for l in listdict[subj][lst]]
s = data.crack(lists=lists, subjects=subjects)
index = pd.MultiIndex.from_arrays([[subj],[lst]], names=[subjname, listname])
opts = dict()
if analysis_type == 'fingerprint':
opts.update({'columns' : features})
elif analysis_type == 'lagcrp':
if kwargs['ts']:
opts.update({'columns' : range(-kwargs['ts'],kwargs['ts']+1)})
else:
opts.update({'columns' : range(-data.list_length,data.list_length+1)})
return pd.DataFrame([analysis(s, features=features, **kwargs)],
index=index, **opts)
subjgroup = subjgroup if subjgroup else data.pres.index.levels[0].values
listgroup = listgroup if listgroup else data.pres.index.levels[1].values
subjdict = {subj : data.pres.index.levels[0].values[subj==np.array(subjgroup)] for subj in set(subjgroup)}
if all(isinstance(el, list) for el in listgroup):
# Per-subject listgroup: listgroup is a list of lists, one per subject
# Map subject indices to their listgroup dictionaries
per_subject_listdict = []
for listgrpsub in listgroup:
ld = {lst : data.pres.index.levels[1].values[lst==np.array(listgrpsub)] for lst in set(listgrpsub)}
per_subject_listdict.append(ld)
# Create listdict keyed by subject group, mapping to the appropriate per-subject dict
# For each subject group, find the corresponding subject indices and their listdicts
listdict = {}
for subj_group in subjdict:
# Get subject indices that belong to this group
subj_indices_in_group = subjdict[subj_group]
if len(subj_indices_in_group) > 0:
# Use the first subject's listdict as representative for the group
# (assuming all subjects in a group have the same list groupings)
first_subj_idx = subj_indices_in_group[0]
if first_subj_idx < len(per_subject_listdict):
listdict[subj_group] = per_subject_listdict[first_subj_idx]
else:
# Fallback: use first listdict
listdict[subj_group] = per_subject_listdict[0]
else:
# Shared list grouping
ld = {lst : data.pres.index.levels[1].values[lst==np.array(listgroup)] for lst in set(listgroup)}
listdict = {subj : ld for subj in subjdict}
# Now listdict is always a dict keyed by subject group
chunks = [(subj, lst) for subj in subjdict for lst in listdict[subj]]
if parallel:
import multiprocessing
from pathos.multiprocessing import ProcessingPool as Pool
try:
p = Pool(multiprocessing.cpu_count())
res = p.map(_analysis, chunks)
finally:
p.close()
p.join()
p.clear()
else:
res = [_analysis(c) for c in chunks]
return pd.concat(res)