├── predict.py ├── pics ├── nb_A.png ├── nb_B.png ├── cnn_95.png ├── cnn_95a.png ├── testA.png ├── testB.png ├── trainA.png ├── trainB.png ├── acc_A_train.png ├── vallina_cnn_A.png ├── vallina_cnn_B.png ├── vallina_model_cnn.pdf ├── BCI_Comp_III_Wads_2004.pdf └── v_cnn.svg ├── .gitignore ├── dataset ├── sub_a.mat ├── sub_b.mat ├── sub_a_test.mat ├── sub_b_test.mat └── .gitattributes ├── description ├── untitled2.m ├── tru.m ├── untitled3.m ├── test_preprocess.m ├── train.m ├── preprocess.m ├── test.m ├── v_pca.m ├── acc.m ├── eloc64.txt ├── example.m └── topoplotEEG.m ├── README.md ├── train_valid_split.py ├── naive_bayes.py ├── Dataset.py ├── notebook ├── Untitled1.ipynb ├── Untitled.ipynb ├── NB Implem.ipynb └── NaiveBayes.ipynb ├── Model.py └── train.py /predict.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pics/nb_A.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omoiFly/CNN-for-P300/HEAD/pics/nb_A.png -------------------------------------------------------------------------------- /pics/nb_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omoiFly/CNN-for-P300/HEAD/pics/nb_B.png -------------------------------------------------------------------------------- /pics/cnn_95.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omoiFly/CNN-for-P300/HEAD/pics/cnn_95.png -------------------------------------------------------------------------------- /pics/cnn_95a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omoiFly/CNN-for-P300/HEAD/pics/cnn_95a.png -------------------------------------------------------------------------------- /pics/testA.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omoiFly/CNN-for-P300/HEAD/pics/testA.png -------------------------------------------------------------------------------- /pics/testB.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omoiFly/CNN-for-P300/HEAD/pics/testB.png -------------------------------------------------------------------------------- /pics/trainA.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omoiFly/CNN-for-P300/HEAD/pics/trainA.png -------------------------------------------------------------------------------- /pics/trainB.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omoiFly/CNN-for-P300/HEAD/pics/trainB.png -------------------------------------------------------------------------------- /pics/acc_A_train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omoiFly/CNN-for-P300/HEAD/pics/acc_A_train.png -------------------------------------------------------------------------------- /pics/vallina_cnn_A.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omoiFly/CNN-for-P300/HEAD/pics/vallina_cnn_A.png -------------------------------------------------------------------------------- /pics/vallina_cnn_B.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/omoiFly/CNN-for-P300/HEAD/pics/vallina_cnn_B.png -------------------------------------------------------------------------------- /pics/vallina_model_cnn.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omoiFly/CNN-for-P300/HEAD/pics/vallina_model_cnn.pdf -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # PyCharm 2 | .idea/ 3 | 4 | # Python 5 | __pycache__/ 6 | .ipynb_checkpoints/ 7 | 8 | .DS_Store 9 | -------------------------------------------------------------------------------- /pics/BCI_Comp_III_Wads_2004.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omoiFly/CNN-for-P300/HEAD/pics/BCI_Comp_III_Wads_2004.pdf -------------------------------------------------------------------------------- /dataset/sub_a.mat: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:e6564066efc04797b827bdcd1b573a1ae4bf882668579b81a8d91744536f10a1 3 | size 80396790 4 | -------------------------------------------------------------------------------- /dataset/sub_b.mat: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:b01b3afb6f9842e5e62b80747df817067241aa92b3f48eded28dcb8198ce6e76 3 | size 80790550 4 | -------------------------------------------------------------------------------- /dataset/sub_a_test.mat: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:d362e71fcc7a493a53a6577715ec7a314f537a98d0f8fb30d97c19d9cd80fe46 3 | size 93605190 4 | -------------------------------------------------------------------------------- /dataset/sub_b_test.mat: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:22f15eaf821c2b9ee1b819175a205f7ed775f9d48620fb776acebc12f78d9e9f 3 | size 95547937 4 | -------------------------------------------------------------------------------- /description/untitled2.m: -------------------------------------------------------------------------------- 1 | feat = []; 2 | for ep=1:85 3 | for rowcol=1:12 4 | ica = rica(responses(rowcol, :, 11, ep)', 3); 5 | feat = [feat; ica.TransformWeights]; 6 | end 7 | end -------------------------------------------------------------------------------- /dataset/.gitattributes: -------------------------------------------------------------------------------- 1 | sub_a.mat filter=lfs diff=lfs merge=lfs -text 2 | sub_a_test.mat filter=lfs diff=lfs merge=lfs -text 3 | sub_b.mat filter=lfs diff=lfs merge=lfs -text 4 | sub_b_test.mat filter=lfs diff=lfs merge=lfs -text 5 | -------------------------------------------------------------------------------- /description/tru.m: -------------------------------------------------------------------------------- 1 | dataset = data(); 2 | cv = cvpartition(size(dataset,1),'holdout',0.04); 3 | 4 | training_set = dataset(training(cv), :); 5 | testing_set = dataset(test(cv), :); 6 | 7 | training_label = training_set(:, 1); 8 | training_data = training_set(:, 2:5); 9 | 10 | testing_label = 
testing_set(:, 1);
11 | testing_data = testing_set(:, 2:5);
12 | 
13 | 
14 | 
15 | 
--------------------------------------------------------------------------------
/description/untitled3.m:
--------------------------------------------------------------------------------
1 | figure
2 | for ep=1:85
3 |     for rowcol=1:12
4 |         idx = (ep-1)*12 + rowcol;  % feat holds one row per (epoch, stimulus), epoch-major (see untitled2.m)
5 |         if is_stimulate(rowcol, ep) == 1
6 |             scatter3(feat(idx,1), feat(idx,2), feat(idx,3), 'r')
7 |             hold on
8 |         else
9 |             scatter3(feat(idx,1), feat(idx,2), feat(idx,3), 'b')
10 |             hold on
11 |         end
12 |     end
13 | end
--------------------------------------------------------------------------------
/description/test_preprocess.m:
--------------------------------------------------------------------------------
1 | window=240; % window after stimulus (1s)
2 | 
3 | responses = zeros(12, 15, 240, 64, 100);
4 | 
5 | % convert to double precision
6 | Signal=double(Signal);
7 | Flashing=double(Flashing);
8 | StimulusCode=double(StimulusCode);
9 | 
10 | % for each character epoch
11 | for epoch=1:size(Signal,1)
12 |     % get response samples at start of each Flash
13 |     rowcolcnt=ones(1,12);
14 |     block = 1;  % (unused in this script; no labels in the test set)
15 |     for n=2:size(Signal,2)
16 |         if Flashing(epoch,n)==0 && Flashing(epoch,n-1)==1
17 |             rowcol=StimulusCode(epoch,n-1);
18 |             responses(rowcol,rowcolcnt(rowcol),:,:,epoch)=Signal(epoch,n-24:n+window-25,:);
19 |             rowcolcnt(rowcol)=rowcolcnt(rowcol)+1;
20 |         end
21 |     end
22 | end
23 | 
24 | responses = reshape(mean(responses, 2), 12, 240, 64, 100);
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CNN for P300 Evoked Potentials
2 | 
3 | ### Description
4 | 
5 | [Dataset Description](http://www.bbci.de/competition/iii/desc_II.pdf)
6 | 
7 | |          | Subject_A | Subject_B |
8 | | -------- | --------- | --------- |
9 | | Accuracy | 0.87      | 0.95      |
10 | 
11 | A CNN model for P300 detection, implemented in PyTorch.
12 | 
13 | ### Usage
14 | 
15 | Download the [processed data](https://pan.baidu.com/s/1Tmh2D4oyL8PXKg8y-pI0ow) (extraction code: hw8b)
16 | 
17 | into the dataset/ folder.
18 | 
19 | Run train.py (simply change 'A' to 'B' to train on the other subject).
20 | 
21 | ### Model
22 | 
23 | See class Vanilla in Model.py.
24 | 
25 | ### Dataset
26 | 
27 | Each of the 12 row/column stimuli is repeated 15 times per character; the repetitions are averaged to raise the SNR.
28 | 
29 | Every character epoch yields only 2 positive samples but 10 negative ones, so each positive sample is repeated 4 extra times (see Dataset.py) to balance the classes; a code sketch follows below.
30 | 
31 | The dataset/ folder contains only the processed data.
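For illustration, a minimal sketch of this averaging and oversampling, assuming a `responses` array of shape (12, 15, T, 64, n_chars) and an `is_stimulate` array of shape (12, n_chars) as produced by description/preprocess.m. The helper name `build_trainset` is made up for this sketch; the actual logic lives in Dataset.py.

```python
import numpy as np

def build_trainset(responses, is_stimulate):
    """responses: (12, 15, T, 64, n_chars); is_stimulate: (12, n_chars)."""
    averaged = responses.mean(axis=1)  # average the 15 repetitions -> (12, T, 64, n_chars)
    data, target = [], []
    for i in range(12):                       # 12 row/column stimuli
        for j in range(averaged.shape[-1]):   # character epochs
            # positives are stored 5 times (1 + 4 repeats), negatives once
            reps = 5 if is_stimulate[i, j] == 1 else 1
            for _ in range(reps):
                data.append(averaged[i, :, :, j])
                target.append(is_stimulate[i, j])
    return np.array(data), np.array(target)
```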
32 | 
33 | ### Optimizer
34 | 
35 | SGD with momentum: lr=5e-4, momentum=0.9, weight_decay=1e-4.
36 | 
--------------------------------------------------------------------------------
/description/train.m:
--------------------------------------------------------------------------------
1 | clear all;
2 | clc;
3 | temp=zeros(12,85,64,240);
4 | responses_train=zeros(12,85,11,240);
5 | load ('data.mat');
6 | data = [];
7 | label = [];
8 | 
9 | temp=permute(responses,[2,1,3,4]);
10 | data_temp1=squeeze(temp(:,:,3,:));   % channels 3,5,9,11,13,22,24,34,51,56,60 = Fc1,Fc2,C3,Cz,C4,Fp1,Fp2,Fz,Pz,Po7,Po8 (see eloc64.txt)
11 | data_temp2=squeeze(temp(:,:,5,:));
12 | data_temp3=squeeze(temp(:,:,9,:));
13 | data_temp4=squeeze(temp(:,:,11,:));
14 | data_temp5=squeeze(temp(:,:,13,:));
15 | data_temp6=squeeze(temp(:,:,22,:));
16 | data_temp7=squeeze(temp(:,:,24,:));
17 | data_temp8=squeeze(temp(:,:,34,:));
18 | data_temp9=squeeze(temp(:,:,51,:));
19 | data_temp10=squeeze(temp(:,:,56,:));
20 | data_temp11=squeeze(temp(:,:,60,:));
21 | 
22 | responses_train=[data_temp1;data_temp2;data_temp3;data_temp4;data_temp5;data_temp6;data_temp7;data_temp8;data_temp9;data_temp10;data_temp11];
23 | 
24 | 
25 | responses_train=permute(responses_train,[2,3,1]);
26 | responses_train=reshape(responses_train,12,85,2640);
27 | for i=1:12
28 |     for j=1:85
29 |         data = [data; responses_train(i, j,:)];
30 |         label = [label; is_stimulate(i, j)];
31 |     end
32 | end
33 | 
34 | data = reshape(data, 1020, 2640);
35 | 
36 | 
37 | 
38 | Model_A= fitcsvm(data,label);
--------------------------------------------------------------------------------
/description/preprocess.m:
--------------------------------------------------------------------------------
1 | window=120; % window after stimulus (0.5s)
2 | 
3 | responses = zeros(12, 15, window, 64, 85);
4 | is_stimulate = zeros(12, 15, 85);
5 | 
6 | % convert to double precision
7 | Signal=double(Signal);
8 | Flashing=double(Flashing);
9 | StimulusCode=double(StimulusCode);
10 | StimulusType=double(StimulusType);
11 | 
12 | % for each character epoch
13 | for epoch=1:size(Signal,1)
14 |     % get response samples at start of each Flash
15 |     rowcolcnt=ones(1,12);
16 |     block = 1;
17 |     for n=2:size(Signal,2)
18 |         if Flashing(epoch,n)==0 && Flashing(epoch,n-1)==1
19 |             rowcol=StimulusCode(epoch,n-1);
20 |             responses(rowcol,rowcolcnt(rowcol),:,:,epoch)=Signal(epoch,n-24:n+window-25,:);
21 |             rowcolcnt(rowcol)=rowcolcnt(rowcol)+1;
22 |             if StimulusType(epoch, n-1) == 1
23 |                 is_stimulate(rowcol, block, epoch) = 1;
24 |             end
25 |         end
26 |         if mod(n, 504) == 0
27 |             block = block + 1;
28 |         end
29 |     end
30 | end
31 | 
32 | 
33 | responses = reshape(mean(responses, 2), 12, window, 64, 85);
34 | is_stimulate = uint8(reshape(mean(is_stimulate, 2), 12, 85));
--------------------------------------------------------------------------------
/train_valid_split.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.utils.data import Dataset
3 | 
4 | 
5 | class GenHelper(Dataset):
6 |     def __init__(self, mother, length, mapping):
7 |         # mapping from this subset's index to the mother dataset's index
8 |         self.mapping=mapping
9 |         self.length=length
10 |         self.mother=mother
11 | 
12 |     def __getitem__(self, index):
13 |         return self.mother[self.mapping[index]]
14 | 
15 |     def __len__(self):
16 |         return self.length
17 | 
18 | 
19 | def train_valid_split(ds, split_fold=10, random_seed=None):
20 |     '''
21 |     Generic PyTorch helper: takes a data.Dataset object and efficiently splits it into training and
22 |     validation subsets.
23 | :return: 24 | ''' 25 | if random_seed!=None: 26 | np.random.seed(random_seed) 27 | 28 | dslen=len(ds) 29 | indices= list(range(dslen)) 30 | valid_size=dslen//split_fold 31 | np.random.shuffle(indices) 32 | train_mapping=indices[valid_size:] 33 | valid_mapping=indices[:valid_size] 34 | train=GenHelper(ds, dslen - valid_size, train_mapping) 35 | valid=GenHelper(ds, valid_size, valid_mapping) 36 | 37 | return train, valid 38 | -------------------------------------------------------------------------------- /description/test.m: -------------------------------------------------------------------------------- 1 | 2 | load ('data_test.mat'); 3 | test_data=[]; 4 | 5 | % responses_predict=permute(responses_test,[1,4,3,2]); 6 | % responses_predict=reshape(responses_predict,12,100,15360); 7 | temp=permute(responses_test,[2,1,3,4]); 8 | data_temp1=squeeze(temp(:,:,3,:)); 9 | data_temp2=squeeze(temp(:,:,5,:)); 10 | data_temp3=squeeze(temp(:,:,9,:)); 11 | data_temp4=squeeze(temp(:,:,11,:)); 12 | data_temp5=squeeze(temp(:,:,13,:)); 13 | data_temp6=squeeze(temp(:,:,22,:)); 14 | data_temp7=squeeze(temp(:,:,24,:)); 15 | data_temp8=squeeze(temp(:,:,34,:)); 16 | data_temp9=squeeze(temp(:,:,51,:)); 17 | data_temp10=squeeze(temp(:,:,56,:)); 18 | data_temp11=squeeze(temp(:,:,60,:)); 19 | 20 | responses_preidict=[data_temp1;data_temp2;data_temp3;data_temp4;data_temp5;data_temp6;data_temp7;data_temp8;data_temp9;data_temp10;data_temp11]; 21 | 22 | 23 | responses_preidict=permute(responses_preidict,[2,3,1]); 24 | responses_preidict=reshape(responses_preidict,12,100,2640); 25 | 26 | 27 | for m=1:12 28 | for n=1:100 29 | test_data = [test_data; responses_preidict(m, n,:)]; 30 | end 31 | end 32 | 33 | test_data = reshape(test_data, 1200, 2640); 34 | 35 | [label,score] = predict(Model_A,test_data); 36 | 37 | -------------------------------------------------------------------------------- /description/v_pca.m: -------------------------------------------------------------------------------- 1 | function [RotMatrix, coe_PC, xRot, ProjectedData]=v_pca(data) 2 | % Implement PCA to obtain the rotation matrix (RotMatrix), which is the 3 | % eigenbasis of covariance of input. 4 | % edit by Chih-Sheng (Tommy) Huang. 
5 | % 6 | % Input: 7 | % data: input data, n*dim 8 | % Output: 9 | % RotMatrix: PCA transformation matrix 10 | % coe_PC: variance 11 | % xRot: projected data 12 | % ProjectedData.xPCAwhite :PCA whitening 13 | % ProjectedData.xZCAWhite :ZCA whitening 14 | 15 | [N dim ]=size(data); 16 | if (dim> N) 17 | data=data'; 18 | tmpN =dim; 19 | tmpdim=N; 20 | N=tmpN; 21 | dim=tmpdim; 22 | end 23 | 24 | avg = mean(data); 25 | data = data - repmat(avg, N, 1); 26 | 27 | % covariance matrix 28 | sigma = data' * data / N; 29 | % singular value decomposition (SVD) 30 | [RotMatrix, S] = svd(sigma); 31 | coe_PC=diag(S); 32 | xRot = RotMatrix* data'; 33 | 34 | 35 | 36 | % PCA-whitening 37 | epsilon=10^(-5); 38 | xPCAwhite = diag(1./sqrt(coe_PC + epsilon)) * RotMatrix * data'; 39 | 40 | % ZCA-whitening 41 | xZCAWhite = RotMatrix * diag(1./sqrt(coe_PC + epsilon)) * RotMatrix * data'; 42 | 43 | ProjectedData.xPCAwhite=xPCAwhite; 44 | ProjectedData.xZCAWhite=xZCAWhite; -------------------------------------------------------------------------------- /description/acc.m: -------------------------------------------------------------------------------- 1 | 2 | 3 | dis = ['A' 'B' 'C' 'D' 'E' 'F'; 4 | 'G' 'H' 'I' 'J' 'K' 'L'; 5 | 'M' 'N' 'O' 'P' 'Q' 'R'; 6 | 'S' 'T' 'U' 'V' 'W' 'X'; 7 | 'Y' 'Z' '1' '2' '3' '4'; 8 | '5' '6' '7' '8' '9' '_']; 9 | 10 | target=('WQXPLZCOMRKO97YFZDEZ1DPI9NNVGRQDJCUVRMEUOOOJD2UFYPOO6J7LDGYEGOA5VHNEHBTXOO1TDOILUEE5BFAEEXAW_K4R3MRU'); 11 | % load ('Predict_A.mat', 'label'); 12 | row_col=zeros(100,2); 13 | 14 | 15 | for i=1:100 16 | for col=1:6 17 | j=col+12*(i-1); 18 | if label(j)==1 19 | row_col(i,2)=col; 20 | 21 | elseif col==6&&row_col(i,2)==0 22 | row_col(i,2)=1; 23 | end 24 | end 25 | for row=7:12 26 | k=row+12*(i-1); 27 | if label(k)==1 28 | row_col(i,1)=row-6; 29 | 30 | elseif row==12&&row_col(i,1)==0 31 | row_col(i,1)=1; 32 | end 33 | end 34 | 35 | end 36 | 37 | is_target=[]; 38 | for i=1:100 39 | 40 | is_target(i)=dis(row_col(i,1),row_col(i,2)); 41 | predict_char=char(is_target()); 42 | 43 | end 44 | count=0; 45 | for i=1:100 46 | if predict_char(i)==target(i) 47 | count=count+1; 48 | end 49 | end 50 | accuracy=count/100; 51 | disp(sprintf('accuracy= %8.2\n',accuracy)); 52 | -------------------------------------------------------------------------------- /description/eloc64.txt: -------------------------------------------------------------------------------- 1 | 1 -68 0.3 Fc5. 2 | 2 -60 0.219 Fc3. 3 | 3 -42 0.14 Fc1. 4 | 4 0 0.101 Fcz. 5 | 5 42 0.14 Fc2. 6 | 6 60 0.219 Fc4. 7 | 7 68 0.3 Fc6. 8 | 8 -90 0.304 C5.. 9 | 9 -90 0.203 C3.. 10 | 10 -90 0.101 C1.. 11 | 11 0 0 Cz.. 12 | 12 90 0.101 C2.. 13 | 13 90 0.203 C4.. 14 | 14 90 0.304 C6.. 15 | 15 -112 0.3 Cp5. 16 | 16 -120 0.219 Cp3. 17 | 17 -138 0.14 Cp1. 18 | 18 180 0.101 Cpz. 19 | 19 138 0.14 Cp2. 20 | 20 120 0.219 Cp4. 21 | 21 112 0.3 Cp6. 22 | 22 -18 0.406 Fp1. 23 | 23 0 0.406 Fpz. 24 | 24 18 0.406 Fp2. 25 | 25 -36 0.406 Af7. 26 | 26 -23 0.343 Af3. 27 | 27 0 0.304 Afz. 28 | 28 23 0.343 Af4. 29 | 29 36 0.406 Af8. 30 | 30 -54 0.406 F7.. 31 | 31 -46 0.363 F5.. 32 | 32 -35 0.275 F3.. 33 | 33 -21 0.219 F1.. 34 | 34 0 0.203 Fz.. 35 | 35 21 0.219 F2.. 36 | 36 35 0.275 F4.. 37 | 37 46 0.363 F6.. 38 | 38 54 0.406 F8.. 39 | 39 -72 0.406 Ft7. 40 | 40 72 0.406 Ft8. 41 | 41 -90 0.406 T7.. 42 | 42 90 0.406 T8.. 43 | 43 -90 0.499 T9.. 44 | 44 90 0.499 T10. 45 | 45 -108 0.406 Tp7. 46 | 46 108 0.406 Tp8. 47 | 47 -126 0.406 P7.. 48 | 48 -134 0.343 P5.. 49 | 49 -145 0.275 P3.. 50 | 50 -159 0.219 P1.. 51 | 51 180 0.181 Pz.. 52 | 52 159 0.219 P2.. 
53 | 53 145 0.275 P4..
54 | 54 134 0.343 P6..
55 | 55 126 0.406 P8..
56 | 56 -144 0.406 Po7.
57 | 57 -157 0.343 Po3.
58 | 58 180 0.304 Poz.
59 | 59 157 0.343 Po4.
60 | 60 144 0.406 Po8.
61 | 61 -162 0.406 O1..
62 | 62 180 0.406 Oz..
63 | 63 162 0.406 O2..
64 | 64 180 0.499 Iz..
65 | 
66 | 
--------------------------------------------------------------------------------
/naive_bayes.py:
--------------------------------------------------------------------------------
1 | from sklearn.naive_bayes import GaussianNB
2 | from sklearn import svm
3 | from train import get_feature_map, get_feature, train_loaderB, trainA, valA, trainB, valB
4 | from Dataset import Testset
5 | from torch.utils.data import DataLoader
6 | import numpy as np
7 | 
8 | for i in range(60):
9 |     trainB(i)
10 |     valB(i)
11 | 
12 | test_loader = DataLoader(
13 |     dataset=Testset('B'),
14 |     shuffle=False,
15 |     batch_size=100
16 | )
17 | 
18 | clf = svm.SVC()
19 | gnb = GaussianNB()
20 | 
21 | cols, rows = get_feature(test_loader)
22 | x, y = get_feature_map(train_loaderB)
23 | 
24 | gnb.fit(x, y)
25 | print(gnb.score(x, y))  # GaussianNB accuracy on its own training features
26 | 
27 | clf.fit(x, y)
28 | 
29 | pred_cols = []
30 | pred_rows = []
31 | for i in range(6):
32 |     pred_cols.append(clf.predict(cols[i,:,:]))
33 |     pred_rows.append(clf.predict(rows[i,:,:]))
34 | 
35 | pred_rows = np.array(pred_rows)
36 | pred_cols = np.array(pred_cols)
37 | 
38 | char = [['A', 'B', 'C', 'D', 'E', 'F'],
39 |         ['G', 'H', 'I', 'J', 'K', 'L'],
40 |         ['M', 'N', 'O', 'P', 'Q', 'R'],
41 |         ['S', 'T', 'U', 'V', 'W', 'X'],
42 |         ['Y', 'Z', '1', '2', '3', '4'],
43 |         ['5', '6', '7', '8', '9', '_']]
44 | 
45 | series = []
46 | real_a = 'WQXPLZCOMRKO97YFZDEZ1DPI9NNVGRQDJCUVRMEUOOOJD2UFYPOO6J7LDGYEGOA5VHNEHBTXOO1TDOILUEE5BFAEEXAW_K4R3MRU'
47 | real_b = 'MERMIROOMUHJPXJOHUVLEORZP3GLOO7AUFDKEFTWEOOALZOP9ROCGZET1Y19EWX65QUYU7NAK_4YCJDVDNGQXODBEV2B5EFDIDNR'
48 | 
49 | for i in range(100):
50 |     series.append(char[np.argmax(pred_rows[:, i])][np.argmax(pred_cols[:, i])])
51 | 
52 | series = ''.join(series)
53 | print(series)
54 | counter = 0
55 | for i in range(len(real_b)):  # this script evaluates Subject_B
56 |     if real_b[i] == series[i]:
57 |         counter += 1
58 | 
59 | print(counter / len(real_b))
60 | 
--------------------------------------------------------------------------------
/Dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | import numpy as np
4 | from scipy.io import loadmat
5 | 
6 | 
7 | class Trainset(Dataset):
8 |     def __init__(self, subject_name):
9 |         if subject_name == 'A':
10 |             raw_data = loadmat('dataset/sub_a.mat')
11 |         else:
12 |             raw_data = loadmat('dataset/sub_b.mat')
13 | 
14 |         signals = raw_data['responses']
15 |         label = raw_data['is_stimulate']
16 |         data = []
17 |         target = []
18 |         for i in range(12):
19 |             for j in range(85):
20 |                 if label[i, j] == 1:  # positive: add 4 extra copies to balance the 2:10 class ratio
21 |                     data.append(signals[i, :, :, j].reshape(-1, 64))
22 |                     target.append(label[i, j])
23 |                     data.append(signals[i, :, :, j].reshape(-1, 64))
24 |                     target.append(label[i, j])
25 |                     data.append(signals[i, :, :, j].reshape(-1, 64))
26 |                     target.append(label[i, j])
27 |                     data.append(signals[i, :, :, j].reshape(-1, 64))
28 |                     target.append(label[i, j])
29 |                 data.append(signals[i, :, :, j].reshape(-1, 64))  # every sample is appended once
30 |                 target.append(label[i, j])
31 | 
32 |         self.data = np.array(data)
33 |         self.target = np.array(target)
34 | 
35 |     def __len__(self):
36 |         return self.target.shape[0]
37 | 
38 |     def __getitem__(self, index):
39 |         return self.data[index, :, :].astype(np.float32).T, self.target[index].astype(np.float32)
40 | 
41 | 
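# Note: Testset below keeps the 6 column responses and the 6 row responses of
# each character epoch separate, so predictA/predictB in train.py can score
# the column and row stimuli independently.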
42 | class Testset(Dataset): 43 | def __init__(self, subject_name): 44 | if subject_name == 'A': 45 | raw_data = loadmat('dataset/sub_a_test.mat') 46 | else: 47 | raw_data = loadmat('dataset/sub_b_test.mat') 48 | self.signals = raw_data['responses'] 49 | 50 | def __len__(self): 51 | return self.signals.shape[-1] 52 | 53 | def __getitem__(self, index): 54 | col = self.signals[:6, :, :, index].astype(np.float32).transpose([0, 2, 1]) 55 | row = self.signals[6:, :, :, index].astype(np.float32).transpose([0, 2, 1]) 56 | return col, row 57 | 58 | -------------------------------------------------------------------------------- /notebook/Untitled1.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "load Subject_A_Train.mat;\n", 10 | "window=240; % window after stimulus (1s)\n", 11 | "\n", 12 | "responses_train = zeros(12, 15, 240, 64, 85);\n", 13 | "is_stimulate = zeros(12, 15, 85);\n", 14 | "\n", 15 | "% convert to double precision\n", 16 | "Signal=double(Signal);\n", 17 | "Flashing=double(Flashing);\n", 18 | "StimulusCode=double(StimulusCode);\n", 19 | "StimulusType=double(StimulusType);\n", 20 | "\n", 21 | "% for each character epoch\n", 22 | "for epoch=1:size(Signal,1)\n", 23 | " % get reponse samples at start of each Flash\n", 24 | " rowcolcnt=ones(1,12);\n", 25 | " block = 1;\n", 26 | " for n=2:size(Signal,2)\n", 27 | " if Flashing(epoch,n)==0 && Flashing(epoch,n-1)==1\n", 28 | " rowcol=StimulusCode(epoch,n-1);\n", 29 | " responses_train(rowcol,rowcolcnt(rowcol),:,:,epoch)=Signal(epoch,n-24:n+window-25,:);\n", 30 | " rowcolcnt(rowcol)=rowcolcnt(rowcol)+1;\n", 31 | " if StimulusType(epoch, n-1) == 1\n", 32 | " is_stimulate(rowcol, block, epoch) = 1; \n", 33 | " end\n", 34 | " end\n", 35 | " if mod(n, 504) == 0\n", 36 | " block = block + 1;\n", 37 | " end\n", 38 | " end\n", 39 | "end\n", 40 | "\n", 41 | "\n", 42 | "responses_train = reshape(mean(responses_train, 2), 12, 240, 64, 85);\n", 43 | "is_stimulate_train = uint8(reshape(mean(is_stimulate, 2), 12, 85));" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "clear all; \n", 53 | "clc;\n", 54 | "temp=zeros(12,85,64,240);\n", 55 | "\n", 56 | "load ('data_train_A.mat');\n", 57 | "data = [];\n", 58 | "label = [];\n", 59 | "\n", 60 | "temp=permute(responses_train,[2,1,3,4]);\n", 61 | "data_temp1=squeeze(temp(:,:,3,:));\n", 62 | "data_temp2=squeeze(temp(:,:,5,:));\n", 63 | "data_temp3=squeeze(temp(:,:,9,:));\n", 64 | "data_temp4=squeeze(temp(:,:,11,:));\n", 65 | "data_temp5=squeeze(temp(:,:,13,:));\n", 66 | "data_temp6=squeeze(temp(:,:,22,:));\n", 67 | "data_temp7=squeeze(temp(:,:,24,:));\n", 68 | "data_temp8=squeeze(temp(:,:,34,:));\n", 69 | "data_temp9=squeeze(temp(:,:,51,:));\n", 70 | "data_temp10=squeeze(temp(:,:,56,:));\n", 71 | "data_temp11=squeeze(temp(:,:,60,:));\n", 72 | "\n", 73 | "responses_train=[data_temp1;data_temp2;data_temp3;data_temp4;data_temp5;data_temp6;data_temp7;data_temp8;data_temp9;data_temp10;data_temp11];\n", 74 | "\n", 75 | "\n", 76 | "responses_train=permute(responses_train,[2,3,1]);\n", 77 | "responses_train=reshape(responses_train,12,85,2640);\n", 78 | "for i=1:12\n", 79 | " for j=1:85 \n", 80 | " data = [data; responses_train(i, j,:)];\n", 81 | " label = [label; is_stimulate_train(i, j)];\n", 82 | " end\n", 83 | "end\n", 84 | "\n", 85 | "data = 
reshape(data, 1020, 2640);\n", 86 | "\n", 87 | "Model_A= fitcsvm(data,label);" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [] 96 | } 97 | ], 98 | "metadata": { 99 | "kernelspec": { 100 | "display_name": "Matlab [conda env:jmatlab]", 101 | "language": "matlab", 102 | "name": "conda-env-jmatlab-matlab" 103 | }, 104 | "language_info": { 105 | "codemirror_mode": "octave", 106 | "file_extension": ".m", 107 | "help_links": [ 108 | { 109 | "text": "MetaKernel Magics", 110 | "url": "https://github.com/calysto/metakernel/blob/master/metakernel/magics/README.md" 111 | } 112 | ], 113 | "mimetype": "text/x-octave", 114 | "name": "matlab", 115 | "version": "0.16.1" 116 | } 117 | }, 118 | "nbformat": 4, 119 | "nbformat_minor": 2 120 | } 121 | -------------------------------------------------------------------------------- /notebook/Untitled.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "/home/lusx/anaconda3/envs/torch/lib/python3.6/site-packages/torch/nn/modules/upsampling.py:129: UserWarning: nn.Upsample is deprecated. Use nn.functional.interpolate instead.\n", 13 | " warnings.warn(\"nn.{} is deprecated. Use nn.functional.interpolate instead.\".format(self.name))\n" 14 | ] 15 | }, 16 | { 17 | "name": "stdout", 18 | "output_type": "stream", 19 | "text": [ 20 | "torch.Size([1020, 50])\n", 21 | "torch.Size([85, 50])\n", 22 | "0.6352941176470588\n" 23 | ] 24 | } 25 | ], 26 | "source": [ 27 | "from naive_bayes import predict, gnb\n", 28 | "\n", 29 | "import numpy as np" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 20, 35 | "metadata": { 36 | "scrolled": true 37 | }, 38 | "outputs": [ 39 | { 40 | "name": "stdout", 41 | "output_type": "stream", 42 | "text": [ 43 | "GCHAMGARKDF5DY4F9X4Q27ZM8ZXQJCHA_4UPLE8CBS8YSPLEJ16Y92UNGB3TQIAMV95I5HFXK2IZG_4WPIBY5\n" 44 | ] 45 | }, 46 | { 47 | "data": { 48 | "text/plain": [ 49 | "0.011764705882352941" 50 | ] 51 | }, 52 | "execution_count": 20, 53 | "metadata": {}, 54 | "output_type": "execute_result" 55 | } 56 | ], 57 | "source": [ 58 | "char = [['A', 'B', 'C', 'D', 'E', 'F'],\n", 59 | " ['G', 'H', 'I', 'J', 'K', 'L'],\n", 60 | " ['M', 'N', 'O', 'P', 'Q', 'R'],\n", 61 | " ['S', 'T', 'U', 'V', 'W', 'X'],\n", 62 | " ['Y', 'Z', '1', '2', '3', '4'],\n", 63 | " ['5', '6', '7', '8', '9', '_']]\n", 64 | "\n", 65 | "series = []\n", 66 | "real_a = 'WQXPLZCOMRKO97YFZDEZ1DPI9NNVGRQDJCUVRMEUOOOJD2UFYPOO6J7LDGYEGOA5VHNEHBTXOO1TDOILUEE5BFAEEXAW_K4R3MRU'\n", 67 | "real_b = 'MERMIROOMUHJPXJOHUVLEORZP3GLOO7AUFDKEFTWEOOALZOP9ROCGZET1Y19EWX65QUYU7NAK_4YCJDVDNGQXODBEV2B5EFDIDNR'\n", 68 | "train_a = 'EAEVQTDOJG8RBRGONCEDHCTUIDBPUHMEM6OUXOCFOUKWA4VJEFRZROLHYNQDW_EKTLBWXEPOUIKZERYOOTHQI'\n", 69 | "\n", 70 | "for i in range(85):\n", 71 | " col_pred_set = predict[i+6:i+12]\n", 72 | " row_pred_set = predict[i:i+6]\n", 73 | " if np.max(col_pred_set, axis=0) != 0:\n", 74 | " col_pred = np.where(col_pred_set == np.max(col_pred_set, axis=0))[0][0]\n", 75 | " else:\n", 76 | " col_pred = np.random.randint(6)\n", 77 | " if np.max(row_pred_set, axis=0):\n", 78 | " row_pred = np.where(row_pred_set == np.max(row_pred_set, axis=0))[0][0]\n", 79 | " else:\n", 80 | " row_pred = np.random.randint(6)\n", 81 | " \n", 82 | " 
series.append(char[int(row_pred)][int(col_pred)])\n", 83 | "\n", 84 | "series = ''.join(series)\n", 85 | "\n", 86 | "print(series)\n", 87 | "counter = 0\n", 88 | "for i in range(len(train_a)):\n", 89 | " if train_a[i] == series[i]:\n", 90 | " counter += 1\n", 91 | "\n", 92 | "counter / len(train_a)" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 18, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "import numpy as np" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 4, 107 | "metadata": {}, 108 | "outputs": [ 109 | { 110 | "data": { 111 | "text/plain": [ 112 | "2.302585092994046" 113 | ] 114 | }, 115 | "execution_count": 4, 116 | "metadata": {}, 117 | "output_type": "execute_result" 118 | } 119 | ], 120 | "source": [ 121 | "np.log(10)" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [] 130 | } 131 | ], 132 | "metadata": { 133 | "kernelspec": { 134 | "display_name": "Python [conda env:torch]", 135 | "language": "python", 136 | "name": "conda-env-torch-py" 137 | }, 138 | "language_info": { 139 | "codemirror_mode": { 140 | "name": "ipython", 141 | "version": 3 142 | }, 143 | "file_extension": ".py", 144 | "mimetype": "text/x-python", 145 | "name": "python", 146 | "nbconvert_exporter": "python", 147 | "pygments_lexer": "ipython3", 148 | "version": "3.6.6" 149 | } 150 | }, 151 | "nbformat": 4, 152 | "nbformat_minor": 2 153 | } 154 | -------------------------------------------------------------------------------- /description/example.m: -------------------------------------------------------------------------------- 1 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 2 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 3 | %%%%% %%%%%% 4 | %%%%% sample classification ofthe P300 test data %%%%%% 5 | %%%%% %%%%%% 6 | %%%%% BCI Competition III Challenge %%%%%% 7 | %%%%% %%%%%% 8 | %%%%% (C) Dean Krusienski and Gerwin Schalk 2004 %%%%%% 9 | %%%%% Wadsworth Center/NYSDOH %%%%%% 10 | %%%%% %%%%%% 11 | %%%%% function calls: topoplotEEG.m %%%%%% 12 | %%%%% %%%%%% 13 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 14 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 15 | 16 | close all; clear all; clc 17 | 18 | fprintf(1, '********************************************* \n' ); 19 | fprintf(1, ' Sample classification of the P300 test data \n' ); 20 | fprintf(1, ' BCI Competition III Challenge \n' ); 21 | fprintf(1, ' (C) Dean Krusienski and Gerwin Schalk 2004 \n' ); 22 | fprintf(1, ' Wadsworth Center/NYSDOH \n' ); 23 | fprintf(1, '********************************************* \n\n' ); 24 | 25 | TargetChar=[]; 26 | StimulusType=[]; 27 | 28 | fprintf(1, 'Collecting Responses and Performing classification... \n\n' ); 29 | load 'Subject_A_Train.mat' % load data file 30 | window=240; % window after stimulus (1s) 31 | channel=11; % only using Cz for analysis and plots 32 | 33 | % convert to double precision 34 | Signal=double(Signal); 35 | Flashing=double(Flashing); 36 | StimulusCode=double(StimulusCode); 37 | StimulusType=double(StimulusType); 38 | 39 | % 6 X 6 onscreen matrix 40 | screen=char('A','B','C','D','E','F',... 41 | 'G','H','I','J','K','L',... 42 | 'M','N','O','P','Q','R',... 43 | 'S','T','U','V','W','X',... 44 | 'Y','Z','1','2','3','4',... 
45 | '5','6','7','8','9','_'); 46 | 47 | % for each character epoch 48 | for epoch=1:size(Signal,1) 49 | 50 | % get reponse samples at start of each Flash 51 | rowcolcnt=ones(1,12); 52 | for n=2:size(Signal,2) 53 | if Flashing(epoch,n)==0 && Flashing(epoch,n-1)==1 54 | rowcol=StimulusCode(epoch,n-1); 55 | responses(rowcol,rowcolcnt(rowcol),:,:)=Signal(epoch,n-24:n+window-25,:); 56 | rowcolcnt(rowcol)=rowcolcnt(rowcol)+1; 57 | end 58 | end 59 | 60 | % average and group responses by letter 61 | m=1; 62 | avgresp=mean(responses,2); 63 | avgresp=reshape(avgresp,12,window,64); 64 | for row=7:12 65 | for col=1:6 66 | % row-column intersection 67 | letter(m,:,:)=(avgresp(row,:,:)+avgresp(col,:,:))/2; 68 | % the crude avg peak classifier score (**tuned for Subject_A**) 69 | score(m)=mean(letter(m,54:124,channel))-mean(letter(m,134:174,channel)); 70 | m=m+1; 71 | end 72 | end 73 | 74 | [val,index]=max(score); 75 | charvect(epoch)=screen(index); 76 | 77 | % if labeled, get target label and response 78 | if isempty(StimulusType)==0 79 | label=unique(StimulusCode(epoch,:).*StimulusType(epoch,:)); 80 | targetlabel=(6*(label(3)-7))+label(2); 81 | Target(epoch,:,:)=.5*(avgresp(label(2),:,:)+avgresp(label(3),:,:)); 82 | NonTarget(epoch,:,:)=mean(avgresp,1)-(1/6)*Target(epoch,:,:); 83 | end 84 | end 85 | 86 | % display results 87 | 88 | if isempty(TargetChar)==0 89 | 90 | k=0; 91 | for p=1:size(Signal,1) 92 | if charvect(p)==TargetChar(p) 93 | k=k+1; 94 | end 95 | end 96 | 97 | correct=(k/size(Signal,1))*100; 98 | 99 | fprintf(1, 'Classification Results: \n\n' ); 100 | for kk=1:size(Signal,1) 101 | fprintf(1, 'Epoch: %d Predicted: %c Target: %c\n',kk,charvect(kk),TargetChar(kk)); 102 | end 103 | fprintf(1, '\n %% Correct from Labeled Data: %2.2f%% \n',correct); 104 | 105 | % plot averaged responses and topography 106 | Tavg=reshape(mean(Target(:,:,:),1),window,64); 107 | NTavg=reshape(mean(NonTarget(:,:,:),1),window,64); 108 | figure 109 | plot([1:window]/window,Tavg(:,channel),'linewidth',2) 110 | hold on 111 | plot([1:window]/window,NTavg(:,channel),'r','linewidth',2) 112 | title('Averaged P300 Responses over Cz') 113 | legend('Targets','NonTargets'); 114 | xlabel('time (s) after stimulus') 115 | ylabel('amplitude (uV)') 116 | 117 | % Target/NonTarget voltage topography plot at 300ms (sample 72) 118 | vdiff=abs(Tavg(72,:)-NTavg(72,:)); 119 | figure 120 | topoplotEEG(vdiff,'eloc64.txt','gridscale',150) 121 | title('Target/NonTarget Voltage Difference Topography at 300ms') 122 | caxis([min(vdiff) max(vdiff)]) 123 | colorbar 124 | 125 | else 126 | 127 | for kk=1:size(Signal,1) 128 | fprintf(1, 'Epoch: %d Predicted: %c\n',kk,charvect(kk)); 129 | end 130 | 131 | end 132 | 133 | fprintf(1, '\nThe resulting classified character vector is the variable named "charvect". \n'); 134 | fprintf(1, 'This is an example of how the results *must* be formatted for submission. \n'); 135 | fprintf(1, 'The character vectors from each case and subject are to be labeled, grouped, and submitted according to the accompanied documentation. 
\n');
136 | 
--------------------------------------------------------------------------------
/Model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | 
6 | class Vanilla(nn.Module):
7 |     def __init__(self):
8 |         super(Vanilla, self).__init__()
9 | 
10 |         self.conv = nn.Sequential(
11 |             nn.Conv1d(64, 5, 5, padding=2),
12 |             nn.MaxPool1d(2),
13 |             nn.ReLU()
14 |         )
15 | 
16 |         self.fc = nn.Sequential(
17 |             nn.Linear(5*120, 30),
18 |             nn.Dropout(0.5),
19 |             nn.ReLU()
20 |         )
21 | 
22 |         self.out = nn.Sequential(
23 |             nn.Linear(30, 2),
24 |             nn.Dropout(0.5),
25 |         )
26 | 
27 |     def forward(self, x):
28 |         x = self.conv(x)
29 |         fc = self.fc(x.view(x.size(0), -1))
30 |         x = self.out(fc).squeeze()
31 | 
32 |         return x, fc
33 | 
34 | 
35 | class RNN(nn.Module):
36 |     def __init__(self):
37 |         super(RNN, self).__init__()
38 | 
39 |         self.rnn = nn.LSTM(        # LSTM works much better than nn.RNN() here
40 |             input_size=240,        # features per time step (here: the 240 samples of one channel)
41 |             hidden_size=64,        # rnn hidden unit
42 |             num_layers=1,          # number of stacked RNN layers
43 |             batch_first=True,      # input & output tensors put batch size first, e.g. (batch, time_step, input_size)
44 |         )
45 | 
46 |         self.out = nn.Linear(64, 2)  # output layer
47 | 
48 |     def forward(self, x):
49 |         # x shape (batch, time_step, input_size)
50 |         # r_out shape (batch, time_step, output_size)
51 |         # h_n shape (n_layers, batch, hidden_size)   LSTM has two hidden states: h_n (hidden state), h_c (cell state)
52 |         # h_c shape (n_layers, batch, hidden_size)
53 |         r_out, (h_n, h_c) = self.rnn(x, None)  # None: the initial hidden state defaults to all zeros
54 | 
55 |         # take the r_out output at the last time step
56 |         # here r_out[:, -1, :] has the same value as h_n
57 |         out = self.out(r_out[:, -1, :])
58 |         return out, 0
59 | 
60 | 
61 | class ResBlock(nn.Module):
62 |     def __init__(self, in_channels, out_channels, stride=1, downsample=None):
63 |         super(ResBlock, self).__init__()
64 |         self.conv0 = nn.Conv1d(in_channels, out_channels, 3, padding=1)
65 | 
66 |         self.conv1 = nn.Conv1d(out_channels, out_channels, 3, stride, padding=1)
67 |         self.bn1 = nn.BatchNorm1d(out_channels)
68 |         self.relu = nn.ReLU(inplace=True)
69 |         self.conv2 = nn.Conv1d(out_channels, out_channels, 3, stride, padding=1)
70 |         self.bn2 = nn.BatchNorm1d(out_channels)
71 |         self.downsample = downsample
72 | 
73 |     def forward(self, x):
74 |         x = self.conv0(x)
75 |         residual = x
76 |         out = self.conv1(x)
77 |         out = self.bn1(out)
78 |         out = self.relu(out)
79 |         out = self.conv2(out)
80 |         out = self.bn2(out)
81 |         if self.downsample:
82 |             residual = self.downsample(x)
83 |         out += residual
84 |         out = self.relu(out)
85 |         return out
86 | 
87 | 
88 | class ResCNN(nn.Module):
89 |     def __init__(self):
90 |         super(ResCNN, self).__init__()
91 |         self.in_channels = 16
92 |         self.conv = nn.Conv1d(64, self.in_channels, 3)
93 |         self.bn = nn.BatchNorm1d(self.in_channels)
94 | 
95 |         self.resblk = ResBlock(self.in_channels, 6)
96 |         self.sameblk = ResBlock(6, 6)
97 | 
98 |         self.dropout = nn.Dropout(0.5)
99 |         self.fc = nn.Linear(6*238, 2)
100 | 
101 |     def forward(self, x):
102 |         x = self.conv(x)
103 |         x = self.bn(x)
104 |         x = self.dropout(x)
105 |         x = self.resblk(x)
106 |         x = self.dropout(x)
107 |         out = self.fc(x.view(x.size(0), -1))
108 | 
109 |         return out, x.view(x.size(0), -1)
110 | 
111 | 
112 | class AutoEncoder(nn.Module):
113 |     def __init__(self):
114 |         super(AutoEncoder, self).__init__()
115 |         self.conv_down = nn.Sequential(
116 |             nn.Conv1d(64, 5, 5, padding=2),
117 |             nn.MaxPool1d(2),
118 |             nn.ReLU(),
119 |         )
120 | 
121 |         self.dense_down = nn.Sequential(
122 |             nn.Linear(5*120, 30),
123 |             nn.ReLU(),
124 |         )
125 | 
126 |         self.dense_up = nn.Sequential(
127 |             nn.Linear(30, 5*120),
128 |         )
129 | 
130 |         # needs to be reshaped back to (5, 120) before conv_up (see forward)
131 |         self.conv_up = nn.Sequential(
132 |             nn.ConvTranspose1d(5, 64, 5, padding=2),
133 |             nn.Upsample(scale_factor=2),
134 |         )
135 | 
136 |     def forward(self, x):
137 |         x = self.conv_down(x)
138 |         hidden = self.dense_down(x.view(-1, 5*120))
139 |         out = self.dense_up(hidden)
140 |         out = self.conv_up(out.view(-1, 5, 120))
141 | 
142 |         return out, hidden
143 | 
144 | 
145 | class InstructedAE(nn.Module):
146 |     def __init__(self):
147 |         super(InstructedAE, self).__init__()
148 | 
149 |         self.conv1 = nn.Sequential(
150 |             nn.Conv1d(64, 10, 5, padding=2),
151 |             nn.MaxPool1d(4),
152 |             nn.ReLU()
153 |         )
154 | 
155 |         self.fc1 = nn.Sequential(
156 |             nn.Linear(10*60, 50),
157 |             nn.ReLU(),
158 |         )
159 | 
160 |         self.linear_upsample = nn.Sequential(
161 |             nn.Linear(50, 10*60),
162 |             nn.ReLU(),
163 |         )
164 | 
165 |         self.conv_upsample = nn.Sequential(
166 |             nn.Upsample(scale_factor=4),
167 |             nn.Conv1d(10, 64, 5, padding=2),
168 |         )
169 | 
170 |     def forward(self, x):
171 |         x = self.conv1(x).view(x.size(0), -1)
172 |         hidden = self.fc1(x)
173 |         out = self.linear_upsample(hidden)
174 |         out = self.conv_upsample(out.view(out.size(0), 10, -1))
175 | 
176 |         return out, hidden
177 | 
178 | 
179 | if __name__ == '__main__':
180 |     import torch.onnx as onnx
181 |     model = Vanilla()
182 |     model.train(False)
183 |     onnx.export(model, torch.zeros([85, 64, 240]), './vanilla_cnn')
--------------------------------------------------------------------------------
/notebook/NB Implem.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": 1,
6 |    "metadata": {},
7 |    "outputs": [],
8 |    "source": [
9 |     "import numpy as np"
10 |    ]
11 |   },
12 |   {
13 |    "cell_type": "code",
14 |    "execution_count": 44,
15 |    "metadata": {},
16 |    "outputs": [],
17 |    "source": [
18 |     "def separate_by_class(dataset):\n",
19 |     "    separated = {}\n",
20 |     "    for i in range(len(dataset)):\n",
21 |     "        vector = dataset[i]\n",
22 |     "        if (vector[-1] not in separated):\n",
23 |     "            separated[vector[-1]] = []\n",
24 |     "        separated[vector[-1]].append(vector)\n",
25 |     "    return separated"
26 |    ]
27 |   },
28 |   {
29 |    "cell_type": "code",
30 |    "execution_count": 71,
31 |    "metadata": {},
32 |    "outputs": [
33 |     {
34 |      "data": {
35 |       "text/plain": [
36 |        "{0: [[1, 2, 3, 0], [1, 3, 2, 0]], 1: [[0, 3, 2, 1], [0, 2, 3, 1]]}"
37 |       ]
38 |      },
39 |      "execution_count": 71,
40 |      "metadata": {},
41 |      "output_type": "execute_result"
42 |     }
43 |    ],
44 |    "source": [
45 |     "data = [[1, 2, 3, 0],\n",
46 |     "        [1, 3, 2, 0],\n",
47 |     "        [0, 3, 2, 1],\n",
48 |     "        [0, 2, 3, 1]]\n",
49 |     "\n",
50 |     "testset = [[1, 3, 3, 0],\n",
51 |     "           [0, 3, 3, 1]]\n",
52 |     "\n",
53 |     "separate_by_class(data)"
54 |    ]
55 |   },
56 |   {
57 |    "cell_type": "code",
58 |    "execution_count": 91,
59 |    "metadata": {},
60 |    "outputs": [
61 |     {
62 |      "name": "stdout",
63 |      "output_type": "stream",
64 |      "text": [
65 |       "separated {0: [[1, 2, 3, 0], [1, 3, 2, 0]], 1: [[0, 3, 2, 1], [0, 2, 3, 1]]}\n"
66 |      ]
67 |     },
68 |     {
69 |      "data": {
70 |       "text/plain": [
71 |        "{0: [(1.0, 0.0), (2.5, 0.7071067811865476), (2.5, 0.7071067811865476)],\n",
72 |        " 1: [(0.0, 0.0), (2.5, 0.7071067811865476), (2.5, 0.7071067811865476)]}"
73 |       ]
74 |      },
75 |      "execution_count": 91,
76 |      "metadata": {},
77 |      "output_type": "execute_result"
78 |     }
79 |    ],
80 |    "source": [
81 |     "def mean(numbers):\n",
82 |     "    return sum(numbers) / float(len(numbers))\n",
83 |     "\n",
84 
| "def stdev(numbers):\n", 85 | " avg = mean(numbers)\n", 86 | " variance = sum([pow(x - avg, 2) for x in numbers]) / float(len(numbers) - 1)\n", 87 | " return np.sqrt(variance)\n", 88 | "\n", 89 | "def summary(dataset):\n", 90 | " summaries = [(mean(attr), stdev(attr)) for attr in zip(*dataset)]\n", 91 | " del summaries[-1]\n", 92 | " return summaries\n", 93 | "\n", 94 | "def summarize_by_class(dataset):\n", 95 | " separated = separate_by_class(dataset)\n", 96 | " print(\"separated {}\".format(separated))\n", 97 | " summaries = {}\n", 98 | " for cls_val, instances in separated.items():\n", 99 | " summaries[cls_val] = summary(instances)\n", 100 | " return summaries\n", 101 | "\n", 102 | "summarize_by_class(data)" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 83, 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "def calc_prob(x, mean, stdev):\n", 112 | " exp = np.exp(-(pow(x-mean, 2)/(2*pow(stdev, 2))))\n", 113 | " return (1/(np.sqrt(2*np.pi)*stdev))*exp\n", 114 | "\n", 115 | "def calc_cls_prob(summaries, new_vector):\n", 116 | " probs = {}\n", 117 | " for cls_val, cls_summaries in summaries.items():\n", 118 | " probs[cls_val] = 1\n", 119 | " for i in range(len(cls_summaries)):\n", 120 | " mean, stdev = cls_summaries[i]\n", 121 | " probs[cls_val] *= calc_prob(new_vector[i], mean, stdev)\n", 122 | " return probs" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 84, 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "def predict(summaries, new_vector):\n", 132 | " probs = calc_cls_prob(summaries, new_vector)\n", 133 | " best_label, best_prob = None, -1\n", 134 | " for cls_val, prob in probs.items():\n", 135 | " if best_label is None or prob > best_prob:\n", 136 | " best_label = cls_val\n", 137 | " best_prob = prob\n", 138 | " return best_label\n", 139 | "\n", 140 | "def get_pred(summaries, testset):\n", 141 | " preds = []\n", 142 | " for i in range(len(testset)):\n", 143 | " result = predict(summaries, testset[i])\n", 144 | " preds.append(result)\n", 145 | " return preds\n", 146 | "\n", 147 | "def get_acc(testset, preds):\n", 148 | " correct = 0\n", 149 | " for x in range(len(testset)):\n", 150 | " if testset[x][-1] == preds[x]:\n", 151 | " correct += 1\n", 152 | " return (correct / float(len(testset))) * 100.0" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 85, 158 | "metadata": {}, 159 | "outputs": [ 160 | { 161 | "data": { 162 | "text/plain": [ 163 | "[[1, 2, 3, 0], [1, 3, 2, 0], [0, 3, 2, 1], [0, 2, 3, 1]]" 164 | ] 165 | }, 166 | "execution_count": 85, 167 | "metadata": {}, 168 | "output_type": "execute_result" 169 | } 170 | ], 171 | "source": [ 172 | "data" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 86, 178 | "metadata": {}, 179 | "outputs": [ 180 | { 181 | "name": "stdout", 182 | "output_type": "stream", 183 | "text": [ 184 | "{0: [(1.0, 0.0), (2.5, 0.7071067811865476), (2.5, 0.7071067811865476)], 1: [(0.0, 0.0), (2.5, 0.7071067811865476), (2.5, 0.7071067811865476)]}\n", 185 | "[0, 0]\n", 186 | "50.0\n" 187 | ] 188 | }, 189 | { 190 | "name": "stderr", 191 | "output_type": "stream", 192 | "text": [ 193 | "/home/lusx/anaconda3/envs/torch/lib/python3.6/site-packages/ipykernel_launcher.py:2: RuntimeWarning: invalid value encountered in double_scalars\n", 194 | " \n", 195 | "/home/lusx/anaconda3/envs/torch/lib/python3.6/site-packages/ipykernel_launcher.py:3: RuntimeWarning: divide by zero encountered in double_scalars\n", 196 | " This is 
separate from the ipykernel package so we can avoid doing imports until\n", 197 | "/home/lusx/anaconda3/envs/torch/lib/python3.6/site-packages/ipykernel_launcher.py:2: RuntimeWarning: divide by zero encountered in double_scalars\n", 198 | " \n", 199 | "/home/lusx/anaconda3/envs/torch/lib/python3.6/site-packages/ipykernel_launcher.py:3: RuntimeWarning: invalid value encountered in double_scalars\n", 200 | " This is separate from the ipykernel package so we can avoid doing imports until\n" 201 | ] 202 | } 203 | ], 204 | "source": [ 205 | "summaries = summarize_by_class(data)\n", 206 | "print(summaries)\n", 207 | "preds = get_pred(summaries, testset)\n", 208 | "acc = get_acc(testset, preds)\n", 209 | "print(preds)\n", 210 | "print(acc)" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": null, 216 | "metadata": {}, 217 | "outputs": [], 218 | "source": [] 219 | } 220 | ], 221 | "metadata": { 222 | "kernelspec": { 223 | "display_name": "Python [conda env:torch]", 224 | "language": "python", 225 | "name": "conda-env-torch-py" 226 | }, 227 | "language_info": { 228 | "codemirror_mode": { 229 | "name": "ipython", 230 | "version": 3 231 | }, 232 | "file_extension": ".py", 233 | "mimetype": "text/x-python", 234 | "name": "python", 235 | "nbconvert_exporter": "python", 236 | "pygments_lexer": "ipython3", 237 | "version": "3.6.6" 238 | } 239 | }, 240 | "nbformat": 4, 241 | "nbformat_minor": 2 242 | } 243 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from torch.utils.data import DataLoader 5 | import numpy as np 6 | import hiddenlayer as hl 7 | import torch.onnx as onnx 8 | from Dataset import Trainset, Testset 9 | from train_valid_split import train_valid_split 10 | from Model import Vanilla, AutoEncoder, InstructedAE, ResCNN, RNN 11 | 12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 13 | 14 | model = Vanilla().to(device) 15 | 16 | random_seed = np.random.seed(666) 17 | 18 | trainsetA, valsetA = train_valid_split(Trainset('A'), random_seed=random_seed) 19 | trainsetB, valsetB = train_valid_split(Trainset('B'), random_seed=random_seed) 20 | 21 | train_loaderA = DataLoader( 22 | dataset=trainsetA, 23 | batch_size=85, 24 | shuffle=True, 25 | ) 26 | 27 | valid_loaderA = DataLoader( 28 | dataset=valsetA, 29 | batch_size=85 30 | ) 31 | 32 | train_loaderB = DataLoader( 33 | dataset=trainsetB, 34 | batch_size=85, 35 | shuffle=True, 36 | ) 37 | 38 | valid_loaderB = DataLoader( 39 | dataset=valsetB, 40 | batch_size=85 41 | ) 42 | 43 | test_loaderA = DataLoader( 44 | dataset=Testset('A'), 45 | shuffle=False 46 | ) 47 | 48 | test_loaderB = DataLoader( 49 | dataset=Testset('B'), 50 | shuffle=False 51 | ) 52 | 53 | mse_criterion = nn.MSELoss() 54 | cross_entropy_criterion = nn.CrossEntropyLoss() 55 | 56 | # optimizer = optim.SGD([ 57 | # {'params':model.conv1.parameters()}, 58 | # {'params':model.conv2.parameters()}, 59 | # {'params':model.fc.parameters()}, 60 | # {'params':model.out.parameters(), 'lr': 1e-8}], lr=5e-4, momentum=0.9, weight_decay=1e-4) 61 | # 62 | # ae_optimizer = optim.Adam([ 63 | # {'params':model.linear_upsample.parameters()}, 64 | # {'params':model.conv_upsample.parameters()}], lr=1e-3, weight_decay=1e-4) 65 | 66 | optimizer = optim.SGD(model.parameters(), lr=5e-4, momentum=0.9, 67 | weight_decay=1e-4) 68 | 69 | if False: 70 | state = 
torch.load('./ae') 71 | model.load_state_dict(state['model_state_dict']) 72 | optimizer.load_state_dict(state['optimizer_state_dict']) 73 | 74 | 75 | def trainA(ep): 76 | model.train() 77 | total_loss = 0 78 | total_acc = 0 79 | for step, (x, y) in enumerate(train_loaderA): 80 | data, target = x.to(device), y.to(device).long() 81 | output, _ = model(data) 82 | 83 | optimizer.zero_grad() 84 | loss = cross_entropy_criterion(output, target) 85 | total_loss += loss 86 | loss.backward() 87 | 88 | optimizer.step() 89 | 90 | predict = output.data.max(1)[1] 91 | acc = predict.eq(target.data).sum() 92 | total_acc += acc.item() / train_loaderA.batch_size 93 | 94 | avg_loss = total_loss / len(train_loaderA) 95 | avg_acc = total_acc / len(train_loaderA) 96 | print("Epoch: {} Loss: {}, Acc: {}".format(ep, avg_loss, 97 | avg_acc)) 98 | 99 | if ep == 499: 100 | torch.save({ 101 | 'epoch': ep, 102 | 'model_state_dict': model.state_dict(), 103 | 'optimizer_state_dict': optimizer.state_dict(), 104 | 'loss': total_loss, 105 | 'acc': total_acc 106 | }, './vallina_cnn_A') 107 | 108 | 109 | def trainB(ep): 110 | model.train() 111 | total_loss = 0 112 | total_acc = 0 113 | for step, (x, y) in enumerate(train_loaderB): 114 | data, target = x.to(device), y.to(device).long() 115 | output, _ = model(data) 116 | 117 | optimizer.zero_grad() 118 | loss = cross_entropy_criterion(output, target) 119 | total_loss += loss 120 | loss.backward() 121 | 122 | optimizer.step() 123 | 124 | predict = output.data.max(1)[1] 125 | acc = predict.eq(target.data).sum() 126 | total_acc += acc.item() / train_loaderB.batch_size 127 | 128 | avg_loss = total_loss / len(train_loaderB) 129 | avg_acc = total_acc / len(train_loaderB) 130 | print("Epoch: {} Loss: {}, Acc: {}".format(ep, avg_loss, 131 | avg_acc)) 132 | 133 | if ep == 499: 134 | torch.save({ 135 | 'epoch': ep, 136 | 'model_state_dict': model.state_dict(), 137 | 'optimizer_state_dict': optimizer.state_dict(), 138 | 'loss': total_loss, 139 | 'acc': total_acc 140 | }, './vallina_cnn_B') 141 | 142 | 143 | 144 | def valA(ep): 145 | model.eval() 146 | total_loss = 0 147 | total_acc = 0 148 | 149 | with torch.no_grad(): 150 | for step, (x, y) in enumerate(valid_loaderA): 151 | data, target = x.to(device), y.to(device).long() 152 | output, _ = model(data) 153 | 154 | loss = cross_entropy_criterion(output, target) 155 | total_loss += loss 156 | 157 | predict = output.data.max(1)[1] 158 | acc = predict.eq(target.data).sum() 159 | total_acc += acc.item() / valid_loaderA.batch_size 160 | print("Valid Epoch: {} Loss: {}, Acc: {}".format(ep, total_loss / len(valid_loaderA), 161 | total_acc / len(valid_loaderA))) 162 | 163 | 164 | def valB(ep): 165 | model.eval() 166 | total_loss = 0 167 | total_acc = 0 168 | 169 | with torch.no_grad(): 170 | for step, (x, y) in enumerate(valid_loaderB): 171 | data, target = x.to(device), y.to(device).long() 172 | output, _ = model(data) 173 | 174 | loss = cross_entropy_criterion(output, target) 175 | total_loss += loss 176 | 177 | predict = output.data.max(1)[1] 178 | acc = predict.eq(target.data).sum() 179 | total_acc += acc.item() / valid_loaderB.batch_size 180 | print("Valid Epoch: {} Loss: {}, Acc: {}".format(ep, total_loss / len(valid_loaderB), 181 | total_acc / len(valid_loaderB))) 182 | 183 | 184 | def predictA(): 185 | char = [['A', 'B', 'C', 'D', 'E', 'F'], 186 | ['G', 'H', 'I', 'J', 'K', 'L'], 187 | ['M', 'N', 'O', 'P', 'Q', 'R'], 188 | ['S', 'T', 'U', 'V', 'W', 'X'], 189 | ['Y', 'Z', '1', '2', '3', '4'], 190 | ['5', '6', '7', '8', '9', '_']] 191 | 
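    # Speller decision rule used below: score all 6 column stimuli and all
    # 6 row stimuli with the CNN, take the argmax of the positive-class
    # score within each group, and read the predicted character off the
    # row/column intersection in the 6x6 `char` matrix above.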
192 | series = [] 193 | real_a = 'WQXPLZCOMRKO97YFZDEZ1DPI9NNVGRQDJCUVRMEUOOOJD2UFYPOO6J7LDGYEGOA5VHNEHBTXOO1TDOILUEE5BFAEEXAW_K4R3MRU' 194 | model.eval() 195 | with torch.no_grad(): 196 | for step, (cols, rows) in enumerate(test_loaderA): 197 | col_pred_set = [] 198 | row_pred_set = [] 199 | # print(" 0 - 6") 200 | for line in range(6): 201 | data = cols[:, line, :, :].to(device) 202 | output, _ = model(data) 203 | col_pred_set.append(output.data.cpu().numpy()) 204 | # print(" 7 - 12") 205 | for line in range(6): 206 | data = rows[:, line, :, :].to(device) 207 | output, _ = model(data) 208 | row_pred_set.append(output.data.cpu().numpy()) 209 | col_pred_set = np.array(col_pred_set).squeeze() 210 | row_pred_set = np.array(row_pred_set).squeeze() 211 | col_pred = np.argmax(col_pred_set, axis=0)[1] 212 | row_pred = np.argmax(row_pred_set, axis=0)[1] 213 | 214 | series.append(char[row_pred][col_pred]) 215 | series = ''.join(series) 216 | print(series) 217 | counter = 0 218 | for i in range(len(real_a)): 219 | if real_a[i] == series[i]: 220 | counter += 1 221 | print(counter / len(real_a)) 222 | return counter / len(real_a) 223 | 224 | 225 | def predictB(): 226 | char = [['A', 'B', 'C', 'D', 'E', 'F'], 227 | ['G', 'H', 'I', 'J', 'K', 'L'], 228 | ['M', 'N', 'O', 'P', 'Q', 'R'], 229 | ['S', 'T', 'U', 'V', 'W', 'X'], 230 | ['Y', 'Z', '1', '2', '3', '4'], 231 | ['5', '6', '7', '8', '9', '_']] 232 | 233 | series = [] 234 | real_b = 'MERMIROOMUHJPXJOHUVLEORZP3GLOO7AUFDKEFTWEOOALZOP9ROCGZET1Y19EWX65QUYU7NAK_4YCJDVDNGQXODBEV2B5EFDIDNR' 235 | model.eval() 236 | with torch.no_grad(): 237 | for step, (cols, rows) in enumerate(test_loaderB): 238 | col_pred_set = [] 239 | row_pred_set = [] 240 | # print(" 0 - 6") 241 | for line in range(6): 242 | data = cols[:, line, :, :].to(device) 243 | output, _ = model(data) 244 | col_pred_set.append(output.data.cpu().numpy()) 245 | # print(" 7 - 12") 246 | for line in range(6): 247 | data = rows[:, line, :, :].to(device) 248 | output, _ = model(data) 249 | row_pred_set.append(output.data.cpu().numpy()) 250 | col_pred_set = np.array(col_pred_set).squeeze() 251 | row_pred_set = np.array(row_pred_set).squeeze() 252 | col_pred = np.argmax(col_pred_set, axis=0)[1] 253 | row_pred = np.argmax(row_pred_set, axis=0)[1] 254 | 255 | series.append(char[row_pred][col_pred]) 256 | series = ''.join(series) 257 | print(series) 258 | counter = 0 259 | for i in range(len(real_b)): 260 | if real_b[i] == series[i]: 261 | counter += 1 262 | print(counter / len(real_b)) 263 | return counter / len(real_b) 264 | 265 | def get_feature_map(loader, contain_label=True): 266 | with torch.no_grad(): 267 | for data, target in loader: 268 | data = data.to(device) 269 | _, feature = model(data) 270 | print(feature.shape) 271 | return feature.cpu().numpy(), target.cpu().numpy() 272 | 273 | 274 | def get_feature(loader): 275 | with torch.no_grad(): 276 | col_pred_set = [] 277 | row_pred_set = [] 278 | for step, (cols, rows) in enumerate(loader): 279 | # print(" 0 - 6") 280 | for line in range(6): 281 | data = cols[:, line, :, :].to(device) 282 | _, output = model(data) 283 | col_pred_set.append(output.data.cpu().numpy()) 284 | # print(" 7 - 12") 285 | for line in range(6): 286 | data = rows[:, line, :, :].to(device) 287 | _, output = model(data) 288 | row_pred_set.append(output.data.cpu().numpy()) 289 | 290 | return np.array(col_pred_set), np.array(row_pred_set) 291 | 292 | 293 | def train_AE(ep): 294 | model.train() 295 | total_loss = 0 296 | for step, (x, y) in enumerate(train_loaderB): 297 | data = 
-------------------------------------------------------------------------------- /description/topoplotEEG.m: -------------------------------------------------------------------------------- 1 | % topoplot() - plot a topographic map of an EEG field as a 2-D 2 | % circular view (looking down at the top of the head) 3 | % using interpolation on a fine cartesian grid. 4 | % Usage: 5 | % >> topoplot(datavector,'eloc_file'); 6 | % >> topoplot(datavector,'eloc_file', 'Param1','Value1', ...) 7 | % Inputs: 8 | % datavector = vector of values at the corresponding locations. 9 | % 'eloc_file' = name of an EEG electrode position file {0 -> 'chan_file'} 10 | % 11 | % Optional Parameters & Values (in any order): 12 | % Param Value 13 | % 'colormap' - any sized colormap 14 | % 'interplimits' - 'electrodes' to furthest electrode 15 | % 'head' to edge of head 16 | % {default 'head'} 17 | % 'gridscale' - scaling grid size {default 67} 18 | % 'maplimits' - 'absmax' +/- the absolute-max 19 | % 'maxmin' scale to data range 20 | % [clim1,clim2] user-defined lo/hi 21 | % {default = 'absmax'} 22 | % 'style' - 'straight' colormap only 23 | % 'contour' contour lines only 24 | % 'both' both colormap and contour lines 25 | % 'fill' constant color between lines 26 | % 'blank' just head and electrodes 27 | % {default = 'both'} 28 | % 'numcontour' - number of contour lines 29 | % {default = 6} 30 | % 'shading' - 'flat','interp' {default = 'flat'} 31 | % 'headcolor' - Color of head cartoon {default black} 32 | % 'electrodes' - 'on','off','labels','numbers' 33 | % 'efontsize','electcolor','emarker','emarkersize' - details 34 | % 35 | % Note: topoplot() only works when map limits are >= the max and min 36 | % interpolated data values.
37 | % Eloc_file format: 38 | % chan_number degrees radius reject_level amp_gain channel_name 39 | % (Angle-0 = Cz-to-Fz; C3-angle =-90; Radius at edge of image = 0.5) 40 | % 41 | % For a sample eloc file: >> topoplot('example') 42 | % 43 | % Note: topoplot() will ignore any electrode with a position outside 44 | % the head (radius > 0.5) 45 | 46 | % Topoplot Version 2.1 47 | 48 | % Begun by Andy Spydell, NHRC, 7-23-96 49 | % 8-96 Revised by Colin Humphries, CNL / Salk Institute, La Jolla CA 50 | % -changed surf command to imagesc (faster) 51 | % -can now handle arbitrary scaling of electrode distances 52 | % -can now handle non integer angles in eloc_file 53 | % 4-4-97 Revised again by Colin Humphries, reformat by SM 54 | % -added parameters 55 | % -changed eloc_file format 56 | % 2-26-98 Revised by Colin 57 | % -changed image back to surface command 58 | % -added fill and blank styles 59 | % -removed extra background colormap entry (now use any colormap) 60 | % -added parameters for electrode colors and labels 61 | % -now each topoplot axes use the caxis command again. 62 | % -removed OUTPUT parameter 63 | % 3-11-98 changed default emarkersize, improve help msg -sm 64 | 65 | function handle = topoplot(Vl,loc_file,p1,v1,p2,v2,p3,v3,p4,v4,p5,v5,p6,v6,p7,v7,p8,v8,p9,v9) 66 | 67 | % User Defined Defaults: 68 | MAXCHANS = 256; 69 | DEFAULT_ELOC = 'eloc64.txt'; 70 | INTERPLIMITS = 'head'; % head, electrodes 71 | MAPLIMITS = 'absmax'; % absmax, maxmin, [values] 72 | GRID_SCALE = 67; 73 | CONTOURNUM = 6; 74 | STYLE = 'both'; % both,straight,fill,contour,blank 75 | HCOLOR = [0 0 0]; 76 | ECOLOR = [0 0 0]; 77 | CONTCOLOR = [0 0 0]; 78 | ELECTROD = 'on'; % ON OFF LABEL 79 | EMARKERSIZE = 6; 80 | EFSIZE = get(0,'DefaultAxesFontSize'); 81 | HLINEWIDTH = 2; 82 | EMARKER = '.'; 83 | SHADING = 'flat'; % flat or interp 84 | 85 | %%%%%%%%%%%%%%%%%%%%%%% 86 | nargs = nargin; 87 | if nargs < 2 88 | loc_file = DEFAULT_ELOC; 89 | end 90 | if nargs == 1 91 | if isstr(Vl) 92 | if any(strcmp(lower(Vl),{'example','demo'})) 93 | fprintf(['This is an example of an electrode location file,\n',... 94 | 'an ascii file consisting of the following four columns:\n',... 95 | ' channel_number degrees arc_length channel_name\n\n',... 96 | 'Example:\n',... 97 | ' 1 -18 .352 Fp1.\n',... 98 | ' 2 18 .352 Fp2.\n',... 99 | ' 5 -90 .181 C3..\n',... 100 | ' 6 90 .181 C4..\n',... 101 | ' 7 -90 .500 A1..\n',... 102 | ' 8 90 .500 A2..\n',... 103 | ' 9 -142 .231 P3..\n',... 104 | '10 142 .231 P4..\n',... 105 | '11 0 .181 Fz..\n',... 106 | '12 0 0 Cz..\n',... 107 | '13 180 .181 Pz..\n\n',... 108 | 'The model head sphere has a diameter of 1.\n',... 109 | 'The vertex (Cz) has arc length 0. Channels with arc \n',... 110 | 'lengths > 0.5 are not plotted nor used for interpolation.\n'... 111 | 'Zero degrees is towards the nasion. Positive angles\n',... 112 | 'point to the right hemisphere; negative to the left.\n',... 113 | 'Channel names should each be four chars, padded with\n',... 
114 | 'periods (in place of spaces).\n']) 115 | return 116 | 117 | end 118 | end 119 | end 120 | if isempty(loc_file) 121 | loc_file = 0; 122 | end 123 | if loc_file == 0 124 | loc_file = DEFAULT_ELOC; 125 | end 126 | 127 | if nargs > 2 128 | if ~(round(nargs/2) == nargs/2) 129 | error('topoplot(): Odd number of inputs?') 130 | end 131 | for i = 3:2:nargs 132 | Param = eval(['p',int2str((i-3)/2 +1)]); 133 | Value = eval(['v',int2str((i-3)/2 +1)]); 134 | if ~isstr(Param) 135 | error('topoplot(): Parameter must be a string') 136 | end 137 | Param = lower(Param); 138 | switch lower(Param) 139 | case 'colormap' 140 | if size(Value,2)~=3 141 | error('topoplot(): Colormap must be a n x 3 matrix') 142 | end 143 | colormap(Value) 144 | case {'interplimits','headlimits'} 145 | if ~isstr(Value) 146 | error('topoplot(): interplimits value must be a string') 147 | end 148 | Value = lower(Value); 149 | if ~strcmp(Value,'electrodes') & ~strcmp(Value,'head') 150 | error('topoplot(): Incorrect value for interplimits') 151 | end 152 | INTERPLIMITS = Value; 153 | case 'maplimits' 154 | MAPLIMITS = Value; 155 | case 'gridscale' 156 | GRID_SCALE = Value; 157 | case 'style' 158 | STYLE = lower(Value); 159 | case 'numcontour' 160 | CONTOURNUM = Value; 161 | case 'electrodes' 162 | ELECTROD = lower(Value); 163 | case 'emarker' 164 | EMARKER = Value; 165 | case {'headcolor','hcolor'} 166 | HCOLOR = Value; 167 | case {'electcolor','ecolor'} 168 | ECOLOR = Value; 169 | case {'emarkersize','emsize'} 170 | EMARKERSIZE = Value; 171 | case {'efontsize','efsize'} 172 | EFSIZE = Value; 173 | case 'shading' 174 | SHADING = lower(Value); 175 | if ~any(strcmp(SHADING,{'flat','interp'})) 176 | error('Invalid Shading Parameter') 177 | end 178 | otherwise 179 | error('Unknown parameter.') 180 | end 181 | end 182 | end 183 | 184 | [r,c] = size(Vl); 185 | if r>1 & c>1, 186 | error('topoplot(): data should be a single vector\n'); 187 | end 188 | fid = fopen(loc_file); 189 | if fid<1, 190 | fprintf('topoplot(): cannot open eloc_file (%s).\n',loc_file); 191 | return 192 | end 193 | A = fscanf(fid,'%d %f %f %s',[7 MAXCHANS]); 194 | fclose(fid); 195 | 196 | A = A'; 197 | 198 | if length(Vl) ~= size(A,1), 199 | fprintf(... 200 | 'topoplot(): data vector must have the same rows (%d) as eloc_file (%d)\n',... 
201 | length(Vl),size(A,1)); 202 | A 203 | error(''); 204 | end 205 | 206 | labels = setstr(A(:,4:7)); 207 | idx = find(labels == '.'); % some labels have dots 208 | labels(idx) = setstr(abs(' ')*ones(size(idx))); % replace them with spaces 209 | 210 | Th = pi/180*A(:,2); % convert degrees to radians 211 | Rd = A(:,3); 212 | ii = find(Rd <= 0.5); % interpolate on-head channels only 213 | Th = Th(ii); 214 | Rd = Rd(ii); 215 | Vl = Vl(ii); 216 | labels = labels(ii,:); 217 | 218 | [x,y] = pol2cart(Th,Rd); % transform from polar to cartesian coordinates 219 | rmax = 0.5; 220 | 221 | ha = gca; 222 | cla 223 | hold on 224 | 225 | if ~strcmp(STYLE,'blank') 226 | % find limits for interpolation 227 | if strcmp(INTERPLIMITS,'head') 228 | xmin = min(-.5,min(x)); xmax = max(0.5,max(x)); 229 | ymin = min(-.5,min(y)); ymax = max(0.5,max(y)); 230 | else 231 | xmin = max(-.5,min(x)); xmax = min(0.5,max(x)); 232 | ymin = max(-.5,min(y)); ymax = min(0.5,max(y)); 233 | end 234 | 235 | xi = linspace(xmin,xmax,GRID_SCALE); % x-axis description (row vector) 236 | yi = linspace(ymin,ymax,GRID_SCALE); % y-axis description (row vector) 237 | 238 | [Xi,Yi,Zi] = griddata(y,x,Vl,yi',xi,'invdist'); % Interpolate data 239 | 240 | % Take data within head 241 | mask = (sqrt(Xi.^2+Yi.^2) <= rmax); 242 | ii = find(mask == 0); 243 | Zi(ii) = NaN; 244 | 245 | % calculate colormap limits 246 | m = size(colormap,1); 247 | if isstr(MAPLIMITS) 248 | if strcmp(MAPLIMITS,'absmax') 249 | amin = -max(max(abs(Zi))); 250 | amax = max(max(abs(Zi))); 251 | elseif strcmp(MAPLIMITS,'maxmin') 252 | amin = min(min(Zi)); 253 | amax = max(max(Zi)); 254 | end 255 | else 256 | amin = MAPLIMITS(1); 257 | amax = MAPLIMITS(2); 258 | end 259 | delta = xi(2)-xi(1); % length of grid entry 260 | 261 | % Draw topoplot on head 262 | if strcmp(STYLE,'contour') 263 | contour(Xi,Yi,Zi,CONTOURNUM,'k'); 264 | elseif strcmp(STYLE,'both') 265 | surface(Xi-delta/2,Yi-delta/2,zeros(size(Zi)),Zi,'EdgeColor','none',... 266 | 'FaceColor',SHADING); 267 | contour(Xi,Yi,Zi,CONTOURNUM,'k'); 268 | elseif strcmp(STYLE,'straight') 269 | surface(Xi-delta/2,Yi-delta/2,zeros(size(Zi)),Zi,'EdgeColor','none',... 270 | 'FaceColor',SHADING); 271 | elseif strcmp(STYLE,'fill') 272 | contourf(Xi,Yi,Zi,CONTOURNUM,'k'); 273 | else 274 | error('Invalid style') 275 | end 276 | caxis([amin amax]) % set coloraxis 277 | end 278 | 279 | set(ha,'Xlim',[-rmax*1.3 rmax*1.3],'Ylim',[-rmax*1.3 rmax*1.3]) 280 | 281 | % %%% Draw Head %%%% 282 | l = 0:2*pi/100:2*pi; 283 | basex = .18*rmax; 284 | tip = rmax*1.15; base = rmax-.004; 285 | EarX = [.497 .510 .518 .5299 .5419 .54 .547 .532 .510 .489]; 286 | EarY = [.0555 .0775 .0783 .0746 .0555 -.0055 -.0932 -.1313 -.1384 -.1199]; 287 | 288 | % Plot Electrodes 289 | if strcmp(ELECTROD,'on') 290 | hp2 = plot(y,x,EMARKER,'Color',ECOLOR,'markersize',EMARKERSIZE); 291 | elseif strcmp(ELECTROD,'labels') 292 | for i = 1:size(labels,1) 293 | text(y(i),x(i),labels(i,:),'HorizontalAlignment','center',... 294 | 'VerticalAlignment','middle','Color',ECOLOR,... 295 | 'FontSize',EFSIZE) 296 | end 297 | elseif strcmp(ELECTROD,'numbers') 298 | whos y x 299 | for i = 1:size(labels,1) 300 | text(y(i),x(i),int2str(i),'HorizontalAlignment','center',... 301 | 'VerticalAlignment','middle','Color',ECOLOR,... 302 | 'FontSize',EFSIZE) 303 | end 304 | end 305 | 306 | % Plot Head, Ears, Nose 307 | plot(cos(l).*rmax,sin(l).*rmax,... 308 | 'color',HCOLOR,'Linestyle','-','LineWidth',HLINEWIDTH); 309 | 310 | plot([.18*rmax;0;-.18*rmax],[base;tip;base],... 
311 | 'Color',HCOLOR,'LineWidth',HLINEWIDTH); 312 | 313 | plot(EarX,EarY,'color',HCOLOR,'LineWidth',HLINEWIDTH) 314 | plot(-EarX,EarY,'color',HCOLOR,'LineWidth',HLINEWIDTH) 315 | 316 | hold off 317 | axis off 318 | 319 | -------------------------------------------------------------------------------- /notebook/NaiveBayes.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import numpy as np\n", 11 | "from Dataset import Trainset, Testset\n", 12 | "from train_valid_split import train_valid_split\n", 13 | "from scipy.io import loadmat\n", 14 | "from Model import Net" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 12, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "from sklearn import datasets\n", 24 | "from sklearn.decomposition import PCA, FastICA\n", 25 | "from sklearn.naive_bayes import GaussianNB\n" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 63, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "pca = PCA(10, whiten=True)" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 86, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "raw_data = loadmat('./sub_a.mat')\n", 44 | "\n", 45 | "signals = raw_data['responses']\n", 46 | "label = raw_data['is_stimulate']\n", 47 | "\n", 48 | "data = []\n", 49 | "target = []\n", 50 | "\n", 51 | "for i in range(12):\n", 52 | " for j in range(85):\n", 53 | " data.append(pca.fit_transform(signals[i, :, :, j].T).T)\n", 54 | " target.append(label[i, j])\n", 55 | "\n", 56 | "data = np.array(data)\n", 57 | "target = np.array(target)\n" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 87, 63 | "metadata": {}, 64 | "outputs": [ 65 | { 66 | "data": { 67 | "text/plain": [ 68 | "(1020, 10, 64)" 69 | ] 70 | }, 71 | "execution_count": 87, 72 | "metadata": {}, 73 | "output_type": "execute_result" 74 | } 75 | ], 76 | "source": [ 77 | "data.shape" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 48, 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "raw_data = loadmat('./sub_b_test.mat')\n", 87 | "\n", 88 | "signals = raw_data['responses']\n", 89 | "\n", 90 | "data = []\n", 91 | "\n", 92 | "for i in range(12):\n", 93 | " for j in range(85):\n", 94 | " data.append(signals[i, :, :, j].reshape(-1, 64))\n", 95 | "\n", 96 | "data = np.array(data)" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 5, 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "from sklearn.model_selection import train_test_split\n", 106 | "\n", 107 | "x_train, x_test, y_train, y_test = train_test_split(data,\n", 108 | " target,\n", 109 | " )" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 88, 115 | "metadata": {}, 116 | "outputs": [ 117 | { 118 | "name": "stdout", 119 | "output_type": "stream", 120 | "text": [ 121 | "0.8303921568627451\n", 122 | "0.8333333333333334\n", 123 | "0.8352941176470589\n", 124 | "0.8333333333333334\n", 125 | "0.8303921568627451\n", 126 | "0.8333333333333334\n", 127 | "0.8333333333333334\n", 128 | "0.8333333333333334\n", 129 | "0.8333333333333334\n", 130 | "0.8333333333333334\n", 131 | "0.8323529411764706\n", 132 | "0.8303921568627451\n", 133 | "0.8333333333333334\n", 134 | "0.8303921568627451\n", 135 | "0.8333333333333334\n", 136 | 
"0.8333333333333334\n", 137 | "0.8313725490196079\n", 138 | "0.8333333333333334\n", 139 | "0.8323529411764706\n", 140 | "0.8362745098039216\n", 141 | "0.8333333333333334\n", 142 | "0.8323529411764706\n", 143 | "0.8333333333333334\n", 144 | "0.8323529411764706\n", 145 | "0.8333333333333334\n", 146 | "0.8333333333333334\n", 147 | "0.8333333333333334\n", 148 | "0.8352941176470589\n", 149 | "0.8333333333333334\n", 150 | "0.8333333333333334\n", 151 | "0.8313725490196079\n", 152 | "0.8333333333333334\n", 153 | "0.8333333333333334\n", 154 | "0.8343137254901961\n", 155 | "0.8343137254901961\n", 156 | "0.8333333333333334\n", 157 | "0.8343137254901961\n", 158 | "0.8333333333333334\n", 159 | "0.8333333333333334\n", 160 | "0.8323529411764706\n", 161 | "0.8333333333333334\n", 162 | "0.8323529411764706\n", 163 | "0.8333333333333334\n", 164 | "0.8343137254901961\n", 165 | "0.8352941176470589\n", 166 | "0.8333333333333334\n", 167 | "0.8294117647058824\n", 168 | "0.8333333333333334\n", 169 | "0.8333333333333334\n", 170 | "0.8333333333333334\n", 171 | "0.8323529411764706\n", 172 | "0.8333333333333334\n", 173 | "0.8382352941176471\n", 174 | "0.8333333333333334\n", 175 | "0.8333333333333334\n", 176 | "0.8323529411764706\n", 177 | "0.8333333333333334\n", 178 | "0.8333333333333334\n", 179 | "0.8323529411764706\n", 180 | "0.8343137254901961\n", 181 | "0.8333333333333334\n", 182 | "0.8343137254901961\n", 183 | "0.8333333333333334\n", 184 | "0.8343137254901961\n" 185 | ] 186 | } 187 | ], 188 | "source": [ 189 | "classifiers = []\n", 190 | "\n", 191 | "for i in range(64):\n", 192 | " gnb = GaussianNB()\n", 193 | " bayes = gnb.fit(data[:,:,i], target)\n", 194 | " print(gnb.score(data[:,:,i], target))\n", 195 | " classifiers.append(bayes)" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": null, 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 89, 208 | "metadata": {}, 209 | "outputs": [], 210 | "source": [ 211 | "vote = np.zeros(data.shape[0])\n", 212 | "for step, bayes in enumerate(classifiers):\n", 213 | " vote = vote + bayes.predict(data[:, :, step])\n", 214 | " \n" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 90, 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [ 223 | "char = [['A', 'B', 'C', 'D', 'E', 'F'],\n", 224 | " ['G', 'H', 'I', 'J', 'K', 'L'],\n", 225 | " ['M', 'N', 'O', 'P', 'Q', 'R'],\n", 226 | " ['S', 'T', 'U', 'V', 'W', 'X'],\n", 227 | " ['Y', 'Z', '1', '2', '3', '4'],\n", 228 | " ['5', '6', '7', '8', '9', '_']]" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": 91, 234 | "metadata": { 235 | "scrolled": true 236 | }, 237 | "outputs": [ 238 | { 239 | "name": "stdout", 240 | "output_type": "stream", 241 | "text": [ 242 | "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 243 | "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 244 | "[0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", 245 | "[0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n", 246 | "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 247 | "[1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 248 | "[1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 249 | "[0. 0. 0. 0. 0. 0. 1. 0. 1. 0. 0. 0.]\n", 250 | "[0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 1.]\n", 251 | "[0. 0. 0. 0. 0. 0. 2. 0. 0. 0. 0. 0.]\n", 252 | "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 253 | "[0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", 254 | "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 255 | "[0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 1. 1.]\n", 256 | "[1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 
0.]\n", 257 | "[0. 0. 0. 1. 0. 0. 1. 1. 2. 0. 0. 1.]\n", 258 | "[0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0.]\n", 259 | "[0. 1. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", 260 | "[0. 0. 0. 0. 0. 0. 0. 0. 0. 2. 1. 0.]\n", 261 | "[1. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", 262 | "[0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n", 263 | "[0. 1. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", 264 | "[0. 2. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 265 | "[0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", 266 | "[0. 0. 0. 0. 0. 1. 0. 0. 0. 1. 0. 0.]\n", 267 | "[0. 1. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", 268 | "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 269 | "[0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", 270 | "[0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n", 271 | "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n", 272 | "[0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", 273 | "[0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 1. 1.]\n", 274 | "[0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", 275 | "[0. 0. 1. 1. 0. 0. 3. 0. 0. 0. 0. 0.]\n", 276 | "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 277 | "[1. 0. 0. 2. 0. 0. 0. 0. 0. 1. 0. 0.]\n", 278 | "[0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 279 | "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 280 | "[0. 0. 0. 0. 0. 2. 0. 1. 0. 0. 0. 0.]\n", 281 | "[0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0.]\n", 282 | "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 283 | "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 284 | "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 285 | "[0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 286 | "[0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", 287 | "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n", 288 | "[0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0.]\n", 289 | "[0. 2. 0. 0. 1. 0. 0. 1. 3. 0. 0. 0.]\n", 290 | "[0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 291 | "[0. 0. 0. 0. 0. 1. 0. 0. 1. 0. 0. 0.]\n", 292 | "[0. 0. 0. 2. 2. 0. 0. 0. 0. 1. 1. 0.]\n", 293 | "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 294 | "[1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 295 | "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 296 | "[1. 0. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", 297 | "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 298 | "[0. 0. 0. 0. 1. 2. 0. 1. 0. 0. 0. 0.]\n", 299 | "[0. 0. 0. 1. 0. 0. 0. 0. 1. 0. 0. 0.]\n", 300 | "[0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 301 | "[0. 0. 1. 0. 0. 0. 0. 1. 0. 1. 0. 0.]\n", 302 | "[1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n", 303 | "[2. 2. 0. 0. 0. 2. 1. 0. 0. 0. 0. 0.]\n", 304 | "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 305 | "[0. 0. 0. 1. 0. 0. 0. 1. 0. 0. 0. 0.]\n", 306 | "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 307 | "[0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", 308 | "[0. 0. 1. 0. 0. 0. 0. 0. 0. 2. 0. 0.]\n", 309 | "[0. 0. 0. 1. 0. 0. 0. 1. 0. 0. 0. 0.]\n", 310 | "[0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", 311 | "[0. 0. 0. 0. 0. 0. 1. 0. 2. 1. 0. 0.]\n", 312 | "[0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 313 | "[0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n", 314 | "[0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", 315 | "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 316 | "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 317 | "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 318 | "[1. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", 319 | "[0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 320 | "[0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n", 321 | "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 322 | "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 323 | "[0. 0. 0. 0. 0. 0. 0. 1. 0. 1. 0. 0.]\n", 324 | "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 325 | "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 326 | "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 
0.]\n" 327 | ] 328 | } 329 | ], 330 | "source": [ 331 | "series = []\n", 332 | "for i in range(0, data.shape[0], 12):\n", 333 | " print(vote[i:i+12])\n", 334 | " series.append(char[np.argmax(vote[i+6:i+12])][np.argmax(vote[i:i+6])])\n" 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": 92, 340 | "metadata": {}, 341 | "outputs": [], 342 | "source": [ 343 | "str = ''.join(series)" 344 | ] 345 | }, 346 | { 347 | "cell_type": "code", 348 | "execution_count": 93, 349 | "metadata": {}, 350 | "outputs": [ 351 | { 352 | "data": { 353 | "text/plain": [ 354 | "'AAEAAA5AGAAMAZAPDNSAZBBEXBAGAYS1ECAVCALGAAA6EYGNBRV5AASALPBIYAAJAMUJSMDABAAAA6BAAG5AA'" 355 | ] 356 | }, 357 | "execution_count": 93, 358 | "metadata": {}, 359 | "output_type": "execute_result" 360 | } 361 | ], 362 | "source": [ 363 | "str" 364 | ] 365 | }, 366 | { 367 | "cell_type": "code", 368 | "execution_count": 94, 369 | "metadata": {}, 370 | "outputs": [], 371 | "source": [ 372 | "real_b = 'MERMIROOMUHJPXJOHUVLEORZP3GLOO7AUFDKEFTWEOOALZOP9ROCGZET1Y19EWX65QUYU7NAK_4YCJDVDNGQXODBEV2B5EFDIDNR'\n", 373 | "real_b_train = 'VGREAAH8TVRHBYN_UGCOLO4EUERDOOHCIFOMDNU6LQCPKEIREKOYRQIDJXPBKOJDWZEUEWWFOEBHXTQTTZUMO'\n", 374 | "\n", 375 | "real_a_train = 'EAEVQTDOJG8RBRGONCEDHCTUIDBPUHMEM6OUXOCFOUKWA4VJEFRZROLHYNQDW_EKTLBWXEPOUIKZERYOOTHQI'\n" 376 | ] 377 | }, 378 | { 379 | "cell_type": "code", 380 | "execution_count": 95, 381 | "metadata": {}, 382 | "outputs": [ 383 | { 384 | "data": { 385 | "text/plain": [ 386 | "2" 387 | ] 388 | }, 389 | "execution_count": 95, 390 | "metadata": {}, 391 | "output_type": "execute_result" 392 | } 393 | ], 394 | "source": [ 395 | "counter = 0\n", 396 | "for i in range(85):\n", 397 | " if str[i] == real_a_train[i]:\n", 398 | " counter += 1\n", 399 | "\n", 400 | "counter" 401 | ] 402 | }, 403 | { 404 | "cell_type": "code", 405 | "execution_count": null, 406 | "metadata": {}, 407 | "outputs": [], 408 | "source": [] 409 | } 410 | ], 411 | "metadata": { 412 | "kernelspec": { 413 | "display_name": "Python [conda env:torch]", 414 | "language": "python", 415 | "name": "conda-env-torch-py" 416 | }, 417 | "language_info": { 418 | "codemirror_mode": { 419 | "name": "ipython", 420 | "version": 3 421 | }, 422 | "file_extension": ".py", 423 | "mimetype": "text/x-python", 424 | "name": "python", 425 | "nbconvert_exporter": "python", 426 | "pygments_lexer": "ipython3", 427 | "version": "3.6.6" 428 | } 429 | }, 430 | "nbformat": 4, 431 | "nbformat_minor": 2 432 | } 433 | -------------------------------------------------------------------------------- /pics/v_cnn.svg: -------------------------------------------------------------------------------- 1 | 2 | 4 | 6 | 7 | 9 | 10 | torch-jit-export 11 | 12 | 13 | Conv (op#0)\n input0 0\n input1 1\n input2 2\n output0 7 14 | 15 | 16 | Conv (op#0) 17 | input0 0 18 | input1 1 19 | input2 2 20 | output0 7 21 | 22 | 23 | 24 | 25 | 70 26 | 27 | 7 28 | 29 | 30 | Conv (op#0)\n input0 0\n input1 1\n input2 2\n output0 7->70 31 | 32 | 33 | 34 | 35 | MaxPool (op#1)\n input0 7\n output0 8 36 | 37 | 38 | MaxPool (op#1) 39 | input0 7 40 | output0 8 41 | 42 | 43 | 44 | 45 | 70->MaxPool (op#1)\n input0 7\n output0 8 46 | 47 | 48 | 49 | 50 | 00 51 | 52 | 0 53 | 54 | 55 | 00->Conv (op#0)\n input0 0\n input1 1\n input2 2\n output0 7 56 | 57 | 58 | 59 | 60 | 10 61 | 62 | 1 63 | 64 | 65 | 10->Conv (op#0)\n input0 0\n input1 1\n input2 2\n output0 7 66 | 67 | 68 | 69 | 70 | 20 71 | 72 | 2 73 | 74 | 75 | 20->Conv (op#0)\n input0 0\n input1 1\n input2 2\n output0 7 76 | 77 | 78 | 79 | 80 | 
-------------------------------------------------------------------------------- /pics/v_cnn.svg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omoiFly/CNN-for-P300/HEAD/pics/v_cnn.svg [SVG: Graphviz rendering of the exported ONNX graph "torch-jit-export" for the vanilla CNN: Conv -> MaxPool -> Relu -> Reshape (flatten, via Shape/Gather/Unsqueeze/Concat) -> Gemm -> Dropout -> Relu -> Gemm -> Dropout -> Squeeze] --------------------------------------------------------------------------------
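Finally, a minimal sketch of restoring the checkpoint that train_AE writes to ./ae, assuming Net from the repo's Model.py rebuilds the saved architecture and takes no constructor arguments (adjust if it does):

import torch
from Model import Net

model = Net()  # must match the architecture that was saved
checkpoint = torch.load('./ae', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
start_epoch = checkpoint['epoch'] + 1  # resume counting from the saved epoch
model.eval()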