├── .dockerignore ├── .gitignore ├── LICENSE ├── README.md ├── docker ├── Dockerfile ├── elementwise_binary_broadcast_op-inl.h └── mxnet │ └── Dockerfile ├── exp_csl ├── exp_dbc ├── exp_ninapro ├── run.py ├── scripts ├── capgmyo.ipynb ├── chi.ipynb ├── convnet.ipynb ├── csl.ipynb ├── exp ├── exp-b ├── exp-ws43 ├── exp-ws44 ├── exp-ws45 ├── exp-ws46 ├── exp2 ├── exp3 ├── exp_revgrad ├── exp_revgrad2 ├── exp_tzeng ├── exp_tzeng2 ├── exp_tzeng3 ├── exp_tzeng4 ├── exp_tzeng5 ├── figure.ipynb ├── mount-cache ├── ninapro_lowpass.m ├── paper-bars.ipynb ├── rundocker ├── runsrep ├── sdata.ipynb ├── sigr ├── srep.ipynb ├── test_csl_multistream.py ├── test_db1_input.py ├── test_dbc_multistream.py ├── test_ninapro_multistream.py ├── test_semimyo.py ├── testsrep.py ├── testsrep.sh ├── train_semimyo.sh ├── trainsrep.sh └── waste.ipynb └── sigr ├── __init__.py ├── activity_img ├── actimg_extractor.py ├── actimg_extractor_1.py └── activity_image.py ├── app.py ├── backup.py ├── base_module.py ├── constant.py ├── coral.py ├── data ├── __init__.py ├── capgmyo │ ├── __init__.py │ ├── dba.py │ ├── dbb.py │ └── dbc.py ├── capgmyoiter │ ├── __init__.py │ ├── blockwise_multistream_iter.py │ ├── blockwise_multistream_iter_v2.py │ ├── capgmyoiter_dbb.py │ ├── piecesigimg_multistream_iter.py │ ├── piecewise_multistream_iter.py │ ├── piecewise_multistream_iter_v2.py │ ├── piecewise_plus_rawimg_multistream_iter.py │ ├── piecewisetwoaxis_multistream_iter.py │ └── single_frame_multistream_iter.py ├── capgmyoiter_dbb │ ├── __init__.py │ ├── blockwise_multistream_iter.py │ ├── blockwise_multistream_iter_v2.py │ ├── capgmyoiter_dbb.py │ ├── piecesigimg_multistream_iter.py │ ├── piecewise_multistream_iter.py │ ├── piecewise_multistream_iter_v2.py │ ├── piecewise_plus_rawimg_multistream_iter.py │ ├── piecewisetwoaxis_multistream_iter.py │ └── single_frame_multistream_iter.py ├── capgmyoiter_dbc │ ├── __init__.py │ ├── blockwise_multistream_iter.py │ ├── blockwise_multistream_iter_v2.py │ ├── 
capgmyoiter_dbb.py │ ├── piecesigimg_multistream_iter.py │ ├── piecewise_multistream_iter.py │ ├── piecewise_multistream_iter_v2.py │ ├── piecewise_plus_rawimg_multistream_iter.py │ ├── piecewisetwoaxis_multistream_iter.py │ └── single_frame_multistream_iter.py ├── csl.py ├── csliter │ ├── __init__.py │ ├── csl_blockpiecewise_multistream_iter.py │ ├── csl_blockwise_diff_multistream_iter.py │ ├── csl_blockwise_multistream_iter.py │ ├── csl_blockwise_multistream_iter_v2.py │ ├── csl_blockwise_plusrawimg_multistream_iter.py │ ├── csl_piecewise_multistream_iter.py │ ├── csl_piecewise_multistream_iter_v2.py │ ├── csl_piecewise_plusrawimg_multistream_iter.py │ ├── csl_rms_blockwise_multistream_iter.py │ └── dualCHstream_iter.py ├── ninapro │ ├── MULTISRC_rawemg_feature_chwise_multistream_iter.py │ ├── MULTISRC_rawemg_feature_multistream_iter.py │ ├── SINGLESRC_feature_chwise_multistream_iter.py │ ├── SINGLESRC_rawemg_feature_singlestream_iter.py │ ├── __init__.py │ ├── block_multistream_iter.py │ ├── block_multistream_iter_v2.py │ ├── caputo.py │ ├── ch_multistream_iter.py │ ├── ch_multistream_plus_rawimg_iter.py │ ├── chdiff_sigimg_multistream_iter.py │ ├── chdiffimage_iter.py │ ├── chdiffsigimage_iter.py │ ├── db1.py │ ├── db1_feature_map.py │ ├── db1_g12.py │ ├── db1_g5.py │ ├── db1_g53.py │ ├── db1_g8.py │ ├── db1_matlab_lowpass.py │ ├── db1_raw_semg_glove.py │ ├── db1_rawdata_semgfeature.py │ ├── db1_signal_image.py │ ├── db1_signal_image_fast.py │ ├── db1_softmax_as_input.py │ ├── featureimage_iter.py │ ├── featureimg_rawimg_multistream_iter.py │ ├── featuremap_iter.py │ ├── frame_multistream_iter.py │ ├── sigimg_rawimg_multistream_iter.py │ ├── simplestacked_iter.py │ ├── single_frame_multistream_iter.py │ └── test_db1_raw_semg_glove.py ├── preprocess.py ├── s21.py └── s21_soft_label.scv ├── emg_features.py ├── evaluation.py ├── evaluation_db1input.py ├── evaluation_db1multistream.py ├── evaluation_db1multistream_outputpreds.py ├── evaluation_dbcmultistream.py ├── 
evaluation_semimyo.py ├── feature_map ├── activity_image.py ├── emg_features.py ├── feature-map.py ├── ninapro_feature_map_extractor-43.py ├── ninapro_feature_map_extractor-45.py ├── ninapro_feature_map_extractor-49.py └── ninapro_feature_map_extractor.py ├── fft.py ├── genIndex.py ├── lstm.py ├── module.py ├── module_multistream.py ├── module_multistream_loadsinglestreamparams.py ├── module_semimyo.py ├── parse_log.py ├── sklearn_module.py ├── stacked_optical_flow ├── optical_flow_extractor.py └── test_optical_flow_extractor.py ├── symbol.py ├── symbol_multistream.py ├── symbol_multistream_dynetwork.py ├── symbol_semimyo.py ├── train_feature.py ├── train_high_density_emg.py ├── train_semimyo.py ├── train_sigimg_actimg_fast.py ├── utils ├── __init__.py └── proxy.py └── vote.py /.dockerignore: -------------------------------------------------------------------------------- 1 | .cache/ 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | *.pyc 3 | .ipynb_checkpoints/ 4 | .cache/ 5 | /scripts/exp_inter 6 | /tmp/ 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A Multi-stream Convolutional Neural Network for sEMG-based Gesture Recognition in Muscle-Computer Interface 2 | 3 | This repo contains the code for the experiments in the paper: Wentao Wei, Yongkang Wong, Yu Du, Yu Hu, Mohan Kankanhalli, Weidong Geng. 
" [A Multi-stream Convolutional Neural Network for sEMG-based Gesture Recognition in Muscle-Computer Interface](https://www.sciencedirect.com/science/article/abs/pii/S0167865517304439)" 4 | 5 | ## Requirements 6 | - A CUDA compatible GPU 7 | - Ubuntu >= 14.04 or any other Linux/Unix that can run Docker 8 | - [Docker](http://docker.io/) 9 | - [Nvidia Docker](https://github.com/NVIDIA/nvidia-docker) 10 | 11 | ## Usage 12 | - **Pull or build docker image** 13 | ``` 14 | docker pull zjucapg/semg:latest 15 | ``` 16 | or 17 | ``` 18 | docker build -t zjucapg/semg:latest -f docker/Dockerfile docker 19 | ``` 20 | - **Dataset** 21 | 22 | Three databases, NinaPro DB1, CapgMyo and CSL-HDEMG, can be used for training and testing. 23 | 24 | ``` 25 | mkdir .cache 26 | # put NinaPro DB1 in .cache/ninapro-db1 27 | # put CapgMyo DB-a in .cache/dba or DB-b in .cache/dbb or DB-c in .cache/dbc 28 | # put CSL-HDEMG in .cache/csl 29 | ``` 30 | NinaPro DB1 needs to be segmented by gesture label and stored in Matlab format as follows: `.cache/ninapro-db1/data/sss/ggg/sss_ggg_ttt.mat` contains a field `data` that represents trial `ttt` of gesture `ggg` of subject `sss`. All numbers start from zero, and gesture 0 is the rest gesture. 31 | 32 | For instance, `.cache/ninapro-db1/data/000/001/000_001_000.mat` is trial 0 of gesture 1 of subject 0. 33 | 34 | You can download the original dataset from or download the prepared dataset from our site . CapgMyo and CSL-HDEMG datasets can be acquired on and , respectively. 
35 | 36 | - **Quick Start** 37 | ``` 38 | # Start a shell in the zjucapg/semg container 39 | nvidia-docker run -ti -v your_projectdir:/code zjucapg/semg /bin/bash 40 | # Train 41 | ./exp_ninapro # Ninapro DB1 42 | or ./exp_dbc # CapgMyo DB-c 43 | or ./exp_csl # CSL HDEMG 44 | # Test 45 | python scripts/test_ninapro_multistream.py # Ninapro DB1 46 | python scripts/test_dbc_multistream.py # CapgMyo DB-c 47 | python scripts/test_csl_multistream.py # CSL HDEMG 48 | ``` 49 | 50 | 51 | - **Trained model** 52 | We also provide trained models for direct use: just extract the zip file and put the extracted files into `.cache`. 53 | - The model files are stored on Google Drive in three categories: 54 | - [Ninapro DB1](https://drive.google.com/open?id=1oc2smnwt5JuKqpOX9BSIqwf496x_3neE) 55 | - [CapgMyo DB-c](https://drive.google.com/open?id=1ehHyfhxoZnIwCXh1qkrWXrKhoj5ABaAS) 56 | - [CSL-HDEMG](https://drive.google.com/open?id=1GgAz1QwjPWtvfwU2O-mCOcp4-5odXgTb) 57 | 58 | ## License 59 | Licensed under the GPL v3.0 license. 
60 | 61 | ## Bibtex 62 | ``` 63 | @article{WEI20190301, 64 | title = "A multi-stream convolutional neural network for sEMG-based gesture recognition in muscle-computer interface", 65 | journal = "Pattern Recognition Letters", 66 | volume = "119", 67 | pages = "131 - 138", 68 | year = "2019", 69 | note = "Deep Learning for Pattern Recognition", 70 | issn = "0167-8655", 71 | doi = "https://doi.org/10.1016/j.patrec.2017.12.005", 72 | url = "http://www.sciencedirect.com/science/article/pii/S0167865517304439", 73 | author = "Wentao Wei and Yongkang Wong and Yu Du and Yu Hu and Mohan Kankanhalli and Weidong Geng" 74 | } 75 | ``` -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM answeror/mxnet:f2684a6-cuda8 2 | MAINTAINER answeror 3 | 4 | RUN apt-get install -y python-pip python-scipy 5 | # RUN mkdir ~/.pip && \ 6 | # echo "[global]" > ~/.pip/pip.conf && \ 7 | # echo "index-url = https://pypi.mirrors.ustc.edu.cn/simple" >> ~/.pip/pip.conf 8 | RUN pip install click==6.6 logbook==1.0.0 joblib==0.10.3 nose==1.3.7 9 | 10 | RUN cd /mxnet && \ 11 | git reset --hard && \ 12 | git checkout master && \ 13 | git pull 14 | 15 | RUN cd /mxnet && \ 16 | git checkout 7a485bb && \ 17 | git submodule update && \ 18 | git checkout 887491d src/operator/elementwise_binary_broadcast_op-inl.h && \ 19 | sed -i -e 's/CHECK(ksize_x <= dshape\[3\] && ksize_y <= dshape\[2\])/CHECK(ksize_x <= dshape[3] + 2 * param_.pad[1] \&\& ksize_y <= dshape[2] + 2 * param_.pad[0])/' src/operator/convolution-inl.h && \ 20 | cp make/config.mk . 
&& \ 21 | echo "USE_CUDA=1" >>config.mk && \ 22 | echo "USE_CUDA_PATH=/usr/local/cuda" >>config.mk && \ 23 | echo "USE_CUDNN=1" >>config.mk && \ 24 | echo "USE_BLAS=openblas" >>config.mk && \ 25 | make clean && \ 26 | make -j8 ADD_LDFLAGS=-L/usr/local/cuda/lib64/stubs 27 | 28 | ADD elementwise_binary_broadcast_op-inl.h /mxnet/src/operator/elementwise_binary_broadcast_op-inl.h 29 | RUN cd /mxnet && \ 30 | make clean && \ 31 | make -j8 ADD_LDFLAGS=-L/usr/local/cuda/lib64/stubs 32 | 33 | # RUN pip install jupyter pandas matplotlib seaborn scikit-learn 34 | # RUN mkdir -p -m 700 /root/.jupyter/ && \ 35 | # echo "c.NotebookApp.ip = '*'" >> /root/.jupyter/jupyter_notebook_config.py 36 | # EXPOSE 8888 37 | # CMD ["sh", "-c", "jupyter notebook"] 38 | 39 | WORKDIR /code 40 | -------------------------------------------------------------------------------- /docker/mxnet/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:8.0-cudnn5-devel 2 | MAINTAINER answeror 3 | 4 | RUN echo "deb http://ftp.sjtu.edu.cn/ubuntu/ trusty main restricted universe multiverse" > /etc/apt/sources.list && \ 5 | echo "deb http://ftp.sjtu.edu.cn/ubuntu/ trusty-security main restricted universe multiverse" >> /etc/apt/sources.list && \ 6 | echo "deb http://ftp.sjtu.edu.cn/ubuntu/ trusty-updates main restricted universe multiverse" >> /etc/apt/sources.list && \ 7 | echo "deb http://ftp.sjtu.edu.cn/ubuntu/ trusty-proposed main restricted universe multiverse" >> /etc/apt/sources.list && \ 8 | echo "deb http://ftp.sjtu.edu.cn/ubuntu/ trusty-backports main restricted universe multiverse" >> /etc/apt/sources.list && \ 9 | echo "deb-src http://ftp.sjtu.edu.cn/ubuntu/ trusty main restricted universe multiverse" >> /etc/apt/sources.list && \ 10 | echo "deb-src http://ftp.sjtu.edu.cn/ubuntu/ trusty-security main restricted universe multiverse" >> /etc/apt/sources.list && \ 11 | echo "deb-src http://ftp.sjtu.edu.cn/ubuntu/ trusty-updates main 
restricted universe multiverse" >> /etc/apt/sources.list && \ 12 | echo "deb-src http://ftp.sjtu.edu.cn/ubuntu/ trusty-proposed main restricted universe multiverse" >> /etc/apt/sources.list && \ 13 | echo "deb-src http://ftp.sjtu.edu.cn/ubuntu/ trusty-backports main restricted universe multiverse" >> /etc/apt/sources.list && \ 14 | apt-get -qqy update 15 | 16 | # mxnet 17 | RUN apt-get update && apt-get install -y \ 18 | build-essential \ 19 | git \ 20 | libopenblas-dev \ 21 | libopencv-dev \ 22 | python-numpy \ 23 | wget \ 24 | unzip 25 | RUN git clone --recursive https://github.com/dmlc/mxnet/ && cd mxnet && \ 26 | git checkout f2684a6 && \ 27 | sed -i -e 's/CHECK(ksize_x <= dshape\[3\] && ksize_y <= dshape\[2\])/CHECK(ksize_x <= dshape[3] + 2 * param_.pad[1] \&\& ksize_y <= dshape[2] + 2 * param_.pad[0])/' src/operator/convolution-inl.h && \ 28 | cp make/config.mk . && \ 29 | echo "USE_CUDA=1" >>config.mk && \ 30 | echo "USE_CUDA_PATH=/usr/local/cuda" >>config.mk && \ 31 | echo "USE_CUDNN=1" >>config.mk && \ 32 | echo "USE_BLAS=openblas" >>config.mk && \ 33 | make -j8 ADD_LDFLAGS=-L/usr/local/cuda/lib64/stubs 34 | ENV LD_LIBRARY_PATH /usr/local/cuda/lib64:$LD_LIBRARY_PATH 35 | 36 | ENV PYTHONPATH /mxnet/python 37 | -------------------------------------------------------------------------------- /exp_csl: -------------------------------------------------------------------------------- 1 | for i in $(seq 0 9); do 2 | python -m sigr.train_high_density_emg exp --log log --snapshot model \ 3 | --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \ 4 | --root .cache/CSL-piece_multistream-universal-intra-session-LR28-16-24-$i \ 5 | --batch-size 1000 --decay-all --adabn --dataset csliter \ 6 | --preprocess '(csl-cut,abs,ninapro-lowpass,downsample-5)' \ 7 | --balance-gesture 1 \ 8 | --num-filter 64 \ 9 | --num-semg-row 24 --num-semg-col 7 \ 10 | --feature-name 'piece_multistream' \ 11 | --fusion-type 'fuse_5' \ 12 | --window 1 \ 13 | --num-pixel 2 \ 14 | 
--dropout 0.5 \ 15 | --no-zscore \ 16 | crossval --crossval-type universal-intra-session --fold $i 17 | done 18 | 19 | for i in $(seq 0 249); do 20 | python -m sigr.train_high_density_emg exp --log log --snapshot model \ 21 | --num-epoch 10 --lr-step 4 --lr-step 8 --snapshot-period 10 \ 22 | --root .cache/CSL-piece_multistream-1-1-one-fold-intra-session-$i \ 23 | --params .cache/CSL-piece_multistream-universal-intra-session-LR28-16-24-$(($i % 10))/model-0028.params \ 24 | --batch-size 1000 --decay-all --adabn --num-adabn-epoch 10 --dataset csliter \ 25 | --preprocess '(csl-cut,abs,ninapro-lowpass)' \ 26 | --balance-gesture 1 \ 27 | --num-filter 64 \ 28 | --num-semg-row 24 --num-semg-col 7 \ 29 | --feature-name 'piece_multistream' \ 30 | --fusion-type 'fuse_5' \ 31 | --window 1 \ 32 | --num-pixel 2 \ 33 | --dropout 0.5 \ 34 | --no-zscore \ 35 | crossval --crossval-type intra-session --fold $i 36 | done 37 | -------------------------------------------------------------------------------- /exp_dbc: -------------------------------------------------------------------------------- 1 | # #dba 2 | # python -m sigr.train_high_density_emg exp --log log --snapshot model \ 3 | # --root .cache/capgmyo-piece_multistream-1-1-universal-one-fold-intra-subject \ 4 | # --batch-size 1000 --decay-all --dataset capgmyoiter \ 5 | # --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \ 6 | # --num-filter 64 \ 7 | # --feature-name 'piece_multistream_v2' \ 8 | # --fusion-type 'fuse_5' \ 9 | # --num-pixel 2 \ 10 | # --window 1 \ 11 | # --num-semg-row 16 --num-semg-col 8 \ 12 | # --no-zscore \ 13 | # --gpu 1 \ 14 | # crossval --crossval-type universal-one-fold-intra-subject --fold 0 15 | 16 | # for i in $(seq 0 17 | shuf); do 17 | # python -m sigr.train_high_density_emg exp --log log --snapshot model \ 18 | # --root .cache/capgmyo-piece_multistream-1-1-one-fold-intra-subject-fold-$i \ 19 | # --params 
.cache/capgmyo-piece_multistream-1-1-universal-one-fold-intra-subject/model-0028.params \ 20 | # --batch-size 1000 --decay-all --dataset capgmyoiter \ 21 | # --num-filter 64 \ 22 | # --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \ 23 | # --feature-name 'piece_multistream_v2' \ 24 | # --fusion-type 'fuse_5' \ 25 | # --num-pixel 2 \ 26 | # --window 1 \ 27 | # --num-semg-row 16 --num-semg-col 8 \ 28 | # --no-zscore \ 29 | # --gpu 1 \ 30 | # crossval --crossval-type one-fold-intra-subject --fold $i 31 | # done 32 | 33 | ##dbb 34 | # python -m sigr.train_high_density_emg exp --log log --snapshot model \ 35 | # --root .cache/capgmyo-dbb-piece_multistream-1-1-universal-one-fold-intra-subject \ 36 | # --batch-size 1000 --decay-all --dataset capgmyoiter_dbb \ 37 | # --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \ 38 | # --num-filter 64 \ 39 | # --feature-name 'piece_multistream_v2' \ 40 | # --fusion-type 'fuse_5' \ 41 | # --num-pixel 2 \ 42 | # --window 1 \ 43 | # --num-semg-row 16 --num-semg-col 8 \ 44 | # --no-zscore \ 45 | # --gpu 1 \ 46 | # crossval --crossval-type universal-one-fold-intra-subject --fold 0 47 | 48 | # for i in $(seq 0 9 | shuf); do 49 | # python -m sigr.train_high_density_emg exp --log log --snapshot model \ 50 | # --root .cache/capgmyo-dbb-piece_multistream-1-1-one-fold-intra-subject-fold-$i \ 51 | # --params .cache/capgmyo-dbb-piece_multistream-1-1-universal-one-fold-intra-subject/model-0028.params \ 52 | # --batch-size 1000 --decay-all --dataset capgmyoiter_dbb \ 53 | # --num-filter 64 \ 54 | # --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \ 55 | # --feature-name 'piece_multistream_v2' \ 56 | # --fusion-type 'fuse_5' \ 57 | # --num-pixel 2 \ 58 | # --window 1 \ 59 | # --num-semg-row 16 --num-semg-col 8 \ 60 | # --no-zscore \ 61 | # --gpu 1 \ 62 | # crossval --crossval-type one-fold-intra-subject --fold $i 63 | # done 64 | 65 | 66 | 67 | 68 | python -m sigr.train_high_density_emg exp --log log --snapshot 
model \ 69 | --root .cache/capgmyo-dbc-piece_multistream-1-1-universal-one-fold-intra-subject \ 70 | --batch-size 1000 --decay-all --dataset capgmyoiter_dbc \ 71 | --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \ 72 | --num-filter 64 \ 73 | --feature-name 'piece_multistream_v2' \ 74 | --fusion-type 'fuse_5' \ 75 | --num-pixel 2 \ 76 | --window 1 \ 77 | --num-semg-row 16 --num-semg-col 8 \ 78 | --no-zscore \ 79 | --gpu 1 \ 80 | crossval --crossval-type universal-one-fold-intra-subject --fold 0 81 | 82 | for i in $(seq 0 9 | shuf); do 83 | python -m sigr.train_high_density_emg exp --log log --snapshot model \ 84 | --root .cache/capgmyo-dbc-piece_multistream-1-1-one-fold-intra-subject-fold-$i \ 85 | --params .cache/capgmyo-dbc-piece_multistream-1-1-universal-one-fold-intra-subject/model-0028.params \ 86 | --batch-size 1000 --decay-all --dataset capgmyoiter_dbc \ 87 | --num-filter 64 \ 88 | --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \ 89 | --feature-name 'piece_multistream_v2' \ 90 | --fusion-type 'fuse_5' \ 91 | --num-pixel 2 \ 92 | --window 1 \ 93 | --num-semg-row 16 --num-semg-col 8 \ 94 | --no-zscore \ 95 | --gpu 1 \ 96 | crossval --crossval-type one-fold-intra-subject --fold $i 97 | done -------------------------------------------------------------------------------- /exp_ninapro: -------------------------------------------------------------------------------- 1 | ver=1.0.0.0 2 | python -m sigr.train_sigimg_actimg_fast exp --log log --snapshot model \ 3 | --root .cache/ninapro-db1-ch_multistream-20-1-universal-one-fold-intra-subject \ 4 | --batch-size 1000 --decay-all --dataset ninapro-db1-sigimg-fast \ 5 | --num-filter 64 \ 6 | --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \ 7 | --balance-gesture 1 \ 8 | --feature-name 'ch_multistream' \ 9 | --window 20 \ 10 | --num-pixel 2 \ 11 | --fusion-type 'fuse_5' \ 12 | --num-semg-row 1 --num-semg-col 10 \ 13 | --preprocess 'ninapro-lowpass' \ 14 | --no-zscore \ 15 | --gpu 
0 \ 16 | crossval --crossval-type universal-one-fold-intra-subject --fold 0 17 | 18 | 19 | ver=1.0.0.1 20 | for i in $(seq 0 26); do 21 | python -m sigr.train_sigimg_actimg_fast exp --log log --snapshot model \ 22 | --root .cache/ninapro-db1-ch_multistream-20-1-one-fold-intra-subject-fold-$i-v$ver \ 23 | --batch-size 1000 --decay-all --dataset ninapro-db1-sigimg-fast \ 24 | --params .cache/ninapro-db1-ch_multistream-20-1-universal-one-fold-intra-subject/model-0028.params \ 25 | --num-filter 64 \ 26 | --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \ 27 | --balance-gesture 1 \ 28 | --feature-name 'ch_multistream' \ 29 | --window 20 \ 30 | --num-pixel 2 \ 31 | --fusion-type 'fuse_5' \ 32 | --num-semg-row 1 --num-semg-col 10 \ 33 | --preprocess 'ninapro-lowpass' \ 34 | --no-zscore \ 35 | --gpu 0 \ 36 | crossval --crossval-type one-fold-intra-subject --fold $i 37 | done 38 | 39 | 40 | -------------------------------------------------------------------------------- /scripts/exp-b: -------------------------------------------------------------------------------- 1 | ver=1.0.0.0 2 | python -m sigr.train_feature exp --log log --snapshot model \ 3 | --root .cache/ninapro-db1-actimg-15-1-universal-one-fold-intra-subject \ 4 | --batch-size 1000 --decay-all --dataset ninapro-db1-var-raw-prepro-lowpass \ 5 | --num-filter 64 \ 6 | --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \ 7 | --balance-gesture 1 \ 8 | --num-semg-row 15 --num-semg-col 50 \ 9 | crossval --crossval-type universal-one-fold-intra-subject --fold 0 10 | 11 | 12 | ver=1.0.0.1 13 | for i in $(seq 0 26); do 14 | python -m sigr.train_feature exp --log log --snapshot model \ 15 | --root .cache/ninapro-db1-actimg-15-1-one-fold-intra-subject-fold-$i-v$ver \ 16 | --batch-size 1000 --decay-all --dataset ninapro-db1-var-raw-prepro-lowpass \ 17 | --params .cache/ninapro-db1-actimg-15-1-universal-one-fold-intra-subject/model-0028.params \ 18 | --num-filter 64 \ 19 | --num-epoch 28 --lr-step 16 
--lr-step 24 --snapshot-period 28 \ 20 | --balance-gesture 1 \ 21 | --num-semg-row 15 --num-semg-col 50 \ 22 | crossval --crossval-type one-fold-intra-subject --fold $i 23 | done -------------------------------------------------------------------------------- /scripts/exp-ws43: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # ver=843 4 | # for i in $(seq 0 9); do 5 | # scripts/sigr python -m sigr.app inter --log log --snapshot model --root .cache/sigr-inter-augscale-adabn-$i-v$ver --fold $i --batch-size 900 --num-pixel 2 --num-filter 16 --adabn --minibatch --random-scale 1 6 | # done 7 | # ver=868 8 | # for i in $@; do 9 | # scripts/sigr python -m sigr.app inter --log log --snapshot model --root .cache/sigr-inter-w150-lstm-adabn-$i-v$ver --fold $i --batch-size 180 --num-pixel 2 --num-filter 16 --adabn --minibatch --window 150 --adabn-num-epoch 10 --num-epoch 10 --lr-step 4 --lr-step 8 --lstm --lstm-last --lstm-dropout 0.5 10 | # done 11 | # ver=924 12 | # for i in $@; do 13 | # scripts/sigr python -m sigr.app inter --log log --snapshot model --root .cache/sigr-inter-w150-lstm-adabn-$i-v$ver --fold $i --batch-size 900 --num-pixel 2 --num-filter 16 --adabn --window 150 --lstm-window 5 --num-adabn-epoch 1 --lstm --minibatch --num-lstm-hidden 128 --lstm-last 1 --lstm-dropout 0.5 14 | # done 15 | # ver=934 16 | # for i in $@; do 17 | # scripts/sigr python -m sigr.app inter --log log --snapshot model --root .cache/sigr-inter-calib-w150-faug-lstm-adabn-$i-v$ver --fold $i --batch-size 1000 --num-pixel 2 --num-filter 16 --window 150 --lstm-window 15 --num-adabn-epoch 1 --lstm --num-lstm-hidden 16 --lstm-last 1 --lstm-dropout 0.5 --only-calib --params .cache/sigr-inter-w150-lstm-adabn-$i-v927/model-0060.params --adabn --faug 0.5 18 | # done 19 | # ver=937 20 | # for i in $@; do 21 | # scripts/sigr python -m sigr.app inter --log log --snapshot model --root .cache/sigr-inter-calib-w150-faug-lstm-adabn-$i-v$ver 
--fold $i --batch-size 1000 --num-pixel 2 --num-filter 16 --window 150 --lstm-window 15 --num-adabn-epoch 1 --lstm --num-lstm-hidden 16 --lstm-last 1 --lstm-dropout 0.5 --only-calib --params .cache/sigr-inter-w150-lstm-adabn-$i-v927/model-0060.params --adabn --lr 0.01 --faug 1 22 | # done 23 | # ver=955.70 24 | # for i in $@; do 25 | # scripts/sigr python -m sigr.app inter --log log --snapshot model --root .cache/sigr-inter-adabn-$i-v$ver --fold $i --batch-size 900 --adabn --minibatch --pixel-reduce-smooth --pixel-reduce-loss-weight 900 26 | # done 27 | 28 | # ver=957.50 29 | # for i in $(seq 0 9); do 30 | # scripts/sigr python -m sigr.app exp --log log --snapshot model \ 31 | # --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \ 32 | # --root .cache/sigr-csl-universal-intra-session-$i-v$ver \ 33 | # --num-semg-row 24 --num-semg-col 7 \ 34 | # --batch-size 2500 --decay-all --adabn --minibatch --dataset csl \ 35 | # --preprocess '(csl-bandpass,csl-cut,downsample-5,median)' \ 36 | # --balance-gesture 1 \ 37 | # crossval --crossval-type universal-intra-session --fold $i 38 | # done 39 | # ver=957.51 40 | # for i in $(seq 0 249); do 41 | # scripts/sigr python -m sigr.app exp --log log --snapshot model \ 42 | # --num-epoch 14 --lr-step 8 --lr-step 12 --snapshot-period 14 \ 43 | # --root .cache/sigr-csl-intra-session-$i-v$ver \ 44 | # --num-semg-row 24 --num-semg-col 7 \ 45 | # --batch-size 1000 --decay-all --adabn --num-adabn-epoch 10 --dataset csl \ 46 | # --preprocess '(csl-bandpass,csl-cut,median)' \ 47 | # --balance-gesture 1 \ 48 | # --params .cache/sigr-csl-universal-intra-session-$(($i % 10))-v957.50/model-0028.params \ 49 | # crossval --crossval-type intra-session --fold $i 50 | # done 51 | 52 | ver=958.16 53 | for i in 0; do 54 | scripts/sigr python -m sigr.app exp --log log --snapshot model \ 55 | --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \ 56 | --root .cache/sigr-ninapro-db1-universal-one-fold-intra-subject-$i-v$ver \ 57 | 
--num-semg-row 1 --num-semg-col 10 \ 58 | --batch-size 1000 --decay-all --dataset ninapro-db1 \ 59 | --num-filter 64 \ 60 | crossval --crossval-type universal-one-fold-intra-subject --fold $i 61 | done 62 | ver=958.16.1 63 | for i in $(seq 0 26 | shuf); do 64 | scripts/sigr python -m sigr.app exp --log log --snapshot model \ 65 | --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \ 66 | --root .cache/sigr-ninapro-db1-one-fold-intra-subject-$i-v$ver \ 67 | --num-semg-row 1 --num-semg-col 10 \ 68 | --batch-size 1000 --decay-all --dataset ninapro-db1 \ 69 | --num-filter 64 \ 70 | --params .cache/sigr-ninapro-db1-universal-one-fold-intra-subject-0-v958.16/model-0028.params \ 71 | crossval --crossval-type one-fold-intra-subject --fold $i 72 | done 73 | 74 | ver=958.17 75 | for i in $(seq 0 4 | shuf); do 76 | scripts/sigr python -m sigr.app exp --log log --snapshot model \ 77 | --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \ 78 | --root .cache/sigr-universal-inter-session-$i-v$ver \ 79 | --num-semg-row 24 --num-semg-col 7 \ 80 | --batch-size 1000 --decay-all --dataset csl \ 81 | --preprocess '(csl-bandpass,csl-cut,downsample-5,median)' \ 82 | --balance-gesture 1 \ 83 | --num-filter 64 \ 84 | crossval --crossval-type universal-inter-session --fold $i 85 | done 86 | ver=958.17.1 87 | for i in $(seq 0 24 | shuf); do 88 | scripts/sigr python -m sigr.app exp --log log --snapshot model \ 89 | --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \ 90 | --root .cache/sigr-csl-inter-session-$i-v$ver \ 91 | --num-semg-row 24 --num-semg-col 7 \ 92 | --batch-size 1000 --decay-all --dataset csl \ 93 | --preprocess '(csl-bandpass,csl-cut,median)' \ 94 | --balance-gesture 1 \ 95 | --params .cache/sigr-universal-inter-session-$(($i % 5))-v958.17/model-0028.params \ 96 | --num-filter 64 \ 97 | crossval --crossval-type inter-session --fold $i 98 | done 99 | 100 | ver=958.18 101 | for i in $(seq 0 9 | shuf); do 102 | scripts/sigr python -m sigr.app exp 
--log log --snapshot model \ 103 | --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \ 104 | --root .cache/sigr-csl-universal-intra-session-$i-v$ver \ 105 | --num-semg-row 24 --num-semg-col 7 \ 106 | --batch-size 2500 --decay-all --adabn --minibatch --dataset csl \ 107 | --preprocess '(csl-bandpass,csl-cut,downsample-5,median)' \ 108 | --balance-gesture 1 \ 109 | --num-filter 64 \ 110 | --num-pixel 0 \ 111 | crossval --crossval-type universal-intra-session --fold $i 112 | done 113 | ver=958.18.1 114 | for i in $(seq 0 249 | shuf); do 115 | scripts/sigr python -m sigr.app exp --log log --snapshot model \ 116 | --num-epoch 14 --lr-step 8 --lr-step 12 --snapshot-period 14 \ 117 | --root .cache/sigr-csl-intra-session-$i-v$ver \ 118 | --num-semg-row 24 --num-semg-col 7 \ 119 | --batch-size 1000 --decay-all --adabn --num-adabn-epoch 10 --dataset csl \ 120 | --preprocess '(csl-bandpass,csl-cut,median)' \ 121 | --balance-gesture 1 \ 122 | --params .cache/sigr-csl-universal-intra-session-$(($i % 10))-v958.18/model-0028.params \ 123 | --num-filter 64 \ 124 | --num-pixel 0 \ 125 | crossval --crossval-type intra-session --fold $i 126 | done 127 | 128 | ver=958.19 129 | for i in $(seq 0 9 | shuf); do 130 | scripts/sigr python -m sigr.app exp --log log --snapshot model \ 131 | --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \ 132 | --root .cache/sigr-csl-universal-intra-session-$i-v$ver \ 133 | --num-semg-row 24 --num-semg-col 7 \ 134 | --batch-size 2500 --decay-all --adabn --minibatch --dataset csl \ 135 | --preprocess '(csl-bandpass,csl-cut,downsample-5,median)' \ 136 | --balance-gesture 1 \ 137 | --num-filter 64 \ 138 | --num-pixel 0 --num-conv 4 \ 139 | crossval --crossval-type universal-intra-session --fold $i 140 | done 141 | ver=958.19.1 142 | for i in $(seq 0 249 | shuf); do 143 | scripts/sigr python -m sigr.app exp --log log --snapshot model \ 144 | --num-epoch 14 --lr-step 8 --lr-step 12 --snapshot-period 14 \ 145 | --root 
.cache/sigr-csl-intra-session-$i-v$ver \ 146 | --num-semg-row 24 --num-semg-col 7 \ 147 | --batch-size 1000 --decay-all --adabn --num-adabn-epoch 10 --dataset csl \ 148 | --preprocess '(csl-bandpass,csl-cut,median)' \ 149 | --balance-gesture 1 \ 150 | --params .cache/sigr-csl-universal-intra-session-$(($i % 10))-v958.19/model-0028.params \ 151 | --num-filter 64 \ 152 | --num-pixel 0 --num-conv 4 \ 153 | crossval --crossval-type intra-session --fold $i 154 | done 155 | -------------------------------------------------------------------------------- /scripts/exp-ws46: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # ver=843 4 | # for i in $(seq 0 9); do 5 | # scripts/sigr python -m sigr.app inter --log log --snapshot model --root .cache/sigr-inter-adabn-$i-v$ver --fold $i --batch-size 900 --num-pixel 2 --num-filter 16 --adabn --minibatch 6 | # done 7 | 8 | # ver=847 9 | # for i in $@; do 10 | # scripts/sigr python -m sigr.app inter --log log --snapshot model --root .cache/sigr-inter-w150-lstm-adabn-$i-v$ver --fold $i --batch-size 900 --num-pixel 2 --num-filter 16 --adabn --minibatch --window 150 --adabn-num-epoch 1 --num-epoch 30 --lr-step 10 --lr-step 20 --lstm 11 | # done 12 | # ver=869 13 | # for i in $@; do 14 | # scripts/sigr python -m sigr.app inter --log log --snapshot model --root .cache/sigr-inter-w150-lstm-adabn-$i-v$ver --fold $i --batch-size 180 --num-pixel 2 --num-filter 16 --adabn --minibatch --window 150 --adabn-num-epoch 10 --num-epoch 10 --lr-step 4 --lr-step 8 --lstm --lstm-last --lstm-dropout 0.5 --params .cache/sigr-inter-adabn-$i-v843/model-0060.params --ignore-params 'gesture_last_fc_.*' 15 | # done 16 | # ver=927 17 | # for i in $@; do 18 | # scripts/sigr python -m sigr.app inter --log log --snapshot model --root .cache/sigr-inter-w150-lstm-adabn-$i-v$ver --fold $i --batch-size 900 --num-pixel 2 --num-filter 16 --adabn --window 150 --lstm-window 15 --num-adabn-epoch 1 --lstm 
--minibatch --num-lstm-hidden 16 --lstm-last 1 --lstm-dropout 0.5 19 | # done 20 | # ver=927 21 | # for i in $@; do 22 | # scripts/sigr python -m sigr.app inter --log log --snapshot model --root .cache/sigr-inter-calib-w150-lstm-adabn-$i-v$ver --fold $i --batch-size 1000 --num-pixel 2 --num-filter 16 --window 150 --lstm-window 15 --num-adabn-epoch 1 --lstm --num-lstm-hidden 16 --lstm-last 1 --lstm-dropout 0.5 --only-calib --params .cache/sigr-inter-w150-lstm-adabn-$i-v927/model-0060.params --adabn 23 | # done 24 | # ver=927 25 | # for i in $@; do 26 | # scripts/sigr python -m sigr.app inter --log log --snapshot model --root .cache/sigr-inter-calib-w150-aug-lstm-adabn-$i-v$ver --fold $i --batch-size 1000 --num-pixel 2 --num-filter 16 --window 150 --lstm-window 15 --num-adabn-epoch 1 --lstm --num-lstm-hidden 16 --lstm-last 1 --lstm-dropout 0.5 --only-calib --params .cache/sigr-inter-w150-lstm-adabn-$i-v927/model-0060.params --adabn --random-bad-channel -1 --random-bad-channel 0 --random-bad-channel 1 --random-scale 1 27 | # done 28 | # ver=932 29 | # for i in $@; do 30 | # scripts/sigr python -m sigr.app inter --log log --snapshot model --root .cache/sigr-inter-calib-w150-lstm-adabn-$i-v$ver --fold $i --batch-size 1000 --num-pixel 2 --num-filter 16 --window 150 --lstm-window 15 --num-adabn-epoch 1 --lstm --num-lstm-hidden 16 --lstm-last 1 --lstm-dropout 0.5 --only-calib --params .cache/sigr-inter-w150-lstm-adabn-$i-v927/model-0060.params --adabn --lr 0.001 31 | # done 32 | # ver=933 33 | # for i in $@; do 34 | # scripts/sigr python -m sigr.app inter --log log --snapshot model --root .cache/sigr-inter-calib-w150-faug-lstm-adabn-$i-v$ver --fold $i --batch-size 1000 --num-pixel 2 --num-filter 16 --window 150 --lstm-window 15 --num-adabn-epoch 1 --lstm --num-lstm-hidden 16 --lstm-last 1 --lstm-dropout 0.5 --only-calib --params .cache/sigr-inter-w150-lstm-adabn-$i-v927/model-0060.params --adabn --lr 0.001 --faug 0.5 35 | # done 36 | # ver=955.73 37 | # for i in $@; do 38 
| # scripts/sigr python -m sigr.app inter --log log --snapshot model --root .cache/sigr-inter-adabn-$i-v$ver --fold $i --batch-size 900 --adabn --minibatch --pixel-reduce-smooth --pixel-reduce-loss-weight 0 39 | # done 40 | 41 | # ver=957.40.2 42 | # for i in $(seq 0 9); do 43 | # scripts/sigr python -m sigr.app exp --log log --snapshot model \ 44 | # --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \ 45 | # --root .cache/sigr-dbb-inter-subject-$i-v$ver \ 46 | # --num-semg-row 16 --num-semg-col 8 \ 47 | # --batch-size 1800 --decay-all --adabn --minibatch --dataset dbb \ 48 | # --preprocess '(median)' \ 49 | # crossval --crossval-type inter-subject --fold $i 50 | # done 51 | # ver=957.40 52 | # for i in $(seq 0 9); do 53 | # scripts/sigr python -m sigr.app exp --log log --snapshot model \ 54 | # --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \ 55 | # --root .cache/sigr-dbb-inter-subject-$i-v$ver \ 56 | # --num-semg-row 16 --num-semg-col 8 \ 57 | # --batch-size 1800 --decay-all --adabn --minibatch --dataset dbb \ 58 | # crossval --crossval-type inter-subject --fold $i 59 | # done 60 | # ver=957.41 61 | # for i in $(seq 0 9); do 62 | # scripts/sigr python -m sigr.app exp --log log --snapshot model \ 63 | # --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \ 64 | # --root .cache/sigr-dbb-universal-intra-subject-$i-v$ver \ 65 | # --num-semg-row 16 --num-semg-col 8 \ 66 | # --batch-size 2000 --decay-all --adabn --minibatch --dataset dbb \ 67 | # crossval --crossval-type universal-intra-subject --fold $i 68 | # done 69 | # ver=957.42 70 | # for i in $@; do 71 | # scripts/sigr python -m sigr.app exp --log log --snapshot model \ 72 | # --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \ 73 | # --root .cache/sigr-dbb-intra-subject-$i-v$ver \ 74 | # --num-semg-row 16 --num-semg-col 8 \ 75 | # --batch-size 1000 --decay-all --dataset dbb \ 76 | # --params .cache/sigr-dbb-universal-intra-subject-$(($i % 
10))-v957.41/model-0028.params \ 77 | # crossval --crossval-type intra-subject --fold $i 78 | # done 79 | # ver=957.42.1 80 | # for i in $(seq 0 99); do 81 | # scripts/sigr python -m sigr.app exp --log log --snapshot model \ 82 | # --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \ 83 | # --root .cache/sigr-dbb-intra-subject-$i-v$ver \ 84 | # --num-semg-row 16 --num-semg-col 8 \ 85 | # --batch-size 1000 --decay-all --adabn --num-adabn-epoch 10 --dataset dbb \ 86 | # --params .cache/sigr-dbb-universal-intra-subject-$(($i % 10))-v957.41/model-0028.params \ 87 | # crossval --crossval-type intra-subject --fold $i 88 | # done 89 | # ver=957.42.2 90 | # for i in $@; do 91 | # scripts/sigr python -m sigr.app exp --log log --snapshot model \ 92 | # --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \ 93 | # --root .cache/sigr-dbb-intra-subject-$i-v$ver \ 94 | # --num-semg-row 16 --num-semg-col 8 \ 95 | # --batch-size 200 --decay-all --adabn --num-adabn-epoch 10 --dataset dbb \ 96 | # --params .cache/sigr-dbb-universal-intra-subject-$(($i % 10))-v957.41/model-0028.params \ 97 | # crossval --crossval-type intra-subject --fold $i 98 | # done 99 | # ver=957.42.3 100 | # for i in $@; do 101 | # scripts/sigr python -m sigr.app exp --log log --snapshot model \ 102 | # --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \ 103 | # --root .cache/sigr-dbb-intra-subject-$i-v$ver \ 104 | # --num-semg-row 16 --num-semg-col 8 \ 105 | # --batch-size 1000 --decay-all --adabn --num-adabn-epoch 10 --dataset dbb \ 106 | # --num-filter 64 \ 107 | # crossval --crossval-type intra-subject --fold $i 108 | # done 109 | # ver=957.42.4 110 | # for i in $@; do 111 | # scripts/sigr python -m sigr.app exp --log log --snapshot model \ 112 | # --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \ 113 | # --root .cache/sigr-dbb-intra-subject-$i-v$ver \ 114 | # --num-semg-row 16 --num-semg-col 8 \ 115 | # --batch-size 1000 --decay-all --adabn --num-adabn-epoch 10 
--dataset dbb \ 116 | # crossval --crossval-type intra-subject --fold $i 117 | # done 118 | # ver=957.42.5 119 | # for i in $@; do 120 | # scripts/sigr python -m sigr.app exp --log log --snapshot model \ 121 | # --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \ 122 | # --root .cache/sigr-dbb-intra-subject-$i-v$ver \ 123 | # --num-semg-row 16 --num-semg-col 8 \ 124 | # --batch-size 1000 --decay-all --adabn --num-adabn-epoch 10 --dataset dbb \ 125 | # --preprocess '(median)' \ 126 | # crossval --crossval-type intra-subject --fold $i 127 | # done 128 | # ver=957.43 129 | # for i in $@; do 130 | # scripts/sigr python -m sigr.app exp --log log --snapshot model \ 131 | # --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \ 132 | # --root .cache/sigr-dba-intra-subject-$i-v$ver \ 133 | # --num-semg-row 16 --num-semg-col 8 \ 134 | # --batch-size 1000 --decay-all --adabn --num-adabn-epoch 10 --dataset dba \ 135 | # crossval --crossval-type intra-subject --fold $i 136 | # done 137 | 138 | # ver=957.51 139 | # for i in $@; do 140 | # scripts/sigr python -m sigr.app exp --log log --snapshot model \ 141 | # --num-epoch 14 --lr-step 8 --lr-step 12 --snapshot-period 14 \ 142 | # --root .cache/sigr-csl-intra-session-$i-v$ver \ 143 | # --num-semg-row 24 --num-semg-col 7 \ 144 | # --batch-size 1000 --decay-all --adabn --num-adabn-epoch 10 --dataset csl \ 145 | # --preprocess '(csl-bandpass,csl-cut,median)' \ 146 | # --balance-gesture 1 \ 147 | # --params .cache/sigr-csl-universal-intra-session-$(($i % 10))-v957.50/model-0028.params \ 148 | # crossval --crossval-type intra-session --fold $i 149 | # done 150 | -------------------------------------------------------------------------------- /scripts/exp2: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # ver=509 4 | # for i in $(seq 0 9); do 5 | # scripts/sigr python -m sigr.app inter --gpu 0 --gpu 1 --log log --snapshot model --root 
.cache/sigr-inter-adabn-tzeng-$i-v$ver --fold $i --adabn --batch-size 1800 --num-pixel 2 --num-filter 16 --tzeng --subject-loss-weight 1 --subject-confusion-loss-weight 0.1 --lambda-scale 1 --num-subject-block 0 --minibatch 6 | # done 7 | ver=510 8 | for i in $(seq 0 9); do 9 | scripts/sigr python -m sigr.app inter --gpu 0 --gpu 1 --log log --snapshot model --root .cache/sigr-inter-adabn-tzeng-$i-v$ver --fold $i --adabn --batch-size 1800 --num-pixel 2 --num-filter 16 --tzeng --subject-loss-weight 1 --subject-confusion-loss-weight 0.1 --lambda-scale 1 --num-subject-block 0 --minibatch --params .cache/sigr-inter-adabn-$i-v506/model-0060.params --lr 0.01 --num-epoch 15 --lr-step 5 10 | done 11 | -------------------------------------------------------------------------------- /scripts/exp3: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ver=509 4 | for i in $(seq 0 9); do 5 | scripts/sigr python -m sigr.app inter --gpu 0 --gpu 1 --log log --snapshot model --root .cache/sigr-inter-tzeng-$i-v$ver --fold $i --batch-size 1800 --num-pixel 2 --num-filter 16 --tzeng --subject-loss-weight 1 --subject-confusion-loss-weight 0.1 --lambda-scale 1 --num-subject-block 0 --minibatch 6 | done 7 | -------------------------------------------------------------------------------- /scripts/exp_revgrad: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # ver=405 4 | # for i in $(seq 8 8); do 5 | # scripts/sigr python -m sigr.app inter --gpu 0 --gpu 1 --log log --snapshot model --root .cache/sigr-inter-adabn-revgrad-$i-v$ver --fold $i --adabn --revgrad --params .cache/sigr-inter-adabn-$i-v403/model-0060.params --gamma 1e8 --lr 0.01 --batch-size 2000 --num-filter 16 --lambda-scale 0.1 6 | # done 7 | # ver=512 8 | # for i in $(seq 0 0); do 9 | # scripts/sigr python -m sigr.app inter --gpu 0 --gpu 1 --log log --snapshot model --root 
.cache/sigr-inter-adabn-revgrad-$i-v$ver --fold $i --adabn --batch-size 1800 --num-pixel 2 --num-filter 16 --revgrad --subject-loss-weight 0.1 --lambda-scale 1 --num-subject-block 0 --minibatch --params .cache/sigr-inter-adabn-$i-v506/model-0060.params --lr 0.01 --num-epoch 15 --lr-step 5 --confuse-all 10 | # done 11 | # ver=608 12 | # for i in $(seq 3 9); do 13 | # scripts/sigr python -m sigr.app inter --log log --snapshot model --root .cache/sigr-inter-adabn-revgrad-$i-v$ver --fold $i --adabn --batch-size 1800 --num-pixel 2 --num-filter 16 --revgrad --subject-loss-weight 0.1 --lambda-scale 10 --num-subject-block 0 --minibatch --params .cache/sigr-inter-adabn-$i-v506/model-0060.params --lr 0.01 --num-epoch 15 --lr-step 5 --confuse-all --gamma 1e8 --revgrad-num-batch 2 14 | # done 15 | # ver=610 16 | # for i in $(seq 3 3); do 17 | # scripts/sigr python -m sigr.app inter --log log --snapshot model --root .cache/sigr-inter-adabn-revgrad-$i-v$ver --fold $i --adabn --batch-size 1800 --num-pixel 2 --num-filter 16 --revgrad --subject-loss-weight 0.01 --lambda-scale 10 --num-subject-block 0 --minibatch --confuse-all --gamma 1e8 --revgrad-num-batch 3 18 | # done 19 | # ver=701 20 | # for i in $(seq 0 9); do 21 | # scripts/sigr python -m sigr.app inter --log log --snapshot model --root .cache/sigr-inter-adabn-revgrad-$i-v$ver --fold $i --adabn --batch-size 1800 --num-pixel 2 --num-filter 16 --revgrad --subject-loss-weight 1 --lambda-scale 0.1 --num-subject-block 2 --minibatch --params .cache/sigr-inter-adabn-$i-v506/model-0060.params --lr 0.01 --num-epoch 15 --lr-step 5 --confuse-all --gamma 1e8 --revgrad-num-batch 2 22 | # done 23 | # ver=702 24 | # for i in $(seq 0 9); do 25 | # scripts/sigr python -m sigr.app inter --log log --snapshot model --root .cache/sigr-inter-adabn-revgrad-$i-v$ver --fold $i --adabn --batch-size 1800 --num-pixel 2 --num-filter 16 --revgrad --subject-loss-weight 1 --lambda-scale 0.1 --num-subject-block 2 --minibatch --confuse-all --gamma 1e8 
--revgrad-num-batch 3 26 | # done 27 | #ver=617 28 | #for i in $(seq 0 9); do 29 | # scripts/sigr python -m sigr.app inter --log log --snapshot model --root .cache/sigr-inter-adabn-revgrad-$i-v$ver --fold $i --adabn --batch-size 1800 --num-pixel 2 --num-filter 16 --revgrad --subject-loss-weight 0.1 --lambda-scale 1 --num-subject-block 0 --minibatch --params .cache/sigr-inter-adabn-$i-v506/model-0060.params --lr 0.01 --num-epoch 15 --lr-step 5 --confuse-all --gamma 1e8 --revgrad-num-batch 10 30 | #done 31 | #ver=618 32 | #for i in $(seq 0 9); do 33 | # scripts/sigr python -m sigr.app inter --log log --snapshot model --root .cache/sigr-inter-adabn-revgrad-$i-v$ver --fold $i --adabn --batch-size 1800 --num-pixel 2 --num-filter 16 --revgrad --subject-loss-weight 0.1 --lambda-scale 1 --num-subject-block 0 --minibatch --params .cache/sigr-inter-adabn-$i-v506/model-0060.params --lr 0.01 --num-epoch 15 --lr-step 5 --confuse-all --gamma 1e8 --revgrad-num-batch 1 34 | #done 35 | #ver=619 36 | #for i in $(seq 0 9); do 37 | # scripts/sigr python -m sigr.app inter --log log --snapshot model --root .cache/sigr-inter-adabn-revgrad-$i-v$ver --fold $i --adabn --batch-size 1800 --num-pixel 2 --num-filter 16 --revgrad --subject-loss-weight 0.1 --lambda-scale 1 --num-subject-block 0 --minibatch --params .cache/sigr-inter-adabn-$i-v506/model-0060.params --lr 0.1 --num-epoch 15 --lr-step 5 --confuse-all --gamma 1e8 --revgrad-num-batch 100 38 | #done 39 | # ver=705 40 | # for i in 7 8 4 5; do 41 | # scripts/sigr python -m sigr.app inter --log log --snapshot model --root .cache/sigr-inter-adabn-revgrad-$i-v$ver --fold $i --adabn --batch-size 1800 --num-pixel 2 --num-filter 16 --revgrad --subject-loss-weight 1 --lambda-scale 0.5 --num-subject-block 2 --minibatch --confuse-all --gamma 1e8 --revgrad-num-batch 3 --num-epoch 120 --lr-step 40 42 | # done 43 | # ver=710 44 | # for i in 3 2 1 4 0; do 45 | # scripts/sigr python -m sigr.app inter --log log --snapshot model --root 
.cache/sigr-inter-adabn-revgrad-$i-v$ver --fold $i --adabn --batch-size 1800 --num-pixel 2 --num-filter 16 --revgrad --subject-loss-weight 1 --lambda-scale 0.1 --num-subject-block 2 --minibatch --confuse-all --gamma 1e8 --revgrad-num-batch 2 46 | # done 47 | # ver=713 48 | # for i in 3 2 1 4 0; do 49 | # scripts/sigr python -m sigr.app inter --log log --snapshot model --root .cache/sigr-inter-adabn-revgrad-$i-v$ver --fold $i --adabn --batch-size 1800 --num-pixel 2 --num-filter 16 --revgrad --subject-loss-weight 1 --lambda-scale 0.5 --num-subject-block 2 --minibatch --confuse-all --gamma 1e8 --revgrad-num-batch 2 --params .cache/sigr-inter-adabn-$i-v506/model-0060.params 50 | # done 51 | # ver=723 52 | # for i in 3 2 1 4 0; do 53 | # scripts/sigr python -m sigr.app inter --log log --snapshot model --root .cache/sigr-inter-adabn-revgrad-$i-v$ver --fold $i --adabn --batch-size 1800 --num-pixel 2 --num-filter 16 --revgrad --subject-loss-weight 0.1 --lambda-scale 1 --num-subject-block 0 --minibatch --confuse-all --gamma 1e8 --revgrad-num-batch 4 --params .cache/sigr-inter-adabn-$i-v506/model-0060.params 54 | # done 55 | # ver=738 56 | # for i in $(seq 0 9); do 57 | # scripts/sigr python -m sigr.app inter --log log --snapshot model --root .cache/sigr-inter-adabn-revgrad-$i-v$ver --fold $i --adabn --batch-size 900 --num-pixel 2 --num-filter 16 --revgrad --subject-loss-weight 0.1 --lambda-scale 1 --num-subject-block 2 --minibatch --confuse-all --gamma 1e8 --revgrad-num-batch 2 --drop-branch 58 | # done 59 | ver=815 60 | for i in 5 6 7 8 9; do 61 | scripts/sigr python -m sigr.app inter --log log --snapshot model --root .cache/sigr-inter-adabn-revgrad-$i-v$ver --fold $i --adabn --batch-size 900 --num-pixel 2 --num-filter 16 --revgrad --subject-loss-weight 1 --lambda-scale 0.01 --num-subject-block 0 --minibatch --gamma 1e8 --revgrad-num-batch 2 62 | done 63 | -------------------------------------------------------------------------------- /scripts/exp_revgrad2: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ver=706 4 | for i in 0 1 2 3 4; do 5 | scripts/sigr python -m sigr.app inter --log log --snapshot model --root .cache/sigr-inter-revgrad-$i-v$ver --fold $i --batch-size 1000 --num-pixel 2 --num-filter 16 --revgrad --subject-loss-weight 1 --lambda-scale 0.5 --num-subject-block 2 --confuse-all --gamma 1e8 --revgrad-num-batch 3 --num-epoch 120 --lr-step 40 6 | done 7 | -------------------------------------------------------------------------------- /scripts/exp_tzeng: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # ver=432 4 | # for i in $(seq 0 9); do 5 | # scripts/sigr python -m sigr.app inter --gpu 0 --gpu 1 --log log --snapshot model --root .cache/sigr-inter-adabn-tzeng-$i-v$ver --fold $i --adabn --lr 0.1 --batch-size 2000 --subject-loss-weight 1 --subject-confusion-loss-weight 0.1 --num-filter 16 --tzeng --num-epoch 60 --lr-step 20 --lambda-scale 1 --tzeng-num-batch 10 6 | # done 7 | # ver=511 8 | # for i in $(seq 7 9); do 9 | # scripts/sigr python -m sigr.app inter --gpu 0 --gpu 1 --log log --snapshot model --root .cache/sigr-inter-adabn-tzeng-$i-v$ver --fold $i --adabn --batch-size 1800 --num-pixel 2 --num-filter 16 --tzeng --subject-loss-weight 1 --subject-confusion-loss-weight 0.1 --lambda-scale 1 --num-subject-block 0 --minibatch --params .cache/sigr-inter-adabn-$i-v506/model-0060.params --lr 0.01 --num-epoch 15 --lr-step 5 --confuse-conv 10 | # done 11 | # ver=512 12 | # for i in 2 3 4; do 13 | # scripts/sigr python -m sigr.app inter --gpu 0 --gpu 1 --log log --snapshot model --root .cache/sigr-inter-adabn-tzeng-$i-v$ver --fold $i --adabn --batch-size 1800 --num-pixel 2 --num-filter 16 --tzeng --subject-loss-weight 0.1 --subject-confusion-loss-weight 0.1 --lambda-scale 1 --num-subject-block 0 --minibatch --params .cache/sigr-inter-adabn-$i-v506/model-0060.params --lr 0.01 
--num-epoch 15 --lr-step 5 --confuse-all 14 | # done 15 | # ver=600 16 | # for i in $(seq 0 9); do 17 | # scripts/sigr python -m sigr.app inter --log log --snapshot model --root .cache/sigr-inter-adabn-tzeng-$i-v$ver --fold $i --adabn --batch-size 1800 --num-pixel 2 --num-filter 16 --tzeng --subject-loss-weight 0.1 --subject-confusion-loss-weight 1 --lambda-scale 1 --num-subject-block 0 --minibatch --params .cache/sigr-inter-adabn-$i-v506/model-0060.params --lr 0.01 --num-epoch 15 --lr-step 5 --confuse-conv 18 | # done 19 | # ver=606 20 | # for i in $(seq 0 9); do 21 | # scripts/sigr python -m sigr.app inter --log log --snapshot model --root .cache/sigr-inter-tzeng-$i-v$ver --fold $i --batch-size 1000 --num-pixel 2 --num-filter 16 --tzeng --subject-loss-weight 0.1 --subject-confusion-loss-weight 1 --lambda-scale 1 --num-subject-block 0 --params .cache/sigr-inter-$i-v411/model-0060.params --lr 0.01 --num-epoch 15 --lr-step 5 22 | # done 23 | # ver=607 24 | # for i in $(seq 0 9); do 25 | # scripts/sigr python -m sigr.app inter --log log --snapshot model --root .cache/sigr-inter-tzeng-$i-v$ver --fold $i --batch-size 1000 --num-pixel 2 --num-filter 16 --tzeng --subject-loss-weight 0.1 --subject-confusion-loss-weight 1 --lambda-scale 1 --num-subject-block 0 26 | # done 27 | ver=750 28 | for i in 6 5 7 8 9; do 29 | scripts/sigr python -m sigr.app inter --log log --snapshot model --root .cache/sigr-inter-adabn-tzeng-$i-v$ver --fold $i --adabn --batch-size 900 --num-pixel 2 --num-filter 16 --tzeng --subject-loss-weight 1 --subject-confusion-loss-weight 0.01 --lambda-scale 1 --num-subject-block 0 --minibatch --tzeng-num-batch 10 --lr-step 40 --lr-step 60 --lr-step 80 30 | done 31 | -------------------------------------------------------------------------------- /scripts/exp_tzeng2: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # ver=512 4 | # for i in 5 6 9; do 5 | # scripts/sigr python -m sigr.app inter 
--gpu 0 --gpu 1 --log log --snapshot model --root .cache/sigr-inter-adabn-tzeng-$i-v$ver --fold $i --adabn --batch-size 1800 --num-pixel 2 --num-filter 16 --tzeng --subject-loss-weight 0.1 --subject-confusion-loss-weight 0.1 --lambda-scale 1 --num-subject-block 0 --minibatch --params .cache/sigr-inter-adabn-$i-v506/model-0060.params --lr 0.01 --num-epoch 15 --lr-step 5 --confuse-all 6 | # done 7 | ver=604 8 | for i in $(seq 0 9); do 9 | scripts/sigr python -m sigr.app inter --log log --snapshot model --root .cache/sigr-inter-tzeng-$i-v$ver --fold $i --batch-size 1000 --num-pixel 2 --num-filter 16 --tzeng --subject-loss-weight 0.1 --subject-confusion-loss-weight 1 --lambda-scale 1 --num-subject-block 0 --params .cache/sigr-inter-$i-v411/model-0060.params --lr 0.01 --num-epoch 15 --lr-step 5 --confuse-conv 10 | done 11 | ver=605 12 | for i in $(seq 0 9); do 13 | scripts/sigr python -m sigr.app inter --log log --snapshot model --root .cache/sigr-inter-tzeng-$i-v$ver --fold $i --batch-size 1000 --num-pixel 2 --num-filter 16 --tzeng --subject-loss-weight 0.1 --subject-confusion-loss-weight 1 --lambda-scale 1 --num-subject-block 0 --params .cache/sigr-inter-$i-v411/model-0060.params --lr 0.01 --num-epoch 15 --lr-step 5 --confuse-all 14 | done 15 | -------------------------------------------------------------------------------- /scripts/exp_tzeng3: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ver=521 4 | for i in $(seq 9 9); do 5 | scripts/sigr python -m sigr.app inter --gpu 0 --gpu 1 --log log --snapshot model --root .cache/sigr-inter-adabn-tzeng-$i-v$ver --fold $i --adabn --batch-size 1800 --num-pixel 2 --num-filter 16 --tzeng --subject-loss-weight 0.1 --subject-confusion-loss-weight 1 --lambda-scale 1 --num-subject-block 0 --minibatch --params .cache/sigr-inter-adabn-$i-v506/model-0060.params --lr 0.01 --num-epoch 15 --lr-step 5 --confuse-all 6 | done 7 | ver=522 8 | for i in $(seq 0 9); do 9 | 
scripts/sigr python -m sigr.app inter --gpu 0 --gpu 1 --log log --snapshot model --root .cache/sigr-inter-adabn-tzeng-$i-v$ver --fold $i --adabn --batch-size 1800 --num-pixel 2 --num-filter 16 --tzeng --subject-loss-weight 0.01 --subject-confusion-loss-weight 1 --lambda-scale 1 --num-subject-block 0 --minibatch --params .cache/sigr-inter-adabn-$i-v506/model-0060.params --lr 0.01 --num-epoch 15 --lr-step 5 --confuse-all 10 | done 11 | ver=525 12 | for i in $(seq 0 9); do 13 | scripts/sigr python -m sigr.app inter --gpu 0 --gpu 1 --log log --snapshot model --root .cache/sigr-inter-tzeng-$i-v$ver --fold $i --batch-size 2000 --num-pixel 2 --num-filter 16 --tzeng --subject-loss-weight 0.01 --subject-confusion-loss-weight 0.1 --lambda-scale 1 --num-subject-block 0 --confuse-all 14 | done 15 | ver=526 16 | for i in $(seq 0 9); do 17 | scripts/sigr python -m sigr.app inter --gpu 0 --gpu 1 --log log --snapshot model --root .cache/sigr-inter-adabn-tzeng-$i-v$ver --fold $i --adabn --batch-size 1800 --num-pixel 2 --num-filter 16 --tzeng --subject-loss-weight 0.01 --subject-confusion-loss-weight 0.1 --lambda-scale 1 --num-subject-block 0 --minibatch --confuse-all 18 | done 19 | -------------------------------------------------------------------------------- /scripts/exp_tzeng4: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ver=524 4 | for i in $(seq 0 9); do 5 | scripts/sigr python -m sigr.app inter --gpu 0 --gpu 1 --log log --snapshot model --root .cache/sigr-inter-tzeng-$i-v$ver --fold $i --batch-size 2000 --num-pixel 2 --num-filter 16 --tzeng --subject-loss-weight 0.1 --subject-confusion-loss-weight 1 --lambda-scale 1 --num-subject-block 0 --minibatch --confuse-all 6 | done 7 | ver=523 8 | for i in $(seq 0 9); do 9 | scripts/sigr python -m sigr.app inter --gpu 0 --gpu 1 --log log --snapshot model --root .cache/sigr-inter-adabn-tzeng-$i-v$ver --fold $i --adabn --batch-size 1800 --num-pixel 2 --num-filter 16 
--tzeng --subject-loss-weight 0.01 --subject-confusion-loss-weight 1 --lambda-scale 1 --num-subject-block 0 --minibatch --params .cache/sigr-inter-adabn-$i-v506/model-0060.params --lr 0.1 --num-epoch 30 --lr-step 10 --confuse-all 10 | done 11 | ver=527 12 | for i in $(seq 0 9); do 13 | scripts/sigr python -m sigr.app inter --gpu 0 --gpu 1 --log log --snapshot model --root .cache/sigr-inter-adabn-tzeng-$i-v$ver --fold $i --adabn --batch-size 1800 --num-pixel 2 --num-filter 16 --tzeng --subject-loss-weight 0.1 --subject-confusion-loss-weight 1 --lambda-scale 1 --num-subject-block 0 --minibatch --confuse-all 14 | done 15 | -------------------------------------------------------------------------------- /scripts/exp_tzeng5: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ver=528 4 | for i in $(seq 0 9); do 5 | scripts/sigr python -m sigr.app inter --gpu 0 --gpu 1 --log log --snapshot model --root .cache/sigr-inter-tzeng-$i-v$ver --fold $i --batch-size 2000 --num-pixel 2 --num-filter 16 --tzeng --subject-loss-weight 0.01 --subject-confusion-loss-weight 1 --lambda-scale 1 --num-subject-block 0 --confuse-all 6 | done 7 | ver=529 8 | for i in $(seq 0 9); do 9 | scripts/sigr python -m sigr.app inter --gpu 0 --gpu 1 --log log --snapshot model --root .cache/sigr-inter-adabn-tzeng-$i-v$ver --fold $i --adabn --batch-size 1800 --num-pixel 2 --num-filter 16 --tzeng --subject-loss-weight 0.01 --subject-confusion-loss-weight 1 --lambda-scale 1 --num-subject-block 0 --minibatch --confuse-all 10 | done 11 | -------------------------------------------------------------------------------- /scripts/mount-cache: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | sudo mount -t cifs -o user=answeror,iocharset=utf8,noperm //fat.zju-capg.org/home/deepsemg-cache .cache 4 | -------------------------------------------------------------------------------- 
/scripts/ninapro_lowpass.m: -------------------------------------------------------------------------------- 1 | function ninapro_lowpass(in,out) 2 | for subject=0:26 3 | for gesture=1:52 4 | for trial=0:9 5 | path = sprintf('%03d/%03d/%03d_%03d_%03d.mat',... 6 | subject,gesture,subject,gesture,trial); 7 | fprintf([path '\n']); 8 | deal_one(in,out,path); 9 | end 10 | end 11 | end 12 | end 13 | 14 | function deal_one(in,out,path) 15 | in = [in '/' path]; 16 | out = [out '/' path]; 17 | dir = out(1:end-16); 18 | if ~exist(dir,'dir') 19 | mkdir(dir); 20 | end 21 | f = load(in); 22 | data = f.data; 23 | label = f.label; 24 | subject = f.subject; 25 | parfor ch=1:10 26 | data(:,ch) = lowpass(data(:,ch)); 27 | end 28 | save(out,'data','label','subject'); 29 | end 30 | 31 | function y = lowpass(x) 32 | fc = 1; 33 | fs = 100; 34 | [b,a] = butter(1,fc/(fs/2)); 35 | y = filtfilt(b,a,x); 36 | end -------------------------------------------------------------------------------- /scripts/rundocker: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | nvidia-docker run -ti -v $(pwd):/code $@ 4 | -------------------------------------------------------------------------------- /scripts/runsrep: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | scripts/rundocker answeror/sigr:2016-09-21 $@ 4 | -------------------------------------------------------------------------------- /scripts/sigr: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | scripts/rundocker answeror/sigr:2016-07-06 $@ 4 | -------------------------------------------------------------------------------- /scripts/test_csl_multistream.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | import sys 4 | sys.path.insert(0, os.getcwd()) 5 | 
import numpy as np 6 | import mxnet as mx 7 | import scipy.io as sio 8 | from sigr.evaluation_dbcmultistream import CrossValEvaluation as CV, Exp 9 | from sigr.data import Preprocess, Dataset 10 | from sigr import Context 11 | 12 | one_fold_intra_subject_eval = CV(crossval_type='one-fold-intra-subject', batch_size=1000) 13 | intra_session_eval = CV(crossval_type='intra-session', batch_size=1000) 14 | 15 | print('CSL Multistream') 16 | print('===========') 17 | 18 | 19 | #semg_row = [] 20 | #semg_col = [] 21 | #num_ch = [] 22 | #for i in range(10): 23 | # num_ch.append(1) 24 | # semg_row.append(1) 25 | # semg_col.append(20) 26 | #window=20 27 | #num_raw_semg_row=1 28 | #num_raw_semg_col=10 29 | #feature_name = 'ch_multistream' 30 | #fusion_type = 'fuse_2' 31 | # 32 | #with Context(parallel=True, level='DEBUG'): 33 | # acc = one_fold_intra_subject_eval.vote_accuracy_curves( 34 | # [Exp(dataset=Dataset.from_name('ninapro-db1-sigimg-fast'), 35 | # dataset_args=dict(preprocess=Preprocess.parse('ninapro-lowpass')), 36 | # Mod=dict(num_gesture=52, 37 | # context=[mx.gpu(0)], 38 | # multi_stream = True, 39 | # num_stream=len(semg_col), 40 | # symbol_kargs=dict(dropout=0, num_stream=len(semg_col), fusion_type=fusion_type, num_semg_row=semg_row, num_semg_col=semg_col, num_channel=num_ch, num_filter=64), 41 | # params='.cache/ninapro-db1-ch_multistream-20-1-one-fold-intra-subject-fold-%d-v1.0.0.1/model-0028.params'))], 42 | # folds=np.arange(27), 43 | # windows=np.arange(1, 5), 44 | # window=window, 45 | # num_semg_row = num_raw_semg_row, 46 | # num_semg_col = num_raw_semg_col, 47 | # feature_name = feature_name, 48 | # balance=True) 49 | # acc = acc.mean(axis=(0, 1)) 50 | # print('Single frame accuracy: %f' % acc[0]) 51 | ## print('5 frames (50 ms) majority voting accuracy: %f' % acc[4]) 52 | ## print('10 frames (100 ms) majority voting accuracy: %f' % acc[9]) 53 | ## print('15 frames (150 ms) majority voting accuracy: %f' % acc[14]) 54 | ## print('20 frames (200 ms) 
majority voting accuracy: %f' % acc[19]) 55 | ## print('25 frames (250 ms) majority voting accuracy: %f' % acc[24]) 56 | 57 | 58 | 59 | 60 | 61 | #semg_row = [] 62 | #semg_col = [] 63 | #num_ch = [] 64 | #for i in range(2): 65 | # num_ch.append(1) 66 | # semg_row.append(1) 67 | # semg_col.append(10) 68 | #window=1 69 | #num_raw_semg_row=1 70 | #num_raw_semg_col=10 71 | #feature_name = 'singleframe_multistream' 72 | #fusion_type = 'multistream_multistruct_fuse_1' 73 | # 74 | #with Context(parallel=True, level='DEBUG'): 75 | # acc = one_fold_intra_subject_eval.vote_accuracy_curves( 76 | # [Exp(dataset=Dataset.from_name('ninapro-db1-sigimg-fast'), 77 | # dataset_args=dict(preprocess=Preprocess.parse('ninapro-lowpass')), 78 | # Mod=dict(num_gesture=52, 79 | # context=[mx.gpu(0)], 80 | # multi_stream = True, 81 | # num_stream=len(semg_col), 82 | # symbol_kargs=dict(dropout=0, num_stream=len(semg_col), fusion_type=fusion_type, num_semg_row=semg_row, num_semg_col=semg_col, num_channel=num_ch, num_filter=64), 83 | # params='.cache/TEST-ninapro-db1-singleframe_multistream-1-1-one-fold-intra-subject-fold-%d/model-0028.params'))], 84 | # folds=np.arange(27), 85 | # windows=np.arange(1, 501), 86 | # window=window, 87 | # num_semg_row = num_raw_semg_row, 88 | # num_semg_col = num_raw_semg_col, 89 | # feature_name = feature_name, 90 | # balance=True) 91 | # acc = acc.mean(axis=(0, 1)) 92 | # print('Single frame accuracy: %f' % acc[0]) 93 | # print('5 frames (50 ms) majority voting accuracy: %f' % acc[4]) 94 | # print('10 frames (100 ms) majority voting accuracy: %f' % acc[9]) 95 | # print('15 frames (150 ms) majority voting accuracy: %f' % acc[14]) 96 | # print('20 frames (200 ms) majority voting accuracy: %f' % acc[19]) 97 | # print('25 frames (250 ms) majority voting accuracy: %f' % acc[24]) 98 | 99 | 100 | semg_row = [] 101 | semg_col = [] 102 | num_ch = [] 103 | for i in range(3): 104 | num_ch.append(1) 105 | semg_row.append(8) 106 | semg_col.append(7) 107 | window=1 108 | 
num_raw_semg_row=24 109 | num_raw_semg_col=7 110 | feature_name = 'piece_multistream' 111 | fusion_type = 'fuse_5' 112 | 113 | with Context(parallel=True, level='DEBUG'): 114 | acc = intra_session_eval.vote_accuracy_curves( 115 | [Exp(dataset=Dataset.from_name('csliter'), 116 | dataset_args=dict(preprocess=Preprocess.parse('(csl-cut,abs,ninapro-lowpass)')), 117 | Mod=dict(num_gesture=27, 118 | adabn=True, 119 | num_adabn_epoch=10, 120 | context=[mx.gpu(1)], 121 | multi_stream = True, 122 | num_stream=len(semg_col), 123 | symbol_kargs=dict(dropout=0, zscore=False, num_pixel=2, num_stream=len(semg_col), fusion_type=fusion_type, num_semg_row=semg_row, num_semg_col=semg_col, num_channel=num_ch, num_filter=64), 124 | params='.cache/CSL-piece_multistream-1-1-one-fold-intra-session-%d/model-0010.params'))], 125 | folds=np.arange(250), 126 | windows=np.arange(1, 2049), 127 | window=window, 128 | num_semg_row = num_raw_semg_row, 129 | num_semg_col = num_raw_semg_col, 130 | feature_name = feature_name, 131 | balance=True) 132 | acc = acc.mean(axis=(0, 1)) 133 | print('Single frame accuracy: %f' % acc[0]) 134 | print('307 frames majority voting accuracy: %f' % acc[306]) 135 | print('350 frames majority voting accuracy: %f' % acc[349]) 136 | print('614 frames majority voting accuracy: %f' % acc[613]) 137 | 138 | save_root = "/home/weiwentao/public-2/wwt/voting_acc/" 139 | if os.path.isdir(save_root) is False: 140 | os.makedirs(save_root) 141 | out_acc = os.path.join(save_root, 'csl_multistream.v.0.1.0.mat') 142 | sio.savemat(out_acc, {'acc': acc}) 143 | 144 | 145 | with Context(parallel=True, level='DEBUG'): 146 | acc = intra_session_eval.accuracies( 147 | [Exp(dataset=Dataset.from_name('csliter'), vote=-1, 148 | dataset_args=dict(preprocess=Preprocess.parse('(csl-cut,abs,ninapro-lowpass)')), 149 | Mod=dict(num_gesture=27, 150 | adabn=True, 151 | num_adabn_epoch=10, 152 | context=[mx.gpu(1)], 153 | multi_stream = True, 154 | num_stream=len(semg_col), 155 | 
symbol_kargs=dict(dropout=0, zscore=False, num_pixel=2, num_stream=len(semg_col), fusion_type=fusion_type, num_semg_row=semg_row, num_semg_col=semg_col, num_channel=num_ch, num_filter=64), 156 | params='.cache/CSL-piece_multistream-1-1-one-fold-intra-session-%d/model-0010.params'))], 157 | folds=np.arange(250), 158 | window=window, 159 | num_semg_row = num_raw_semg_row, 160 | num_semg_col = num_raw_semg_col, 161 | feature_name = feature_name) 162 | print('Per-trial majority voting accuracy: %f' % acc.mean()) 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | #with Context(parallel=True, level='DEBUG'): 174 | # acc = one_fold_intra_subject_eval.accuracies( 175 | # [Exp(dataset=Dataset.from_name('ninapro-db1-sigimg-fast'), vote=-1, 176 | # dataset_args=dict(preprocess=Preprocess.parse('ninapro-lowpass')), 177 | # Mod=dict(num_gesture=52, 178 | # context=[mx.gpu(0)], 179 | # symbol_kargs=dict(dropout=0, num_semg_row=semg_row, num_semg_col=semg_col, num_filter=64), 180 | # params='.cache/ninapro-db1-sigimg-1-1-one-fold-intra-subject-fold-%d-v1.0.0.6/model-0028.params'))], 181 | # folds=np.arange(27)) 182 | # print('Per-trial majority voting accuracy: %f' % acc.mean()) 183 | -------------------------------------------------------------------------------- /scripts/test_db1_input.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | import sys 4 | sys.path.insert(0, os.getcwd()) 5 | import numpy as np 6 | import mxnet as mx 7 | from sigr.evaluation_db1input import CrossValEvaluation as CV, Exp 8 | from sigr.data import Preprocess, Dataset 9 | from sigr import Context 10 | 11 | one_fold_intra_subject_eval = CV(crossval_type='one-fold-intra-subject', batch_size=1000) 12 | intra_session_eval = CV(crossval_type='intra-session', batch_size=1000) 13 | 14 | print('NinaPro DB1 SigImage') 15 | print('===========') 16 | 17 | semg_row = 1 18 | semg_col = 50 19 | 
window=1 20 | num_raw_semg_row=1 21 | num_raw_semg_col=10 22 | feature_name = 'sigimg' 23 | 24 | #with Context(parallel=True, level='DEBUG'): 25 | # acc = one_fold_intra_subject_eval.vote_accuracy_curves( 26 | # [Exp(dataset=Dataset.from_name('ninapro-db1-sigimg-fast'), 27 | # dataset_args=dict(preprocess=Preprocess.parse('ninapro-lowpass')), 28 | # Mod=dict(num_gesture=52, 29 | # context=[mx.gpu(0)], 30 | # symbol_kargs=dict(dropout=0, num_semg_row=semg_row, num_semg_col=semg_col, num_filter=64), 31 | # params='.cache/ninapro-db1-sigimg-1-1-one-fold-intra-subject-fold-%d-v1.0.0.6/model-0028.params'))], 32 | # folds=np.arange(27), 33 | # windows=np.arange(1, 501), 34 | # window=window, 35 | # num_semg_row = num_raw_semg_row, 36 | # num_semg_col = num_raw_semg_col, 37 | # feature_name = feature_name, 38 | # balance=True) 39 | # acc = acc.mean(axis=(0, 1)) 40 | # print('Single frame accuracy: %f' % acc[0]) 41 | # print('5 frames (50 ms) majority voting accuracy: %f' % acc[4]) 42 | # print('10 frames (100 ms) majority voting accuracy: %f' % acc[9]) 43 | # print('15 frames (150 ms) majority voting accuracy: %f' % acc[14]) 44 | # print('20 frames (200 ms) majority voting accuracy: %f' % acc[19]) 45 | # print('25 frames (250 ms) majority voting accuracy: %f' % acc[24]) 46 | 47 | top_k = 3 48 | 49 | with Context(parallel=True, level='DEBUG'): 50 | acc = one_fold_intra_subject_eval.topk_accuracy_curves( 51 | [Exp(dataset=Dataset.from_name('ninapro-db1-sigimg-fast'), 52 | dataset_args=dict(preprocess=Preprocess.parse('ninapro-lowpass')), 53 | Mod=dict(num_gesture=52, 54 | context=[mx.gpu(0)], 55 | symbol_kargs=dict(dropout=0, num_semg_row=semg_row, num_semg_col=semg_col, num_filter=64), 56 | params='test_result/ninapro-db1-sigimg-1-1-one-fold-intra-subject-fold-%d-v1.0.0.6/model-0028.params'))], 57 | folds=np.arange(27), 58 | windows=np.arange(1, 501), 59 | window=window, 60 | num_semg_row = num_raw_semg_row, 61 | num_semg_col = num_raw_semg_col, 62 | feature_name = 
feature_name, 63 | topk = top_k, 64 | balance=True) 65 | acc = acc.mean(axis=(0, 1)) 66 | print('Single frame accuracy: %f' % acc[0]) 67 | print('5 frames (50 ms) majority voting accuracy: %f' % acc[4]) 68 | print('10 frames (100 ms) majority voting accuracy: %f' % acc[9]) 69 | print('15 frames (150 ms) majority voting accuracy: %f' % acc[14]) 70 | print('20 frames (200 ms) majority voting accuracy: %f' % acc[19]) 71 | print('25 frames (250 ms) majority voting accuracy: %f' % acc[24]) -------------------------------------------------------------------------------- /scripts/test_ninapro_multistream.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | import sys 4 | sys.path.insert(0, os.getcwd()) 5 | import numpy as np 6 | import mxnet as mx 7 | from sigr.evaluation_db1multistream_outputpreds import CrossValEvaluation as CV, Exp 8 | from sigr.data import Preprocess, Dataset 9 | from sigr import Context 10 | 11 | one_fold_intra_subject_eval = CV(crossval_type='one-fold-intra-subject', batch_size=1000) 12 | intra_session_eval = CV(crossval_type='intra-session', batch_size=1000) 13 | 14 | print('NinaPro DB1 Multistream') 15 | print('===========') 16 | 17 | 18 | semg_row = [] 19 | semg_col = [] 20 | num_ch = [] 21 | for i in range(10): 22 | num_ch.append(1) 23 | semg_row.append(1) 24 | semg_col.append(20) 25 | window=20 26 | num_raw_semg_row=1 27 | num_raw_semg_col=10 28 | feature_name = 'ch_multistream' 29 | fusion_type = 'fuse_5' 30 | 31 | with Context(parallel=True, level='DEBUG'): 32 | acc = one_fold_intra_subject_eval.vote_accuracy_curves( 33 | [Exp(dataset=Dataset.from_name('ninapro-db1-sigimg-fast'), 34 | dataset_args=dict(preprocess=Preprocess.parse('ninapro-lowpass')), 35 | Mod=dict(num_gesture=52, 36 | context=[mx.gpu(0)], 37 | multi_stream = True, 38 | num_stream=len(semg_col), 39 | symbol_kargs=dict(dropout=0, zscore=False, num_pixel=2, 
num_stream=len(semg_col), fusion_type=fusion_type, num_semg_row=semg_row, num_semg_col=semg_col, num_channel=num_ch, num_filter=64), 40 | params='.cache/ninapro-db1-ch_multistream-20-1-one-fold-intra-subject-fold-%d-v1.0.0.1/model-0028.params'))], 41 | folds=np.arange(27), 42 | windows=np.arange(1, 5), 43 | window=window, 44 | num_semg_row = num_raw_semg_row, 45 | num_semg_col = num_raw_semg_col, 46 | feature_name = feature_name, 47 | balance=True) 48 | acc = acc.mean(axis=(0, 1)) 49 | 50 | 51 | kk = one_fold_intra_subject_eval.output_softmax_preds( 52 | [Exp(dataset=Dataset.from_name('ninapro-db1-sigimg-fast'), 53 | dataset_args=dict(preprocess=Preprocess.parse('ninapro-lowpass')), 54 | Mod=dict(num_gesture=52, 55 | context=[mx.gpu(0)], 56 | multi_stream = True, 57 | num_stream=len(semg_col), 58 | symbol_kargs=dict(dropout=0, zscore=False, num_pixel=2, num_stream=len(semg_col), fusion_type=fusion_type, num_semg_row=semg_row, num_semg_col=semg_col, num_channel=num_ch, num_filter=64), 59 | params='.cache/ninapro-db1-ch_multistream-20-1-one-fold-intra-subject-fold-%d-v1.0.0.1/model-0028.params'))], 60 | folds=np.arange(27), 61 | windows=np.arange(1, 5), 62 | window=window, 63 | num_semg_row = num_raw_semg_row, 64 | num_semg_col = num_raw_semg_col, 65 | feature_name = feature_name, 66 | balance=True) 67 | 68 | print('Single frame accuracy: %f' % acc[0]) 69 | # print('5 frames (50 ms) majority voting accuracy: %f' % acc[4]) 70 | # print('10 frames (100 ms) majority voting accuracy: %f' % acc[9]) 71 | # print('15 frames (150 ms) majority voting accuracy: %f' % acc[14]) 72 | # print('20 frames (200 ms) majority voting accuracy: %f' % acc[19]) 73 | # print('25 frames (250 ms) majority voting accuracy: %f' % acc[24]) 74 | 75 | 76 | #with Context(parallel=True, level='DEBUG'): 77 | # acc = one_fold_intra_subject_eval.accuracies( 78 | # [Exp(dataset=Dataset.from_name('ninapro-db1-sigimg-fast'), vote=-1, 79 | # 
dataset_args=dict(preprocess=Preprocess.parse('ninapro-lowpass')), 80 | # Mod=dict(num_gesture=52, 81 | # context=[mx.gpu(0)], 82 | # symbol_kargs=dict(dropout=0, num_semg_row=semg_row, num_semg_col=semg_col, num_filter=64), 83 | # params='.cache/ninapro-db1-sigimg-1-1-one-fold-intra-subject-fold-%d-v1.0.0.6/model-0028.params'))], 84 | # folds=np.arange(27)) 85 | # print('Per-trial majority voting accuracy: %f' % acc.mean()) 86 | -------------------------------------------------------------------------------- /scripts/test_semimyo.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | import sys 4 | sys.path.insert(0, os.getcwd()) 5 | import numpy as np 6 | import mxnet as mx 7 | from sigr.evaluation_semimyo import CrossValEvaluation as CV, Exp 8 | from sigr.data import Preprocess, Dataset 9 | from sigr import Context 10 | 11 | 12 | one_fold_intra_subject_eval = CV(crossval_type='one-fold-intra-subject', batch_size=1000) 13 | 14 | with Context(parallel=True, level='DEBUG'): 15 | acc = one_fold_intra_subject_eval.vote_accuracy_curves( 16 | [Exp(dataset=Dataset.from_name('ninapro-db1-raw/semg-glove'), 17 | dataset_args=dict(preprocess=Preprocess.parse('{ninapro-lowpass-parallel,identity,identity}')), 18 | Mod=dict(for_training=False, 19 | context=[mx.gpu(1)], 20 | symbol_kargs=dict( 21 | dropout=0, 22 | num_gesture=52, 23 | num_glove=22, 24 | num_semg_row=1, 25 | num_semg_col=10, 26 | glove_loss_weight=0.01, 27 | num_glove_layer=2, 28 | num_glove_hidden=128 29 | ), 30 | params='.cache/semimyo-ninapro-db1-raw-semg-glove-one-fold-intra-subject-%d-v20161127.2/model-0028.params'))], 31 | folds=np.arange(27), 32 | windows=[1], 33 | balance=True) 34 | acc = acc.mean(axis=(0, 1)) 35 | print('Single frame accuracy: %f' % acc[0]) 36 | 37 | # with Context(parallel=True, level='DEBUG'): 38 | # acc = one_fold_intra_subject_eval.vote_accuracy_curves( 39 | # 
[Exp(dataset=Dataset.from_name('ninapro-db1-raw/semg-glove'), 40 | # dataset_args=dict(preprocess=Preprocess.parse('{ninapro-lowpass-parallel,identity,identity}')), 41 | # Mod=dict(for_training=False, 42 | # context=[mx.gpu(0)], 43 | # symbol_kargs=dict( 44 | # dropout=0, 45 | # num_gesture=52, 46 | # num_glove=22, 47 | # num_semg_row=1, 48 | # num_semg_col=10, 49 | # glove_loss_weight=0, 50 | # num_glove_layer=2, 51 | # num_glove_hidden=128 52 | # ), 53 | # params='.cache/semimyo-ninapro-db1-raw-semg-glove-one-fold-intra-subject-%d-v20161127.1/model-0028.params'))], 54 | # folds=np.arange(27), 55 | # windows=[1], 56 | # balance=True) 57 | # acc = acc.mean(axis=(0, 1)) 58 | # print('Single frame accuracy: %f' % acc[0]) 59 | 60 | # with Context(parallel=True, level='DEBUG'): 61 | # acc = one_fold_intra_subject_eval.vote_accuracy_curves( 62 | # [Exp(dataset=Dataset.from_name('ninapro-db1-raw/semg-glove'), 63 | # dataset_args=dict(preprocess=Preprocess.parse('{ninapro-lowpass-parallel,identity,identity}')), 64 | # Mod=dict(for_training=False, 65 | # context=[mx.gpu(1)], 66 | # symbol_kargs=dict( 67 | # dropout=0, 68 | # num_gesture=52, 69 | # num_glove=22, 70 | # num_semg_row=1, 71 | # num_semg_col=10, 72 | # glove_loss_weight=0.1, 73 | # num_glove_layer=2, 74 | # num_glove_hidden=128 75 | # ), 76 | # params='.cache/semimyo-ninapro-db1-raw-semg-glove-one-fold-intra-subject-%d-v20161127.3/model-0028.params'))], 77 | # folds=np.arange(27), 78 | # windows=[1], 79 | # balance=True) 80 | # acc = acc.mean(axis=(0, 1)) 81 | # print('Single frame accuracy: %f' % acc[0]) 82 | 83 | # with Context(parallel=True, level='DEBUG'): 84 | # acc = one_fold_intra_subject_eval.vote_accuracy_curves( 85 | # [Exp(dataset=Dataset.from_name('ninapro-db1-raw/semg-glove'), 86 | # dataset_args=dict(preprocess=Preprocess.parse('{ninapro-lowpass-parallel,identity,identity}')), 87 | # Mod=dict(for_training=False, 88 | # context=[mx.gpu(1)], 89 | # symbol_kargs=dict( 90 | # dropout=0, 91 | # 
num_gesture=52, 92 | # num_glove=22, 93 | # num_semg_row=1, 94 | # num_semg_col=10, 95 | # glove_loss_weight=0.01, 96 | # num_glove_layer=2, 97 | # num_glove_hidden=256 98 | # ), 99 | # params='.cache/semimyo-ninapro-db1-raw-semg-glove-one-fold-intra-subject-%d-v20161127.4/model-0028.params'))], 100 | # folds=np.arange(27), 101 | # windows=[1], 102 | # balance=True) 103 | # acc = acc.mean(axis=(0, 1)) 104 | # print('Single frame accuracy: %f' % acc[0]) 105 | 106 | # with Context(parallel=True, level='DEBUG'): 107 | # acc = one_fold_intra_subject_eval.vote_accuracy_curves( 108 | # [Exp(dataset=Dataset.from_name('ninapro-db1-raw/semg-glove'), 109 | # dataset_args=dict(preprocess=Preprocess.parse('{ninapro-lowpass-parallel,identity,identity}')), 110 | # Mod=dict(for_training=False, 111 | # context=[mx.gpu(1)], 112 | # symbol_kargs=dict( 113 | # dropout=0, 114 | # num_gesture=52, 115 | # num_glove=22, 116 | # num_semg_row=1, 117 | # num_semg_col=10, 118 | # glove_loss_weight=0.01, 119 | # num_glove_layer=4, 120 | # num_glove_hidden=128 121 | # ), 122 | # params='.cache/semimyo-ninapro-db1-raw-semg-glove-one-fold-intra-subject-%d-v20161127.5/model-0028.params'))], 123 | # folds=np.arange(27), 124 | # windows=[1], 125 | # balance=True) 126 | # acc = acc.mean(axis=(0, 1)) 127 | # print('Single frame accuracy: %f' % acc[0]) 128 | -------------------------------------------------------------------------------- /scripts/testsrep.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | import sys 4 | sys.path.insert(0, os.getcwd()) 5 | import numpy as np 6 | import mxnet as mx 7 | from sigr.evaluation import CrossValEvaluation as CV, Exp 8 | from sigr.data import Preprocess, Dataset 9 | from sigr import Context 10 | 11 | 12 | one_fold_intra_subject_eval = CV(crossval_type='one-fold-intra-subject', batch_size=1000) 13 | intra_session_eval = CV(crossval_type='intra-session', 
batch_size=1000) 14 | 15 | print('NinaPro DB1') 16 | print('===========') 17 | 18 | with Context(parallel=True, level='DEBUG'): 19 | acc = one_fold_intra_subject_eval.vote_accuracy_curves( 20 | [Exp(dataset=Dataset.from_name('ninapro-db1'), 21 | dataset_args=dict(preprocess=Preprocess.parse('ninapro-lowpass')), 22 | Mod=dict(num_gesture=52, 23 | context=[mx.gpu(0)], 24 | symbol_kargs=dict(dropout=0, num_semg_row=1, num_semg_col=10, num_filter=64), 25 | params='.cache/srep-ninapro-db1-one-fold-intra-subject-%d/model-0028.params'))], 26 | folds=np.arange(27), 27 | windows=np.arange(1, 501), 28 | balance=True) 29 | acc = acc.mean(axis=(0, 1)) 30 | print('Single frame accuracy: %f' % acc[0]) 31 | print('15 frames (150 ms) majority voting accuracy: %f' % acc[14]) 32 | 33 | with Context(parallel=True, level='DEBUG'): 34 | acc = one_fold_intra_subject_eval.accuracies( 35 | [Exp(dataset=Dataset.from_name('ninapro-db1'), vote=-1, 36 | dataset_args=dict(preprocess=Preprocess.parse('ninapro-lowpass')), 37 | Mod=dict(num_gesture=52, 38 | context=[mx.gpu(0)], 39 | symbol_kargs=dict(dropout=0, num_semg_row=1, num_semg_col=10, num_filter=64), 40 | params='.cache/srep-ninapro-db1-one-fold-intra-subject-%d/model-0028.params'))], 41 | folds=np.arange(27)) 42 | print('Per-trial majority voting accuracy: %f' % acc.mean()) 43 | 44 | print('') 45 | print('CapgMyo DB-a') 46 | print('============') 47 | 48 | with Context(parallel=True, level='DEBUG'): 49 | acc = one_fold_intra_subject_eval.vote_accuracy_curves( 50 | [Exp(dataset=Dataset.from_name('dba'), 51 | Mod=dict(num_gesture=8, 52 | context=[mx.gpu(0)], 53 | symbol_kargs=dict(dropout=0, num_semg_row=16, num_semg_col=8, num_filter=64), 54 | params='.cache/srep-dba-one-fold-intra-subject-%d/model-0028.params'))], 55 | folds=np.arange(18), 56 | windows=np.arange(1, 1001)) 57 | acc = acc.mean(axis=(0, 1)) 58 | print('Single frame accuracy: %f' % acc[0]) 59 | print('150 frames (150 ms) majority voting accuracy: %f' % acc[149]) 60 | 61 | 
with Context(parallel=True, level='DEBUG'): 62 | acc = one_fold_intra_subject_eval.accuracies( 63 | [Exp(dataset=Dataset.from_name('dba'), vote=-1, 64 | Mod=dict(num_gesture=8, 65 | context=[mx.gpu(0)], 66 | symbol_kargs=dict(dropout=0, num_semg_row=16, num_semg_col=8, num_filter=64), 67 | params='.cache/srep-dba-one-fold-intra-subject-%d/model-0028.params'))], 68 | folds=np.arange(18)) 69 | print('Per-trial majority voting accuracy: %f' % acc.mean()) 70 | 71 | # print('') 72 | # print('# CSL-HDEMG') 73 | # print('===========') 74 | 75 | # with Context(parallel=True, level='DEBUG'): 76 | # acc = intra_session_eval.vote_accuracy_curves( 77 | # [Exp(dataset=Dataset.from_name('csl'), 78 | # dataset_args=dict(preprocess=Preprocess.parse('(csl-bandpass,csl-cut,median)')), 79 | # Mod=dict(num_gesture=27, 80 | # adabn=True, 81 | # num_adabn_epoch=10, 82 | # context=[mx.gpu(0)], 83 | # symbol_kargs=dict(dropout=0, num_semg_row=24, num_semg_col=7, num_filter=64), 84 | # params='.cache/srep-csl-intra-session-%d/model-0028.params'))], 85 | # folds=np.arange(250), 86 | # windows=np.arange(1, 2049), 87 | # balance=True) 88 | # acc = acc.mean(axis=(0, 1)) 89 | # print('Single frame accuracy: %f' % acc[0]) 90 | # print('307 frames (150 ms) majority voting accuracy: %f' % acc[306]) 91 | 92 | # with Context(parallel=True, level='DEBUG'): 93 | # acc = intra_session_eval.accuracies( 94 | # [Exp(dataset=Dataset.from_name('csl'), vote=-1, 95 | # dataset_args=dict(preprocess=Preprocess.parse('(csl-bandpass,csl-cut,median)')), 96 | # Mod=dict(num_gesture=27, 97 | # adabn=True, 98 | # num_adabn_epoch=10, 99 | # context=[mx.gpu(0)], 100 | # symbol_kargs=dict(dropout=0, num_semg_row=24, num_semg_col=7, num_filter=64), 101 | # params='.cache/srep-csl-intra-session-%d/model-0028.params'))], 102 | # folds=np.arange(250)) 103 | # print('Per-trial majority voting accuracy: %f' % acc.mean()) 104 | -------------------------------------------------------------------------------- 
/scripts/testsrep.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | scripts/runsrep python scripts/testsrep.py 4 | -------------------------------------------------------------------------------- /scripts/trainsrep.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Recognition of 8 gestures in CapgMyo DB-a 4 | scripts/runsrep python -m sigr.app exp --log log --snapshot model \ 5 | --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \ 6 | --root .cache/srep-dba-universal-one-fold-intra-subject \ 7 | --num-semg-row 16 --num-semg-col 8 \ 8 | --batch-size 1000 --decay-all --dataset dba \ 9 | --num-filter 64 \ 10 | crossval --crossval-type universal-one-fold-intra-subject --fold 0 11 | for i in $(seq 0 17 | shuf); do 12 | scripts/runsrep python -m sigr.app exp --log log --snapshot model \ 13 | --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \ 14 | --root .cache/srep-dba-one-fold-intra-subject-$i \ 15 | --num-semg-row 16 --num-semg-col 8 \ 16 | --batch-size 1000 --decay-all --dataset dba \ 17 | --num-filter 64 \ 18 | --params .cache/srep-dba-universal-one-fold-intra-subject/model-0028.params \ 19 | crossval --crossval-type one-fold-intra-subject --fold $i 20 | done 21 | 22 | # Recognition of 52 gestures in NinaPro DB1 23 | scripts/runsrep python -m sigr.app exp --log log --snapshot model \ 24 | --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \ 25 | --root .cache/srep-ninapro-db1-universal-one-fold-intra-subject \ 26 | --num-semg-row 1 --num-semg-col 10 \ 27 | --batch-size 1000 --decay-all --dataset ninapro-db1 \ 28 | --num-filter 64 \ 29 | --balance-gesture 1 \ 30 | --preprocess 'ninapro-lowpass' \ 31 | crossval --crossval-type universal-one-fold-intra-subject --fold 0 32 | for i in $(seq 0 26 | shuf); do 33 | scripts/runsrep python -m sigr.app exp --log log --snapshot model \ 34 | --num-epoch 28 
--lr-step 16 --lr-step 24 --snapshot-period 28 \ 35 | --root .cache/srep-ninapro-db1-one-fold-intra-subject-$i \ 36 | --num-semg-row 1 --num-semg-col 10 \ 37 | --batch-size 1000 --decay-all --dataset ninapro-db1 \ 38 | --num-filter 64 \ 39 | --params .cache/srep-ninapro-db1-universal-one-fold-intra-subject/model-0028.params \ 40 | --balance-gesture 1 \ 41 | --preprocess 'ninapro-lowpass' \ 42 | crossval --crossval-type one-fold-intra-subject --fold $i 43 | done 44 | 45 | # Recognition of 27 gestures in CSL-HDEMG 46 | for i in $(seq 0 24 | shuf); do 47 | scripts/runsrep python -m sigr.app exp --log log --snapshot model \ 48 | --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \ 49 | --root .cache/srep-csl-universal-intra-session-$i \ 50 | --num-semg-row 24 --num-semg-col 7 \ 51 | --batch-size 2500 --decay-all --adabn --minibatch --dataset csl \ 52 | --preprocess '(csl-bandpass,csl-cut,downsample-5,median)' \ 53 | --balance-gesture 1 \ 54 | --num-filter 64 \ 55 | crossval --crossval-type universal-intra-session --fold $i 56 | done 57 | for i in $(seq 0 249 | shuf); do 58 | scripts/runsrep python -m sigr.app exp --log log --snapshot model \ 59 | --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \ 60 | --root .cache/srep-csl-intra-session-$i \ 61 | --num-semg-row 24 --num-semg-col 7 \ 62 | --batch-size 1000 --decay-all --adabn --num-adabn-epoch 10 --dataset csl \ 63 | --preprocess '(csl-bandpass,csl-cut,median)' \ 64 | --balance-gesture 1 \ 65 | --params .cache/srep-csl-universal-intra-session-$(($i % 10))/model-0028.params \ 66 | --num-filter 64 \ 67 | crossval --crossval-type intra-session --fold $i 68 | done 69 | -------------------------------------------------------------------------------- /sigr/__init__.py: -------------------------------------------------------------------------------- 1 | import mxnet as mx 2 | import numpy as np 3 | import random 4 | 5 | mx.random.seed(42) 6 | np.random.seed(43) 7 | random.seed(44) 8 | 9 | import os 10 
| 11 | os.environ['JOBLIB_TEMP_FOLDER'] = '/tmp' 12 | 13 | ROOT = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) 14 | CACHE = os.path.join(ROOT, '.cache') 15 | #CACHE = '/home/weiwentao/exp/.cache/' 16 | 17 | 18 | from contextlib import contextmanager 19 | 20 | 21 | @contextmanager 22 | def Context(log=None, parallel=False, level=None): 23 | from .utils import logging_context 24 | with logging_context(log, level=level): 25 | if not parallel: 26 | yield 27 | else: 28 | import joblib as jb 29 | from multiprocessing import cpu_count 30 | with jb.Parallel(n_jobs=cpu_count()) as par: 31 | Context.parallel = par 32 | yield 33 | 34 | 35 | def _patch(func): 36 | func() 37 | return lambda: None 38 | 39 | 40 | @_patch 41 | def _patch_click(): 42 | import click 43 | orig = click.option 44 | 45 | def option(*args, **kargs): 46 | if 'help' in kargs and 'default' in kargs: 47 | kargs['help'] += ' (default {})'.format(kargs['default']) 48 | return orig(*args, **kargs) 49 | 50 | click.option = option 51 | 52 | 53 | from .data import s21 as data_s21 54 | 55 | 56 | __all__ = ['ROOT', 'CACHE', 'Context', 'data_s21'] 57 | -------------------------------------------------------------------------------- /sigr/activity_img/actimg_extractor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import scipy.io as sio 3 | import numpy as np 4 | from activity_image import get_signal_img 5 | from itertools import product 6 | from collections import namedtuple 7 | 8 | #import cv2 9 | #from ..utils import butter_lowpass_filter as lowpass 10 | 11 | 12 | subjects = list(range(27)) 13 | gestures = list(range(53)) 14 | trials = list(range(10)) 15 | #input_path = '/home/weiwentao/public/duyu/misc/ninapro-db1' 16 | #output_path = '/home/weiwentao/public/semg/ninapro-feature/ninapro-db1-var-raw' 17 | 18 | 19 | input_path = 'Y:/duyu/misc/ninapro-db1' 20 | 21 | 22 | filtering_type = 'lowpass' 23 | framerate = 100 24 | 25 | window_length_ms = 
150 26 | window_stride_ms = 10 27 | 28 | window = window_length_ms*framerate/1000 29 | stride = window_stride_ms*framerate/1000 30 | 31 | output_path = ('Y:/semg/ninapro-feature/TEST-ninapro-db1-var-raw-prepro-%s-win-%d-stride-%d' % (filtering_type, window, stride)) 32 | 33 | Combo = namedtuple('Combo', ['subject', 'gesture', 'trial'], verbose=False) 34 | 35 | def get_combos(*args): 36 | for arg in args: 37 | if isinstance(arg, tuple): 38 | arg = [arg] 39 | for a in arg: 40 | yield Combo(*a) 41 | 42 | 43 | 44 | #the following functions can be loaded from ..utils 45 | 46 | 47 | def butter_lowpass_filter(data, cut, fs, order, zero_phase=False): 48 | from scipy.signal import butter, lfilter, filtfilt 49 | 50 | nyq = 0.5 * fs 51 | cut = cut / nyq 52 | 53 | b, a = butter(order, cut, btype='low') 54 | y = (filtfilt if zero_phase else lfilter)(b, a, data) 55 | return y 56 | 57 | 58 | def get_segments(data, window, stride): 59 | return windowed_view( 60 | data.flat, 61 | window * data.shape[1], 62 | (window-stride)* data.shape[1] 63 | ) 64 | 65 | def windowed_view(arr, window, overlap): 66 | from numpy.lib.stride_tricks import as_strided 67 | arr = np.asarray(arr) 68 | window_step = window - overlap 69 | new_shape = arr.shape[:-1] + ((arr.shape[-1] - overlap) // window_step, 70 | window) 71 | new_strides = (arr.strides[:-1] + (window_step * arr.strides[-1],) + 72 | arr.strides[-1:]) 73 | return as_strided(arr, shape=new_shape, strides=new_strides) 74 | 75 | 76 | def dft(data): 77 | f = np.fft.fft2(data) 78 | fshift = np.fft.fftshift(f) 79 | magnitude_spectrum = 20*np.log(np.abs(fshift)) 80 | return magnitude_spectrum 81 | 82 | #the following functions can be loaded from .. 
83 | 84 | def dft_dy(data): 85 | data = data.T 86 | n = data.shape[-1] 87 | window = np.hanning(n) 88 | windowed = data * window 89 | spectrum = np.fft.fft(windowed) 90 | return np.abs(spectrum.T) 91 | 92 | 93 | 94 | if __name__ == '__main__': 95 | 96 | print ("NinaPro activity image generation, use window = %d frames, stride = %d frames" % (window, stride)) 97 | 98 | combos = get_combos(product(subjects, gestures, trials)) 99 | 100 | combos = list(combos) 101 | 102 | for combo in combos: 103 | in_path = os.path.join( 104 | input_path, 'data', 105 | '{c.subject:03d}', 106 | '{c.gesture:03d}', 107 | '{c.subject:03d}_{c.gesture:03d}_{c.trial:03d}.mat').format(c=combo) 108 | 109 | out_dir = os.path.join( 110 | output_path, 111 | '{c.subject:03d}', 112 | '{c.gesture:03d}').format(c=combo) 113 | 114 | if os.path.isdir(out_dir) is False: 115 | os.makedirs(out_dir) 116 | 117 | 118 | data = sio.loadmat(in_path)['data'].astype(np.float32) 119 | 120 | print ("Subject %d Gesture %d Trial %d data loaded!" % (combo.subject, combo.gesture, combo.trial)) 121 | 122 | if filtering_type is 'lowpass': 123 | # data = np.transpose([lowpass(ch, 1, framerate, 1, zero_phase=True) for ch in data.T]) 124 | data = np.transpose([butter_lowpass_filter(ch, 1, framerate, 1, zero_phase=True) for ch in data.T]) 125 | print ("Subject %d Gesture %d Trial %d bandpass filtering finished!" 
% (combo.subject, combo.gesture, combo.trial)) 126 | else: 127 | pass 128 | 129 | 130 | 131 | chnum = data.shape[1]; 132 | data = get_segments(data, window, stride) 133 | data = data.reshape(-1, window, chnum) 134 | 135 | data = [np.transpose(get_signal_img(seg.T)) for seg in data] 136 | data = np.array(data) 137 | 138 | out_path = os.path.join( 139 | out_dir, 140 | '{c.subject:03d}_{c.gesture:03d}_{c.trial:03d}_sigimg.mat').format(c=combo) 141 | sio.savemat(out_path, {'data': data, 'label': combo.gesture, 'subject': combo.subject, 'trial':combo.trial}) 142 | 143 | print ("Subject %d Gesture %d Trial %d sig image saved!" % (combo.subject, combo.gesture, combo.trial)) 144 | 145 | ## for test only 146 | # data = data[0:25,] 147 | # data = dft(data) 148 | # data = cv2.resize(data,None,fx=20,fy=20) 149 | # cv2.imshow('image',data) 150 | # cv2.waitKey(0) 151 | # cv2.destroyAllWindows() 152 | 153 | 154 | data_fft = [] 155 | for seg in data: 156 | spectrum = dft_dy(seg) 157 | data_fft.append(spectrum) 158 | data = np.array(data_fft) 159 | 160 | 161 | print ("Subject %d Gesture %d Trial %d data windowing finished!" % (combo.subject, combo.gesture, combo.trial)) 162 | 163 | out_path = os.path.join( 164 | out_dir, 165 | '{c.subject:03d}_{c.gesture:03d}_{c.trial:03d}_actimg.mat').format(c=combo) 166 | sio.savemat(out_path, {'data': data, 'label': combo.gesture, 'subject': combo.subject, 'trial':combo.trial}) 167 | 168 | print ("Subject %d Gesture %d Trial %d activity image saved!" 
% (combo.subject, combo.gesture, combo.trial)) 169 | 170 | 171 | -------------------------------------------------------------------------------- /sigr/activity_img/actimg_extractor_1.py: -------------------------------------------------------------------------------- 1 | import os 2 | import scipy.io as sio 3 | import numpy as np 4 | from activity_image import get_signal_img 5 | from itertools import product 6 | from collections import namedtuple 7 | 8 | #import cv2 9 | #from ..utils import butter_lowpass_filter as lowpass 10 | 11 | 12 | subjects = list(range(27)) 13 | gestures = list(range(53)) 14 | trials = list(range(10)) 15 | #input_path = '/home/weiwentao/public/duyu/misc/ninapro-db1' 16 | #output_path = '/home/weiwentao/public/semg/ninapro-feature/ninapro-db1-var-raw' 17 | 18 | 19 | input_path = '/home/weiwentao/public/duyu/misc/ninapro-db1' 20 | 21 | 22 | filtering_type = 'lowpass' 23 | framerate = 100 24 | 25 | window_length_ms = 200 26 | window_stride_ms = 10 27 | 28 | window = window_length_ms*framerate/1000 29 | stride = window_stride_ms*framerate/1000 30 | 31 | output_path = ('/home/weiwentao/public/semg/ninapro-feature/ninapro-db1-var-raw-prepro-%s-win-%d-stride-%d' % (filtering_type, window, stride)) 32 | 33 | Combo = namedtuple('Combo', ['subject', 'gesture', 'trial'], verbose=False) 34 | 35 | def get_combos(*args): 36 | for arg in args: 37 | if isinstance(arg, tuple): 38 | arg = [arg] 39 | for a in arg: 40 | yield Combo(*a) 41 | 42 | 43 | 44 | #the following functions can be loaded from ..utils 45 | 46 | 47 | def butter_lowpass_filter(data, cut, fs, order, zero_phase=False): 48 | from scipy.signal import butter, lfilter, filtfilt 49 | 50 | nyq = 0.5 * fs 51 | cut = cut / nyq 52 | 53 | b, a = butter(order, cut, btype='low') 54 | y = (filtfilt if zero_phase else lfilter)(b, a, data) 55 | return y 56 | 57 | 58 | def get_segments(data, window, stride): 59 | return windowed_view( 60 | data.flat, 61 | window * data.shape[1], 62 | (window-stride)* 
data.shape[1] 63 | ) 64 | 65 | def windowed_view(arr, window, overlap): 66 | from numpy.lib.stride_tricks import as_strided 67 | arr = np.asarray(arr) 68 | window_step = window - overlap 69 | new_shape = arr.shape[:-1] + ((arr.shape[-1] - overlap) // window_step, 70 | window) 71 | new_strides = (arr.strides[:-1] + (window_step * arr.strides[-1],) + 72 | arr.strides[-1:]) 73 | return as_strided(arr, shape=new_shape, strides=new_strides) 74 | 75 | 76 | def dft(data): 77 | f = np.fft.fft2(data) 78 | fshift = np.fft.fftshift(f) 79 | magnitude_spectrum = 20*np.log(np.abs(fshift)) 80 | return magnitude_spectrum 81 | 82 | #the following functions can be loaded from .. 83 | 84 | def dft_dy(data): 85 | data = data.T 86 | n = data.shape[-1] 87 | window = np.hanning(n) 88 | windowed = data * window 89 | spectrum = np.fft.fft(windowed) 90 | return np.abs(spectrum) 91 | 92 | 93 | 94 | if __name__ == '__main__': 95 | 96 | print ("NinaPro activity image generation, use window = %d frames, stride = %d frames" % (window, stride)) 97 | 98 | combos = get_combos(product(subjects, gestures, trials)) 99 | 100 | combos = list(combos) 101 | 102 | for combo in combos: 103 | in_path = os.path.join( 104 | input_path, 'data', 105 | '{c.subject:03d}', 106 | '{c.gesture:03d}', 107 | '{c.subject:03d}_{c.gesture:03d}_{c.trial:03d}.mat').format(c=combo) 108 | 109 | out_dir = os.path.join( 110 | output_path, 111 | '{c.subject:03d}', 112 | '{c.gesture:03d}').format(c=combo) 113 | 114 | if os.path.isdir(out_dir) is False: 115 | os.makedirs(out_dir) 116 | 117 | 118 | data = sio.loadmat(in_path)['data'].astype(np.float32) 119 | 120 | print ("Subject %d Gesture %d Trial %d data loaded!" 
% (combo.subject, combo.gesture, combo.trial)) 121 | 122 | if filtering_type is 'lowpass': 123 | # data = np.transpose([lowpass(ch, 1, framerate, 1, zero_phase=True) for ch in data.T]) 124 | data = np.transpose([butter_lowpass_filter(ch, 1, framerate, 1, zero_phase=True) for ch in data.T]) 125 | print ("Subject %d Gesture %d Trial %d bandpass filtering finished!" % (combo.subject, combo.gesture, combo.trial)) 126 | else: 127 | pass 128 | 129 | 130 | 131 | chnum = data.shape[1]; 132 | data = get_segments(data, window, stride) 133 | data = data.reshape(-1, window, chnum) 134 | 135 | data = [np.transpose(get_signal_img(seg.T)) for seg in data] 136 | data = np.array(data) 137 | 138 | out_path = os.path.join( 139 | out_dir, 140 | '{c.subject:03d}_{c.gesture:03d}_{c.trial:03d}_sigimg.mat').format(c=combo) 141 | sio.savemat(out_path, {'data': data, 'label': combo.gesture, 'subject': combo.subject, 'trial':combo.trial}) 142 | 143 | print ("Subject %d Gesture %d Trial %d sig image saved!" % (combo.subject, combo.gesture, combo.trial)) 144 | 145 | ## for test only 146 | # data = data[0:25,] 147 | # data = dft(data) 148 | # data = cv2.resize(data,None,fx=20,fy=20) 149 | # cv2.imshow('image',data) 150 | # cv2.waitKey(0) 151 | # cv2.destroyAllWindows() 152 | 153 | 154 | 155 | 156 | 157 | data = [dft(seg) for seg in data] 158 | data = np.array(data) 159 | 160 | 161 | print ("Subject %d Gesture %d Trial %d data windowing finished!" % (combo.subject, combo.gesture, combo.trial)) 162 | 163 | out_path = os.path.join( 164 | out_dir, 165 | '{c.subject:03d}_{c.gesture:03d}_{c.trial:03d}_actimg.mat').format(c=combo) 166 | sio.savemat(out_path, {'data': data, 'label': combo.gesture, 'subject': combo.subject, 'trial':combo.trial}) 167 | 168 | print ("Subject %d Gesture %d Trial %d activity image saved!" 
% (combo.subject, combo.gesture, combo.trial)) 169 | 170 | 171 | -------------------------------------------------------------------------------- /sigr/activity_img/activity_image.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | #import cv2 3 | 4 | def genIndex(chanums): 5 | 6 | index = [] 7 | i = 1 8 | j = i+1 9 | 10 | if (chanums % 2) == 0: 11 | Ns = chanums+1 12 | else: 13 | Ns = chanums 14 | 15 | 16 | index.append(1) 17 | t = chr(i+ord('A')) 18 | while(i!=j): 19 | l = "" 20 | l = l+chr(i+ord('A')) 21 | l = l+chr(j+ord('A')) 22 | r = "" 23 | r = r+chr(j+ord('A')) 24 | r = r+chr(i+ord('A')) 25 | if(j>Ns): 26 | j = 1 27 | elif(t.find(l)==-1 and t.find(r)==-1): 28 | index.append(j) 29 | t = t+chr(j+ord('A')) 30 | i = j 31 | j = i+1 32 | else: 33 | j = j+1 34 | 35 | 36 | 37 | new_index = [] 38 | if (chanums % 2) == 0: 39 | for i in range(len(index)): 40 | if index[i] != chanums+1: 41 | new_index.append(index[i]) 42 | 43 | index = np.array(new_index) 44 | index = index-1 45 | return index 46 | 47 | 48 | def get_signal_img(data): 49 | 50 | ch_num = data.shape[0] 51 | index = genIndex(ch_num) 52 | signal_img = data[index] 53 | signal_img = signal_img[:-1] 54 | # print signal_img.shape 55 | return signal_img 56 | 57 | def get_activity_img(data): 58 | 59 | signal_img = get_signal_img(data) 60 | 61 | f = np.fft.fft2(signal_img) 62 | fshift = np.fft.fftshift(f) 63 | magnitude_spectrum = 20*np.log(np.abs(fshift)) 64 | # magnitude_spectrum = cv2.resize(magnitude_spectrum,None,fx=1,fy=8) 65 | # cv2.imshow('image',magnitude_spectrum) 66 | # cv2.waitKey(0) 67 | # cv2.destroyAllWindows() 68 | return magnitude_spectrum 69 | 70 | 71 | -------------------------------------------------------------------------------- /sigr/base_module.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | 4 | class Meta(type): 5 | 6 | impls = [] 7 | 8 | def 
__init__(cls, name, bases, fields): 9 | type.__init__(cls, name, bases, fields) 10 | Meta.impls.append(cls) 11 | 12 | 13 | class BaseModule(object): 14 | 15 | __metaclass__ = Meta 16 | 17 | @classmethod 18 | def parse(cls, text, **kargs): 19 | if cls is BaseModule: 20 | for impl in Meta.impls: 21 | if impl is not BaseModule: 22 | inst = impl.parse(text, **kargs) 23 | if inst is not None: 24 | return inst 25 | 26 | 27 | __all__ = ['BaseModule'] 28 | -------------------------------------------------------------------------------- /sigr/constant.py: -------------------------------------------------------------------------------- 1 | NUM_LSTM_HIDDEN = 128 2 | NUM_LSTM_LAYER = 1 3 | LSTM_DROPOUT = 0. 4 | NUM_SEMG_ROW = 16 5 | NUM_SEMG_COL = 8 6 | NUM_SEMG_POINT = NUM_SEMG_ROW * NUM_SEMG_COL 7 | NUM_FILTER = 16 8 | NUM_HIDDEN = 512 9 | NUM_BOTTLENECK = 128 10 | DROPOUT = 0.5 11 | GAMMA = 10 12 | NUM_FEATURE_BLOCK = 2 13 | NUM_GESTURE_BLOCK = 0 14 | NUM_SUBJECT_BLOCK = 0 15 | NUM_PIXEL = 2 16 | LAMBDA_SCALE = 1 17 | NUM_TZENG_BATCH = 2 18 | NUM_ADABN_EPOCH = 1 19 | RANDOM_SHIFT_FILL = 'zero' 20 | NUM_CONV_LAYER = 2 21 | NUM_CONV_FILTER = 64 22 | NUM_LC_LAYER = 2 23 | NUM_LC_HIDDEN = 64 24 | LC_KERNEL = 1 25 | LC_STRIDE = 1 26 | LC_PAD = 0 27 | NUM_FC_LAYER = 2 28 | NUM_FC_HIDDEN = 512 29 | NUM_MINI_BATCH = 1 30 | LR = 0.1 31 | WD = 0.0001 32 | LR_FACTOR = 0.1 33 | # GLOVE_LOSS_WEIGHT = 1 34 | # NUM_GLOVE_LAYER = 128 35 | # NUM_GLOVE_HIDDEN = 128 36 | NUM_EPOCH = 28 37 | LR_STEP = [16, 24] 38 | BATCH_SIZE = 1000 39 | DECAY_ALL = True 40 | SNAPSHOT_PERIOD = 28 41 | FEATURE_EXTRACTION_WIN_LEN = 20 42 | FEATURE_EXTRACTION_WIN_STRIDE = 1 43 | ACTIVITY_IMAGE_PREPROCESS = 'lowpass' 44 | FEATURE_MAP_PREPROCESS = 'lowpass' 45 | FEATURE_LIST = ['dwpt'] 46 | 47 | -------------------------------------------------------------------------------- /sigr/coral.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import 
scipy.linalg as splg 3 | 4 | 5 | def get_coral_params(ds, dt, lam=1e-3): 6 | ms = ds.mean(axis=0) 7 | ds = ds - ms 8 | mt = dt.mean(axis=0) 9 | dt = dt - mt 10 | cs = np.cov(ds.T) + lam * np.eye(ds.shape[1]) 11 | ct = np.cov(dt.T) + lam * np.eye(dt.shape[1]) 12 | sqrt = splg.sqrtm 13 | w = sqrt(ct).dot(np.linalg.inv(sqrt(cs))) 14 | b = mt - w.dot(ms.reshape(-1, 1)).ravel() 15 | return w, b 16 | -------------------------------------------------------------------------------- /sigr/data/capgmyo/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import os 3 | from itertools import product 4 | import numpy as np 5 | import scipy.io as sio 6 | from logbook import Logger 7 | from ... import utils, CACHE 8 | from .. import Dataset as Base, Combo, Trial, SingleSessionMixin 9 | 10 | 11 | TRIALS = list(range(1, 11)) 12 | NUM_TRIAL = len(TRIALS) 13 | NUM_SEMG_ROW = 16 14 | NUM_SEMG_COL = 8 15 | FRAMERATE = 1000 16 | PREPROCESS_KARGS = dict( 17 | framerate=FRAMERATE, 18 | num_semg_row=NUM_SEMG_ROW, 19 | num_semg_col=NUM_SEMG_COL 20 | ) 21 | 22 | logger = Logger(__name__) 23 | 24 | 25 | class GetTrial(object): 26 | 27 | def __init__(self, gestures, trials, preprocess=None): 28 | self.preprocess = preprocess 29 | self.memo = {} 30 | self.gesture_and_trials = list(product(gestures, trials)) 31 | 32 | def get_path(self, root, combo): 33 | return os.path.join( 34 | root, 35 | '{c.subject:03d}-{c.gesture:03d}-{c.trial:03d}.mat'.format(c=combo)) 36 | 37 | def __call__(self, root, combo): 38 | path = self.get_path(root, combo) 39 | if path not in self.memo: 40 | logger.debug('Load subject {}', combo.subject) 41 | paths = [self.get_path(root, Combo(combo.subject, gesture, trial)) 42 | for gesture, trial in self.gesture_and_trials] 43 | self.memo.update({path: data for path, data in 44 | zip(paths, _get_data(paths, self.preprocess))}) 45 | data = self.memo[path] 46 | data = data.copy() 47 | gesture = 
np.repeat(combo.gesture, len(data)) 48 | subject = np.repeat(combo.subject, len(data)) 49 | return Trial(data=data, gesture=gesture, subject=subject) 50 | 51 | 52 | @utils.cached 53 | def _get_data(paths, preprocess): 54 | # return list(Context.parallel( 55 | # jb.delayed(_get_data_aux)(path, preprocess) for path in paths)) 56 | return [_get_data_aux(path, preprocess) for path in paths] 57 | 58 | 59 | def _get_data_aux(path, preprocess): 60 | data = sio.loadmat(path)['data'].astype(np.float32) 61 | if preprocess: 62 | data = preprocess(data, **PREPROCESS_KARGS) 63 | return data 64 | 65 | 66 | class Dataset(SingleSessionMixin, Base): 67 | 68 | framerate = FRAMERATE 69 | num_semg_row = NUM_SEMG_ROW 70 | num_semg_col = NUM_SEMG_COL 71 | trials = TRIALS 72 | 73 | def __init__(self, root): 74 | self.root = root 75 | 76 | def get_trial_func(self, *args, **kargs): 77 | return GetTrial(*args, **kargs) 78 | 79 | @classmethod 80 | def parse(cls, text): 81 | if cls is not Dataset and text == cls.name: 82 | return cls(root=os.path.join(CACHE, cls.name.split('/')[0])) 83 | 84 | 85 | from . import dba, dbb, dbc 86 | assert dba and dbb and dbc 87 | -------------------------------------------------------------------------------- /sigr/data/capgmyo/dba.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from . import Dataset as Base 3 | 4 | 5 | class Dataset(Base): 6 | 7 | name = 'dba' 8 | subjects = list(range(1, 19)) 9 | gestures = list(range(1, 9)) 10 | -------------------------------------------------------------------------------- /sigr/data/capgmyo/dbb.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from functools import partial 3 | from itertools import product 4 | from logbook import Logger 5 | from . import Dataset as Base 6 | from .. import get_data 7 | from ... 
import constant 8 | 9 | 10 | logger = Logger(__name__) 11 | 12 | 13 | class Dataset(Base): 14 | 15 | name = 'dbb' 16 | subjects = list(range(2, 21, 2)) 17 | gestures = list(range(1, 9)) 18 | num_session = 2 19 | sessions = [1, 2] 20 | 21 | def get_universal_inter_session_data(self, fold, batch_size, preprocess, adabn, minibatch, balance_gesture, **kargs): 22 | get_trial = self.get_trial_func(self.gestures, self.trials, preprocess=preprocess) 23 | load = partial(get_data, 24 | framerate=self.framerate, 25 | root=self.root, 26 | last_batch_handle='pad', 27 | get_trial=get_trial, 28 | batch_size=batch_size, 29 | num_semg_row=self.num_semg_row, 30 | num_semg_col=self.num_semg_col) 31 | session = fold + 1 32 | subjects = list(range(1, 11)) 33 | num_subject = 10 34 | train = load(combos=self.get_combos(product([self.encode_subject_and_session(s, i) for s, i in 35 | product(subjects, [i for i in self.sessions if i != session])], 36 | self.gestures, self.trials)), 37 | adabn=adabn, 38 | mini_batch_size=batch_size // (num_subject * (self.num_session - 1) if minibatch else 1), 39 | balance_gesture=balance_gesture, 40 | random_shift_fill=kargs.get('random_shift_fill', constant.RANDOM_SHIFT_FILL), 41 | random_shift_horizontal=kargs.get('random_shift_horizontal', 0), 42 | random_shift_vertical=kargs.get('random_shift_vertical', 0), 43 | shuffle=True) 44 | logger.debug('Training set loaded') 45 | val = load(combos=self.get_combos(product([self.encode_subject_and_session(s, session) for s in subjects], 46 | self.gestures, self.trials)), 47 | adabn=adabn, 48 | mini_batch_size=batch_size // (num_subject if minibatch else 1), 49 | shuffle=False) 50 | logger.debug('Test set loaded') 51 | return train, val 52 | 53 | def get_inter_session_data(self, fold, batch_size, preprocess, adabn, minibatch, balance_gesture, **kargs): 54 | get_trial = self.get_trial_func(self.gestures, self.trials, preprocess=preprocess) 55 | load = partial(get_data, 56 | framerate=self.framerate, 57 | 
root=self.root, 58 | last_batch_handle='pad', 59 | get_trial=get_trial, 60 | batch_size=batch_size, 61 | num_semg_row=self.num_semg_row, 62 | num_semg_col=self.num_semg_col) 63 | subject = fold // self.num_session + 1 64 | session = fold % self.num_session + 1 65 | train = load(combos=self.get_combos(product([self.encode_subject_and_session(subject, i) for i in self.sessions if i != session], 66 | self.gestures, self.trials)), 67 | adabn=adabn, 68 | mini_batch_size=batch_size // (self.num_session - 1 if minibatch else 1), 69 | balance_gesture=balance_gesture, 70 | random_shift_fill=kargs.get('random_shift_fill', constant.RANDOM_SHIFT_FILL), 71 | random_shift_horizontal=kargs.get('random_shift_horizontal', 0), 72 | random_shift_vertical=kargs.get('random_shift_vertical', 0), 73 | shuffle=True) 74 | logger.debug('Training set loaded') 75 | val = load(combos=self.get_combos(product([self.encode_subject_and_session(subject, session)], 76 | self.gestures, self.trials)), 77 | shuffle=False) 78 | logger.debug('Test set loaded') 79 | return train, val 80 | 81 | def get_inter_session_val(self, fold, batch_size, preprocess=None, **kargs): 82 | get_trial = self.get_trial_func(self.gestures, self.trials, preprocess=preprocess) 83 | load = partial(get_data, 84 | framerate=self.framerate, 85 | root=self.root, 86 | last_batch_handle='pad', 87 | get_trial=get_trial, 88 | batch_size=batch_size, 89 | num_semg_row=self.num_semg_row, 90 | num_semg_col=self.num_semg_col) 91 | subject = fold // self.num_session + 1 92 | session = fold % self.num_session + 1 93 | val = load(combos=self.get_combos(product([self.encode_subject_and_session(subject, session)], 94 | self.gestures, self.trials)), 95 | shuffle=False) 96 | logger.debug('Test set loaded') 97 | return val 98 | 99 | def encode_subject_and_session(self, subject, session): 100 | return (subject - 1) * self.num_session + session 101 | -------------------------------------------------------------------------------- 
/sigr/data/capgmyo/dbc.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from . import Dataset as Base 3 | 4 | 5 | class Dataset(Base): 6 | 7 | name = 'dbc' 8 | subjects = list(range(1, 11)) 9 | gestures = list(range(1, 13)) 10 | -------------------------------------------------------------------------------- /sigr/data/capgmyoiter/capgmyoiter_dbb.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from functools import partial 3 | from itertools import product 4 | from logbook import Logger 5 | from . import Dataset as Base 6 | from .. import get_data 7 | from ... import constant 8 | 9 | 10 | logger = Logger(__name__) 11 | 12 | 13 | class Dataset(Base): 14 | 15 | name = 'capgmyoiter-dbb' 16 | subjects = list(range(2, 21, 2)) 17 | gestures = list(range(1, 9)) 18 | num_session = 2 19 | sessions = [1, 2] 20 | 21 | root = ('.cache/dbb') 22 | 23 | @classmethod 24 | def parse(cls, text): 25 | if text == 'capgmyoiter-dbb': 26 | return cls(root = '.cache/dbb') 27 | 28 | -------------------------------------------------------------------------------- /sigr/data/capgmyoiter_dbb/capgmyoiter_dbb.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from functools import partial 3 | from itertools import product 4 | from logbook import Logger 5 | from . import Dataset as Base 6 | from .. import get_data 7 | from ... 
import constant 8 | 9 | 10 | logger = Logger(__name__) 11 | 12 | 13 | class Dataset(Base): 14 | 15 | name = 'capgmyoiter-dbb' 16 | subjects = list(range(2, 21, 2)) 17 | gestures = list(range(1, 9)) 18 | num_session = 2 19 | sessions = [1, 2] 20 | 21 | root = ('.cache/dbb') 22 | 23 | @classmethod 24 | def parse(cls, text): 25 | if text == 'capgmyoiter-dbb': 26 | return cls(root = '.cache/dbb') 27 | 28 | -------------------------------------------------------------------------------- /sigr/data/capgmyoiter_dbc/capgmyoiter_dbb.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from functools import partial 3 | from itertools import product 4 | from logbook import Logger 5 | from . import Dataset as Base 6 | from .. import get_data 7 | from ... import constant 8 | 9 | 10 | logger = Logger(__name__) 11 | 12 | 13 | class Dataset(Base): 14 | 15 | name = 'capgmyoiter-dbb' 16 | subjects = list(range(2, 21, 2)) 17 | gestures = list(range(1, 9)) 18 | num_session = 2 19 | sessions = [1, 2] 20 | 21 | root = ('.cache/dbb') 22 | 23 | @classmethod 24 | def parse(cls, text): 25 | if text == 'capgmyoiter-dbb': 26 | return cls(root = '.cache/dbb') 27 | 28 | -------------------------------------------------------------------------------- /sigr/data/ninapro/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import os 3 | from itertools import product 4 | import numpy as np 5 | import scipy.io as sio 6 | from logbook import Logger 7 | from ... import utils, CACHE 8 | from .. 
import Dataset as Base, Combo, Trial, SingleSessionMixin 9 | 10 | 11 | NUM_SEMG_ROW = 1 12 | NUM_SEMG_COL = 10 13 | FRAMERATE = 100 14 | PREPROCESS_KARGS = dict( 15 | framerate=FRAMERATE, 16 | num_semg_row=NUM_SEMG_ROW, 17 | num_semg_col=NUM_SEMG_COL 18 | ) 19 | 20 | logger = Logger(__name__) 21 | 22 | 23 | class Dataset(SingleSessionMixin, Base): 24 | 25 | framerate = FRAMERATE 26 | num_semg_row = NUM_SEMG_ROW 27 | num_semg_col = NUM_SEMG_COL 28 | subjects = list(range(27)) 29 | gestures = list(range(53)) 30 | trials = list(range(10)) 31 | 32 | def __init__(self, root): 33 | self.root = root 34 | 35 | def get_one_fold_intra_subject_trials(self): 36 | return [0, 2, 3, 5, 7, 8, 9], [1, 4, 6] 37 | 38 | def get_trial_func(self, *args, **kargs): 39 | return GetTrial(*args, **kargs) 40 | 41 | @classmethod 42 | def parse(cls, text): 43 | if cls is not Dataset and text == cls.name: 44 | return cls(root=getattr(cls, 'root', os.path.join(CACHE, cls.name.split('/')[0], 'data'))) 45 | 46 | 47 | class GetTrial(object): 48 | 49 | def __init__(self, gestures, trials, preprocess=None): 50 | self.preprocess = preprocess 51 | self.memo = {} 52 | self.gesture_and_trials = list(product(gestures, trials)) 53 | 54 | def get_path(self, root, combo): 55 | return os.path.join( 56 | root, 57 | '{c.subject:03d}', 58 | '{c.gesture:03d}', 59 | '{c.subject:03d}_{c.gesture:03d}_{c.trial:03d}.mat').format(c=combo) 60 | 61 | def __call__(self, root, combo): 62 | path = self.get_path(root, combo) 63 | if path not in self.memo: 64 | logger.debug('Load subject {}', combo.subject) 65 | paths = [self.get_path(root, Combo(combo.subject, gesture, trial)) 66 | for gesture, trial in self.gesture_and_trials] 67 | self.memo.update({path: data for path, data in 68 | zip(paths, _get_data(paths, self.preprocess))}) 69 | data = self.memo[path] 70 | data = data.copy() 71 | gesture = np.repeat(combo.gesture, len(data)) 72 | subject = np.repeat(combo.subject, len(data)) 73 | return Trial(data=data, 
gesture=gesture, subject=subject) 74 | 75 | 76 | @utils.cached 77 | def _get_data(paths, preprocess): 78 | # return list(Context.parallel( 79 | # jb.delayed(_get_data_aux)(path, preprocess) for path in paths)) 80 | return [_get_data_aux(path, preprocess) for path in paths] 81 | 82 | 83 | def _get_data_aux(path, preprocess): 84 | data = sio.loadmat(path)['data'].astype(np.float32) 85 | if preprocess: 86 | data = preprocess(data, **PREPROCESS_KARGS) 87 | return data 88 | 89 | 90 | from . import db1, db1_g53, db1_g5, db1_g8, db1_g12, caputo, db1_matlab_lowpass, db1_raw_semg_glove, db1_signal_image, db1_signal_image_fast, db1_softmax_as_input, db1_rawdata_semgfeature 91 | assert db1 and db1_g53 and db1_g5 and db1_g8 and db1_g12 and caputo and db1_matlab_lowpass and db1_raw_semg_glove and db1_signal_image and db1_signal_image_fast and db1_softmax_as_input and db1_rawdata_semgfeature 92 | -------------------------------------------------------------------------------- /sigr/data/ninapro/caputo.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from . import Dataset as Base 3 | 4 | 5 | class Dataset(Base): 6 | 7 | name = 'ninapro-db1/caputo' 8 | gestures = list(range(1, 53)) 9 | 10 | def get_one_fold_intra_subject_trials(self): 11 | return [i - 1 for i in [1, 3, 4, 5, 9]], [i - 1 for i in [2, 6, 7, 8, 10]] 12 | -------------------------------------------------------------------------------- /sigr/data/ninapro/db1.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from . import Dataset as Base 3 | 4 | 5 | class Dataset(Base): 6 | 7 | name = 'ninapro-db1' 8 | gestures = list(range(1, 53)) 9 | -------------------------------------------------------------------------------- /sigr/data/ninapro/db1_g12.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from . 
import Dataset as Base 3 | 4 | 5 | class Dataset(Base): 6 | 7 | name = 'ninapro-db1/g12' 8 | gestures = list(range(1, 13)) 9 | -------------------------------------------------------------------------------- /sigr/data/ninapro/db1_g5.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from . import Dataset as Base 3 | 4 | 5 | class Dataset(Base): 6 | 7 | name = 'ninapro-db1/g5' 8 | gestures = list(range(25, 30)) 9 | -------------------------------------------------------------------------------- /sigr/data/ninapro/db1_g53.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from . import Dataset as Base 3 | 4 | 5 | class Dataset(Base): 6 | 7 | name = 'ninapro-db1/g53' 8 | gestures = list(range(0, 53)) 9 | -------------------------------------------------------------------------------- /sigr/data/ninapro/db1_g8.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from . import Dataset as Base 3 | 4 | 5 | class Dataset(Base): 6 | 7 | name = 'ninapro-db1/g8' 8 | gestures = list(range(13, 21)) 9 | -------------------------------------------------------------------------------- /sigr/data/ninapro/db1_matlab_lowpass.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from . import Dataset as Base 3 | 4 | 5 | class Database(Base): 6 | 7 | name = 'ninapro-db1-matlab-lowpass' 8 | gestures = list(range(1, 53)) 9 | -------------------------------------------------------------------------------- /sigr/data/ninapro/test_db1_raw_semg_glove.py: -------------------------------------------------------------------------------- 1 | from nose.tools import assert_equal 2 | from numpy.testing import assert_array_equal 3 | import numpy as np 4 | from .. 
import Dataset, Combo 5 | 6 | 7 | def test_get_trial(): 8 | dataset = Dataset.from_name('ninapro-db1-raw/semg-glove') 9 | get_trial = dataset.get_trial_func() 10 | 11 | def do(subject, gesture, trial): 12 | trial = get_trial(Combo(subject=subject, gesture=gesture, trial=trial)) 13 | assert_array_equal(trial.subject, subject) 14 | assert_equal(trial.gesture[0], -1) 15 | assert_equal(trial.gesture[-1], gesture) 16 | assert np.in1d(trial.gesture, [-1, gesture]).all() 17 | 18 | for subject, gesture, trial in [(1, 1, 1), 19 | (27, 1, 1), 20 | (1, 1, 10), 21 | (1, 12, 1), 22 | (1, 13, 1), 23 | (1, 29, 1), 24 | (1, 30, 1), 25 | (1, 52, 1)]: 26 | yield do, subject, gesture, trial 27 | 28 | 29 | def test_get_trial_norest(): 30 | dataset = Dataset.from_name('ninapro-db1-raw/semg-glove') 31 | get_trial = dataset.get_trial_func(norest=True) 32 | 33 | def do(subject, gesture, trial): 34 | trial = get_trial(Combo(subject=subject, gesture=gesture, trial=trial)) 35 | assert_array_equal(trial.subject, subject) 36 | assert_array_equal(trial.gesture, gesture) 37 | 38 | for subject, gesture, trial in [(1, 1, 1), 39 | (27, 1, 1), 40 | (1, 1, 10), 41 | (1, 12, 1), 42 | (1, 13, 1), 43 | (1, 29, 1), 44 | (1, 30, 1), 45 | (1, 52, 1)]: 46 | yield do, subject, gesture, trial 47 | -------------------------------------------------------------------------------- /sigr/data/s21.py: -------------------------------------------------------------------------------- 1 | from itertools import product, starmap 2 | from . import get_data, Combo 3 | from .. 
import ROOT 4 | import os 5 | import numpy as np 6 | 7 | 8 | ROOT = os.path.join(ROOT, '.cache/mat.s21.bandstop-45-55.s1000m.scale-01') 9 | 10 | 11 | def get_coral(folds, batch_size): 12 | subjects = [1, 3, 6, 8, 10, 12, 14, 16, 18, 20] 13 | return get_data( 14 | root=ROOT, 15 | # combos=get_combos(product([subjects[fold] for fold in folds], [100, 101], [0])), 16 | combos=get_combos(product([subjects[fold] for fold in folds], range(1, 9), [0])), 17 | mean=0.5, 18 | scale=2, 19 | batch_size=2000, 20 | last_batch_handle='pad', 21 | shuffle=False, 22 | adabn=True 23 | ) 24 | 25 | 26 | def get_combos(prods): 27 | return list(starmap(Combo, prods)) 28 | 29 | 30 | def get_stats(): 31 | subjects = [1, 3, 6, 8, 10, 12, 14, 16, 18, 20] 32 | load = lambda subject: get_data( 33 | root=ROOT, 34 | combos=get_combos(product([subject], range(1, 9), range(10))), 35 | mean=0.5, 36 | scale=2, 37 | batch_size=1000, 38 | last_batch_handle='roll_over' 39 | ) 40 | stats = [] 41 | for subject in subjects: 42 | batch = next(load(subject)[0]) 43 | data = batch.data[0].asnumpy() 44 | stats.append({ 45 | 'std': data.std() 46 | }) 47 | import pandas as pd 48 | return pd.DataFrame(stats, index=range(10)) 49 | 50 | 51 | def get_general_data(root, batch_size, with_subject): 52 | subjects = [1, 3, 6, 8, 10, 12, 14, 16, 18, 20] 53 | load = lambda **kargs: get_data( 54 | root=root, 55 | mean=0.5, 56 | scale=2, 57 | with_subject=with_subject, 58 | batch_size=batch_size, 59 | last_batch_handle='roll_over', 60 | **kargs 61 | ) 62 | val, num_val = load(combos=get_combos(product(subjects, range(1, 9), range(1, 10, 2)))) 63 | train, num_train = load(combos=get_combos(product(subjects, range(1, 9), range(0, 10, 2)))) 64 | return train, val, num_train, num_val 65 | 66 | 67 | def get_inter_subject_data( 68 | root, 69 | fold, 70 | batch_size, 71 | maxforce, 72 | target_binary, 73 | calib, 74 | with_subject, 75 | with_target_gesture, 76 | random_scale, 77 | random_bad_channel, 78 | shuffle, 79 | adabn, 80 | 
window, 81 | only_calib, 82 | soft_label, 83 | minibatch, 84 | fft, 85 | fft_append, 86 | dual_stream, 87 | lstm, 88 | dense_window, 89 | lstm_window 90 | ): 91 | subjects = [1, 3, 6, 8, 10, 12, 14, 16, 18, 20] 92 | 93 | num_subject = 10 if maxforce or calib else 9 94 | if minibatch: 95 | assert batch_size % num_subject == 0, '%d %% %d' % (batch_size, num_subject) 96 | mini_batch_size = batch_size // num_subject 97 | else: 98 | mini_batch_size = batch_size 99 | 100 | load = lambda **kargs: get_data( 101 | root=root, 102 | mean=0.5, 103 | scale=2, 104 | with_subject=with_subject, 105 | target_binary=target_binary, 106 | batch_size=batch_size, 107 | with_target_gesture=with_target_gesture, 108 | fft=fft, 109 | fft_append=fft_append, 110 | dual_stream=dual_stream, 111 | **kargs 112 | ) 113 | val_subject = subjects[fold] 114 | del subjects[fold] 115 | val = load( 116 | combos=get_combos(product([val_subject], range(1, 9), range(1, 10) if calib else range(10))), 117 | last_batch_handle='pad', 118 | shuffle=False, 119 | window=(window // (lstm_window or window)) if lstm else window, 120 | num_ignore_per_segment=window - 1 if lstm else 0, 121 | dense_window=dense_window 122 | ) 123 | 124 | if maxforce and calib: 125 | target_combos = get_combos(product([val_subject], list(range(1, 9)) * 10 + [100, 101], [0] * (9 if target_binary else 1))) 126 | elif maxforce: 127 | target_combos = get_combos(product([val_subject], [100, 101], [0] * 41 * (9 if target_binary else 1))) 128 | elif only_calib: 129 | target_combos = get_combos(product([val_subject], list(range(1, 9)), [0])) 130 | elif calib: 131 | target_combos = get_combos(product([val_subject], list(range(1, 9)) * 10, [0] * (9 if target_binary else 1))) 132 | else: 133 | target_combos = None 134 | 135 | if only_calib: 136 | combos = [] 137 | else: 138 | combos = get_combos(product(subjects, range(1, 9), range(10))) 139 | if maxforce: 140 | combos += get_combos(product(subjects, [100, 101], [0])) 141 | 142 | if soft_label: 143 
| import pandas as pd 144 | soft_label = pd.DataFrame.from_csv(os.path.join(os.path.dirname(__file__), 's21_soft_label.scv')) 145 | 146 | train = load( 147 | combos=combos, 148 | target_combos=target_combos, 149 | random_scale=random_scale, 150 | random_bad_channel=random_bad_channel, 151 | last_batch_handle='pad', 152 | shuffle=shuffle, 153 | mini_batch_size=mini_batch_size, 154 | soft_label=False if soft_label is False else soft_label[soft_label['fold'] == fold][[str(i) for i in range(8)]].as_matrix(), 155 | adabn=adabn, 156 | window=window, 157 | dense_window=dense_window 158 | ) 159 | return train, val 160 | 161 | 162 | def get_inter_subject_val(fold, batch_size, calib, **kargs): 163 | subjects = [1, 3, 6, 8, 10, 12, 14, 16, 18, 20] 164 | return get_data( 165 | combos=get_combos(product([subjects[fold]], range(1, 9), range(1, 10) if calib else range(10))), 166 | root=ROOT, 167 | mean=0.5, 168 | scale=2, 169 | batch_size=batch_size, 170 | last_batch_handle='pad', 171 | shuffle=False, 172 | random_state=np.random.RandomState(42), 173 | **kargs 174 | ) 175 | 176 | 177 | def get_inter_subject_train(fold, batch_size): 178 | subjects = [1, 3, 6, 8, 10, 12, 14, 16, 18, 20] 179 | return get_data( 180 | combos=get_combos(product([subjects[i] for i in range(10) if i != fold], range(1, 9), range(10))), 181 | root=ROOT, 182 | mean=0.5, 183 | scale=2, 184 | batch_size=batch_size, 185 | last_batch_handle='pad', 186 | shuffle=False 187 | ) 188 | -------------------------------------------------------------------------------- /sigr/evaluation.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import os 3 | import numpy as np 4 | from functools import partial 5 | from .parse_log import parse_log 6 | from . import utils 7 | from . 
import module 8 | from logbook import Logger 9 | from copy import deepcopy 10 | import mxnet as mx 11 | 12 | 13 | Exp = utils.Bunch 14 | 15 | logger = Logger(__name__) 16 | 17 | 18 | @utils.cached(ignore=['context']) 19 | def _crossval_predict_aux(self, Mod, get_crossval_val, fold, context, dataset_args=None): 20 | Mod = deepcopy(Mod) 21 | Mod.update(context=context) 22 | mod = module.RuntimeModule(**Mod) 23 | Val = partial( 24 | get_crossval_val, 25 | fold=fold, 26 | batch_size=self.batch_size, 27 | window=mod.num_channel, 28 | **(dataset_args or {}) 29 | ) 30 | return mod.predict(utils.LazyProxy(Val)) 31 | 32 | 33 | @utils.cached(ignore=['context']) 34 | def _crossval_predict_proba_aux(self, Mod, get_crossval_val, fold, context, dataset_args=None): 35 | Mod = deepcopy(Mod) 36 | Mod.update(context=context) 37 | mod = module.RuntimeModule(**Mod) 38 | Val = partial( 39 | get_crossval_val, 40 | fold=fold, 41 | batch_size=self.batch_size, 42 | window=mod.num_channel, 43 | **(dataset_args or {}) 44 | ) 45 | return mod.predict_proba(utils.LazyProxy(Val)) 46 | 47 | 48 | def _crossval_predict(self, **kargs): 49 | proba = kargs.pop('proba', False) 50 | fold = int(kargs.pop('fold')) 51 | Mod = kargs.pop('Mod') 52 | Mod = deepcopy(Mod) 53 | Mod.update(params=self.format_params(Mod['params'], fold)) 54 | context = Mod.pop('context', [mx.gpu(0)]) 55 | # import pickle 56 | # d = kargs.copy() 57 | # d.update(Mod=Mod, fold=fold) 58 | # print(pickle.dumps(d)) 59 | 60 | # Ensure load from disk. 61 | # Otherwise following cached methods like vote will have two caches, 62 | # one for the first computation, 63 | # and the other for the cached one. 
64 | func = _crossval_predict_aux if not proba else _crossval_predict_proba_aux 65 | return func.call_and_shelve(self, Mod=Mod, fold=fold, context=context, **kargs).get() 66 | 67 | 68 | class Evaluation(object): 69 | 70 | def __init__(self, batch_size=None): 71 | self.batch_size = batch_size 72 | 73 | 74 | class CrossValEvaluation(Evaluation): 75 | 76 | def __init__(self, **kargs): 77 | self.crossval_type = kargs.pop('crossval_type') 78 | super(CrossValEvaluation, self).__init__(**kargs) 79 | 80 | def get_crossval_val_func(self, dataset): 81 | return getattr(dataset, 'get_%s_val' % self.crossval_type.replace('-', '_')) 82 | 83 | def format_params(self, params, fold): 84 | try: 85 | return params % fold 86 | except: 87 | return params 88 | 89 | def transform(self, Mod, dataset, fold, dataset_args=None): 90 | get_crossval_val = self.get_crossval_val_func(dataset) 91 | pred, true, _ = _crossval_predict( 92 | self, 93 | proba=True, 94 | Mod=Mod, 95 | get_crossval_val=get_crossval_val, 96 | fold=fold, 97 | dataset_args=dataset_args) 98 | return pred, true 99 | 100 | def accuracy_mod(self, Mod, dataset, fold, 101 | vote=False, 102 | dataset_args=None, 103 | balance=False): 104 | get_crossval_val = self.get_crossval_val_func(dataset) 105 | pred, true, segment = _crossval_predict( 106 | self, 107 | Mod=Mod, 108 | get_crossval_val=get_crossval_val, 109 | fold=fold, 110 | dataset_args=dataset_args) 111 | if vote: 112 | from .vote import vote as do 113 | return do(true, pred, segment, vote, balance) 114 | return (true == pred).sum() / true.size 115 | 116 | def accuracy_exp(self, exp, fold): 117 | if hasattr(exp, 'Mod') and hasattr(exp, 'dataset'): 118 | return self.accuracy_mod(Mod=exp.Mod, 119 | dataset=exp.dataset, 120 | fold=fold, 121 | vote=exp.get('vote', False), 122 | dataset_args=exp.get('dataset_args')) 123 | else: 124 | try: 125 | return parse_log(os.path.join(exp.root % fold, 'log')).val.iloc[-1] 126 | except: 127 | return np.nan 128 | 129 | def accuracy(self, 
**kargs): 130 | if 'exp' in kargs: 131 | return self.accuracy_exp(**kargs) 132 | elif 'Mod' in kargs: 133 | return self.accuracy_mod(**kargs) 134 | else: 135 | assert False 136 | 137 | def accuracies(self, exps, folds): 138 | acc = [] 139 | for exp in exps: 140 | for fold in folds: 141 | acc.append(self.accuracy(exp=exp, fold=fold)) 142 | return np.array(acc).reshape(len(exps), len(folds)) 143 | 144 | def compare(self, exps, fold): 145 | acc = [] 146 | for exp in exps: 147 | if hasattr(exp, 'Mod') and hasattr(exp, 'dataset'): 148 | acc.append(self.accuracy(Mod=exp.Mod, 149 | dataset=exp.dataset, 150 | fold=fold, 151 | vote=exp.get('vote', False), 152 | dataset_args=exp.get('dataset_args'))) 153 | else: 154 | try: 155 | acc.append(parse_log(os.path.join(exp.root % fold, 'log')).val.iloc[-1]) 156 | except: 157 | acc.append(np.nan) 158 | return acc 159 | 160 | def vote_accuracy_curves(self, exps, folds, windows, balance=False): 161 | acc = [] 162 | for exp in exps: 163 | for fold in folds: 164 | acc.append(self.vote_accuracy_curve( 165 | Mod=exp.Mod, 166 | dataset=exp.dataset, 167 | fold=int(fold), 168 | windows=windows, 169 | dataset_args=exp.get('dataset_args'), 170 | balance=balance)) 171 | return np.array(acc).reshape(len(exps), len(folds), len(windows)) 172 | 173 | def vote_accuracy_curve(self, Mod, dataset, fold, windows, 174 | dataset_args=None, 175 | balance=False): 176 | get_crossval_val = self.get_crossval_val_func(dataset) 177 | pred, true, segment = _crossval_predict( 178 | self, 179 | Mod=Mod, 180 | get_crossval_val=get_crossval_val, 181 | fold=fold, 182 | dataset_args=dataset_args) 183 | from .vote import get_vote_accuracy_curve as do 184 | return do(true, pred, segment, windows, balance)[1] 185 | 186 | 187 | def get_crossval_accuracies(crossval_type, exps, folds, batch_size=1000): 188 | acc = [] 189 | evaluation = CrossValEvaluation( 190 | crossval_type=crossval_type, 191 | batch_size=batch_size 192 | ) 193 | for fold in folds: 194 | 
acc.append(evaluation.compare(exps, fold)) 195 | return acc 196 | -------------------------------------------------------------------------------- /sigr/evaluation_db1multistream.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import os 3 | import numpy as np 4 | from functools import partial 5 | from .parse_log import parse_log 6 | from . import utils 7 | from . import module_multistream 8 | from logbook import Logger 9 | from copy import deepcopy 10 | import mxnet as mx 11 | 12 | 13 | Exp = utils.Bunch 14 | 15 | logger = Logger(__name__) 16 | 17 | #feature_name = args.feature_name, 18 | #window=args.window, 19 | #num_semg_row = args.num_semg_row, 20 | #num_semg_col = args.num_semg_col 21 | 22 | 23 | @utils.cached(ignore=['context']) 24 | def _crossval_predict_aux(self, Mod, get_crossval_val, fold, context, 25 | feature_name, 26 | window, 27 | num_semg_row, 28 | num_semg_col, 29 | dataset_args=None): 30 | Mod = deepcopy(Mod) 31 | Mod.update(context=context) 32 | mod = module_multistream.RuntimeModule(**Mod) 33 | Val = partial( 34 | get_crossval_val, 35 | fold=fold, 36 | batch_size=self.batch_size, 37 | window=window, 38 | feature_name=feature_name, 39 | num_semg_row=num_semg_row, 40 | num_semg_col=num_semg_col, 41 | **(dataset_args or {}) 42 | ) 43 | # print Val.name 44 | return mod.predict(utils.LazyProxy(Val)) 45 | 46 | 47 | @utils.cached(ignore=['context']) 48 | def _crossval_predict_proba_aux(self, Mod, get_crossval_val, fold, context, 49 | feature_name, 50 | window, 51 | num_semg_row, 52 | num_semg_col, 53 | dataset_args=None): 54 | Mod = deepcopy(Mod) 55 | Mod.update(context=context) 56 | mod = module_multistream.RuntimeModule(**Mod) 57 | Val = partial( 58 | get_crossval_val, 59 | fold=fold, 60 | batch_size=self.batch_size, 61 | window=window, 62 | feature_name=feature_name, 63 | num_semg_row=num_semg_row, 64 | num_semg_col=num_semg_col, 65 | **(dataset_args or {}) 66 | ) 67 | 
return mod.predict_proba(utils.LazyProxy(Val)) 68 | 69 | 70 | def _crossval_predict(self, **kargs): 71 | proba = kargs.pop('proba', False) 72 | fold = int(kargs.pop('fold')) 73 | Mod = kargs.pop('Mod') 74 | Mod = deepcopy(Mod) 75 | Mod.update(params=self.format_params(Mod['params'], fold)) 76 | context = Mod.pop('context', [mx.gpu(0)]) 77 | window = kargs.pop('window') 78 | feature_name = kargs.pop('feature_name') 79 | num_semg_row=kargs.pop('num_semg_row') 80 | num_semg_col=kargs.pop('num_semg_col') 81 | # import pickle 82 | # d = kargs.copy() 83 | # d.update(Mod=Mod, fold=fold) 84 | # print(pickle.dumps(d)) 85 | 86 | # Ensure load from disk. 87 | # Otherwise following cached methods like vote will have two caches, 88 | # one for the first computation, 89 | # and the other for the cached one. 90 | func = _crossval_predict_aux if not proba else _crossval_predict_proba_aux 91 | return func.call_and_shelve(self, Mod=Mod, fold=fold, context=context, 92 | window=window, 93 | feature_name=feature_name, 94 | num_semg_row=num_semg_row, 95 | num_semg_col=num_semg_col, 96 | **kargs).get() 97 | 98 | 99 | class Evaluation(object): 100 | 101 | def __init__(self, batch_size=None): 102 | self.batch_size = batch_size 103 | 104 | 105 | class CrossValEvaluation(Evaluation): 106 | 107 | def __init__(self, **kargs): 108 | self.crossval_type = kargs.pop('crossval_type') 109 | super(CrossValEvaluation, self).__init__(**kargs) 110 | 111 | def get_crossval_val_func(self, dataset): 112 | return getattr(dataset, 'get_%s_val' % self.crossval_type.replace('-', '_')) 113 | 114 | def format_params(self, params, fold): 115 | try: 116 | return params % fold 117 | except: 118 | return params 119 | 120 | def transform(self, Mod, dataset, fold, dataset_args=None): 121 | get_crossval_val = self.get_crossval_val_func(dataset) 122 | pred, true, _ = _crossval_predict( 123 | self, 124 | proba=True, 125 | Mod=Mod, 126 | get_crossval_val=get_crossval_val, 127 | fold=fold, 128 | dataset_args=dataset_args) 
129 | return pred, true 130 | 131 | def accuracy_mod(self, Mod, dataset, fold, 132 | vote=False, 133 | dataset_args=None, 134 | balance=False): 135 | get_crossval_val = self.get_crossval_val_func(dataset) 136 | pred, true, segment = _crossval_predict( 137 | self, 138 | Mod=Mod, 139 | get_crossval_val=get_crossval_val, 140 | fold=fold, 141 | dataset_args=dataset_args) 142 | if vote: 143 | from .vote import vote as do 144 | return do(true, pred, segment, vote, balance) 145 | return (true == pred).sum() / true.size 146 | 147 | def accuracy_exp(self, exp, fold): 148 | if hasattr(exp, 'Mod') and hasattr(exp, 'dataset'): 149 | return self.accuracy_mod(Mod=exp.Mod, 150 | dataset=exp.dataset, 151 | fold=fold, 152 | vote=exp.get('vote', False), 153 | dataset_args=exp.get('dataset_args')) 154 | else: 155 | try: 156 | return parse_log(os.path.join(exp.root % fold, 'log')).val.iloc[-1] 157 | except: 158 | return np.nan 159 | 160 | def accuracy(self, **kargs): 161 | if 'exp' in kargs: 162 | return self.accuracy_exp(**kargs) 163 | elif 'Mod' in kargs: 164 | return self.accuracy_mod(**kargs) 165 | else: 166 | assert False 167 | 168 | def accuracies(self, exps, folds): 169 | acc = [] 170 | for exp in exps: 171 | for fold in folds: 172 | acc.append(self.accuracy(exp=exp, fold=fold)) 173 | return np.array(acc).reshape(len(exps), len(folds)) 174 | 175 | def compare(self, exps, fold): 176 | acc = [] 177 | for exp in exps: 178 | if hasattr(exp, 'Mod') and hasattr(exp, 'dataset'): 179 | acc.append(self.accuracy(Mod=exp.Mod, 180 | dataset=exp.dataset, 181 | fold=fold, 182 | vote=exp.get('vote', False), 183 | dataset_args=exp.get('dataset_args'))) 184 | else: 185 | try: 186 | acc.append(parse_log(os.path.join(exp.root % fold, 'log')).val.iloc[-1]) 187 | except: 188 | acc.append(np.nan) 189 | return acc 190 | 191 | def vote_accuracy_curves(self, exps, folds, windows, feature_name, window, num_semg_row, num_semg_col, balance=False): 192 | acc = [] 193 | for exp in exps: 194 | for fold in 
folds: 195 | acc.append(self.vote_accuracy_curve( 196 | Mod=exp.Mod, 197 | dataset=exp.dataset, 198 | fold=int(fold), 199 | windows=windows, 200 | feature_name=feature_name, 201 | window=window, 202 | num_semg_row=num_semg_row, 203 | num_semg_col=num_semg_col, 204 | dataset_args=exp.get('dataset_args'), 205 | balance=balance)) 206 | return np.array(acc).reshape(len(exps), len(folds), len(windows)) 207 | 208 | def vote_accuracy_curve(self, Mod, dataset, fold, windows, feature_name, window, num_semg_row, num_semg_col, 209 | dataset_args=None, 210 | balance=False): 211 | get_crossval_val = self.get_crossval_val_func(dataset) 212 | pred, true, segment = _crossval_predict( 213 | self, 214 | Mod=Mod, 215 | get_crossval_val=get_crossval_val, 216 | fold=fold, 217 | feature_name=feature_name, 218 | window = window, 219 | num_semg_row = num_semg_row, 220 | num_semg_col = num_semg_col, 221 | dataset_args=dataset_args) 222 | from .vote import get_vote_accuracy_curve as do 223 | return do(true, pred, segment, windows, balance)[1] 224 | 225 | 226 | def get_crossval_accuracies(crossval_type, exps, folds, batch_size=1000): 227 | acc = [] 228 | evaluation = CrossValEvaluation( 229 | crossval_type=crossval_type, 230 | batch_size=batch_size 231 | ) 232 | for fold in folds: 233 | acc.append(evaluation.compare(exps, fold)) 234 | return acc 235 | 236 | -------------------------------------------------------------------------------- /sigr/evaluation_semimyo.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import os 3 | import numpy as np 4 | from functools import partial 5 | from .parse_log import parse_log 6 | from . import utils 7 | from . 
import module_semimyo as module 8 | from logbook import Logger 9 | from copy import deepcopy 10 | import mxnet as mx 11 | 12 | 13 | Exp = utils.Bunch 14 | 15 | logger = Logger(__name__) 16 | 17 | 18 | @utils.cached(ignore=['context']) 19 | def _crossval_predict_aux(self, Mod, get_crossval_val, fold, context, dataset_args=None): 20 | Mod = deepcopy(Mod) 21 | Mod.update(context=context) 22 | mod = module.RuntimeModule(**Mod) 23 | Val = partial( 24 | get_crossval_val, 25 | fold=fold, 26 | batch_size=self.batch_size, 27 | **(dataset_args or {}) 28 | ) 29 | return mod.predict(utils.LazyProxy(Val)) 30 | 31 | 32 | @utils.cached(ignore=['context']) 33 | def _crossval_predict_proba_aux(self, Mod, get_crossval_val, fold, context, dataset_args=None): 34 | Mod = deepcopy(Mod) 35 | Mod.update(context=context) 36 | mod = module.RuntimeModule(**Mod) 37 | Val = partial( 38 | get_crossval_val, 39 | fold=fold, 40 | batch_size=self.batch_size, 41 | **(dataset_args or {}) 42 | ) 43 | return mod.predict_proba(utils.LazyProxy(Val)) 44 | 45 | 46 | def _crossval_predict(self, **kargs): 47 | proba = kargs.pop('proba', False) 48 | fold = int(kargs.pop('fold')) 49 | Mod = kargs.pop('Mod') 50 | Mod = deepcopy(Mod) 51 | Mod.update(params=self.format_params(Mod['params'], fold)) 52 | context = Mod.pop('context', [mx.gpu(0)]) 53 | # import pickle 54 | # d = kargs.copy() 55 | # d.update(Mod=Mod, fold=fold) 56 | # print(pickle.dumps(d)) 57 | 58 | # Ensure load from disk. 59 | # Otherwise following cached methods like vote will have two caches, 60 | # one for the first computation, 61 | # and the other for the cached one. 
62 | func = _crossval_predict_aux if not proba else _crossval_predict_proba_aux 63 | return func.call_and_shelve(self, Mod=Mod, fold=fold, context=context, **kargs).get() 64 | 65 | 66 | class Evaluation(object): 67 | 68 | def __init__(self, batch_size=None): 69 | self.batch_size = batch_size 70 | 71 | 72 | class CrossValEvaluation(Evaluation): 73 | 74 | def __init__(self, **kargs): 75 | self.crossval_type = kargs.pop('crossval_type') 76 | super(CrossValEvaluation, self).__init__(**kargs) 77 | 78 | def get_crossval_val_func(self, dataset): 79 | return getattr(dataset, 'get_%s_val' % self.crossval_type.replace('-', '_')) 80 | 81 | def format_params(self, params, fold): 82 | try: 83 | return params % fold 84 | except: 85 | return params 86 | 87 | def transform(self, Mod, dataset, fold, dataset_args=None): 88 | get_crossval_val = self.get_crossval_val_func(dataset) 89 | pred, true, _ = _crossval_predict( 90 | self, 91 | proba=True, 92 | Mod=Mod, 93 | get_crossval_val=get_crossval_val, 94 | fold=fold, 95 | dataset_args=dataset_args) 96 | return pred, true 97 | 98 | def accuracy_mod(self, Mod, dataset, fold, 99 | vote=False, 100 | dataset_args=None, 101 | balance=False): 102 | get_crossval_val = self.get_crossval_val_func(dataset) 103 | pred, true, segment = _crossval_predict( 104 | self, 105 | Mod=Mod, 106 | get_crossval_val=get_crossval_val, 107 | fold=fold, 108 | dataset_args=dataset_args) 109 | if vote: 110 | from .vote import vote as do 111 | return do(true, pred, segment, vote, balance) 112 | return (true == pred).sum() / true.size 113 | 114 | def accuracy_exp(self, exp, fold): 115 | if hasattr(exp, 'Mod') and hasattr(exp, 'dataset'): 116 | return self.accuracy_mod(Mod=exp.Mod, 117 | dataset=exp.dataset, 118 | fold=fold, 119 | vote=exp.get('vote', False), 120 | dataset_args=exp.get('dataset_args')) 121 | else: 122 | try: 123 | return parse_log(os.path.join(exp.root % fold, 'log')).val.iloc[-1] 124 | except: 125 | return np.nan 126 | 127 | def accuracy(self, 
**kargs): 128 | if 'exp' in kargs: 129 | return self.accuracy_exp(**kargs) 130 | elif 'Mod' in kargs: 131 | return self.accuracy_mod(**kargs) 132 | else: 133 | assert False 134 | 135 | def accuracies(self, exps, folds): 136 | acc = [] 137 | for exp in exps: 138 | for fold in folds: 139 | acc.append(self.accuracy(exp=exp, fold=fold)) 140 | return np.array(acc).reshape(len(exps), len(folds)) 141 | 142 | def compare(self, exps, fold): 143 | acc = [] 144 | for exp in exps: 145 | if hasattr(exp, 'Mod') and hasattr(exp, 'dataset'): 146 | acc.append(self.accuracy(Mod=exp.Mod, 147 | dataset=exp.dataset, 148 | fold=fold, 149 | vote=exp.get('vote', False), 150 | dataset_args=exp.get('dataset_args'))) 151 | else: 152 | try: 153 | acc.append(parse_log(os.path.join(exp.root % fold, 'log')).val.iloc[-1]) 154 | except: 155 | acc.append(np.nan) 156 | return acc 157 | 158 | def vote_accuracy_curves(self, exps, folds, windows, balance=False): 159 | acc = [] 160 | for exp in exps: 161 | for fold in folds: 162 | acc.append(self.vote_accuracy_curve( 163 | Mod=exp.Mod, 164 | dataset=exp.dataset, 165 | fold=int(fold), 166 | windows=windows, 167 | dataset_args=exp.get('dataset_args'), 168 | balance=balance)) 169 | return np.array(acc).reshape(len(exps), len(folds), len(windows)) 170 | 171 | def vote_accuracy_curve(self, Mod, dataset, fold, windows, 172 | dataset_args=None, 173 | balance=False): 174 | get_crossval_val = self.get_crossval_val_func(dataset) 175 | pred, true, segment = _crossval_predict( 176 | self, 177 | Mod=Mod, 178 | get_crossval_val=get_crossval_val, 179 | fold=fold, 180 | dataset_args=dataset_args) 181 | from .vote import get_vote_accuracy_curve as do 182 | return do(true, pred, segment, windows, balance)[1] 183 | 184 | 185 | def get_crossval_accuracies(crossval_type, exps, folds, batch_size=1000): 186 | acc = [] 187 | evaluation = CrossValEvaluation( 188 | crossval_type=crossval_type, 189 | batch_size=batch_size 190 | ) 191 | for fold in folds: 192 | 
acc.append(evaluation.compare(exps, fold)) 193 | return acc 194 | -------------------------------------------------------------------------------- /sigr/feature_map/activity_image.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | #import cv2 3 | 4 | def genIndex(chanums): 5 | 6 | index = [] 7 | i = 1 8 | j = i+1 9 | 10 | if (chanums % 2) == 0: 11 | Ns = chanums+1 12 | else: 13 | Ns = chanums 14 | 15 | 16 | index.append(1) 17 | t = chr(i+ord('A')) 18 | while(i!=j): 19 | l = "" 20 | l = l+chr(i+ord('A')) 21 | l = l+chr(j+ord('A')) 22 | r = "" 23 | r = r+chr(j+ord('A')) 24 | r = r+chr(i+ord('A')) 25 | if(j>Ns): 26 | j = 1 27 | elif(t.find(l)==-1 and t.find(r)==-1): 28 | index.append(j) 29 | t = t+chr(j+ord('A')) 30 | i = j 31 | j = i+1 32 | else: 33 | j = j+1 34 | 35 | 36 | 37 | new_index = [] 38 | if (chanums % 2) == 0: 39 | for i in range(len(index)): 40 | if index[i] != chanums+1: 41 | new_index.append(index[i]) 42 | 43 | index = np.array(new_index) 44 | index = index-1 45 | return index 46 | 47 | 48 | def get_signal_img(data): 49 | 50 | ch_num = data.shape[0] 51 | index = genIndex(ch_num) 52 | signal_img = data[index] 53 | signal_img = signal_img[:-1] 54 | # print signal_img.shape 55 | return signal_img 56 | 57 | def get_activity_img(data): 58 | 59 | signal_img = get_signal_img(data) 60 | 61 | f = np.fft.fft2(signal_img) 62 | fshift = np.fft.fftshift(f) 63 | magnitude_spectrum = 20*np.log(np.abs(fshift)) 64 | # magnitude_spectrum = cv2.resize(magnitude_spectrum,None,fx=1,fy=8) 65 | # cv2.imshow('image',magnitude_spectrum) 66 | # cv2.waitKey(0) 67 | # cv2.destroyAllWindows() 68 | return magnitude_spectrum 69 | 70 | 71 | -------------------------------------------------------------------------------- /sigr/feature_map/feature-map.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import emg_features 3 | 4 | def feature_map(x, feature_list): 5 | 
6 | res = [] 7 | for i in range(x.shape[0]): 8 | single_channel = [] 9 | for j in range(len(feature_list)): 10 | func = 'emg_features.emg_'+feature_list[j] 11 | single_channel.append(eval(str(func))(x[i,:])) 12 | # print single_channel[0].shape 13 | single_channel = np.hstack(single_channel) 14 | res.append(single_channel) 15 | res =np.vstack(res) 16 | return res -------------------------------------------------------------------------------- /sigr/feature_map/ninapro_feature_map_extractor-43.py: -------------------------------------------------------------------------------- 1 | import os 2 | import scipy.io as sio 3 | import numpy as np 4 | import emg_features 5 | #from activity_image import get_signal_img 6 | from itertools import product 7 | from collections import namedtuple 8 | from joblib import Parallel, delayed 9 | #import cv2 10 | #from ..utils import butter_lowpass_filter as lowpass 11 | 12 | 13 | subjects = list(range(0,13)) 14 | gestures = list(range(1,53)) 15 | trials = list(range(10)) 16 | #input_path = '/home/weiwentao/public/duyu/misc/ninapro-db1' 17 | #output_path = '/home/weiwentao/public/semg/ninapro-feature/ninapro-db1-var-raw' 18 | 19 | 20 | input_path = '/home/weiwentao/data/ninapro-db1' 21 | 22 | 23 | filtering_type = 'lowpass' 24 | framerate = 100 25 | 26 | window_length_ms = 200 27 | window_stride_ms = 10 28 | 29 | window = window_length_ms*framerate/1000 30 | stride = window_stride_ms*framerate/1000 31 | 32 | output_path = ('/home/weiwentao/public/semg/ninapro-feature/ninapro-db1-var-raw-prepro-%s-win-%d-stride-%d' % (filtering_type, window, stride)) 33 | 34 | Combo = namedtuple('Combo', ['subject', 'gesture', 'trial'], verbose=False) 35 | 36 | 37 | #feature_list = ['dwt','dwpt','dwpt_mean','dwpt_sd','mav','mavslpphinyomark','wl','sscbestninapro1','rms','hemg20','mdwtdb1ninapro'] 38 | 39 | 40 | feature_list = ['mavslpphinyomark','arc','cc', 'arr', 'fft_power_db1'] 41 | 42 | def get_combos(*args): 43 | for arg in args: 44 | if 
isinstance(arg, tuple): 45 | arg = [arg] 46 | for a in arg: 47 | yield Combo(*a) 48 | 49 | 50 | 51 | #the following functions can be loaded from ..utils 52 | 53 | 54 | def butter_lowpass_filter(data, cut, fs, order, zero_phase=False): 55 | from scipy.signal import butter, lfilter, filtfilt 56 | 57 | nyq = 0.5 * fs 58 | cut = cut / nyq 59 | 60 | b, a = butter(order, cut, btype='low') 61 | y = (filtfilt if zero_phase else lfilter)(b, a, data) 62 | return y 63 | 64 | 65 | def get_segments(data, window, stride): 66 | return windowed_view( 67 | data.flat, 68 | window * data.shape[1], 69 | (window-stride)* data.shape[1] 70 | ) 71 | 72 | def windowed_view(arr, window, overlap): 73 | from numpy.lib.stride_tricks import as_strided 74 | arr = np.asarray(arr) 75 | window_step = window - overlap 76 | new_shape = arr.shape[:-1] + ((arr.shape[-1] - overlap) // window_step, 77 | window) 78 | new_strides = (arr.strides[:-1] + (window_step * arr.strides[-1],) + 79 | arr.strides[-1:]) 80 | return as_strided(arr, shape=new_shape, strides=new_strides) 81 | 82 | 83 | #def dft(data): 84 | # f = np.fft.fft2(data) 85 | # fshift = np.fft.fftshift(f) 86 | # magnitude_spectrum = 20*np.log(np.abs(fshift)) 87 | # return magnitude_spectrum 88 | # 89 | ##the following functions can be loaded from .. 
90 | # 91 | #def dft_dy(data): 92 | # data = data.T 93 | # n = data.shape[-1] 94 | # window = np.hanning(n) 95 | # windowed = data * window 96 | # spectrum = np.fft.fft(windowed) 97 | # return np.abs(spectrum) 98 | 99 | 100 | def feature_map(x): 101 | 102 | res = [] 103 | for i in range(x.shape[0]): 104 | single_channel = [] 105 | for j in range(len(feature_list)): 106 | func = 'emg_features.emg_'+feature_list[j] 107 | single_channel.append(eval(str(func))(x[i,:])) 108 | # print single_channel[0].shape 109 | single_channel = np.hstack(single_channel) 110 | res.append(single_channel) 111 | res =np.vstack(res) 112 | return res 113 | 114 | def extract_emg_feature(x, feature_name): 115 | res = [] 116 | for i in range(x.shape[0]): 117 | func = 'emg_features.emg_'+feature_name 118 | res.append(eval(str(func))(x[i,:])) 119 | res =np.vstack(res) 120 | return res 121 | 122 | def emg_feature_extraction_parallel(out_dir, combo, data, feature_name): 123 | feature = [np.transpose(extract_emg_feature(seg.T, feature_name)) for seg in data] 124 | feature = np.array(feature) 125 | out_path = os.path.join( 126 | out_dir, 127 | '{0.subject:03d}_{0.gesture:03d}_{0.trial:03d}_{1}.mat').format(combo, feature_name) 128 | sio.savemat(out_path, {'data': feature, 'label': combo.gesture, 'subject': combo.subject, 'trial':combo.trial}) 129 | print ("Subject %d Gesture %d Trial %d %s saved!" 
% (combo.subject, combo.gesture, combo.trial, feature_name)) 130 | 131 | 132 | if __name__ == '__main__': 133 | 134 | print ("NinaPro feature map generation, use window = %d frames, stride = %d frames" % (window, stride)) 135 | 136 | combos = get_combos(product(subjects, gestures, trials)) 137 | 138 | combos = list(combos) 139 | 140 | 141 | # # for test only 142 | # pre_data = [] 143 | # feature_dim = 0 144 | # for i in range(len(feature_list)): 145 | # input_dir = os.path.join(output_path, 146 | # '000', 147 | # '000', 148 | # '000_000_000_{0}.mat' 149 | # ).format(feature_list[i]) 150 | # mat = sio.loadmat(input_dir) 151 | # data = mat['data'].astype(np.float32) 152 | # feature_dim = feature_dim + data.shape[1] 153 | # pre_data.append(data) 154 | # pre_data = np.concatenate(pre_data, axis=1) 155 | # feature_dim = pre_data.shape[1] 156 | 157 | 158 | 159 | 160 | for combo in combos: 161 | in_path = os.path.join( 162 | input_path, 'data', 163 | '{c.subject:03d}', 164 | '{c.gesture:03d}', 165 | '{c.subject:03d}_{c.gesture:03d}_{c.trial:03d}.mat').format(c=combo) 166 | 167 | out_dir = os.path.join( 168 | output_path, 169 | '{c.subject:03d}', 170 | '{c.gesture:03d}').format(c=combo) 171 | 172 | if os.path.isdir(out_dir) is False: 173 | os.makedirs(out_dir) 174 | 175 | 176 | data = sio.loadmat(in_path)['data'].astype(np.float32) 177 | 178 | print ("Subject %d Gesture %d Trial %d data loaded!" % (combo.subject, combo.gesture, combo.trial)) 179 | 180 | if filtering_type is 'lowpass': 181 | # data = np.transpose([lowpass(ch, 1, framerate, 1, zero_phase=True) for ch in data.T]) 182 | data = np.transpose([butter_lowpass_filter(ch, 1, framerate, 1, zero_phase=True) for ch in data.T]) 183 | print ("Subject %d Gesture %d Trial %d bandpass filtering finished!" 
% (combo.subject, combo.gesture, combo.trial)) 184 | else: 185 | pass 186 | 187 | 188 | 189 | chnum = data.shape[1]; 190 | data = get_segments(data, window, stride) 191 | data = data.reshape(-1, window, chnum) 192 | 193 | Parallel(n_jobs=4)(delayed(emg_feature_extraction_parallel)(out_dir, combo, data, feature_list[i]) for i in range(len(feature_list))) 194 | 195 | # for i in range(len(feature_list)): 196 | # feature = [np.transpose(extract_emg_feature(seg.T, feature_list[i])) for seg in data] 197 | # feature = np.array(feature) 198 | # out_path = os.path.join( 199 | # out_dir, 200 | # '{0.subject:03d}_{0.gesture:03d}_{0.trial:03d}_{1}.mat').format(combo, feature_list[i]) 201 | # sio.savemat(out_path, {'data': feature, 'label': combo.gesture, 'subject': combo.subject, 'trial':combo.trial}) 202 | # print ("Subject %d Gesture %d Trial %d %s saved!" % (combo.subject, combo.gesture, combo.trial, feature_list[i])) 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | -------------------------------------------------------------------------------- /sigr/feature_map/ninapro_feature_map_extractor-45.py: -------------------------------------------------------------------------------- 1 | import os 2 | import scipy.io as sio 3 | import numpy as np 4 | import emg_features 5 | #from activity_image import get_signal_img 6 | from itertools import product 7 | from collections import namedtuple 8 | 9 | #import cv2 10 | #from ..utils import butter_lowpass_filter as lowpass 11 | 12 | 13 | subjects = list(range(13,27)) 14 | gestures = list(range(1,53)) 15 | trials = list(range(10)) 16 | #input_path = '/home/weiwentao/public/duyu/misc/ninapro-db1' 17 | #output_path = '/home/weiwentao/public/semg/ninapro-feature/ninapro-db1-var-raw' 18 | 19 | 20 | input_path = '/home/weiwentao/data/ninapro-db1' 21 | 22 | 23 | filtering_type = 'lowpass' 24 | framerate = 100 25 | 26 | window_length_ms = 200 27 | window_stride_ms = 10 28 | 29 | window = window_length_ms*framerate/1000 30 
| stride = window_stride_ms*framerate/1000 31 | 32 | output_path = ('/home/weiwentao/public/semg/ninapro-feature/ninapro-db1-var-raw-prepro-%s-win-%d-stride-%d' % (filtering_type, window, stride)) 33 | 34 | Combo = namedtuple('Combo', ['subject', 'gesture', 'trial'], verbose=False) 35 | 36 | 37 | #feature_list = ['dwt','dwpt','dwpt_mean','dwpt_sd','mav','mavslpphinyomark','wl','sscbestninapro1','rms','hemg20','mdwtdb1ninapro'] 38 | 39 | 40 | feature_list = ['mavslpphinyomark','arc','cc', 'arr', 'fft_power_db1'] 41 | 42 | def get_combos(*args): 43 | for arg in args: 44 | if isinstance(arg, tuple): 45 | arg = [arg] 46 | for a in arg: 47 | yield Combo(*a) 48 | 49 | 50 | 51 | #the following functions can be loaded from ..utils 52 | 53 | 54 | def butter_lowpass_filter(data, cut, fs, order, zero_phase=False): 55 | from scipy.signal import butter, lfilter, filtfilt 56 | 57 | nyq = 0.5 * fs 58 | cut = cut / nyq 59 | 60 | b, a = butter(order, cut, btype='low') 61 | y = (filtfilt if zero_phase else lfilter)(b, a, data) 62 | return y 63 | 64 | 65 | def get_segments(data, window, stride): 66 | return windowed_view( 67 | data.flat, 68 | window * data.shape[1], 69 | (window-stride)* data.shape[1] 70 | ) 71 | 72 | def windowed_view(arr, window, overlap): 73 | from numpy.lib.stride_tricks import as_strided 74 | arr = np.asarray(arr) 75 | window_step = window - overlap 76 | new_shape = arr.shape[:-1] + ((arr.shape[-1] - overlap) // window_step, 77 | window) 78 | new_strides = (arr.strides[:-1] + (window_step * arr.strides[-1],) + 79 | arr.strides[-1:]) 80 | return as_strided(arr, shape=new_shape, strides=new_strides) 81 | 82 | 83 | #def dft(data): 84 | # f = np.fft.fft2(data) 85 | # fshift = np.fft.fftshift(f) 86 | # magnitude_spectrum = 20*np.log(np.abs(fshift)) 87 | # return magnitude_spectrum 88 | # 89 | ##the following functions can be loaded from .. 
def feature_map(x):
    """Concatenate every feature in ``feature_list`` for each row of *x*.

    Row ``i`` of the result is ``np.hstack`` of
    ``emg_features.emg_<name>(x[i, :])`` over all names in ``feature_list``.
    The rows are combined with ``np.vstack``.
    """
    rows = []
    for row in x:
        # getattr is the direct, eval-free equivalent of evaluating the
        # string "emg_features.emg_<name>".
        values = [getattr(emg_features, 'emg_' + name)(row) for name in feature_list]
        rows.append(np.hstack(values))
    return np.vstack(rows)


def extract_emg_feature(x, feature_name):
    """Apply ``emg_features.emg_<feature_name>`` to each row of *x* and vstack the results."""
    func = getattr(emg_features, 'emg_' + feature_name)
    return np.vstack([func(row) for row in x])


if __name__ == '__main__':

    print ("NinaPro feature map generation, use window = %d frames, stride = %d frames" % (window, stride))

    combos = list(get_combos(product(subjects, gestures, trials)))

    # Segmentation and reshape need integer sizes; ``window``/``stride`` are
    # integral floats when computed with ``/`` on Python 3.
    win_frames = int(window)
    stride_frames = int(stride)

    for combo in combos:
        in_path = os.path.join(
            input_path, 'data',
            '{c.subject:03d}',
            '{c.gesture:03d}',
            '{c.subject:03d}_{c.gesture:03d}_{c.trial:03d}.mat').format(c=combo)

        out_dir = os.path.join(
            output_path,
            '{c.subject:03d}',
            '{c.gesture:03d}').format(c=combo)

        if not os.path.isdir(out_dir):
            os.makedirs(out_dir)

        data = sio.loadmat(in_path)['data'].astype(np.float32)

        print ("Subject %d Gesture %d Trial %d data loaded!" % (combo.subject, combo.gesture, combo.trial))

        # Compare by value: ``is`` on string literals depends on interning.
        if filtering_type == 'lowpass':
            # Filter each channel (column) with a zero-phase, 1st-order, 1 Hz low-pass.
            data = np.transpose([butter_lowpass_filter(ch, 1, framerate, 1, zero_phase=True) for ch in data.T])
            print ("Subject %d Gesture %d Trial %d bandpass filtering finished!" % (combo.subject, combo.gesture, combo.trial))

        chnum = data.shape[1]
        data = get_segments(data, win_frames, stride_frames)
        data = data.reshape(-1, win_frames, chnum)

        # One .mat file per feature, holding that feature for every window.
        for feature_name in feature_list:
            feature = np.array([np.transpose(extract_emg_feature(seg.T, feature_name)) for seg in data])
            out_path = os.path.join(
                out_dir,
                '{0.subject:03d}_{0.gesture:03d}_{0.trial:03d}_{1}.mat').format(combo, feature_name)
            sio.savemat(out_path, {'data': feature, 'label': combo.gesture, 'subject': combo.subject, 'trial': combo.trial})
            print ("Subject %d Gesture %d Trial %d %s saved!" % (combo.subject, combo.gesture, combo.trial, feature_name))


# ---- /sigr/feature_map/ninapro_feature_map_extractor-49.py ----
import os
import scipy.io as sio
import numpy as np
import emg_features
#from activity_image import get_signal_img
from itertools import product
from collections import namedtuple
from joblib import Parallel, delayed
#import cv2
#from ..utils import butter_lowpass_filter as lowpass


subjects = list(range(0,13))
gestures = list(range(1,53))
trials = list(range(10))
#input_path = '/home/weiwentao/public/duyu/misc/ninapro-db1'
#output_path = '/home/weiwentao/public/semg/ninapro-feature/ninapro-db1-var-raw'


input_path = '/home/weiwentao/data/ninapro-db1'


filtering_type = 'lowpass'
framerate = 100

window_length_ms = 200
window_stride_ms = 10

# Window length and stride in frames. Floor division keeps them integers on
# Python 3 as well; they are used as reshape sizes and with "%d".
window = window_length_ms * framerate // 1000
stride = window_stride_ms * framerate // 1000

output_path = ('/home/weiwentao/public/semg/ninapro-feature/ninapro-db1-var-raw-prepro-%s-win-%d-stride-%d' % (filtering_type, window, stride))

# One recording: (subject, gesture, trial). ``verbose`` was dropped because
# namedtuple no longer accepts it (Python 3.7+); it defaulted to False.
Combo = namedtuple('Combo', ['subject', 'gesture', 'trial'])


#feature_list = ['dwt','dwpt','dwpt_mean','dwpt_sd','mav','mavslpphinyomark','wl','sscbestninapro1','rms','hemg20','mdwtdb1ninapro']


# Features to extract; each name ``x`` is resolved to ``emg_features.emg_x``.
feature_list = ['mavslpphinyomark','arc','cc', 'arr', 'fft_power_db1']


def get_combos(*args):
    """Yield a ``Combo`` for every (subject, gesture, trial) tuple in *args*.

    Each positional argument is either a single tuple or an iterable of tuples.
    """
    for arg in args:
        if isinstance(arg, tuple):
            arg = [arg]
        for a in arg:
            yield Combo(*a)


#the following functions can be loaded from ..utils


def butter_lowpass_filter(data, cut, fs, order, zero_phase=False):
    """Butterworth low-pass filter *data* along its last axis.

    cut is in the same units as the sampling rate *fs*; ``zero_phase``
    selects ``filtfilt`` instead of the causal ``lfilter``.
    """
    from scipy.signal import butter, lfilter, filtfilt

    nyq = 0.5 * fs
    cut = cut / nyq  # scipy expects the cutoff normalised to Nyquist

    b, a = butter(order, cut, btype='low')
    y = (filtfilt if zero_phase else lfilter)(b, a, data)
    return y


def get_segments(data, window, stride):
    """Cut a 2-D (frames, channels) array into overlapping windows.

    Returns shape (num_windows, window * channels); rows start ``stride``
    frames apart.
    """
    window = int(window)
    stride = int(stride)
    num_channels = data.shape[1]
    return windowed_view(
        data.flat,
        window * num_channels,
        (window - stride) * num_channels
    )


def windowed_view(arr, window, overlap):
    """Return overlapping windows over the last axis of *arr* as a strided view.

    Windows start ``window - overlap`` elements apart; an incomplete trailing
    window is dropped. The result shares memory with ``np.asarray(arr)``.
    """
    from numpy.lib.stride_tricks import as_strided
    arr = np.asarray(arr)
    window_step = window - overlap
    new_shape = arr.shape[:-1] + ((arr.shape[-1] - overlap) // window_step,
                                  window)
    new_strides = (arr.strides[:-1] + (window_step * arr.strides[-1],) +
                   arr.strides[-1:])
    return as_strided(arr, shape=new_shape, strides=new_strides)
def feature_map(x):
    """Concatenate every feature in ``feature_list`` for each row of *x*.

    Row ``i`` of the result is ``np.hstack`` of
    ``emg_features.emg_<name>(x[i, :])`` over ``feature_list``; the rows are
    combined with ``np.vstack``.
    """
    rows = []
    for row in x:
        # getattr replaces eval("emg_features.emg_<name>").
        values = [getattr(emg_features, 'emg_' + name)(row) for name in feature_list]
        rows.append(np.hstack(values))
    return np.vstack(rows)


def extract_emg_feature(x, feature_name):
    """Apply ``emg_features.emg_<feature_name>`` to each row of *x* and vstack the results."""
    func = getattr(emg_features, 'emg_' + feature_name)
    return np.vstack([func(row) for row in x])


def emg_feature_extraction_parallel(out_dir, combo, data, feature_name):
    """Compute one feature for every window in *data* and save it to a .mat file.

    data: windows shaped (num_windows, window, channels), as built in
        ``__main__``. Entry ``k`` of the saved array is the transpose of
        ``extract_emg_feature(data[k].T, feature_name)``.
    The file is written to ``out_dir`` as
    ``<subject>_<gesture>_<trial>_<feature_name>.mat``.
    """
    feature = np.array([np.transpose(extract_emg_feature(seg.T, feature_name)) for seg in data])
    out_path = os.path.join(
        out_dir,
        '{0.subject:03d}_{0.gesture:03d}_{0.trial:03d}_{1}.mat').format(combo, feature_name)
    sio.savemat(out_path, {'data': feature, 'label': combo.gesture, 'subject': combo.subject, 'trial': combo.trial})
    print ("Subject %d Gesture %d Trial %d %s saved!" % (combo.subject, combo.gesture, combo.trial, feature_name))


if __name__ == '__main__':

    print ("NinaPro feature map generation, use window = %d frames, stride = %d frames" % (window, stride))

    combos = list(get_combos(product(subjects, gestures, trials)))

    # Segmentation and reshape need integer sizes (``/`` yields floats on Python 3).
    win_frames = int(window)
    stride_frames = int(stride)

    for combo in combos:
        in_path = os.path.join(
            input_path, 'data',
            '{c.subject:03d}',
            '{c.gesture:03d}',
            '{c.subject:03d}_{c.gesture:03d}_{c.trial:03d}.mat').format(c=combo)

        out_dir = os.path.join(
            output_path,
            '{c.subject:03d}',
            '{c.gesture:03d}').format(c=combo)

        if not os.path.isdir(out_dir):
            os.makedirs(out_dir)

        data = sio.loadmat(in_path)['data'].astype(np.float32)

        print ("Subject %d Gesture %d Trial %d data loaded!" % (combo.subject, combo.gesture, combo.trial))

        # Compare by value: ``is`` on string literals depends on interning.
        if filtering_type == 'lowpass':
            # Filter each channel (column) with a zero-phase, 1st-order, 1 Hz low-pass.
            data = np.transpose([butter_lowpass_filter(ch, 1, framerate, 1, zero_phase=True) for ch in data.T])
            print ("Subject %d Gesture %d Trial %d bandpass filtering finished!" % (combo.subject, combo.gesture, combo.trial))

        chnum = data.shape[1]
        data = get_segments(data, win_frames, stride_frames)
        data = data.reshape(-1, win_frames, chnum)

        # One joblib job per feature; each job writes its own .mat file.
        Parallel(n_jobs=4)(
            delayed(emg_feature_extraction_parallel)(out_dir, combo, data, name)
            for name in feature_list)


# ---- /sigr/feature_map/ninapro_feature_map_extractor.py ----
import os
import scipy.io as sio
import numpy as np
import emg_features
#from activity_image import get_signal_img
from itertools import product
from collections import namedtuple
from joblib import Parallel, delayed
#import cv2
#from ..utils import butter_lowpass_filter as lowpass


subjects = list(range(27))
gestures = list(range(1,53))
trials = list(range(10))
#input_path = '/home/weiwentao/public/duyu/misc/ninapro-db1'
#output_path = '/home/weiwentao/public/semg/ninapro-feature/ninapro-db1-var-raw'


input_path = '/home/weiwentao/data/ninapro-db1'


filtering_type = 'lowpass'
framerate = 100

window_length_ms = 200
window_stride_ms = 10

# Window length and stride in frames; floor division keeps them integers on
# Python 3 as well.
window = window_length_ms * framerate // 1000
stride = window_stride_ms * framerate // 1000

output_path = ('/home/weiwentao/public/semg/ninapro-feature/ninapro-db1-var-raw-prepro-%s-win-%d-stride-%d' % (filtering_type, window, stride))

# One recording: (subject, gesture, trial). ``verbose`` was dropped because
# namedtuple no longer accepts it (Python 3.7+); it defaulted to False.
Combo = namedtuple('Combo', ['subject', 'gesture', 'trial'])


#feature_list = ['dwt','dwpt','dwpt_mean','dwpt_sd','mav','mavslpphinyomark','wl','sscbestninapro1','rms','hemg20','mdwtdb1ninapro']


# Features to extract; each name ``x`` is resolved to ``emg_features.emg_x``.
feature_list = ['mavslpphinyomark','arc','cc', 'arr', 'fft_power_db1']


def get_combos(*args):
    """Yield a ``Combo`` for every (subject, gesture, trial) tuple in *args*.

    Each positional argument is either a single tuple or an iterable of tuples.
    """
    for arg in args:
        if isinstance(arg, tuple):
            arg = [arg]
        for a in arg:
            yield Combo(*a)


#the following functions can be loaded from ..utils


def butter_lowpass_filter(data, cut, fs, order, zero_phase=False):
    """Butterworth low-pass filter *data* along its last axis.

    cut is in the same units as the sampling rate *fs*; ``zero_phase``
    selects ``filtfilt`` instead of the causal ``lfilter``.
    """
    from scipy.signal import butter, lfilter, filtfilt

    nyq = 0.5 * fs
    cut = cut / nyq  # scipy expects the cutoff normalised to Nyquist

    b, a = butter(order, cut, btype='low')
    y = (filtfilt if zero_phase else lfilter)(b, a, data)
    return y


def get_segments(data, window, stride):
    """Cut a 2-D (frames, channels) array into overlapping windows.

    Returns shape (num_windows, window * channels); rows start ``stride``
    frames apart.
    """
    window = int(window)
    stride = int(stride)
    num_channels = data.shape[1]
    return windowed_view(
        data.flat,
        window * num_channels,
        (window - stride) * num_channels
    )


def windowed_view(arr, window, overlap):
    """Return overlapping windows over the last axis of *arr* as a strided view.

    Windows start ``window - overlap`` elements apart; an incomplete trailing
    window is dropped. The result shares memory with ``np.asarray(arr)``.
    """
    from numpy.lib.stride_tricks import as_strided
    arr = np.asarray(arr)
    window_step = window - overlap
    new_shape = arr.shape[:-1] + ((arr.shape[-1] - overlap) // window_step,
                                  window)
    new_strides = (arr.strides[:-1] + (window_step * arr.strides[-1],) +
                   arr.strides[-1:])
    return as_strided(arr, shape=new_shape, strides=new_strides)
def feature_map(x):
    """Concatenate every feature in ``feature_list`` for each row of *x*.

    Row ``i`` of the result is ``np.hstack`` of
    ``emg_features.emg_<name>(x[i, :])`` over ``feature_list``; the rows are
    combined with ``np.vstack``.
    """
    rows = []
    for row in x:
        # getattr replaces eval("emg_features.emg_<name>").
        values = [getattr(emg_features, 'emg_' + name)(row) for name in feature_list]
        rows.append(np.hstack(values))
    return np.vstack(rows)


def extract_emg_feature(x, feature_name):
    """Apply ``emg_features.emg_<feature_name>`` to each row of *x* and vstack the results."""
    func = getattr(emg_features, 'emg_' + feature_name)
    return np.vstack([func(row) for row in x])


def emg_feature_extraction_parallel(out_dir, combo, data, feature_name):
    """Compute one feature for every window in *data* and save it to a .mat file.

    data: windows shaped (num_windows, window, channels), as built in
        ``__main__``. Entry ``k`` of the saved array is the transpose of
        ``extract_emg_feature(data[k].T, feature_name)``.
    """
    feature = np.array([np.transpose(extract_emg_feature(seg.T, feature_name)) for seg in data])
    out_path = os.path.join(
        out_dir,
        '{0.subject:03d}_{0.gesture:03d}_{0.trial:03d}_{1}.mat').format(combo, feature_name)
    sio.savemat(out_path, {'data': feature, 'label': combo.gesture, 'subject': combo.subject, 'trial': combo.trial})
    print ("Subject %d Gesture %d Trial %d %s saved!" % (combo.subject, combo.gesture, combo.trial, feature_name))


if __name__ == '__main__':

    print ("NinaPro feature map generation, use window = %d frames, stride = %d frames" % (window, stride))

    combos = list(get_combos(product(subjects, gestures, trials)))

    # Segmentation and reshape need integer sizes (``/`` yields floats on Python 3).
    win_frames = int(window)
    stride_frames = int(stride)

    for combo in combos:
        in_path = os.path.join(
            input_path, 'data',
            '{c.subject:03d}',
            '{c.gesture:03d}',
            '{c.subject:03d}_{c.gesture:03d}_{c.trial:03d}.mat').format(c=combo)

        out_dir = os.path.join(
            output_path,
            '{c.subject:03d}',
            '{c.gesture:03d}').format(c=combo)

        if not os.path.isdir(out_dir):
            os.makedirs(out_dir)

        data = sio.loadmat(in_path)['data'].astype(np.float32)

        print ("Subject %d Gesture %d Trial %d data loaded!" % (combo.subject, combo.gesture, combo.trial))

        # Compare by value: ``is`` on string literals depends on interning.
        if filtering_type == 'lowpass':
            # Filter each channel (column) with a zero-phase, 1st-order, 1 Hz low-pass.
            data = np.transpose([butter_lowpass_filter(ch, 1, framerate, 1, zero_phase=True) for ch in data.T])
            print ("Subject %d Gesture %d Trial %d bandpass filtering finished!" % (combo.subject, combo.gesture, combo.trial))

        chnum = data.shape[1]
        data = get_segments(data, win_frames, stride_frames)
        data = data.reshape(-1, win_frames, chnum)

        # One joblib job per feature; each job writes its own .mat file.
        Parallel(n_jobs=4)(
            delayed(emg_feature_extraction_parallel)(out_dir, combo, data, name)
            for name in feature_list)


# ---- /sigr/fft.py ----
# NOTE: this file began with ``from __future__ import division``, which is only
# legal as a module's first statement; the divisions below are written as
# explicit true division so they do not depend on it.
import numpy as np


def fft(data, fs):
    """Return the one-sided, Hann-windowed amplitude spectrum of *data*.

    data: signal(s); the transform runs along the last axis.
    fs: sampling rate, used only to label the frequency bins.
    Returns ``(freq_half, amplitude_half)`` for the first ceil(n / 2) bins,
    with amplitudes scaled by 2 / n.
    """
    n = data.shape[-1]
    window = np.hanning(n)
    windowed = data * window
    spectrum = np.fft.fft(windowed)
    freq = np.fft.fftfreq(n, 1.0 / fs)
    # ceil(n / 2) as an int: numpy rejects float slice bounds, which the
    # previous ``np.ceil(n / 2)`` produced.
    half_n = (n + 1) // 2
    spectrum_half = (2.0 / n) * spectrum[..., :half_n]
    freq_half = freq[:half_n]
    return freq_half, np.abs(spectrum_half)


# ---- /sigr/genIndex.py ----
# (Originally preceded by ``from __future__ import division``; this function
# performs no division.)
import numpy as np


def genIndex(chanums):
    """Greedily build a channel visiting order; return it 0-based as a numpy array.

    Starting at channel 1, the walk repeatedly steps to the next channel
    (trying i+1, i+2, ... and wrapping to 1) whose pair with the current
    channel, in either order, has not yet appeared as adjacent entries. It
    stops when no such channel is found. For an even *chanums* the walk uses
    a dummy channel ``chanums + 1``, which is removed from the result.
    """
    # The walk runs over an odd number of channels.
    if (chanums % 2) == 0:
        Ns = chanums + 1
    else:
        Ns = chanums

    index = [1]
    i = 1
    j = i + 1
    # ``t`` spells the sequence so far with one character per channel
    # (1 -> 'B', 2 -> 'C', ...), so a pair's adjacency is a substring test.
    t = chr(i + ord('A'))
    while i != j:
        l = chr(i + ord('A')) + chr(j + ord('A'))
        r = chr(j + ord('A')) + chr(i + ord('A'))
        if j > Ns:
            j = 1
        elif l not in t and r not in t:
            index.append(j)
            t = t + chr(j + ord('A'))
            i = j
            j = i + 1
        else:
            j = j + 1

    if (chanums % 2) == 0:
        index = [k for k in index if k != chanums + 1]

    index = np.array(index)
    index = index - 1
    return index


# ---- /sigr/parse_log.py ----
# (Originally preceded by ``from __future__ import division``; ``div`` below
# converts explicitly, so the result does not depend on it.)
import re
import numpy as np


def div(up, down):
    """Return ``up / down`` (true division), or NaN if it cannot be computed."""
    try:
        return up / float(down)
    except (ZeroDivisionError, TypeError, ValueError):
        return np.nan


def parse_log(path):
    """Parse per-epoch accuracy and timing lines from a training log file.

    Recognised lines (the ``[g]`` suffix is optional)::

        ...Epoch[N] Train-accuracy=X
        ...Epoch[N] Validation-accuracy=X
        ...Epoch[N] Time...=X

    Values repeated for the same epoch are averaged. Returns a pandas
    DataFrame with columns ``epoch`` (N + 1), ``train``, ``val`` and
    ``time``; a metric never seen for an epoch is NaN.
    """
    with open(path, 'r') as f:
        lines = f.readlines()

    # Raw strings: "\d" and "\[" are invalid escapes in ordinary literals.
    # The position of each pattern selects its (sum, count) slot below.
    res = [re.compile(r'.*Epoch\[(\d+)\] Train-accuracy(?:\[g\])?=([.\d]+)'),
           re.compile(r'.*Epoch\[(\d+)\] Validation-accuracy(?:\[g\])?=([.\d]+)'),
           re.compile(r'.*Epoch\[(\d+)\] Time.*=([.\d]+)')]

    # epoch -> [sum_0, count_0, sum_1, count_1, ...], one pair per pattern.
    data = {}
    for l in lines:
        for i, r in enumerate(res):
            m = r.match(l)
            if m is not None:
                break
        else:
            continue  # no pattern matched this line

        epoch = int(m.group(1))
        val = float(m.group(2))

        if epoch not in data:
            data[epoch] = [0] * len(res) * 2

        data[epoch][i * 2] += val
        data[epoch][i * 2 + 1] += 1

    df = [{
        'epoch': k + 1,
        'train': div(v[0], v[1]),
        'val': div(v[2], v[3]),
        'time': div(v[4], v[5])
    } for k, v in data.items()]

    import pandas as pd
    return pd.DataFrame(df)
-------------------------------------------------------------------------------- /sigr/sklearn_module.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from nose.tools import assert_equal 3 | import mxnet as mx 4 | import numpy as np 5 | from logbook import Logger 6 | import joblib as jb 7 | from .base_module import BaseModule 8 | 9 | 10 | logger = Logger('sigr') 11 | 12 | 13 | class SklearnModule(BaseModule): 14 | 15 | def _get_data_label(self, data_iter): 16 | data = [] 17 | label = [] 18 | for batch in data_iter: 19 | data.append(batch.data[0].asnumpy().reshape( 20 | batch.data[0].shape[0], -1)) 21 | label.append(batch.label[0].asnumpy()) 22 | if batch.pad: 23 | data[-1] = data[-1][:-batch.pad] 24 | label[-1] = label[-1][:-batch.pad] 25 | data = np.vstack(data) 26 | label = np.hstack(label) 27 | assert_equal(len(data), len(label)) 28 | return data, label 29 | 30 | def fit(self, train_data, eval_data, eval_metric='acc', **kargs): 31 | snapshot = kargs.pop('snapshot') 32 | self.clf.fit(*self._get_data_label(train_data)) 33 | jb.dump(self.clf, snapshot + '-0001.params') 34 | 35 | if not isinstance(eval_metric, mx.metric.EvalMetric): 36 | eval_metric = mx.metric.create(eval_metric) 37 | data, label = self._get_data_label(eval_data) 38 | pred = self.clf.predict(data).astype(np.int64) 39 | prob = np.zeros((len(pred), pred.max() + 1)) 40 | prob[np.arange(len(prob)), pred] = 1 41 | eval_metric.update([mx.nd.array(label)], [mx.nd.array(prob)]) 42 | for name, val in eval_metric.get_name_value(): 43 | logger.info('Epoch[0] Validation-{}={}', name, val) 44 | 45 | 46 | class KNNModule(SklearnModule): 47 | 48 | def __init__(self): 49 | from sklearn.neighbors import KNeighborsClassifier as KNN 50 | self.clf = KNN() 51 | 52 | @classmethod 53 | def parse(cls, text, **kargs): 54 | if text == 'knn': 55 | return cls() 56 | 57 | 58 | class SVMModule(SklearnModule): 59 | 60 | def __init__(self): 61 | from 
sklearn.svm import LinearSVC 62 | self.clf = LinearSVC() 63 | 64 | @classmethod 65 | def parse(cls, text, **kargs): 66 | if text == 'svm': 67 | return cls() 68 | 69 | 70 | class RandomForestsModule(SklearnModule): 71 | 72 | def __init__(self): 73 | from sklearn.ensemble import RandomForestClassifier as RandomForests 74 | self.clf = RandomForests() 75 | 76 | @classmethod 77 | def parse(cls, text, **kargs): 78 | if text == 'random-forests': 79 | return cls() 80 | 81 | 82 | class LDAModule(SklearnModule): 83 | 84 | def __init__(self): 85 | from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA 86 | self.clf = LDA() 87 | 88 | @classmethod 89 | def parse(cls, text, **kargs): 90 | if text == 'lda': 91 | return cls() 92 | -------------------------------------------------------------------------------- /sigr/stacked_optical_flow/optical_flow_extractor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import scipy.io as sio 3 | import numpy as np 4 | from itertools import product 5 | from collections import namedtuple 6 | import cv2 7 | #from ..utils import butter_lowpass_filter as lowpass 8 | 9 | dataset = 'csl' 10 | 11 | subjects = list(range(27)) 12 | gestures = list(range(53)) 13 | trials = list(range(10)) 14 | #input_path = '/home/weiwentao/public/duyu/misc/ninapro-db1' 15 | #output_path = '/home/weiwentao/public/semg/ninapro-feature/ninapro-db1-var-raw' 16 | 17 | 18 | input_path = 'Y:/duyu/misc' 19 | 20 | 21 | 22 | if dataset is 'csl': 23 | subjects = list(range(1, 6)) 24 | sessions = list(range(1, 6)) 25 | gestures = list(range(27)) 26 | trials = list(range(10)) 27 | framerate = 2048 28 | 29 | 30 | 31 | Combo = namedtuple('Combo', ['subject', 'session', 'gesture', 'trial'], verbose=False) 32 | 33 | #elif dataset is 'dba' or 'dbb' or 'dbc' 34 | 35 | 36 | 37 | 38 | filtering_type = 'lowpass' 39 | 40 | window_length_ms = 250 41 | window_stride_ms = 10 42 | 43 | window = 
window_length_ms*framerate/1000 44 | stride = window_stride_ms*framerate/1000 45 | 46 | output_path = ('Y:/semg/%s-feature/%s-prepro-%s-win-%d-stride-%d' % (dataset, dataset, filtering_type, window, stride)) 47 | 48 | 49 | #from contextlib import contextmanager 50 | #@contextmanager 51 | #def Context(log=None, parallel=False, level=None): 52 | # from .utils import logging_context 53 | # with logging_context(log, level=level): 54 | # if not parallel: 55 | # yield 56 | # else: 57 | # import joblib as jb 58 | # from multiprocessing import cpu_count 59 | # with jb.Parallel(n_jobs=cpu_count()) as par: 60 | # Context.parallel = par 61 | # yield 62 | 63 | 64 | 65 | def get_combos(*args): 66 | for arg in args: 67 | if isinstance(arg, tuple): 68 | arg = [arg] 69 | for a in arg: 70 | yield Combo(*a) 71 | 72 | 73 | 74 | #the following functions can be loaded from ..utils 75 | 76 | 77 | def butter_lowpass_filter(data, cut, fs, order, zero_phase=False): 78 | from scipy.signal import butter, lfilter, filtfilt 79 | 80 | nyq = 0.5 * fs 81 | cut = cut / nyq 82 | 83 | b, a = butter(order, cut, btype='low') 84 | y = (filtfilt if zero_phase else lfilter)(b, a, data) 85 | return y 86 | 87 | 88 | def butter_bandpass_filter(data, lowcut, highcut, fs, order): 89 | from scipy.signal import butter, lfilter 90 | 91 | nyq = 0.5 * fs 92 | low = lowcut / nyq 93 | high = highcut / nyq 94 | 95 | b, a = butter(order, [low, high], btype='bandpass') 96 | y = lfilter(b, a, data) 97 | return y 98 | 99 | 100 | def median_filter(data, num_semg_row, num_semg_col): 101 | return np.array([median_filter(image, 3).ravel() for image 102 | in data.reshape(-1, num_semg_row, num_semg_col)]) 103 | 104 | 105 | def cslcut(data, framerate): 106 | begin, end = _csl_cut(data, framerate) 107 | return data[begin:end] 108 | 109 | def _csl_cut(data, framerate): 110 | window = int(np.round(150 * framerate / 2048)) 111 | data = data[:len(data) // window * window].reshape(-1, 150, data.shape[1]) 112 | rms = 
np.sqrt(np.mean(np.square(data), axis=1)) 113 | rms = [median_filter(image, 3).ravel() for image in rms.reshape(-1, 24, 7)] 114 | rms = np.mean(rms, axis=1) 115 | threshold = np.mean(rms) 116 | mask = rms > threshold 117 | for i in range(1, len(mask) - 1): 118 | if not mask[i] and mask[i - 1] and mask[i + 1]: 119 | mask[i] = True 120 | from .. import utils 121 | begin, end = max(utils.continuous_segments(mask), 122 | key=lambda s: (mask[s[0]], s[1] - s[0])) 123 | return begin * window, end * window 124 | 125 | def downsample(data, step): 126 | return data[::step].copy() 127 | 128 | 129 | 130 | def get_segments(data, window, stride): 131 | return windowed_view( 132 | data.flat, 133 | window * data.shape[1], 134 | (window-stride)* data.shape[1] 135 | ) 136 | 137 | def windowed_view(arr, window, overlap): 138 | from numpy.lib.stride_tricks import as_strided 139 | arr = np.asarray(arr) 140 | window_step = window - overlap 141 | new_shape = arr.shape[:-1] + ((arr.shape[-1] - overlap) // window_step, 142 | window) 143 | new_strides = (arr.strides[:-1] + (window_step * arr.strides[-1],) + 144 | arr.strides[-1:]) 145 | return as_strided(arr, shape=new_shape, strides=new_strides) 146 | 147 | 148 | def dft(data): 149 | f = np.fft.fft2(data) 150 | fshift = np.fft.fftshift(f) 151 | magnitude_spectrum = 20*np.log(np.abs(fshift)) 152 | return magnitude_spectrum 153 | 154 | #the following functions can be loaded from .. 
155 | 156 | def dft_dy(data): 157 | data = data.T 158 | n = data.shape[-1] 159 | window = np.hanning(n) 160 | windowed = data * window 161 | spectrum = np.fft.fft(windowed) 162 | return np.abs(spectrum) 163 | 164 | 165 | 166 | if __name__ == '__main__': 167 | 168 | print ("%s stacked optical flow generation, use window = %d frames, stride = %d frames" % (dataset, window, stride)) 169 | 170 | if dataset is 'csl': 171 | 172 | combos = get_combos(product(subjects, sessions, gestures, trials)) 173 | 174 | combos = list(combos) 175 | 176 | for combo in combos: 177 | in_path = os.path.join( 178 | input_path, dataset, 179 | 'subject%d' % combo.subject, 180 | 'session%d' % combo.session, 181 | 'gest%d.mat' % combo.gesture) 182 | 183 | out_dir = os.path.join( 184 | output_path, 185 | '{c.subject:03d}', 186 | '{c.session:03d}', 187 | '{c.gesture:03d}').format(c=combo) 188 | 189 | if os.path.isdir(out_dir) is False: 190 | os.makedirs(out_dir) 191 | 192 | data = sio.loadmat(in_path)['gestures'] 193 | data = [np.transpose(np.delete(segment.astype(np.float32), np.s_[7:192:8], 0)) 194 | for segment in data.flat] 195 | 196 | 197 | 198 | print ("Subject %d Session %d Gesture %d data loaded!" % (combo.subject, combo.session, combo.gesture)) 199 | 200 | 201 | 202 | # if filtering_type is 'lowpass': 203 | ## data = np.transpose([lowpass(ch, 1, framerate, 1, zero_phase=True) for ch in data.T]) 204 | # data = np.transpose([butter_lowpass_filter(ch, 1, framerate, 1, zero_phase=True) for ch in data.T]) 205 | # print ("Subject %d Gesture %d Trial %d bandpass filtering finished!" 
% (combo.subject, combo.gesture, combo.trial)) 206 | # else: 207 | # pass 208 | # 209 | # 210 | # 211 | # chnum = data.shape[1]; 212 | # data = get_segments(data, window, stride) 213 | # data = data.reshape(-1, window, chnum) 214 | # 215 | # data = np.transpose(get_signal_img(data.T)) 216 | # 217 | # out_path = os.path.join( 218 | # out_dir, 219 | # '{c.subject:03d}_{c.gesture:03d}_{c.trial:03d}_sigimg.mat').format(c=combo) 220 | # sio.savemat(out_path, {'data': data, 'label': combo.gesture, 'subject': combo.subject, 'trial':combo.trial}) 221 | # 222 | # print ("Subject %d Gesture %d Trial %d sig image saved!" % (combo.subject, combo.gesture, combo.trial)) 223 | # 224 | ### for test only 225 | ## data = data[0:25,] 226 | ## data = dft(data) 227 | ## data = cv2.resize(data,None,fx=20,fy=20) 228 | ## cv2.imshow('image',data) 229 | ## cv2.waitKey(0) 230 | ## cv2.destroyAllWindows() 231 | # 232 | # 233 | # 234 | # 235 | # 236 | # data = [dft(seg) for seg in data] 237 | # data = np.array(data) 238 | # 239 | # 240 | # print ("Subject %d Gesture %d Trial %d data windowing finished!" 
% (combo.subject, combo.gesture, combo.trial)) 241 | # 242 | # out_path = os.path.join( 243 | # out_dir, 244 | # '{c.subject:03d}_{c.gesture:03d}_{c.trial:03d}_actimg.mat').format(c=combo) 245 | # sio.savemat(out_path, {'data': data, 'label': combo.gesture, 'subject': combo.subject, 'trial':combo.trial}) 246 | 247 | 248 | 249 | 250 | -------------------------------------------------------------------------------- /sigr/stacked_optical_flow/test_optical_flow_extractor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import scipy.io as sio 3 | import numpy as np 4 | import cv2 5 | from itertools import product 6 | from collections import namedtuple 7 | 8 | 9 | #import cv2 10 | #from ..utils import butter_lowpass_filter as lowpass 11 | 12 | 13 | subjects = list(range(27)) 14 | gestures = list(range(53)) 15 | trials = list(range(10)) 16 | #input_path = '/home/weiwentao/public/duyu/misc/ninapro-db1' 17 | #output_path = '/home/weiwentao/public/semg/ninapro-feature/ninapro-db1-var-raw' 18 | 19 | 20 | input_path = 'Y:/duyu/misc/ninapro-db1' 21 | 22 | 23 | filtering_type = 'none' 24 | framerate = 100 25 | 26 | window_length_ms = 150 27 | window_stride_ms = 10 28 | 29 | window = window_length_ms*framerate/1000 30 | stride = window_stride_ms*framerate/1000 31 | 32 | output_path = ('Y:/semg/ninapro-feature/TEST-ninapro-db1-var-raw-prepro-%s-win-%d-stride-%d' % (filtering_type, window, stride)) 33 | 34 | Combo = namedtuple('Combo', ['subject', 'gesture', 'trial'], verbose=False) 35 | 36 | def get_combos(*args): 37 | for arg in args: 38 | if isinstance(arg, tuple): 39 | arg = [arg] 40 | for a in arg: 41 | yield Combo(*a) 42 | 43 | 44 | 45 | #the following functions can be loaded from ..utils 46 | 47 | 48 | def butter_lowpass_filter(data, cut, fs, order, zero_phase=False): 49 | from scipy.signal import butter, lfilter, filtfilt 50 | 51 | nyq = 0.5 * fs 52 | cut = cut / nyq 53 | 54 | b, a = butter(order, cut, btype='low') 55 | 
y = (filtfilt if zero_phase else lfilter)(b, a, data) 56 | return y 57 | 58 | 59 | def get_segments(data, window, stride): 60 | return windowed_view( 61 | data.flat, 62 | window * data.shape[1], 63 | (window-stride)* data.shape[1] 64 | ) 65 | 66 | def windowed_view(arr, window, overlap): 67 | from numpy.lib.stride_tricks import as_strided 68 | arr = np.asarray(arr) 69 | window_step = window - overlap 70 | new_shape = arr.shape[:-1] + ((arr.shape[-1] - overlap) // window_step, 71 | window) 72 | new_strides = (arr.strides[:-1] + (window_step * arr.strides[-1],) + 73 | arr.strides[-1:]) 74 | return as_strided(arr, shape=new_shape, strides=new_strides) 75 | 76 | 77 | def dft(data): 78 | f = np.fft.fft2(data) 79 | fshift = np.fft.fftshift(f) 80 | magnitude_spectrum = 20*np.log(np.abs(fshift)) 81 | return magnitude_spectrum 82 | 83 | #the following functions can be loaded from .. 84 | 85 | def dft_dy(data): 86 | data = data.T 87 | n = data.shape[-1] 88 | window = np.hanning(n) 89 | windowed = data * window 90 | spectrum = np.fft.fft(windowed) 91 | return np.abs(spectrum.T) 92 | 93 | def lucas_kanade_np(im1, im2, win=2): 94 | assert im1.shape == im2.shape 95 | I_x = np.zeros(im1.shape) 96 | I_y = np.zeros(im1.shape) 97 | I_t = np.zeros(im1.shape) 98 | I_x[1:-1, 1:-1] = (im1[1:-1, 2:] - im1[1:-1, :-2]) / 2 99 | I_y[1:-1, 1:-1] = (im1[2:, 1:-1] - im1[:-2, 1:-1]) / 2 100 | I_t[1:-1, 1:-1] = im1[1:-1, 1:-1] - im2[1:-1, 1:-1] 101 | params = np.zeros(im1.shape + (5,)) #Ix2, Iy2, Ixy, Ixt, Iyt 102 | params[..., 0] = I_x * I_x # I_x2 103 | params[..., 1] = I_y * I_y # I_y2 104 | params[..., 2] = I_x * I_y # I_xy 105 | params[..., 3] = I_x * I_t # I_xt 106 | params[..., 4] = I_y * I_t # I_yt 107 | del I_x, I_y, I_t 108 | cum_params = np.cumsum(np.cumsum(params, axis=0), axis=1) 109 | del params 110 | win_params = (cum_params[2 * win + 1:, 2 * win + 1:] - 111 | cum_params[2 * win + 1:, :-1 - 2 * win] - 112 | cum_params[:-1 - 2 * win, 2 * win + 1:] + 113 | cum_params[:-1 - 2 * win, 
:-1 - 2 * win]) 114 | del cum_params 115 | op_flow = np.zeros(im1.shape + (2,)) 116 | det = win_params[...,0] * win_params[..., 1] - win_params[..., 2] **2 117 | op_flow_x = np.where(det != 0, 118 | (win_params[..., 1] * win_params[..., 3] - 119 | win_params[..., 2] * win_params[..., 4]) / det, 120 | 0) 121 | op_flow_y = np.where(det != 0, 122 | (win_params[..., 0] * win_params[..., 4] - 123 | win_params[..., 2] * win_params[..., 3]) / det, 124 | 0) 125 | op_flow[win + 1: -1 - win, win + 1: -1 - win, 0] = op_flow_x[:-1, :-1] 126 | op_flow[win + 1: -1 - win, win + 1: -1 - win, 1] = op_flow_y[:-1, :-1] 127 | return op_flow 128 | 129 | if __name__ == '__main__': 130 | 131 | print ("NinaPro activity image generation, use window = %d frames, stride = %d frames" % (window, stride)) 132 | 133 | combos = get_combos(product(subjects, gestures, trials)) 134 | 135 | combos = list(combos) 136 | 137 | for combo in combos: 138 | in_path = os.path.join( 139 | input_path, 'data', 140 | '{c.subject:03d}', 141 | '{c.gesture:03d}', 142 | '{c.subject:03d}_{c.gesture:03d}_{c.trial:03d}.mat').format(c=combo) 143 | 144 | out_dir = os.path.join( 145 | output_path, 146 | '{c.subject:03d}', 147 | '{c.gesture:03d}').format(c=combo) 148 | 149 | if os.path.isdir(out_dir) is False: 150 | os.makedirs(out_dir) 151 | 152 | 153 | data = sio.loadmat(in_path)['data'].astype(np.float32) 154 | 155 | print ("Subject %d Gesture %d Trial %d data loaded!" % (combo.subject, combo.gesture, combo.trial)) 156 | 157 | if filtering_type is 'lowpass': 158 | # data = np.transpose([lowpass(ch, 1, framerate, 1, zero_phase=True) for ch in data.T]) 159 | data = np.transpose([butter_lowpass_filter(ch, 1, framerate, 1, zero_phase=True) for ch in data.T]) 160 | print ("Subject %d Gesture %d Trial %d bandpass filtering finished!" 
% (combo.subject, combo.gesture, combo.trial)) 161 | else: 162 | pass 163 | 164 | 165 | data = data.T 166 | chnum = data.shape[0] 167 | 168 | first_frame = data[:,30] 169 | second_frame = data[:,50] 170 | 171 | first_frame = np.reshape(first_frame, (10,1)) 172 | second_frame = np.reshape(second_frame, (10,1)) 173 | 174 | # first_frame = cv2.cv.fromarray(first_frame) 175 | # second_frame = cv2.cv.fromarray(second_frame) 176 | 177 | flow = lucas_kanade_np(first_frame, second_frame) 178 | 179 | print flow 180 | -------------------------------------------------------------------------------- /sigr/symbol_multistream.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/computer-animation-perception-group/sEMG-based-mscnn/5cee7d27e087d9a3f198162e6203ad31dc224c9d/sigr/symbol_multistream.py -------------------------------------------------------------------------------- /sigr/train_semimyo.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import click 3 | import mxnet as mx 4 | from logbook import Logger 5 | from pprint import pformat 6 | import os 7 | from .utils import packargs, Bunch 8 | from .module_semimyo import Module 9 | from .data import Preprocess, Dataset 10 | from . 
import Context, constant 11 | 12 | 13 | logger = Logger('semimyo') 14 | 15 | 16 | @click.group() 17 | def cli(): 18 | pass 19 | 20 | 21 | @cli.command() 22 | @click.option('--glove-loss-weight', type=float, required=True) 23 | @click.option('--num-epoch', type=int, default=constant.NUM_EPOCH, help='Maximum epoches') 24 | @click.option('--lr-step', type=int, multiple=True, default=constant.LR_STEP, help='Epoch numbers to decay learning rate') 25 | @click.option('--lr-factor', type=float, multiple=True) 26 | @click.option('--batch-size', type=int, default=constant.BATCH_SIZE, help='Batch size') 27 | @click.option('--lr', type=float, default=constant.LR, help='Base learning rate') 28 | @click.option('--wd', type=float, default=constant.WD, help='Weight decay') 29 | @click.option('--gpu', type=int, multiple=True, default=[0]) 30 | @click.option('--log', type=click.Path(), help='Path of the logging file') 31 | @click.option('--snapshot', type=click.Path(), help='Snapshot prefix') 32 | @click.option('--root', type=click.Path(), help='Root path of the experiment, auto create if not exists') 33 | @click.option('--params', type=click.Path(exists=True), help='Inital weights') 34 | @click.option('--ignore-params', multiple=True, help='Ignore params in --params with regex') 35 | @click.option('--adabn', is_flag=True, help='AdaBN for model adaptation, must be used with --num-mini-batch') 36 | @click.option('--num-adabn-epoch', type=int, default=constant.NUM_ADABN_EPOCH) 37 | @click.option('--num-conv-layer', type=int, default=constant.NUM_CONV_LAYER, help='Conv layers') 38 | @click.option('--num-conv-filter', type=int, default=constant.NUM_CONV_FILTER, help='Kernels of the conv layers') 39 | @click.option('--num-lc-layer', type=int, default=constant.NUM_LC_LAYER, help='LC layers') 40 | @click.option('--num-lc-hidden', type=int, default=constant.NUM_LC_HIDDEN, help='Kernels of the LC layers') 41 | @click.option('--lc-kernel', type=int, default=constant.LC_KERNEL) 42 | 
@click.option('--lc-stride', type=int, default=constant.LC_STRIDE) 43 | @click.option('--lc-pad', type=int, default=constant.LC_PAD) 44 | @click.option('--num-fc-layer', type=int, default=constant.NUM_FC_LAYER, help='FC layers') 45 | @click.option('--num-fc-hidden', type=int, default=constant.NUM_FC_HIDDEN, help='Kernels of the FC layers') 46 | @click.option('--num-bottleneck', type=int, default=constant.NUM_BOTTLENECK, help='Kernels of the bottleneck layer') 47 | @click.option('--dropout', type=float, default=constant.DROPOUT, help='Dropout ratio') 48 | @click.option('--num-glove-layer', type=int, required=True) 49 | @click.option('--num-glove-hidden', type=int, required=True) 50 | @click.option('--num-mini-batch', type=int, default=constant.NUM_MINI_BATCH, help='Split data into mini-batches') 51 | @click.option('--num-eval-epoch', type=int, default=1) 52 | @click.option('--snapshot-period', type=int, default=constant.SNAPSHOT_PERIOD) 53 | @click.option('--fix-params', multiple=True) 54 | @click.option('--decay-all/--no-decay-all', default=constant.DECAY_ALL) 55 | @click.option('--preprocess', callback=lambda ctx, param, value: Preprocess.parse(value)) 56 | @click.option('--dataset', type=click.Choice(['s21', 'csl', 57 | 'dba', 'dbb', 'dbc', 58 | 'ninapro-db1-matlab-lowpass', 59 | 'ninapro-db1/caputo', 60 | 'ninapro-db1', 61 | 'ninapro-db1-raw/semg-glove', 62 | 'ninapro-db1/g53', 63 | 'ninapro-db1/g5', 64 | 'ninapro-db1/g8', 65 | 'ninapro-db1/g12']), required=True) 66 | @click.option('--balance-gesture', type=float, default=0) 67 | @click.option('--module', type=click.Choice(['convnet', 68 | 'knn', 69 | 'svm', 70 | 'random-forests', 71 | 'lda']), default='convnet') 72 | @click.option('--amplitude-weighting', is_flag=True) 73 | @click.option('--fold', type=int, required=True, help='Fold number of the crossval experiment') 74 | @click.option('--crossval-type', type=click.Choice(['intra-session', 75 | 'universal-intra-session', 76 | 'inter-session', 77 | 
'universal-inter-session', 78 | 'intra-subject', 79 | 'universal-intra-subject', 80 | 'inter-subject', 81 | 'one-fold-intra-subject', 82 | 'universal-one-fold-intra-subject']), required=True) 83 | @packargs 84 | def crossval(args): 85 | if args.root: 86 | if args.log: 87 | args.log = os.path.join(args.root, args.log) 88 | if args.snapshot: 89 | args.snapshot = os.path.join(args.root, args.snapshot) 90 | 91 | with Context(args.log, parallel=True): 92 | logger.info('Args:\n{}', pformat(args)) 93 | for i in range(args.num_epoch): 94 | path = args.snapshot + '-%04d.params' % (i + 1) 95 | if os.path.exists(path): 96 | logger.info('Found snapshot {}, exit', path) 97 | return 98 | 99 | dataset = Dataset.from_name(args.dataset) 100 | get_crossval_data = getattr(dataset, 'get_%s_data' % args.crossval_type.replace('-', '_')) 101 | train, val = get_crossval_data( 102 | batch_size=args.batch_size, 103 | fold=args.fold, 104 | preprocess=args.preprocess, 105 | num_mini_batch=args.num_mini_batch, 106 | balance_gesture=args.balance_gesture, 107 | amplitude_weighting=args.amplitude_weighting 108 | ) 109 | logger.info('Train samples: {}', train.num_sample) 110 | logger.info('Val samples: {}', val.num_sample) 111 | mod = Module.parse( 112 | args.module, 113 | adabn=args.adabn, 114 | num_adabn_epoch=args.num_adabn_epoch, 115 | for_training=True, 116 | num_eval_epoch=args.num_eval_epoch, 117 | snapshot_period=args.snapshot_period, 118 | symbol_kargs=dict( 119 | num_gesture=dataset.num_gesture, 120 | num_glove=dataset.num_glove, 121 | num_semg_row=dataset.num_semg_row, 122 | num_semg_col=dataset.num_semg_col, 123 | num_conv_layer=args.num_conv_layer, 124 | num_conv_filter=args.num_conv_filter, 125 | num_lc_layer=args.num_lc_layer, 126 | num_lc_hidden=args.num_lc_hidden, 127 | lc_kernel=args.lc_kernel, 128 | lc_stride=args.lc_stride, 129 | lc_pad=args.lc_pad, 130 | num_fc_layer=args.num_fc_layer, 131 | num_fc_hidden=args.num_fc_hidden, 132 | num_bottleneck=args.num_bottleneck, 133 | 
dropout=args.dropout, 134 | num_glove_layer=args.num_glove_layer, 135 | num_glove_hidden=args.num_glove_hidden, 136 | num_mini_batch=args.num_mini_batch, 137 | glove_loss_weight=args.glove_loss_weight 138 | ), 139 | context=[mx.gpu(i) for i in args.gpu] 140 | ) 141 | mod.fit( 142 | train_data=train, 143 | eval_data=val, 144 | num_epoch=args.num_epoch, 145 | num_train=train.num_sample, 146 | batch_size=args.batch_size, 147 | lr_step=args.lr_step, 148 | lr_factor=args.lr_factor, 149 | lr=args.lr, 150 | wd=args.wd, 151 | snapshot=args.snapshot, 152 | params=args.params, 153 | ignore_params=args.ignore_params, 154 | fix_params=args.fix_params, 155 | decay_all=args.decay_all 156 | ) 157 | 158 | 159 | if __name__ == '__main__': 160 | cli(obj=Bunch()) 161 | -------------------------------------------------------------------------------- /sigr/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | import os 3 | import numpy as np 4 | from .proxy import LazyProxy 5 | assert LazyProxy 6 | 7 | 8 | @contextmanager 9 | def logging_context(path=None, level=None): 10 | from logbook import StderrHandler, FileHandler 11 | from logbook.compat import redirected_logging 12 | with StderrHandler(level=level or 'INFO').applicationbound(): 13 | if path: 14 | if not os.path.isdir(os.path.dirname(path)): 15 | os.makedirs(os.path.dirname(path)) 16 | with FileHandler(path, bubble=True).applicationbound(): 17 | with redirected_logging(): 18 | yield 19 | else: 20 | with redirected_logging(): 21 | yield 22 | 23 | 24 | def return_list(func): 25 | import inspect 26 | from functools import wraps 27 | assert inspect.isgeneratorfunction(func) 28 | 29 | @wraps(func) 30 | def wrapped(*args, **kargs): 31 | return list(func(*args, **kargs)) 32 | 33 | return wrapped 34 | 35 | 36 | @return_list 37 | def continuous_segments(label): 38 | label = np.asarray(label) 39 | 40 | if not len(label): 41 | return 42 | 43 | 
breaks = list(np.where(label[:-1] != label[1:])[0] + 1) 44 | for begin, end in zip([0] + breaks, breaks + [len(label)]): 45 | assert begin < end 46 | yield begin, end 47 | 48 | 49 | def cached(*args, **kargs): 50 | import joblib as jb 51 | from .. import CACHE 52 | memo = getattr(cached, 'memo', None) 53 | if memo is None: 54 | cached.memo = memo = jb.Memory(CACHE, verbose=0) 55 | return memo.cache(*args, **kargs) 56 | 57 | 58 | def get_segments(data, window): 59 | return windowed_view( 60 | data.flat, 61 | window * data.shape[1], 62 | (window - 1) * data.shape[1] 63 | ) 64 | 65 | 66 | def windowed_view(arr, window, overlap): 67 | from numpy.lib.stride_tricks import as_strided 68 | arr = np.asarray(arr) 69 | window_step = window - overlap 70 | new_shape = arr.shape[:-1] + ((arr.shape[-1] - overlap) // window_step, 71 | window) 72 | new_strides = (arr.strides[:-1] + (window_step * arr.strides[-1],) + 73 | arr.strides[-1:]) 74 | return as_strided(arr, shape=new_shape, strides=new_strides) 75 | 76 | 77 | class Bunch(dict): 78 | 79 | def __getattr__(self, key): 80 | if key in self: 81 | return self[key] 82 | raise AttributeError(key) 83 | 84 | def __setattr__(self, key, value): 85 | self[key] = value 86 | 87 | 88 | def _packargs(func): 89 | from functools import wraps 90 | import inspect 91 | 92 | @wraps(func) 93 | def wrapped(ctx_or_args, **kargs): 94 | if isinstance(ctx_or_args, Bunch): 95 | args = ctx_or_args 96 | else: 97 | args = ctx_or_args.obj 98 | ignore = inspect.getargspec(func).args 99 | args.update({key: kargs.pop(key) for key in list(kargs) 100 | if key not in ignore and key not in args}) 101 | return func(ctx_or_args, **kargs) 102 | return wrapped 103 | 104 | 105 | def packargs(func): 106 | import click 107 | return click.pass_obj(_packargs(func)) 108 | 109 | 110 | def butter_bandpass_filter(data, lowcut, highcut, fs, order): 111 | from scipy.signal import butter, lfilter 112 | 113 | nyq = 0.5 * fs 114 | low = lowcut / nyq 115 | high = highcut / nyq 116 | 
117 | b, a = butter(order, [low, high], btype='bandpass') 118 | y = lfilter(b, a, data) 119 | return y 120 | 121 | 122 | def butter_bandstop_filter(data, lowcut, highcut, fs, order): 123 | from scipy.signal import butter, lfilter 124 | 125 | nyq = 0.5 * fs 126 | low = lowcut / nyq 127 | high = highcut / nyq 128 | 129 | b, a = butter(order, [low, high], btype='bandstop') 130 | y = lfilter(b, a, data) 131 | return y 132 | 133 | 134 | def butter_lowpass_filter(data, cut, fs, order, zero_phase=False): 135 | from scipy.signal import butter, lfilter, filtfilt 136 | 137 | nyq = 0.5 * fs 138 | cut = cut / nyq 139 | 140 | b, a = butter(order, cut, btype='low') 141 | y = (filtfilt if zero_phase else lfilter)(b, a, data) 142 | return y 143 | -------------------------------------------------------------------------------- /sigr/utils/proxy.py: -------------------------------------------------------------------------------- 1 | class LazyProxy(object): 2 | 3 | def __init__(self, make): 4 | self._make = make 5 | 6 | def __getattr__(self, name): 7 | if name == '_inst': 8 | self._inst = self._make() 9 | return self._inst 10 | return getattr(self._inst, name) 11 | 12 | def __setattr__(self, name, value): 13 | if name in ('_make', '_inst'): 14 | return super(LazyProxy, self).__setattr__(name, value) 15 | return setattr(self._inst, name, value) 16 | 17 | def __getstate__(self): 18 | return self._make 19 | 20 | def __setstate__(self, make): 21 | self._make = make 22 | 23 | def __hash__(self): 24 | return hash(self._make) 25 | 26 | def __iter__(self): 27 | return self._inst.__iter__() 28 | -------------------------------------------------------------------------------- /sigr/vote.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import numpy as np 3 | import joblib as jb 4 | from nose.tools import assert_greater 5 | from .utils import return_list, cached 6 | from . 
import Context 7 | 8 | 9 | def get_vote_accuracy_curve(labels, predictions, segments, windows, balance=False): 10 | if len(set(segments)) < len(windows): 11 | func = get_vote_accuracy_curve_aux 12 | else: 13 | func = get_vote_accuracy_curve_aux_few_windows 14 | return func(np.asarray(labels), 15 | np.asarray(predictions), 16 | np.asarray(segments), 17 | np.asarray(windows), 18 | balance) 19 | 20 | 21 | @cached 22 | def get_vote_accuracy_curve_aux(labels, predictions, segments, windows, balance): 23 | segment_labels = partial_vote(labels, segments) 24 | return ( 25 | np.asarray(windows), 26 | np.array(list(Context.parallel( 27 | jb.delayed(get_vote_accuracy_curve_step)( 28 | segment_labels, 29 | predictions, 30 | segments, 31 | window, 32 | balance 33 | ) for window in windows 34 | ))) 35 | ) 36 | 37 | 38 | @cached 39 | def get_vote_accuracy_curve_aux_few_windows(labels, predictions, segments, windows, balance): 40 | segment_labels = partial_vote(labels, segments) 41 | return ( 42 | np.asarray(windows), 43 | np.array([ 44 | get_vote_accuracy_curve_step( 45 | segment_labels, 46 | predictions, 47 | segments, 48 | window, 49 | balance, 50 | parallel=True 51 | ) for window in windows 52 | ]) 53 | ) 54 | 55 | 56 | def get_vote_accuracy(labels, predictions, segments, window, balance): 57 | _, y = get_vote_accuracy_curve(labels, predictions, segments, [window], balance) 58 | return y[0] 59 | 60 | 61 | vote = get_vote_accuracy 62 | 63 | 64 | def get_segment_vote_accuracy(segment_label, segment_predictions, window): 65 | def gen(): 66 | count = { 67 | label: np.hstack([[0], np.cumsum(segment_predictions == label)]) 68 | for label in set(segment_predictions) 69 | } 70 | tmp = window 71 | if tmp == -1: 72 | tmp = len(segment_predictions) 73 | tmp = min(tmp, len(segment_predictions)) 74 | for begin in range(len(segment_predictions) - tmp + 1): 75 | yield segment_label == max( 76 | count, 77 | key=lambda label: count[label][begin + tmp] - count[label][begin] 78 | ), 
segment_label 79 | return list(gen()) 80 | 81 | 82 | def get_vote_accuracy_curve_step(segment_labels, predictions, segments, window, 83 | balance, 84 | parallel=False): 85 | def gen(): 86 | # assert_greater(window, 0) 87 | assert window > 0 or window == -1 88 | if not parallel: 89 | for segment_label, segment_predictions in zip(segment_labels, split(predictions, segments)): 90 | for ret in get_segment_vote_accuracy(segment_label, segment_predictions, window): 91 | yield ret 92 | else: 93 | for rets in Context.parallel( 94 | jb.delayed(get_segment_vote_accuracy)(segment_label, segment_predictions, window) 95 | for segment_label, segment_predictions in zip(segment_labels, split(predictions, segments)) 96 | ): 97 | for ret in rets: 98 | yield ret 99 | 100 | good, labels = zip(*list(gen())) 101 | good = np.asarray(good) 102 | 103 | if not balance: 104 | return np.sum(good) / len(good) 105 | else: 106 | acc = [] 107 | for label in set(labels): 108 | mask = [labels == label] 109 | acc.append(np.sum(good[mask]) / np.sum(mask)) 110 | return np.mean(acc) 111 | 112 | 113 | @return_list 114 | def partial_vote(labels, segments, length=None): 115 | for part in split(labels, segments): 116 | part = list(part) 117 | 118 | if length is not None: 119 | part = part[:length] 120 | 121 | assert_greater(len(part), 0) 122 | yield max([(part.count(label), label) for label in set(part)])[1] 123 | 124 | 125 | def split(labels, segments): 126 | return [labels[segments == segment] for segment in sorted(set(segments))] 127 | --------------------------------------------------------------------------------