"""
Functions for aggregating subject data from
Cognitive Battery (https://github.com/sho-87/cognitive-battery)
"""
import os
import pandas as pd
def aggregate_wide(dir_battery, dir_output, response_type="full",
                   use_file=False, save=True):
    """
    Aggregate data from all battery tasks.

    Takes a directory containing individual subject data files created from
    the Cognitive Battery, and calculates summary statistics for all
    subjects across all tasks. A single output summary file is created
    containing the aggregated battery data.

    Parameters
    ----------
    dir_battery : str
        Path to the directory containing subject data files created by the
        Cognitive Battery. Only files ending in '.xls' are read.
    dir_output : str
        Path to the directory where the output summary file will be saved. A
        file named 'battery_data_<response_type>.csv' (for example
        'battery_data_full.csv') will be created in this directory.
    response_type : {'full', 'correct', 'incorrect'}, optional
        Should the summary data be calculated using all trials? Only correct
        trials? Or only incorrect trials? This is not supported in all tasks.
    use_file : bool, optional
        If True and the summary file for `response_type` already exists in
        `dir_output`, it is imported and returned instead of re-aggregating
        the data. If the file does not exist, the data are aggregated as usual.
    save : bool, optional
        Set to True to save an output summary file to the output directory.
        If False, then no file will be saved, but a dataframe will still be
        returned from this function.

    Returns
    -------
    all_data : dataframe
        Pandas dataframe containing the aggregated summary data for all tasks,
        sorted by (integer) subject number.

    Raises
    ------
    ValueError
        If `response_type` is not one of the supported values.
    """
    if response_type not in ("full", "correct", "incorrect"):
        raise ValueError(
            "response_type must be 'full', 'correct' or 'incorrect', "
            "got {!r}".format(response_type))

    summary_path = os.path.join(dir_output,
                                "battery_data_{}.csv".format(response_type))
    if use_file and os.path.isfile(summary_path):
        print("Importing battery summary file...")
        return pd.read_csv(summary_path)

    # Create one dataframe per task; each appended row is one subject.
    df_info = pd.DataFrame(columns=["sub_num", "datetime",
                                    "condition", "age", "sex", "RA"])
    df_ant = pd.DataFrame(columns=[
        "sub_num",
        "ant_follow_error_rt", "ant_follow_correct_rt",
        "ant_neutral_rt", "ant_congruent_rt", "ant_incongruent_rt",
        "ant_neutral_rtsd", "ant_congruent_rtsd", "ant_incongruent_rtsd",
        "ant_neutral_rtcov", "ant_congruent_rtcov", "ant_incongruent_rtcov",
        "ant_neutral_correct", "ant_congruent_correct",
        "ant_incongruent_correct",
        "ant_nocue_rt", "ant_center_rt", "ant_spatial_rt", "ant_double_rt",
        "ant_nocue_rtsd", "ant_center_rtsd", "ant_spatial_rtsd",
        "ant_double_rtsd",
        "ant_nocue_rtcov", "ant_center_rtcov", "ant_spatial_rtcov",
        "ant_double_rtcov",
        "ant_nocue_correct", "ant_center_correct", "ant_spatial_correct",
        "ant_double_correct",
        "ant_conflict_intercept", "ant_conflict_slope",
        "ant_conflict_slope_norm",
        "ant_alerting_intercept", "ant_alerting_slope",
        "ant_alerting_slope_norm",
        "ant_orienting_intercept", "ant_orienting_slope",
        "ant_orienting_slope_norm"])
    # The compatible and incompatible Flanker tables share the same stats.
    flanker_stats = ["follow_error_rt", "follow_correct_rt",
                     "congruent_rt", "incongruent_rt",
                     "congruent_rtsd", "incongruent_rtsd",
                     "congruent_rtcov", "incongruent_rtcov",
                     "congruent_correct", "incongruent_correct",
                     "conflict_intercept", "conflict_slope",
                     "conflict_slope_norm"]
    flanker_compat_cols = ["flanker_compat_" + s for s in flanker_stats]
    flanker_incompat_cols = ["flanker_incompat_" + s for s in flanker_stats]
    df_flanker_compat = pd.DataFrame(columns=["sub_num"] + flanker_compat_cols)
    df_flanker_incompat = pd.DataFrame(
        columns=["sub_num"] + flanker_incompat_cols)
    # aggregate_flanker walks the compatibility conditions in sorted order,
    # so for subjects with both versions the compatible block comes first.
    df_flanker_both = pd.DataFrame(
        columns=["sub_num"] + flanker_compat_cols + flanker_incompat_cols)
    df_digit = pd.DataFrame(columns=["sub_num", "digit_correct_count",
                                     "digit_correct_prop", "digit_num_items"])
    df_mrt = pd.DataFrame(columns=["sub_num", "mrt_count",
                                   "mrt_prop", "mrt_num_items"])
    df_ravens = pd.DataFrame(columns=["sub_num", "ravens_rt", "ravens_count",
                                      "ravens_prop", "ravens_num_items"])
    df_sart = pd.DataFrame(
        columns=["sub_num", "sart_follow_error_rt",
                 "sart_follow_correct_rt", "sart_total_rt",
                 "sart_total_rtsd", "sart_total_rtcov",
                 "sart_frequent_rt", "sart_frequent_rtsd",
                 "sart_frequent_rtcov", "sart_infrequent_rt",
                 "sart_infrequent_rtsd", "sart_infrequent_rtcov",
                 "sart_error_count", "sart_errors_prop",
                 "sart_errors_num_items"])
    df_sternberg = pd.DataFrame(
        columns=["sub_num",
                 "stern_follow_error_rt", "stern_follow_correct_rt",
                 "stern_set_2_rt", "stern_set_6_rt",
                 "stern_set_2_rtsd", "stern_set_6_rtsd",
                 "stern_set_2_rtcov", "stern_set_6_rtcov",
                 "stern_set_2_correct", "stern_set_6_correct",
                 "stern_intercept", "stern_slope", "stern_slope_norm"])

    # Aggregate all data
    for f in os.listdir(dir_battery):
        if not f.endswith(".xls"):
            continue
        print("Summarizing {}".format(f))
        # sheet_name=None loads every sheet as a {sheet name: dataframe} dict.
        sub = pd.read_excel(os.path.join(dir_battery, f), sheet_name=None,
                            converters={"sub_num": str})
        info = sub["info"]
        # The subject column may be named either 'sub_num' or 'subNum'.
        sub_col = "sub_num" if "sub_num" in info.columns else "subNum"
        sub_num = info.loc[0, sub_col]
        session_datetime = info.loc[0, "datetime"]
        # Read scalars from the first row: int() on a whole Series is
        # deprecated in pandas and fails if the sheet has several rows.
        condition = int(info.loc[0, "condition"])
        age = int(info.loc[0, "age"])
        sex = info.loc[0, "sex"]
        ra = info.loc[0, "RA"]

        for task, data in sub.items():
            if task == "info":
                df_info.loc[len(df_info)] = [sub_num, session_datetime,
                                             condition, age, sex, ra]
            elif task == "ANT":
                # full / correct / incorrect
                df_ant.loc[len(df_ant)] = aggregate_ant(
                    data, sub_num, response_type)
            elif task == "Digit span (backwards)":
                df_digit.loc[len(df_digit)] = aggregate_digit_span(
                    data, sub_num)
            elif task == "Eriksen Flanker":
                compat_conditions = data["compatibility"].unique()
                single = len(compat_conditions) == 1
                if single and compat_conditions[0] == "compatible":
                    target = df_flanker_compat
                elif single and compat_conditions[0] == "incompatible":
                    target = df_flanker_incompat
                else:
                    target = df_flanker_both
                # full / correct / incorrect
                target.loc[len(target)] = aggregate_flanker(
                    data, sub_num, response_type)
            elif task == "MRT":
                df_mrt.loc[len(df_mrt)] = aggregate_mrt(data, sub_num)
            elif task == "Ravens Matrices":
                df_ravens.loc[len(df_ravens)] = aggregate_ravens(data, sub_num)
            elif task == "SART":
                df_sart.loc[len(df_sart)] = aggregate_sart(data, sub_num)
            elif task == "Sternberg":
                # full / correct / incorrect
                df_sternberg.loc[len(df_sternberg)] = aggregate_sternberg(
                    data, sub_num, response_type)

    # Merge task data; only tasks that were used (non-empty) are merged.
    # NOTE(review): if some subjects ran a single Flanker version and others
    # ran both, the shared flanker columns get pandas' _x/_y merge suffixes.
    task_frames = [df_ant, df_digit, df_flanker_compat, df_flanker_incompat,
                   df_flanker_both, df_mrt, df_ravens, df_sart, df_sternberg]
    all_data = df_info
    for task_df in task_frames:
        if not task_df.empty:
            all_data = all_data.merge(task_df, on="sub_num", how="left")
    all_data["sub_num"] = all_data["sub_num"].astype(int)
    all_data = all_data.sort_values("sub_num").reset_index(drop=True)

    # Save output csv
    if save:
        all_data.to_csv(summary_path, index=False, sep=",")
    return all_data
def aggregate_ant(data, sub_num, response_type="full"):
    """
    Aggregate data from the ANT task.

    Calculates various summary statistics for the ANT task for a given subject.

    Parameters
    ----------
    data : dataframe
        Pandas dataframe containing a single subject's trial data for the
        task. The columns 'congruency', 'cue', 'RT' and 'correct' are used.
    sub_num : str
        Subject number to which the data file belongs.
    response_type : {'full', 'correct', 'incorrect'}, optional
        Should the summary data be calculated using all trials? Only correct
        trials? Or only incorrect trials? Post-error and post-correct RTs are
        always computed from all trials.

    Returns
    -------
    stats : list
        List containing the calculated data for the subject. If a congruency
        or cue condition has no trials after filtering, its RT statistics are
        NaN and its correct count is 0.

    Raises
    ------
    ValueError
        If `response_type` is not one of the supported values.
    """
    if response_type not in ("full", "correct", "incorrect"):
        raise ValueError(
            "response_type must be 'full', 'correct' or 'incorrect', "
            "got {!r}".format(response_type))

    # Post-error / post-correct RTs use the unfiltered trials, so shift()
    # always refers to the immediately preceding row.
    previous_correct = data["correct"].shift()
    follow_error_rt = data.loc[previous_correct == 0, "RT"].mean()
    follow_correct_rt = data.loc[previous_correct == 1, "RT"].mean()

    if response_type == "correct":
        df = data[data["correct"] == 1]
    elif response_type == "incorrect":
        df = data[data["correct"] == 0]
    else:
        df = data

    nan = float("nan")

    def summarize(group_col, levels):
        """Return per-level lists of mean RT, RT SD, RT CoV and correct count."""
        # Aggregate only the needed columns: pandas >= 2 raises on
        # non-numeric columns in GroupBy.mean()/std().
        grouped = df.groupby(group_col)
        rt_mean = grouped["RT"].mean()
        rt_sd = grouped["RT"].std()
        n_correct = grouped["correct"].sum()
        rts = [rt_mean.get(level, nan) for level in levels]
        rtsds = [rt_sd.get(level, nan) for level in levels]
        rtcovs = [sd / rt for sd, rt in zip(rtsds, rts)]
        corrects = [n_correct.get(level, 0) for level in levels]
        return rts, rtsds, rtcovs, corrects

    # Aggregated descriptives for congruency and cue conditions
    cong_rt, cong_rtsd, cong_rtcov, cong_correct = summarize(
        "congruency", ("neutral", "congruent", "incongruent"))
    cue_rt, cue_rtsd, cue_rtcov, cue_correct = summarize(
        "cue", ("nocue", "center", "spatial", "double"))
    _, congruent_rt, incongruent_rt = cong_rt
    nocue_rt, center_rt, spatial_rt, double_rt = cue_rt

    # Two-point "regressions": the intercept is the baseline condition's RT,
    # the slope is the RT difference from it, and the normalized slope is
    # that difference divided by the intercept.
    conflict_intercept = congruent_rt
    conflict_slope = incongruent_rt - congruent_rt
    conflict_slope_norm = conflict_slope / congruent_rt
    alerting_intercept = double_rt
    alerting_slope = nocue_rt - double_rt
    alerting_slope_norm = alerting_slope / double_rt
    orienting_intercept = spatial_rt
    orienting_slope = center_rt - spatial_rt
    orienting_slope_norm = orienting_slope / spatial_rt

    return ([sub_num, follow_error_rt, follow_correct_rt]
            + cong_rt + cong_rtsd + cong_rtcov + cong_correct
            + cue_rt + cue_rtsd + cue_rtcov + cue_correct
            + [conflict_intercept, conflict_slope, conflict_slope_norm,
               alerting_intercept, alerting_slope, alerting_slope_norm,
               orienting_intercept, orienting_slope, orienting_slope_norm])
def aggregate_digit_span(data, sub_num):
    """
    Aggregate data from the digit span task.

    Summarizes a single subject's digit span trials as a count of correct
    trials and that count relative to the number of trials.

    Parameters
    ----------
    data : dataframe
        Pandas dataframe containing a single subject's trial data for the
        task. The 'correct' column is summed to obtain the correct count.
    sub_num : str
        Subject number to which the data file belongs.

    Returns
    -------
    stats : list
        ``[sub_num, correct_count, correct_count / num_items, num_items]``,
        where ``num_items`` is the number of rows in `data`.
    """
    num_items = len(data.index)
    correct_count = data["correct"].sum()
    return [sub_num, correct_count, correct_count / num_items, num_items]
def aggregate_flanker(data, sub_num, response_type="full"):
    """
    Aggregate data from the Flanker task.

    Calculates various summary statistics for the Flanker task for a
    given subject, separately for each compatibility condition present.

    Parameters
    ----------
    data : dataframe
        Pandas dataframe containing a single subject's trial data for the
        task. The columns 'compatibility', 'congruency', 'RT' and 'correct'
        are used.
    sub_num : str
        Subject number to which the data file belongs.
    response_type : {'full', 'correct', 'incorrect'}, optional
        Should the summary data be calculated using all trials? Only correct
        trials? Or only incorrect trials? Post-error and post-correct RTs are
        always computed from all trials of the compatibility condition.

    Returns
    -------
    stats : list
        ``[sub_num]`` followed by 13 statistics per compatibility condition,
        with conditions in sorted order (so 'compatible' precedes
        'incompatible'). If a congruency level has no trials after filtering,
        its RT statistics are NaN and its correct count is 0.

    Raises
    ------
    ValueError
        If `response_type` is not one of the supported values.
    """
    if response_type not in ("full", "correct", "incorrect"):
        raise ValueError(
            "response_type must be 'full', 'correct' or 'incorrect', "
            "got {!r}".format(response_type))

    nan = float("nan")
    columns = [sub_num]
    # split compatibility conditions
    for comp_type in sorted(data["compatibility"].unique()):
        df_cur = data[data["compatibility"] == comp_type]

        # Post-error / post-correct RTs use the unfiltered trials of this
        # condition, so shift() refers to the immediately preceding row.
        previous_correct = df_cur["correct"].shift()
        follow_error_rt = df_cur.loc[previous_correct == 0, "RT"].mean()
        follow_correct_rt = df_cur.loc[previous_correct == 1, "RT"].mean()

        if response_type == "correct":
            df = df_cur[df_cur["correct"] == 1]
        elif response_type == "incorrect":
            df = df_cur[df_cur["correct"] == 0]
        else:
            df = df_cur

        # Aggregate only the needed columns: pandas >= 2 raises on
        # non-numeric columns in GroupBy.mean()/std().
        grouped = df.groupby("congruency")
        rt_mean = grouped["RT"].mean()
        rt_sd = grouped["RT"].std()
        n_correct = grouped["correct"].sum()

        congruent_rt = rt_mean.get("congruent", nan)
        incongruent_rt = rt_mean.get("incongruent", nan)
        congruent_rtsd = rt_sd.get("congruent", nan)
        incongruent_rtsd = rt_sd.get("incongruent", nan)
        congruent_rtcov = congruent_rtsd / congruent_rt
        incongruent_rtcov = incongruent_rtsd / incongruent_rt
        congruent_correct = n_correct.get("congruent", 0)
        incongruent_correct = n_correct.get("incongruent", 0)

        # Two-point "regression": intercept is the congruent RT, slope is the
        # incongruent - congruent difference, normalized by the intercept.
        conflict_intercept = congruent_rt
        conflict_slope = incongruent_rt - congruent_rt
        conflict_slope_norm = conflict_slope / congruent_rt

        columns += [follow_error_rt, follow_correct_rt,
                    congruent_rt, incongruent_rt,
                    congruent_rtsd, incongruent_rtsd,
                    congruent_rtcov, incongruent_rtcov,
                    congruent_correct, incongruent_correct,
                    conflict_intercept, conflict_slope, conflict_slope_norm]
    return columns
def aggregate_mrt(data, sub_num):
    """
    Aggregate data from the MRT task.

    Summarizes a single subject's MRT trials as a count of correct trials
    and that count relative to the number of trials.

    Parameters
    ----------
    data : dataframe
        Pandas dataframe containing a single subject's trial data for the
        task. The 'correct' column is summed to obtain the correct count.
    sub_num : str
        Subject number to which the data file belongs.

    Returns
    -------
    stats : list
        ``[sub_num, count, count / num_items, num_items]``, where
        ``num_items`` is the number of rows in `data`.
    """
    total_items = len(data.index)
    hits = data["correct"].sum()
    return [sub_num, hits, hits / total_items, total_items]
def aggregate_ravens(data, sub_num):
    """
    Aggregate data from the Raven's Matrices task.

    Summarizes a single subject's Raven's Matrices trials as the mean RT,
    the count of correct trials, and that count relative to the number of
    trials.

    Parameters
    ----------
    data : dataframe
        Pandas dataframe containing a single subject's trial data for the
        task. The 'RT' column is averaged and the 'correct' column summed.
    sub_num : str
        Subject number to which the data file belongs.

    Returns
    -------
    stats : list
        ``[sub_num, mean_rt, count, count / num_items, num_items]``, where
        ``num_items`` is the number of rows in `data`.
    """
    total_items = len(data.index)
    hits = data["correct"].sum()
    mean_rt = data["RT"].mean()
    return [sub_num, mean_rt, hits, hits / total_items, total_items]
def aggregate_sart(data, sub_num):
    """
    Aggregate data from the SART task.

    Calculates RT and error summary statistics for the SART task for a
    given subject. Trials with ``stimulus == 3`` are summarized as the
    infrequent stimulus; all other trials as the frequent stimulus.

    Parameters
    ----------
    data : dataframe
        Pandas dataframe containing a single subject's trial data for the
        task. The columns 'RT', 'accuracy', 'stimulus' and 'key press' are
        used.
    sub_num : str
        Subject number to which the data file belongs.

    Returns
    -------
    stats : list
        ``[sub_num, follow_error_rt, follow_correct_rt,
        total_rt, total_rtsd, total_rtcov,
        frequent_rt, frequent_rtsd, frequent_rtcov,
        infrequent_rt, infrequent_rtsd, infrequent_rtcov,
        error_count, errors_prop, errors_num_items]``. The error count is the
        sum of 'key press' over infrequent-stimulus trials and the proportion
        divides it by the number of those trials.
    """
    def rt_summary(rts):
        # mean, standard deviation, and coefficient of variation (SD / mean)
        avg = rts.mean()
        spread = rts.std()
        return [avg, spread, spread / avg]

    # RTs on trials that follow an inaccurate / accurate trial
    prev_accuracy = data.accuracy.shift()
    stats = [sub_num,
             data.loc[prev_accuracy == 0, "RT"].mean(),
             data.loc[prev_accuracy == 1, "RT"].mean()]

    infrequent_mask = data["stimulus"] == 3
    infrequent_trials = data[infrequent_mask]
    stats += rt_summary(data["RT"])
    stats += rt_summary(data[~infrequent_mask]["RT"])
    stats += rt_summary(infrequent_trials["RT"])

    error_count = infrequent_trials["key press"].sum()
    error_items = infrequent_trials.shape[0]
    stats += [error_count, error_count / error_items, error_items]
    return stats
def aggregate_sternberg(data, sub_num, response_type="full"):
    """
    Aggregate data from the Sternberg task.

    Calculates various summary statistics for the Sternberg task for a
    given subject, for set sizes 2 and 6.

    Parameters
    ----------
    data : dataframe
        Pandas dataframe containing a single subject's trial data for the
        task. The columns 'setSize', 'RT' and 'correct' are used.
    sub_num : str
        Subject number to which the data file belongs.
    response_type : {'full', 'correct', 'incorrect'}, optional
        Should the summary data be calculated using all trials? Only correct
        trials? Or only incorrect trials? Post-error and post-correct RTs are
        always computed from all trials.

    Returns
    -------
    stats : list
        List containing the calculated data for the subject. If a set size
        has no trials after filtering, its RT statistics are NaN and its
        correct count is 0.

    Raises
    ------
    ValueError
        If `response_type` is not one of the supported values.
    """
    if response_type not in ("full", "correct", "incorrect"):
        raise ValueError(
            "response_type must be 'full', 'correct' or 'incorrect', "
            "got {!r}".format(response_type))

    # Post-error / post-correct RTs use the unfiltered trials, so shift()
    # always refers to the immediately preceding row.
    previous_correct = data["correct"].shift()
    follow_error_rt = data.loc[previous_correct == 0, "RT"].mean()
    follow_correct_rt = data.loc[previous_correct == 1, "RT"].mean()

    if response_type == "correct":
        df = data[data["correct"] == 1]
    elif response_type == "incorrect":
        df = data[data["correct"] == 0]
    else:
        df = data

    nan = float("nan")

    # Aggregated descriptives. Only the needed columns are aggregated:
    # pandas >= 2 raises on non-numeric columns in GroupBy.mean()/std().
    grouped = df.groupby("setSize")
    rt_mean = grouped["RT"].mean()
    rt_sd = grouped["RT"].std()
    n_correct = grouped["correct"].sum()

    set_2_rt = rt_mean.get(2, nan)
    set_2_rtsd = rt_sd.get(2, nan)
    set_2_rtcov = set_2_rtsd / set_2_rt
    set_2_correct = n_correct.get(2, 0)
    set_6_rt = rt_mean.get(6, nan)
    set_6_rtsd = rt_sd.get(6, nan)
    set_6_rtcov = set_6_rtsd / set_6_rt
    set_6_correct = n_correct.get(6, 0)

    # Two-point "regression": intercept is the set-size-2 RT, slope is the
    # set-size-6 minus set-size-2 difference, normalized by the intercept.
    intercept = set_2_rt
    slope = set_6_rt - set_2_rt
    slope_norm = slope / set_2_rt

    return [sub_num,
            follow_error_rt, follow_correct_rt,
            set_2_rt, set_6_rt,
            set_2_rtsd, set_6_rtsd,
            set_2_rtcov, set_6_rtcov,
            set_2_correct, set_6_correct,
            intercept, slope, slope_norm]