# Source code for pyutils_sh.battery

"""
Functions for aggregating subject data from 
Cognitive Battery (https://github.com/sho-87/cognitive-battery)
"""

import os
import pandas as pd


def aggregate_wide(
    dir_battery, dir_output, response_type="full", use_file=False, save=True
):
    """
    Aggregate data from all battery tasks.

    Takes a directory containing individual subject data files created from
    the Cognitive Battery, and calculates summary statistics for all subjects
    across all tasks. A single output summary file is created containing the
    aggregated battery data.

    Parameters
    ----------
    dir_battery : str
        Path to the directory containing subject data files created by the
        Cognitive Battery.
    dir_output : str
        Path to the directory where the output summary file will be saved.
        A file named 'battery_data_<response_type>.csv' will be created in
        this directory.
    response_type : {'full', 'correct', 'incorrect'}, optional
        Should the summary data be calculated using all trials? Only correct
        trials? Or only incorrect trials? This is not supported in all tasks.
    use_file : bool, optional
        If True, aggregated battery data will be imported from the existing
        summary file instead of being re-aggregated.
    save : bool, optional
        Set to True to save an output summary file to the output directory.
        If False, then no file will be saved, but a dataframe will still be
        returned from this function.

    Returns
    -------
    all_data : dataframe
        Pandas dataframe containing the aggregated summary data for all tasks.
    """
    # Location of the aggregated summary file for this response type.
    summary_file = os.path.join(
        dir_output, "battery_data_{}.csv".format(response_type)
    )

    # Re-use a previously aggregated summary file rather than re-aggregating.
    if use_file and os.path.isfile(summary_file):
        print("Importing battery summary file...")
        return pd.read_csv(summary_file)

    # Create empty dataframes (one per task) that collect one row per subject.
    df_info = pd.DataFrame(
        columns=["sub_num", "datetime", "condition", "age", "sex", "RA"]
    )

    df_ant = pd.DataFrame(
        columns=[
            "sub_num",
            "ant_follow_error_rt",
            "ant_follow_correct_rt",
            "ant_neutral_rt",
            "ant_congruent_rt",
            "ant_incongruent_rt",
            "ant_neutral_rtsd",
            "ant_congruent_rtsd",
            "ant_incongruent_rtsd",
            "ant_neutral_rtcov",
            "ant_congruent_rtcov",
            "ant_incongruent_rtcov",
            "ant_neutral_correct",
            "ant_congruent_correct",
            "ant_incongruent_correct",
            "ant_nocue_rt",
            "ant_center_rt",
            "ant_spatial_rt",
            "ant_double_rt",
            "ant_nocue_rtsd",
            "ant_center_rtsd",
            "ant_spatial_rtsd",
            "ant_double_rtsd",
            "ant_nocue_rtcov",
            "ant_center_rtcov",
            "ant_spatial_rtcov",
            "ant_double_rtcov",
            "ant_nocue_correct",
            "ant_center_correct",
            "ant_spatial_correct",
            "ant_double_correct",
            "ant_conflict_intercept",
            "ant_conflict_slope",
            "ant_conflict_slope_norm",
            "ant_alerting_intercept",
            "ant_alerting_slope",
            "ant_alerting_slope_norm",
            "ant_orienting_intercept",
            "ant_orienting_slope",
            "ant_orienting_slope_norm",
        ]
    )

    df_flanker_compat = pd.DataFrame(
        columns=[
            "sub_num",
            "flanker_compat_follow_error_rt",
            "flanker_compat_follow_correct_rt",
            "flanker_compat_congruent_rt",
            "flanker_compat_incongruent_rt",
            "flanker_compat_congruent_rtsd",
            "flanker_compat_incongruent_rtsd",
            "flanker_compat_congruent_rtcov",
            "flanker_compat_incongruent_rtcov",
            "flanker_compat_congruent_correct",
            "flanker_compat_incongruent_correct",
            "flanker_compat_conflict_intercept",
            "flanker_compat_conflict_slope",
            "flanker_compat_conflict_slope_norm",
        ]
    )

    df_flanker_incompat = pd.DataFrame(
        columns=[
            "sub_num",
            "flanker_incompat_follow_error_rt",
            "flanker_incompat_follow_correct_rt",
            "flanker_incompat_congruent_rt",
            "flanker_incompat_incongruent_rt",
            "flanker_incompat_congruent_rtsd",
            "flanker_incompat_incongruent_rtsd",
            "flanker_incompat_congruent_rtcov",
            "flanker_incompat_incongruent_rtcov",
            "flanker_incompat_congruent_correct",
            "flanker_incompat_incongruent_correct",
            "flanker_incompat_conflict_intercept",
            "flanker_incompat_conflict_slope",
            "flanker_incompat_conflict_slope_norm",
        ]
    )

    # Combined frame for subjects who completed both compatibility conditions;
    # keep sub_num as the first column after the merge.
    df_flanker_both = df_flanker_compat.merge(df_flanker_incompat, on="sub_num")
    cols = list(df_flanker_both.columns.values)
    cols.pop(cols.index("sub_num"))  # Remove sub_num from list
    df_flanker_both = df_flanker_both[["sub_num"] + cols]

    df_digit = pd.DataFrame(
        columns=[
            "sub_num",
            "digit_correct_count",
            "digit_correct_prop",
            "digit_num_items",
        ]
    )

    df_mrt = pd.DataFrame(columns=["sub_num", "mrt_count", "mrt_prop", "mrt_num_items"])

    df_ravens = pd.DataFrame(
        columns=[
            "sub_num",
            "ravens_rt",
            "ravens_count",
            "ravens_prop",
            "ravens_num_items",
        ]
    )

    df_sart = pd.DataFrame(
        columns=[
            "sub_num",
            "sart_follow_error_rt",
            "sart_follow_correct_rt",
            "sart_total_rt",
            "sart_total_rtsd",
            "sart_total_rtcov",
            "sart_frequent_rt",
            "sart_frequent_rtsd",
            "sart_frequent_rtcov",
            "sart_infrequent_rt",
            "sart_infrequent_rtsd",
            "sart_infrequent_rtcov",
            "sart_error_count",
            # BUG FIX: this column name previously had a stray leading space
            # (" sart_errors_prop"), which produced a malformed CSV header.
            "sart_errors_prop",
            "sart_errors_num_items",
        ]
    )

    df_sternberg = pd.DataFrame(
        columns=[
            "sub_num",
            "stern_follow_error_rt",
            "stern_follow_correct_rt",
            "stern_set_2_rt",
            "stern_set_6_rt",
            "stern_set_2_rtsd",
            "stern_set_6_rtsd",
            "stern_set_2_rtcov",
            "stern_set_6_rtcov",
            "stern_set_2_correct",
            "stern_set_6_correct",
            "stern_intercept",
            "stern_slope",
            "stern_slope_norm",
        ]
    )

    # Aggregate all data
    for f in os.listdir(dir_battery):
        if not f.endswith(".xls"):
            continue

        print("Summarizing {}".format(f))
        # sheet_name=None -> read every sheet: one per task plus 'info'.
        sub = pd.read_excel(
            os.path.join(dir_battery, f), None, converters={"sub_num": str}
        )

        # Older data files stored the subject number under 'subNum'.
        try:
            sub_num = sub["info"].loc[0, "sub_num"]
        except KeyError:
            sub_num = sub["info"].loc[0, "subNum"]

        datetime = sub["info"].loc[0, "datetime"]
        condition = int(sub["info"].loc[0, "condition"])
        age = int(sub["info"].loc[0, "age"])
        sex = sub["info"].loc[0, "sex"]
        ra = sub["info"].loc[0, "RA"]

        for task, data in sub.items():
            if task == "info":
                df_info.loc[df_info.shape[0]] = [
                    sub_num,
                    datetime,
                    condition,
                    age,
                    sex,
                    ra,
                ]
            elif task == "ANT":
                # full / correct / incorrect
                df_ant.loc[df_ant.shape[0]] = aggregate_ant(
                    data, sub_num, response_type
                )
            elif task == "Digit span (backwards)":
                df_digit.loc[df_digit.shape[0]] = aggregate_digit_span(
                    data, sub_num
                )
            elif task == "Eriksen Flanker":
                compat_conditions = data["compatibility"].unique()

                # full / correct / incorrect
                # NOTE: compare the single element explicitly rather than
                # relying on array truthiness (deprecated in numpy).
                if (
                    len(compat_conditions) == 1
                    and compat_conditions[0] == "compatible"
                ):
                    df_flanker_compat.loc[
                        df_flanker_compat.shape[0]
                    ] = aggregate_flanker(data, sub_num, response_type)
                elif (
                    len(compat_conditions) == 1
                    and compat_conditions[0] == "incompatible"
                ):
                    df_flanker_incompat.loc[
                        df_flanker_incompat.shape[0]
                    ] = aggregate_flanker(data, sub_num, response_type)
                else:
                    df_flanker_both.loc[
                        df_flanker_both.shape[0]
                    ] = aggregate_flanker(data, sub_num, response_type)
            elif task == "MRT":
                df_mrt.loc[df_mrt.shape[0]] = aggregate_mrt(data, sub_num)
            elif task == "Ravens Matrices":
                df_ravens.loc[df_ravens.shape[0]] = aggregate_ravens(data, sub_num)
            elif task == "SART":
                df_sart.loc[df_sart.shape[0]] = aggregate_sart(data, sub_num)
            elif task == "Sternberg":
                # full / correct / incorrect
                df_sternberg.loc[df_sternberg.shape[0]] = aggregate_sternberg(
                    data, sub_num, response_type
                )

    # Merge task data
    # Only merge tasks that were used
    tasks = [
        df_ant,
        df_digit,
        df_flanker_compat,
        df_flanker_incompat,
        df_flanker_both,
        df_mrt,
        df_ravens,
        df_sart,
        df_sternberg,
    ]

    all_data = df_info
    for task_df in tasks:
        if task_df.shape[0] != 0:
            all_data = all_data.merge(task_df, on="sub_num", how="left")

    all_data["sub_num"] = all_data["sub_num"].astype(int)
    all_data = all_data.sort_values("sub_num").reset_index(drop=True)

    # Save output csv
    if save:
        all_data.to_csv(summary_file, index=False, sep=",")

    return all_data
def aggregate_ant(data, sub_num, response_type="full"):
    """
    Aggregate data from the ANT task.

    Calculates various summary statistics for the ANT task for a given
    subject.

    Parameters
    ----------
    data : dataframe
        Pandas dataframe containing a single subjects trial data for the task.
    sub_num : str
        Subject number to which the data file belongs.
    response_type : {'full', 'correct', 'incorrect'}, optional
        Should the summary data be calculated using all trials? Only correct
        trials? Or only incorrect trials? This is not supported in all tasks.

    Returns
    -------
    stats : list
        List containing the calculated data for the subject.
    """
    # Calculate times following errors and correct responses
    df = data
    follow_error_rt = df.loc[df.correct.shift() == 0, "RT"].mean()
    follow_correct_rt = df.loc[df.correct.shift() == 1, "RT"].mean()

    # Optionally restrict all remaining summaries to (in)correct trials only.
    if response_type == "correct":
        df = data[data["correct"] == 1]
    elif response_type == "incorrect":
        df = data[data["correct"] == 0]
    elif response_type == "full":
        df = data

    # Aggregated descriptives
    # NOTE: the original implementation used DataFrame.get_value(), which was
    # removed in pandas 1.0. Per-column groupby aggregates indexed by label
    # are the supported equivalent (and avoid aggregating non-numeric columns).

    ## congruency conditions
    grouped_congruency = df.groupby("congruency")
    congruency_rt_mean = grouped_congruency["RT"].mean()
    congruency_rt_std = grouped_congruency["RT"].std()
    congruency_correct = grouped_congruency["correct"].sum()

    neutral_rt = congruency_rt_mean.loc["neutral"]
    congruent_rt = congruency_rt_mean.loc["congruent"]
    incongruent_rt = congruency_rt_mean.loc["incongruent"]

    neutral_rtsd = congruency_rt_std.loc["neutral"]
    congruent_rtsd = congruency_rt_std.loc["congruent"]
    incongruent_rtsd = congruency_rt_std.loc["incongruent"]

    # Coefficient of variation (sd / mean) per condition
    neutral_rtcov = neutral_rtsd / neutral_rt
    congruent_rtcov = congruent_rtsd / congruent_rt
    incongruent_rtcov = incongruent_rtsd / incongruent_rt

    neutral_correct = congruency_correct.loc["neutral"]
    congruent_correct = congruency_correct.loc["congruent"]
    incongruent_correct = congruency_correct.loc["incongruent"]

    ## cue conditions
    grouped_cue = df.groupby("cue")
    cue_rt_mean = grouped_cue["RT"].mean()
    cue_rt_std = grouped_cue["RT"].std()
    cue_correct = grouped_cue["correct"].sum()

    nocue_rt = cue_rt_mean.loc["nocue"]
    center_rt = cue_rt_mean.loc["center"]
    spatial_rt = cue_rt_mean.loc["spatial"]
    double_rt = cue_rt_mean.loc["double"]

    nocue_rtsd = cue_rt_std.loc["nocue"]
    center_rtsd = cue_rt_std.loc["center"]
    spatial_rtsd = cue_rt_std.loc["spatial"]
    double_rtsd = cue_rt_std.loc["double"]

    nocue_rtcov = nocue_rtsd / nocue_rt
    center_rtcov = center_rtsd / center_rt
    spatial_rtcov = spatial_rtsd / spatial_rt
    double_rtcov = double_rtsd / double_rt

    nocue_correct = cue_correct.loc["nocue"]
    center_correct = cue_correct.loc["center"]
    spatial_correct = cue_correct.loc["spatial"]
    double_correct = cue_correct.loc["double"]

    # OLS regression: intercept/slope expressed from the condition means
    conflict_intercept, conflict_slope = congruent_rt, incongruent_rt - congruent_rt
    conflict_slope_norm = conflict_slope / congruent_rt

    alerting_intercept, alerting_slope = double_rt, nocue_rt - double_rt
    alerting_slope_norm = alerting_slope / double_rt

    orienting_intercept, orienting_slope = spatial_rt, center_rt - spatial_rt
    orienting_slope_norm = orienting_slope / spatial_rt

    return [
        sub_num,
        follow_error_rt,
        follow_correct_rt,
        neutral_rt,
        congruent_rt,
        incongruent_rt,
        neutral_rtsd,
        congruent_rtsd,
        incongruent_rtsd,
        neutral_rtcov,
        congruent_rtcov,
        incongruent_rtcov,
        neutral_correct,
        congruent_correct,
        incongruent_correct,
        nocue_rt,
        center_rt,
        spatial_rt,
        double_rt,
        nocue_rtsd,
        center_rtsd,
        spatial_rtsd,
        double_rtsd,
        nocue_rtcov,
        center_rtcov,
        spatial_rtcov,
        double_rtcov,
        nocue_correct,
        center_correct,
        spatial_correct,
        double_correct,
        conflict_intercept,
        conflict_slope,
        conflict_slope_norm,
        alerting_intercept,
        alerting_slope,
        alerting_slope_norm,
        orienting_intercept,
        orienting_slope,
        orienting_slope_norm,
    ]
def aggregate_digit_span(data, sub_num):
    """
    Aggregate data from the digit span task.

    Calculates various summary statistics for the digit span task for a given
    subject.

    Parameters
    ----------
    data : dataframe
        Pandas dataframe containing a single subjects trial data for the task.
    sub_num : str
        Subject number to which the data file belongs.

    Returns
    -------
    stats : list
        List containing the calculated data for the subject.
    """
    # Number of trials, number answered correctly, and proportion correct
    n_items = data.shape[0]
    n_correct = data["correct"].sum()
    prop_correct = n_correct / n_items

    return [sub_num, n_correct, prop_correct, n_items]
def aggregate_flanker(data, sub_num, response_type="full"):
    """
    Aggregate data from the Flanker task.

    Calculates various summary statistics for the Flanker task for a given
    subject.

    Parameters
    ----------
    data : dataframe
        Pandas dataframe containing a single subjects trial data for the task.
    sub_num : str
        Subject number to which the data file belongs.
    response_type : {'full', 'correct', 'incorrect'}, optional
        Should the summary data be calculated using all trials? Only correct
        trials? Or only incorrect trials? This is not supported in all tasks.

    Returns
    -------
    stats : list
        List containing the calculated data for the subject.
    """
    columns = [sub_num]

    # split compatibility conditions; sorted so 'compatible' stats always
    # precede 'incompatible' stats when both conditions are present
    for comp_type in sorted(list(data["compatibility"].unique()), reverse=False):
        df_cur = data[data["compatibility"] == comp_type]

        # Calculate times following errors and correct responses
        follow_error_rt = df_cur.loc[df_cur.correct.shift() == 0, "RT"].mean()
        follow_correct_rt = df_cur.loc[df_cur.correct.shift() == 1, "RT"].mean()

        # Optionally restrict remaining summaries to (in)correct trials only.
        if response_type == "correct":
            df = df_cur[df_cur["correct"] == 1]
        elif response_type == "incorrect":
            df = df_cur[df_cur["correct"] == 0]
        elif response_type == "full":
            df = df_cur

        # NOTE: the original implementation used DataFrame.get_value(), which
        # was removed in pandas 1.0. Per-column groupby aggregates indexed by
        # label are the supported equivalent.
        grouped_congruency = df.groupby(["congruency"])
        congruency_rt_mean = grouped_congruency["RT"].mean()
        congruency_rt_std = grouped_congruency["RT"].std()
        congruency_correct = grouped_congruency["correct"].sum()

        congruent_rt = congruency_rt_mean.loc["congruent"]
        incongruent_rt = congruency_rt_mean.loc["incongruent"]

        congruent_rtsd = congruency_rt_std.loc["congruent"]
        incongruent_rtsd = congruency_rt_std.loc["incongruent"]

        # Coefficient of variation (sd / mean) per condition
        congruent_rtcov = congruent_rtsd / congruent_rt
        incongruent_rtcov = incongruent_rtsd / incongruent_rt

        congruent_correct = congruency_correct.loc["congruent"]
        incongruent_correct = congruency_correct.loc["incongruent"]

        # OLS regression: conflict effect expressed from condition means
        conflict_intercept, conflict_slope = congruent_rt, incongruent_rt - congruent_rt
        conflict_slope_norm = conflict_slope / congruent_rt

        columns += [
            follow_error_rt,
            follow_correct_rt,
            congruent_rt,
            incongruent_rt,
            congruent_rtsd,
            incongruent_rtsd,
            congruent_rtcov,
            incongruent_rtcov,
            congruent_correct,
            incongruent_correct,
            conflict_intercept,
            conflict_slope,
            conflict_slope_norm,
        ]

    return columns
def aggregate_mrt(data, sub_num):
    """
    Aggregate data from the MRT task.

    Calculates various summary statistics for the MRT task for a given
    subject.

    Parameters
    ----------
    data : dataframe
        Pandas dataframe containing a single subjects trial data for the task.
    sub_num : str
        Subject number to which the data file belongs.

    Returns
    -------
    stats : list
        List containing the calculated data for the subject.
    """
    # Number of trials, number answered correctly, and proportion correct
    n_items = data.shape[0]
    n_correct = data["correct"].sum()
    prop_correct = n_correct / n_items

    return [sub_num, n_correct, prop_correct, n_items]
def aggregate_ravens(data, sub_num):
    """
    Aggregate data from the Raven's Matrices task.

    Calculates various summary statistics for the Raven's Matrices task for a
    given subject.

    Parameters
    ----------
    data : dataframe
        Pandas dataframe containing a single subjects trial data for the task.
    sub_num : str
        Subject number to which the data file belongs.

    Returns
    -------
    stats : list
        List containing the calculated data for the subject.
    """
    # Mean response time across all items
    mean_rt = data["RT"].mean()

    # Number of items, number answered correctly, and proportion correct
    n_items = data.shape[0]
    n_correct = data["correct"].sum()
    prop_correct = n_correct / n_items

    return [sub_num, mean_rt, n_correct, prop_correct, n_items]
def aggregate_sart(data, sub_num):
    """
    Aggregate data from the SART task.

    Calculates various summary statistics for the SART task for a given
    subject.

    Parameters
    ----------
    data : dataframe
        Pandas dataframe containing a single subjects trial data for the task.
    sub_num : str
        Subject number to which the data file belongs.

    Returns
    -------
    stats : list
        List containing the calculated data for the subject.
    """
    # Mean RT on trials immediately following an error vs. a correct response
    previous_accuracy = data.accuracy.shift()
    follow_error_rt = data.loc[previous_accuracy == 0, "RT"].mean()
    follow_correct_rt = data.loc[previous_accuracy == 1, "RT"].mean()

    def _rt_stats(frame):
        # (mean, sd, coefficient of variation) of the frame's RT column
        mean_rt = frame["RT"].mean()
        sd_rt = frame["RT"].std()
        return mean_rt, sd_rt, sd_rt / mean_rt

    # Stimulus 3 is the infrequent stimulus; everything else is frequent
    frequent = data[data["stimulus"] != 3]
    infrequent = data[data["stimulus"] == 3]

    total_rt, total_rtsd, total_rtcov = _rt_stats(data)
    frequent_rt, frequent_rtsd, frequent_rtcov = _rt_stats(frequent)
    infrequent_rt, infrequent_rtsd, infrequent_rtcov = _rt_stats(infrequent)

    # A key press on the infrequent stimulus counts as an error
    error_count = infrequent["key press"].sum()
    errors_num_items = infrequent.shape[0]
    errors_prop = error_count / errors_num_items

    return [
        sub_num,
        follow_error_rt,
        follow_correct_rt,
        total_rt,
        total_rtsd,
        total_rtcov,
        frequent_rt,
        frequent_rtsd,
        frequent_rtcov,
        infrequent_rt,
        infrequent_rtsd,
        infrequent_rtcov,
        error_count,
        errors_prop,
        errors_num_items,
    ]
def aggregate_sternberg(data, sub_num, response_type="full"):
    """
    Aggregate data from the Sternberg task.

    Calculates various summary statistics for the Sternberg task for a given
    subject.

    Parameters
    ----------
    data : dataframe
        Pandas dataframe containing a single subjects trial data for the task.
    sub_num : str
        Subject number to which the data file belongs.
    response_type : {'full', 'correct', 'incorrect'}, optional
        Should the summary data be calculated using all trials? Only correct
        trials? Or only incorrect trials? This is not supported in all tasks.

    Returns
    -------
    stats : list
        List containing the calculated data for the subject.
    """
    # Calculate times following errors and correct responses
    df = data
    follow_error_rt = df.loc[df.correct.shift() == 0, "RT"].mean()
    follow_correct_rt = df.loc[df.correct.shift() == 1, "RT"].mean()

    # Optionally restrict remaining summaries to (in)correct trials only.
    if response_type == "correct":
        df = data[data["correct"] == 1]
    elif response_type == "incorrect":
        df = data[data["correct"] == 0]
    elif response_type == "full":
        df = data

    # Aggregated descriptives per memory-set size (2 vs 6 items)
    # NOTE: the original implementation used DataFrame.get_value(), which was
    # removed in pandas 1.0. Per-column groupby aggregates indexed by label
    # are the supported equivalent.
    grouped_set_size = df.groupby("setSize")
    set_rt_mean = grouped_set_size["RT"].mean()
    set_rt_std = grouped_set_size["RT"].std()
    set_correct = grouped_set_size["correct"].sum()

    set_2_rt = set_rt_mean.loc[2]
    set_2_rtsd = set_rt_std.loc[2]
    set_2_rtcov = set_2_rtsd / set_2_rt
    set_2_correct = set_correct.loc[2]

    set_6_rt = set_rt_mean.loc[6]
    set_6_rtsd = set_rt_std.loc[6]
    set_6_rtcov = set_6_rtsd / set_6_rt
    set_6_correct = set_correct.loc[6]

    # OLS regression: memory scanning effect from the two set-size means
    intercept, slope = set_2_rt, set_6_rt - set_2_rt
    slope_norm = slope / set_2_rt

    return [
        sub_num,
        follow_error_rt,
        follow_correct_rt,
        set_2_rt,
        set_6_rt,
        set_2_rtsd,
        set_6_rtsd,
        set_2_rtcov,
        set_6_rtcov,
        set_2_correct,
        set_6_correct,
        intercept,
        slope,
        slope_norm,
    ]