FBrCCA

This script shows how to use FBrCCA from PyntBCI for decoding c-VEP trials. The FBrCCA method uses a template matching classifier where templates are estimated using reconvolution and canonical correlation analysis (CCA). Additionally, FBrCCA uses a filterbank (FB).

The data used in this script come from Thielen et al. (2021), see references [1] and [2].

References

import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn

import pyntbci

# Select seaborn's "paper" plotting context with fonts scaled by a factor 1.5
seaborn.set_context("paper", font_scale=1.5)

Set the data path

The cell below specifies where the dataset has been downloaded to. Please make sure it is set correctly for your machine. If the folder structure of the dataset has not been changed, the cells below should work as is.

# Root folder of the dataset; defaults to the folder of the installed pyntbci package
path = os.path.dirname(pyntbci.__file__)

# Participant identifiers: sub-01, sub-02, ...
n_subjects = 5
subjects = [f"sub-{i_subject:02d}" for i_subject in range(1, n_subjects + 1)]

Analyse multiple participants

In this section, we directly apply FBrCCA and compare it to a single-band rCCA using the pass-bands defined in the filterbank.

# Set trial duration
trialtime = 4.2  # limit trials to a certain duration in seconds
n_trials = 100  # limit the number of trials in the dataset

# Set rCCA (see pyntbci.classifiers.rCCA)
event = "duration"  # event definition type
encoding_length = 0.3  # length of a transient response in seconds
onset_event = True  # forwarded to rCCA's onset_event argument; see the rCCA documentation for its meaning

# Set folds for chronological cross-validation
# Each trial gets a fold label 0..n_folds-1 in recording order. The labels always cover exactly n_trials trials,
# also when n_trials is not a multiple of n_folds, so that they can be used as boolean masks on the trial axis.
n_folds = 5
folds = np.arange(n_trials) * n_folds // n_trials

# Set filterbank (see pyntbci.utilities.filterbank)
tmin = 0  # seconds before trial onset that could catch filter artefacts and is cut off after the filterbank
filterbank = [  # passbands with lower and higher cutoffs in Hz
    (1.0, 60.0),
    (12.0, 60.0),
    (30.0, 60.0)]
ftype = "chebyshev1"  # filter type
gpass = 3  # maximum attenuation in the passband in dB
gstop = 20  # minimum attenuation in the stopband in dB
n_bands = len(filterbank)  # number of pass-bands, i.e., number of single-band classifiers

# Loop participants
# accuracy_rcca stores one accuracy per (participant, fold, pass-band) for the single-band rCCA classifiers,
# accuracy_fbrcca stores one accuracy per (participant, fold) for the filterbank ensemble.
accuracy_rcca = np.zeros((n_subjects, n_folds, n_bands))
accuracy_fbrcca = np.zeros((n_subjects, n_folds))
for i_subject, subject in enumerate(subjects):

    # Load data
    # The context manager closes the .npz archive as soon as the arrays are read, so no file handle is left open
    # per participant.
    fn = os.path.join(path, "data", f"thielen2021_{subject}.npz")
    with np.load(fn) as tmp:
        fs = int(tmp["fs"])
        X = tmp["X"][:n_trials, :, :int(trialtime * fs)]  # first n_trials trials, first trialtime seconds
        y = tmp["y"][:n_trials]
        V = tmp["V"]  # passed to rCCA as the stimulus

    # Apply filterbank
    # The result is indexed below with four axes, with the pass-band on the last axis.
    X = pyntbci.utilities.filterbank(X, filterbank, fs, tmin=tmin, ftype=ftype, gpass=gpass, gstop=gstop)

    # Cross-validation
    for i_fold in range(n_folds):

        # Split data to train and test set
        is_tst = folds == i_fold
        X_trn, y_trn = X[~is_tst, :, :, :], y[~is_tst]
        X_tst, y_tst = X[is_tst, :, :, :], y[is_tst]

        # Setup classifier: an ensemble built from the rCCA estimator, combined by an aggregate gate ("mean")
        rcca = pyntbci.classifiers.rCCA(stimulus=V, fs=fs, event=event, encoding_length=encoding_length,
                                        onset_event=onset_event)
        gate = pyntbci.gates.AggregateGate("mean")
        fbrcca = pyntbci.classifiers.Ensemble(estimator=rcca, gate=gate)

        # Train classifier
        fbrcca.fit(X_trn, y_trn)

        # Apply classifier
        yh_tst = fbrcca.predict(X_tst)

        # Compute accuracy
        accuracy_fbrcca[i_subject, i_fold] = np.mean(yh_tst == y_tst)

        # Loop individual pass-bands: a separate rCCA is trained and tested on each band on its own
        for i_band in range(n_bands):
            # Setup classifier
            rcca = pyntbci.classifiers.rCCA(stimulus=V, fs=fs, event=event, encoding_length=encoding_length,
                                            onset_event=onset_event)

            # Train classifier
            rcca.fit(X_trn[:, :, :, i_band], y_trn)

            # Apply classifier
            yh_tst = rcca.predict(X_tst[:, :, :, i_band])

            # Compute accuracy
            accuracy_rcca[i_subject, i_fold, i_band] = np.mean(yh_tst == y_tst)

# Stack results
# Along the last axis, index 0 holds FBrCCA and indices 1..n_bands hold the single-band rCCAs.
accuracy = np.concatenate((accuracy_fbrcca[:, :, np.newaxis], accuracy_rcca), axis=2)

# Plot accuracy
# One group of bars per participant, one bar per method; error bars show the standard deviation over folds.
n_methods = 1 + len(filterbank)
width = 0.9 / n_methods  # all bars of one participant together span 0.9 of the unit spacing between participants
plt.figure(figsize=(15, 4))
for i in range(n_methods):
    label = "FBrCCA" if i == 0 else f"rCCA {filterbank[i - 1][0]:>4}-{filterbank[i - 1][1]:>4}"
    fold_mean = accuracy[:, :, i].mean(axis=1)
    fold_std = accuracy[:, :, i].std(axis=1)
    plt.bar(i * width + np.arange(n_subjects), fold_mean, width, yerr=fold_std, label=label)

# Chance level: one over the number of rows of V (presumably the number of classes -- confirm against the dataset)
plt.axhline(1 / V.shape[0], linestyle="--", color="k", alpha=0.5, label="chance")

# Put the participant tick at the centre of its group of bars
plt.xticks(width * (len(filterbank) / 2) + np.arange(n_subjects), subjects, rotation=45)
plt.legend(bbox_to_anchor=(1.0, 1.0))
plt.ylabel("accuracy")
plt.title("Decoding performance full dataset")
plt.tight_layout()

# Print accuracy
# Reports mean +/- standard deviation over folds for each method.
# NOTE(review): index -1 selects the last participant only, not an average over participants -- confirm that
# this is intended.
labels = ["FBrCCA"] + [f"rCCA {filterbank[i - 1][0]:>4}-{filterbank[i - 1][1]:>4}"
                       for i in range(1, 1 + len(filterbank))]
for i, label in enumerate(labels):
    print(f"{label}: {np.mean(accuracy[-1, :, i]):.2f} +/- {np.std(accuracy[-1, :, i]):.2f}")

# plt.show()
Decoding performance full dataset
FBrCCA: 0.96 +/- 0.02
rCCA  1.0-60.0: 0.96 +/- 0.02
rCCA 12.0-60.0: 0.92 +/- 0.04
rCCA 30.0-60.0: 0.39 +/- 0.04

Total running time of the script: (1 minutes 7.556 seconds)

Gallery generated by Sphinx-Gallery