Decoding in time-frequency space using Common Spatial Patterns (CSP)

The time-frequency decomposition is estimated by iterating over raw data that have been band-pass filtered in different frequency ranges. For each band, a covariance matrix is computed either over each whole epoch or over a rolling time window, and the CSP-filtered signals are extracted from it. A linear discriminant classifier is then applied to these signals.
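
To make the covariance-then-CSP step concrete, here is a minimal NumPy-only sketch of two-class CSP on synthetic data. It is an illustration under simplifying assumptions (no regularisation, plain per-epoch covariances), not the implementation used by mne.decoding.CSP; the helper names fit_csp and csp_features and the synthetic data are hypothetical.

```python
import numpy as np


def fit_csp(X, y, n_components=4):
    """Return spatial filters of shape (n_components, n_channels).

    X has shape (n_epochs, n_channels, n_times); y holds two distinct labels.
    """
    classes = np.unique(y)
    # Average the per-epoch channel covariance within each class
    covs = []
    for c in classes:
        class_covs = [np.cov(epoch) for epoch in X[y == c]]
        covs.append(np.mean(class_covs, axis=0))
    cov_a, cov_b = covs
    # Whiten the composite covariance, then diagonalise class A in that space
    evals, evecs = np.linalg.eigh(cov_a + cov_b)
    whitener = evecs / np.sqrt(evals)  # column j scaled by 1/sqrt(evals[j])
    lam, rot = np.linalg.eigh(whitener.T @ cov_a @ whitener)
    filters = (whitener @ rot).T  # each row is one spatial filter
    # Keep the filters whose eigenvalues are farthest from 0.5
    order = np.argsort(np.abs(lam - 0.5))[::-1]
    return filters[order[:n_components]]


def csp_features(X, filters):
    """Log of the mean power of each CSP-filtered signal, per epoch."""
    projected = np.einsum('kc,ect->ekt', filters, X)
    return np.log(np.mean(projected ** 2, axis=2))


# Synthetic two-class data: one source is stronger in each class
rng = np.random.default_rng(0)
n_epochs, n_channels, n_times = 40, 6, 200
y = np.repeat([0, 1], n_epochs // 2)
sources = rng.standard_normal((n_epochs, n_channels, n_times))
sources[y == 0, 0] *= 3.
sources[y == 1, 1] *= 3.
mixing = rng.standard_normal((n_channels, n_channels))
X = np.einsum('cs,est->ect', mixing, sources)

filters = fit_csp(X, y, n_components=4)
features = csp_features(X, filters)
print(features.shape)
```

The resulting log-power features are what a linear classifier such as LDA would receive in the pipeline described above.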

# Authors: Laura Gwilliams <laura.gwilliams@nyu.edu>
#          Jean-Remi King <jeanremi.king@gmail.com>
#          Alex Barachant <alexandre.barachant@gmail.com>
#          Alexandre Gramfort <alexandre.gramfort@telecom-paristech.fr>
#
# License: BSD (3-clause)

import numpy as np
import matplotlib.pyplot as plt

from mne import Epochs, find_events, create_info
from mne.io import concatenate_raws, read_raw_edf
from mne.datasets import eegbci
from mne.decoding import CSP
from mne.time_frequency import AverageTFR

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import LabelEncoder

Set parameters and read data

event_id = dict(hands=2, feet=3)  # motor imagery: hands vs feet
subject = 1
runs = [6, 10, 14]
raw_fnames = eegbci.load_data(subject, runs)
raw_files = [read_raw_edf(f, stim_channel='auto', preload=True)
             for f in raw_fnames]
raw = concatenate_raws(raw_files)

# Extract information from the raw file
sfreq = raw.info['sfreq']
events = find_events(raw, shortest_event=0, stim_channel='STI 014')
raw.pick_types(meg=False, eeg=True, stim=False, eog=False, exclude='bads')

# Assemble the classifier using scikit-learn pipeline
clf = make_pipeline(CSP(n_components=4, reg=None, log=True, norm_trace=False),
                    LinearDiscriminantAnalysis())
n_splits = 5  # how many folds to use for cross-validation
cv = StratifiedKFold(n_splits=n_splits, shuffle=True)

# Classification & Time-frequency parameters
tmin, tmax = -.200, 2.000
n_cycles = 10.  # how many complete cycles: used to define window size
min_freq = 5.
max_freq = 25.
n_freqs = 8  # how many frequency bins to use

# Assemble list of frequency range tuples
freqs = np.linspace(min_freq, max_freq, n_freqs)  # assemble frequencies
freq_ranges = list(zip(freqs[:-1], freqs[1:]))  # make freqs list of tuples

# Infer window spacing from the max freq and number of cycles to avoid gaps
window_spacing = (n_cycles / np.max(freqs) / 2.)
centered_w_times = np.arange(tmin, tmax, window_spacing)[1:]
n_windows = len(centered_w_times)

# Instantiate label encoder
le = LabelEncoder()

Out:

Extracting edf Parameters from /home/ubuntu/mne_data/MNE-eegbci-data/physiobank/database/eegmmidb/S001/S001R06.edf...
EDF file detected
Setting channel info structure...
Created Raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Ready.
Extracting edf Parameters from /home/ubuntu/mne_data/MNE-eegbci-data/physiobank/database/eegmmidb/S001/S001R10.edf...
EDF file detected
Setting channel info structure...
Created Raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Ready.
Extracting edf Parameters from /home/ubuntu/mne_data/MNE-eegbci-data/physiobank/database/eegmmidb/S001/S001R14.edf...
EDF file detected
Setting channel info structure...
Created Raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Ready.
Removing orphaned offset at the beginning of the file.
89 events found
Events id: [1 2 3]
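
As a quick check of the parameters set above, the following sketch recomputes the frequency bands and the window spacing with NumPy alone. It only repeats arithmetic from the script and loads no data.

```python
import numpy as np

tmin, tmax = -.200, 2.000
n_cycles = 10.
min_freq, max_freq, n_freqs = 5., 25., 8

freqs = np.linspace(min_freq, max_freq, n_freqs)  # 8 band edges
freq_ranges = list(zip(freqs[:-1], freqs[1:]))    # 7 adjacent bands
window_spacing = n_cycles / np.max(freqs) / 2.    # 10 / 25 / 2 = 0.2 s
centered_w_times = np.arange(tmin, tmax, window_spacing)[1:]

for fmin, fmax in freq_ranges:
    print('%.1f - %.1f Hz' % (fmin, fmax))
print('spacing: %.3f s, %d window centres'
      % (window_spacing, len(centered_w_times)))
```

Eight edges give seven bands, which is why the score arrays further down are allocated with n_freqs - 1 rows.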

Loop through frequencies, apply classifier and save scores

# init scores
freq_scores = np.zeros((n_freqs - 1,))

# Loop through each frequency range of interest
for freq, (fmin, fmax) in enumerate(freq_ranges):

    # Infer window size based on the frequency being used
    w_size = n_cycles / ((fmax + fmin) / 2.)  # in seconds

    # Apply band-pass filter to isolate the specified frequencies
    raw_filter = raw.copy().filter(fmin, fmax, n_jobs=1, fir_design='firwin',
                                   skip_by_annotation='edge')

    # Extract epochs from filtered data, padded by window size
    epochs = Epochs(raw_filter, events, event_id, tmin - w_size, tmax + w_size,
                    proj=False, baseline=None, preload=True)
    epochs.drop_bad()
    y = le.fit_transform(epochs.events[:, 2])

    X = epochs.get_data()

    # Save mean scores over folds for each frequency band
    freq_scores[freq] = np.mean(cross_val_score(estimator=clf, X=X, y=y,
                                                scoring='roc_auc', cv=cv,
                                                n_jobs=1), axis=0)

Out:

Setting up band-pass filter from 5 - 7.9 Hz
l_trans_bandwidth chosen to be 2.0 Hz
h_trans_bandwidth chosen to be 2.0 Hz
Filter length of 264 samples (1.650 sec) selected
Setting up band-pass filter from 5 - 7.9 Hz
l_trans_bandwidth chosen to be 2.0 Hz
h_trans_bandwidth chosen to be 2.0 Hz
Filter length of 264 samples (1.650 sec) selected
Setting up band-pass filter from 5 - 7.9 Hz
l_trans_bandwidth chosen to be 2.0 Hz
h_trans_bandwidth chosen to be 2.0 Hz
Filter length of 264 samples (1.650 sec) selected
45 matching events found
Loading data for 45 events and 851 original time points ...
0 bad epochs dropped
Setting up band-pass filter from 7.9 - 11 Hz
l_trans_bandwidth chosen to be 2.0 Hz
h_trans_bandwidth chosen to be 2.7 Hz
Filter length of 264 samples (1.650 sec) selected
Setting up band-pass filter from 7.9 - 11 Hz
l_trans_bandwidth chosen to be 2.0 Hz
h_trans_bandwidth chosen to be 2.7 Hz
Filter length of 264 samples (1.650 sec) selected
Setting up band-pass filter from 7.9 - 11 Hz
l_trans_bandwidth chosen to be 2.0 Hz
h_trans_bandwidth chosen to be 2.7 Hz
Filter length of 264 samples (1.650 sec) selected
45 matching events found
Loading data for 45 events and 697 original time points ...
0 bad epochs dropped
Setting up band-pass filter from 11 - 14 Hz
l_trans_bandwidth chosen to be 2.7 Hz
h_trans_bandwidth chosen to be 3.4 Hz
Filter length of 197 samples (1.231 sec) selected
Setting up band-pass filter from 11 - 14 Hz
l_trans_bandwidth chosen to be 2.7 Hz
h_trans_bandwidth chosen to be 3.4 Hz
Filter length of 197 samples (1.231 sec) selected
Setting up band-pass filter from 11 - 14 Hz
l_trans_bandwidth chosen to be 2.7 Hz
h_trans_bandwidth chosen to be 3.4 Hz
Filter length of 197 samples (1.231 sec) selected
45 matching events found
Loading data for 45 events and 617 original time points ...
0 bad epochs dropped
Setting up band-pass filter from 14 - 16 Hz
l_trans_bandwidth chosen to be 3.4 Hz
h_trans_bandwidth chosen to be 4.1 Hz
Filter length of 156 samples (0.975 sec) selected
Setting up band-pass filter from 14 - 16 Hz
l_trans_bandwidth chosen to be 3.4 Hz
h_trans_bandwidth chosen to be 4.1 Hz
Filter length of 156 samples (0.975 sec) selected
Setting up band-pass filter from 14 - 16 Hz
l_trans_bandwidth chosen to be 3.4 Hz
h_trans_bandwidth chosen to be 4.1 Hz
Filter length of 156 samples (0.975 sec) selected
45 matching events found
Loading data for 45 events and 567 original time points ...
0 bad epochs dropped
Setting up band-pass filter from 16 - 19 Hz
l_trans_bandwidth chosen to be 4.1 Hz
h_trans_bandwidth chosen to be 4.8 Hz
Filter length of 129 samples (0.806 sec) selected
Setting up band-pass filter from 16 - 19 Hz
l_trans_bandwidth chosen to be 4.1 Hz
h_trans_bandwidth chosen to be 4.8 Hz
Filter length of 129 samples (0.806 sec) selected
Setting up band-pass filter from 16 - 19 Hz
l_trans_bandwidth chosen to be 4.1 Hz
h_trans_bandwidth chosen to be 4.8 Hz
Filter length of 129 samples (0.806 sec) selected
45 matching events found
Loading data for 45 events and 533 original time points ...
0 bad epochs dropped
Setting up band-pass filter from 19 - 22 Hz
l_trans_bandwidth chosen to be 4.8 Hz
h_trans_bandwidth chosen to be 5.5 Hz
Filter length of 110 samples (0.688 sec) selected
Setting up band-pass filter from 19 - 22 Hz
l_trans_bandwidth chosen to be 4.8 Hz
h_trans_bandwidth chosen to be 5.5 Hz
Filter length of 110 samples (0.688 sec) selected
Setting up band-pass filter from 19 - 22 Hz
l_trans_bandwidth chosen to be 4.8 Hz
h_trans_bandwidth chosen to be 5.5 Hz
Filter length of 110 samples (0.688 sec) selected
45 matching events found
Loading data for 45 events and 507 original time points ...
0 bad epochs dropped
Setting up band-pass filter from 22 - 25 Hz
l_trans_bandwidth chosen to be 5.5 Hz
h_trans_bandwidth chosen to be 6.2 Hz
Filter length of 95 samples (0.594 sec) selected
Setting up band-pass filter from 22 - 25 Hz
l_trans_bandwidth chosen to be 5.5 Hz
h_trans_bandwidth chosen to be 6.2 Hz
Filter length of 95 samples (0.594 sec) selected
Setting up band-pass filter from 22 - 25 Hz
l_trans_bandwidth chosen to be 5.5 Hz
h_trans_bandwidth chosen to be 6.2 Hz
Filter length of 95 samples (0.594 sec) selected
45 matching events found
Loading data for 45 events and 489 original time points ...
0 bad epochs dropped
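
The number of time points per epoch in the log shrinks as the band gets higher because the padding w_size is inversely proportional to the band's centre frequency. The sketch below recomputes those counts. It assumes a sampling rate of 160 Hz (inferred from 20000 samples spanning about 125 s in the log) and assumes both epoch edges are rounded to the nearest sample, which may differ in detail from what Epochs does internally.

```python
import numpy as np

sfreq = 160.  # assumed from the log: 20000 samples over ~125 s
tmin, tmax, n_cycles = -.200, 2.000, 10.
freqs = np.linspace(5., 25., 8)

n_samples = []
for fmin, fmax in zip(freqs[:-1], freqs[1:]):
    w_size = n_cycles / ((fmax + fmin) / 2.)  # window length in seconds
    # Assumed rule: round each padded epoch edge to the nearest sample
    first = int(np.round((tmin - w_size) * sfreq))
    last = int(np.round((tmax + w_size) * sfreq))
    n_samples.append(last - first + 1)  # both edges included
    print('%.1f - %.1f Hz: w_size = %.3f s, %d samples'
          % (fmin, fmax, w_size, n_samples[-1]))
```

Under these assumptions the counts agree with the "original time points" values reported in the log above.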

Plot frequency results

# Note: the first argument of plt.bar is named ``x`` (``left`` in old versions)
plt.bar(x=freqs[:-1], height=freq_scores, width=np.diff(freqs)[0],
        align='edge', edgecolor='black')
plt.xticks(freqs)
plt.ylim([0, 1])
plt.axhline(len(epochs['feet']) / len(epochs), color='k', linestyle='--',
            label='chance level')
plt.legend()
plt.xlabel('Frequency (Hz)')
plt.ylabel('Decoding Scores')
plt.title('Frequency Decoding Scores')
[Figure: Frequency Decoding Scores (sphx_glr_plot_decoding_csp_timefreq_001.png)]
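
A note on reading the dashed line: the scores are ROC AUC values, for which uninformative predictions are expected to score about 0.5 whatever the class balance, whereas the line above is drawn at the proportion of 'feet' epochs. The two coincide only when the classes are balanced. The sketch below illustrates this with a rank-based AUC written in NumPy, a stand-in for scikit-learn's scorer; the function name auc_from_scores and the simulated labels are hypothetical.

```python
import numpy as np


def auc_from_scores(y_true, scores):
    """Fraction of (positive, negative) pairs ranked correctly."""
    pos = scores[y_true == 1]
    neg = scores[y_true == 0]
    diff = pos[:, np.newaxis] - neg[np.newaxis, :]
    # Ties count as half a correctly ranked pair
    return np.mean((diff > 0) + 0.5 * (diff == 0))


rng = np.random.default_rng(0)
y_sim = np.array([1] * 60 + [0] * 140)  # imbalanced: 30% positives

# Random scores carry no information about the labels
random_aucs = [auc_from_scores(y_sim, rng.standard_normal(y_sim.size))
               for _ in range(500)]
mean_random_auc = np.mean(random_aucs)
perfect_auc = auc_from_scores(y_sim, y_sim.astype(float))

print('mean AUC of random scores: %.3f' % mean_random_auc)
print('AUC of perfect scores: %.3f' % perfect_auc)
```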

Loop through frequencies and time, apply classifier and save scores

# init scores
tf_scores = np.zeros((n_freqs - 1, n_windows))

# Loop through each frequency range of interest
for freq, (fmin, fmax) in enumerate(freq_ranges):

    # Infer window size based on the frequency being used
    w_size = n_cycles / ((fmax + fmin) / 2.)  # in seconds

    # Apply band-pass filter to isolate the specified frequencies
    raw_filter = raw.copy().filter(fmin, fmax, n_jobs=1, fir_design='firwin',
                                   skip_by_annotation='edge')

    # Extract epochs from filtered data, padded by window size
    epochs = Epochs(raw_filter, events, event_id, tmin - w_size, tmax + w_size,
                    proj=False, baseline=None, preload=True)
    epochs.drop_bad()
    y = le.fit_transform(epochs.events[:, 2])

    # Roll covariance, csp and lda over time
    for t, w_time in enumerate(centered_w_times):

        # Center the min and max of the window
        w_tmin = w_time - w_size / 2.
        w_tmax = w_time + w_size / 2.

        # Crop data into time-window of interest
        X = epochs.copy().crop(w_tmin, w_tmax).get_data()

        # Save mean scores over folds for each frequency and time window
        tf_scores[freq, t] = np.mean(cross_val_score(estimator=clf, X=X, y=y,
                                                     scoring='roc_auc', cv=cv,
                                                     n_jobs=1), axis=0)

Out:

Setting up band-pass filter from 5 - 7.9 Hz
l_trans_bandwidth chosen to be 2.0 Hz
h_trans_bandwidth chosen to be 2.0 Hz
Filter length of 264 samples (1.650 sec) selected
Setting up band-pass filter from 5 - 7.9 Hz
l_trans_bandwidth chosen to be 2.0 Hz
h_trans_bandwidth chosen to be 2.0 Hz
Filter length of 264 samples (1.650 sec) selected
Setting up band-pass filter from 5 - 7.9 Hz
l_trans_bandwidth chosen to be 2.0 Hz
h_trans_bandwidth chosen to be 2.0 Hz
Filter length of 264 samples (1.650 sec) selected
45 matching events found
Loading data for 45 events and 851 original time points ...
0 bad epochs dropped
Setting up band-pass filter from 7.9 - 11 Hz
l_trans_bandwidth chosen to be 2.0 Hz
h_trans_bandwidth chosen to be 2.7 Hz
Filter length of 264 samples (1.650 sec) selected
Setting up band-pass filter from 7.9 - 11 Hz
l_trans_bandwidth chosen to be 2.0 Hz
h_trans_bandwidth chosen to be 2.7 Hz
Filter length of 264 samples (1.650 sec) selected
Setting up band-pass filter from 7.9 - 11 Hz
l_trans_bandwidth chosen to be 2.0 Hz
h_trans_bandwidth chosen to be 2.7 Hz
Filter length of 264 samples (1.650 sec) selected
45 matching events found
Loading data for 45 events and 697 original time points ...
0 bad epochs dropped
Setting up band-pass filter from 11 - 14 Hz
l_trans_bandwidth chosen to be 2.7 Hz
h_trans_bandwidth chosen to be 3.4 Hz
Filter length of 197 samples (1.231 sec) selected
Setting up band-pass filter from 11 - 14 Hz
l_trans_bandwidth chosen to be 2.7 Hz
h_trans_bandwidth chosen to be 3.4 Hz
Filter length of 197 samples (1.231 sec) selected
Setting up band-pass filter from 11 - 14 Hz
l_trans_bandwidth chosen to be 2.7 Hz
h_trans_bandwidth chosen to be 3.4 Hz
Filter length of 197 samples (1.231 sec) selected
45 matching events found
Loading data for 45 events and 617 original time points ...
0 bad epochs dropped
Setting up band-pass filter from 14 - 16 Hz
l_trans_bandwidth chosen to be 3.4 Hz
h_trans_bandwidth chosen to be 4.1 Hz
Filter length of 156 samples (0.975 sec) selected
Setting up band-pass filter from 14 - 16 Hz
l_trans_bandwidth chosen to be 3.4 Hz
h_trans_bandwidth chosen to be 4.1 Hz
Filter length of 156 samples (0.975 sec) selected
Setting up band-pass filter from 14 - 16 Hz
l_trans_bandwidth chosen to be 3.4 Hz
h_trans_bandwidth chosen to be 4.1 Hz
Filter length of 156 samples (0.975 sec) selected
45 matching events found
Loading data for 45 events and 567 original time points ...
0 bad epochs dropped
Setting up band-pass filter from 16 - 19 Hz
l_trans_bandwidth chosen to be 4.1 Hz
h_trans_bandwidth chosen to be 4.8 Hz
Filter length of 129 samples (0.806 sec) selected
Setting up band-pass filter from 16 - 19 Hz
l_trans_bandwidth chosen to be 4.1 Hz
h_trans_bandwidth chosen to be 4.8 Hz
Filter length of 129 samples (0.806 sec) selected
Setting up band-pass filter from 16 - 19 Hz
l_trans_bandwidth chosen to be 4.1 Hz
h_trans_bandwidth chosen to be 4.8 Hz
Filter length of 129 samples (0.806 sec) selected
45 matching events found
Loading data for 45 events and 533 original time points ...
0 bad epochs dropped
Setting up band-pass filter from 19 - 22 Hz
l_trans_bandwidth chosen to be 4.8 Hz
h_trans_bandwidth chosen to be 5.5 Hz
Filter length of 110 samples (0.688 sec) selected
Setting up band-pass filter from 19 - 22 Hz
l_trans_bandwidth chosen to be 4.8 Hz
h_trans_bandwidth chosen to be 5.5 Hz
Filter length of 110 samples (0.688 sec) selected
Setting up band-pass filter from 19 - 22 Hz
l_trans_bandwidth chosen to be 4.8 Hz
h_trans_bandwidth chosen to be 5.5 Hz
Filter length of 110 samples (0.688 sec) selected
45 matching events found
Loading data for 45 events and 507 original time points ...
0 bad epochs dropped
Setting up band-pass filter from 22 - 25 Hz
l_trans_bandwidth chosen to be 5.5 Hz
h_trans_bandwidth chosen to be 6.2 Hz
Filter length of 95 samples (0.594 sec) selected
Setting up band-pass filter from 22 - 25 Hz
l_trans_bandwidth chosen to be 5.5 Hz
h_trans_bandwidth chosen to be 6.2 Hz
Filter length of 95 samples (0.594 sec) selected
Setting up band-pass filter from 22 - 25 Hz
l_trans_bandwidth chosen to be 5.5 Hz
h_trans_bandwidth chosen to be 6.2 Hz
Filter length of 95 samples (0.594 sec) selected
45 matching events found
Loading data for 45 events and 489 original time points ...
0 bad epochs dropped
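
The rolling windows in the loop above can be checked without any data. This sketch recomputes, for every band, the edges of each window and verifies two properties that follow from the parameters: every window fits inside the padded epoch, and consecutive windows overlap because the window length exceeds the spacing. It only repeats arithmetic from the script.

```python
import numpy as np

tmin, tmax, n_cycles = -.200, 2.000, 10.
freqs = np.linspace(5., 25., 8)
window_spacing = n_cycles / np.max(freqs) / 2.
centered_w_times = np.arange(tmin, tmax, window_spacing)[1:]

inside = []
overlap = []
for fmin, fmax in zip(freqs[:-1], freqs[1:]):
    w_size = n_cycles / ((fmax + fmin) / 2.)  # window length in seconds
    w_tmins = centered_w_times - w_size / 2.
    w_tmaxs = centered_w_times + w_size / 2.
    # Epochs were cut from tmin - w_size to tmax + w_size
    inside.append(w_tmins.min() >= tmin - w_size
                  and w_tmaxs.max() <= tmax + w_size)
    # Positive value means neighbouring windows share samples
    overlap.append(w_size - window_spacing)
    print('%.1f - %.1f Hz: window %.3f s, overlap %.3f s'
          % (fmin, fmax, w_size, overlap[-1]))
```

Because the spacing is derived from the highest frequency, even the shortest windows overlap, so no part of the epoch between the first and last window centres is skipped.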

Plot time-frequency results

# Set up time frequency object
av_tfr = AverageTFR(create_info(['freq'], sfreq), tf_scores[np.newaxis, :],
                    centered_w_times, freqs[1:], 1)

chance = np.mean(y)  # set chance level to white in the plot
av_tfr.plot([0], vmin=chance, title="Time-Frequency Decoding Scores",
            cmap=plt.cm.Reds)
[Figure: Time-Frequency Decoding Scores (sphx_glr_plot_decoding_csp_timefreq_002.png)]

Out:

No baseline correction applied
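
The score matrix is wrapped as a single-"channel" time-frequency object so that the existing plotting code can be reused. The sketch below shows only the array shapes involved, using NumPy; the window count of 10 is an illustrative value, and the band-centre alternative for the frequency axis is a suggestion, not what the script does (it labels each band by its upper edge, freqs[1:]).

```python
import numpy as np

n_freqs, n_windows = 8, 10  # n_windows is illustrative here
freqs = np.linspace(5., 25., n_freqs)
tf_scores = np.zeros((n_freqs - 1, n_windows))

# Add a leading axis: (n_channels, n_bands, n_windows) with one channel
data = tf_scores[np.newaxis, :]
print(data.shape)

# The script labels each band by its upper edge
upper_edges = freqs[1:]
# Hypothetical alternative: label each band by its centre frequency
band_centres = (freqs[:-1] + freqs[1:]) / 2.
print(upper_edges.round(1), band_centres.round(1))
```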

Total running time of the script: ( 0 minutes 15.026 seconds)
