├── shutdown.sh
├── update_pip.py
├── git_setup.sh
├── jupyter_related.sh
├── cleanup_output.sh
├── startup.sh
├── download_images.sh
├── python_install.sh
├── README.md
├── conv2list.py
├── env.sh
├── addmissing.py
├── .gitignore
├── installs.sh
├── cmds.txt
├── LICENSE
└── TrainLoop.ipynb
/shutdown.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | rsync -avz /mnt/ram-disk/imaterialist_fashion/ /mnt/disks/imaterialist_fashion/
--------------------------------------------------------------------------------
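
Note: shutdown.sh persists the RAM-disk working copy back to the persistent disk before a (preemptible) VM disappears; startup.sh performs the mirror-image restore. A dry run (-n) previews the sync without writing anything, using the same paths as the script:

    rsync -avzn /mnt/ram-disk/imaterialist_fashion/ /mnt/disks/imaterialist_fashion/
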
/update_pip.py:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | curl -O https://bootstrap.pypa.io/get-pip.py
4 | python3 get-pip.py
5 |
--------------------------------------------------------------------------------
/git_setup.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | git config --global user.name "Sourabh Daptardar"
4 | git config --global user.email saurabh.daptardar@gmail.com
5 |
--------------------------------------------------------------------------------
/jupyter_related.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | source env.sh
4 |
5 | conda install -y jupyter nb_conda
6 |
7 | jupyter notebook --generate-config
8 | jupyter notebook password
9 |
10 | conda install -y tqdm ipdb matplotlib
11 | conda install -y pytorch torchvision cuda91 -c pytorch
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
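
Note: jupyter_related.sh sets a notebook password but does not start the server. A common way to reach a notebook running on the VM is an SSH tunnel; the instance name, zone, and port 8888 below are assumptions based on cmds.txt and Jupyter's default:

    # on the local machine: forward local port 8888 to the VM
    gcloud compute ssh gpu-vm --zone=us-central1-f -- -L 8888:localhost:8888
    # inside the SSH session, start the server; then browse to http://localhost:8888
    jupyter notebook --no-browser --port=8888
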
/cleanup_output.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -vx
3 |
4 | rm -rf /mnt/disks/imaterialist_fashion/data_ifood/output/submissions
5 | mkdir -p /mnt/disks/imaterialist_fashion/data_ifood/output/submissions
6 | rm -rf /mnt/ram-disk/imaterialist_fashion/data_ifood/output/submissions
7 | mkdir -p /mnt/ram-disk/imaterialist_fashion/data_ifood/output/submissions
8 |
--------------------------------------------------------------------------------
/startup.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Create RAM disk
4 | mkdir -p /mnt/ram-disk
5 | mount -t tmpfs -o size=50g tmpfs /mnt/ram-disk
6 | mkdir -p /mnt/ram-disk/imaterialist_fashion
7 | chown -R saurabh_daptardar:saurabh_daptardar /mnt/ram-disk/imaterialist_fashion
8 |
9 | # Mount Persistent disk and rsync
10 | mkdir -p /mnt/disks
11 | mkdir -p /mnt/disks/imaterialist_fashion
12 | mount -t ext4 /dev/sdb /mnt/disks/imaterialist_fashion && \
13 | rsync -avz /mnt/disks/imaterialist_fashion/ /mnt/ram-disk/imaterialist_fashion/
14 |
--------------------------------------------------------------------------------
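
Note: tmpfs lives entirely in RAM, so /mnt/ram-disk is lost on shutdown or preemption; shutdown.sh exists precisely to flush it back to the persistent disk. A quick post-boot sanity check:

    df -h /mnt/ram-disk                  # should report a ~50G tmpfs filesystem
    mount | grep imaterialist_fashion    # persistent disk mounted under /mnt/disks
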
/download_images.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [ $# -ne 3 ]
4 | then
5 | cat << MSG
6 | Usage: ./download_images.sh /path/to/filelist.txt /path/to/dest/folder num_jobs
7 | MSG
8 | exit 1
9 | fi
10 |
11 | filelist="$1"
12 | dst="$2"
13 | njobs="$3"
14 |
15 | dwld() { # fetch $1 into $2.jpg unless a valid JPEG already exists, then resize
16 |     [ "$(identify "$2.jpg" |& awk '{ print $2 == "JPEG" }')" == "1" ] || (wget -q -t 5 "$1" -O "$2.jpg" && mogrify -resize "256^>" "$2.jpg")
17 | }
18 |
19 | export -f dwld
20 |
21 | cd "$dst" || exit 1
22 | pwd
23 | parallel --no-notice --load="100%" --progress --bar --colsep=" " -j $njobs "dwld {1} {2}" :::: "$filelist"
24 | cd -
25 |
--------------------------------------------------------------------------------
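
Note: each line of the file list is expected to hold "URL imageId", which is exactly what conv2list.py emits, and the JPEG check via identify makes the script resumable (reruns only fetch missing or corrupt images). A typical invocation, with paths from cmds.txt and a job count matching the 24-vCPU downloader VM (both illustrative):

    ./download_images.sh /mnt/disks/imaterialist_fashion/data/input/train.txt \
        /mnt/disks/imaterialist_fashion/data/input/img_train 24
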
/python_install.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | export CROOT="$HOME/.local/conda"
4 | export MROOT="$CROOT/miniconda3"
5 |
6 | mkdir -p "$HOME/.local"
7 | # mkdir -p "$CROOT"
8 |
9 | curl -o ~/miniconda.sh https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
10 | chmod +x ~/miniconda.sh && \
11 | ~/miniconda.sh -b -p $MROOT && \
12 | rm ~/miniconda.sh && \
13 | $MROOT/bin/conda create -n py3k python=3 && \
14 | $MROOT/bin/conda list -n py3k && \
15 | $MROOT/bin/conda install -n py3k numpy pyyaml scipy ipython mkl mkl-include && \
16 | $MROOT/bin/conda install -n py3k -c pytorch magma-cuda91 && \
17 | $MROOT/bin/conda install -n py3k -c jupyter ipykernel
18 |
--------------------------------------------------------------------------------
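
Note: this script only bootstraps Miniconda (under $MROOT, the path env.sh expects) and the py3k environment; PyTorch itself is installed later by jupyter_related.sh. A minimal check once both scripts have run:

    source env.sh
    python -c 'import torch; print(torch.__version__, torch.cuda.is_available())'
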
/README.md:
--------------------------------------------------------------------------------
1 | # Fine-Grained Visual Categorization
2 |
3 | Code for the [Fine-Grained Visual Categorization challenge (FGVC5) at CVPR 2018](https://sites.google.com/view/fgvc5/home).
4 |
5 | ## Code
6 |
7 | * [iMaterialist-Fashion](https://www.kaggle.com/c/imaterialist-challenge-fashion-2018):
8 |   a ResNet-50 multilabel classifier trained on 10,000 images on a single NVIDIA GTX 980.
9 | * [iFood](https://sites.google.com/view/fgvc5/competitions/fgvcx/ifood):
10 |   a PNASNet-5-Large multilabel classifier trained on 101K training images on an 8 x V100 GPU virtual machine on Google Compute Engine.
11 |
12 | ## Pretrained Models
13 |
14 | Coming soon.
15 |
16 | ## References
17 |
18 | * [PyTorch ImageNet example](https://github.com/pytorch/examples/tree/master/imagenet)
19 |
--------------------------------------------------------------------------------
/conv2list.py:
--------------------------------------------------------------------------------
1 | import ijson
2 | import argparse
3 | import sys
4 |
5 |
6 | def conv2list(ifile, ofile):
7 |     objs = ijson.items(ifile, 'images.item')  # stream each element of the "images" array
8 |     for x in objs:
9 |         ofile.write('%s %s\n' % (x['url'], x['imageId']))
10 |
11 |
12 | if __name__ == '__main__':
13 | parser = argparse.ArgumentParser(description='Extract image URLs from JSON and convert to list')
14 | parser.add_argument(
15 | '--input', '-i', type=argparse.FileType('r'), default=sys.stdin,
16 | metavar='PATH',
17 | help="Input JSON (default: standard input).")
18 | parser.add_argument(
19 | '--output', '-o', type=argparse.FileType('w'), default=sys.stdout,
20 | metavar='PATH',
21 | help="Output file (default: standard output)")
22 | args = parser.parse_args()
23 | conv2list(args.input, args.output)
24 |
--------------------------------------------------------------------------------
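
Note: a typical use, turning a competition JSON into the "URL imageId" list consumed by download_images.sh; the file names follow cmds.txt:

    python conv2list.py -i validation.json -o validation.txt
    head -2 validation.txt    # each line: <url> <imageId>
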
/env.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Setup environment for coding
3 |
4 | # 1) CUDA
5 | export CUDAROOT="/usr/local/cuda"
6 | export PATH="$CUDAROOT/bin:$PATH"
7 | export LD_LIBRARY_PATH="$CUDAROOT/lib64:$LD_LIBRARY_PATH"
8 | export CUDNNROOT="$CUDAROOT"
9 | export PATH="$CUDNNROOT/bin:$PATH"
10 | export LD_LIBRARY_PATH="$CUDNNROOT/lib64:$LD_LIBRARY_PATH"
11 |
12 | # 2) Local installs
13 | export LOCALDIR="$HOME/.local"
14 | export PATH="$LOCALDIR/bin:$PATH"
15 |
16 | # 3) Miniconda
17 | export CONDA="$LOCALDIR/conda"
18 | export MCONDA="$CONDA/miniconda3"
19 | export PY3K="$MCONDA/envs/py3k"
20 | export PATH="$CONDA/bin:$PY3K/bin:$MCONDA/bin:$PATH"
21 | export PYTHONPATH="$PY3K/lib/python3.6/site-packages:$PYTHONPATH"
22 | export PATH="$PY3K/bin:$PATH"
23 | export CPATH="$PY3K/include:$CPATH"
24 | export LD_LIBRARY_PATH="$PY3K/lib:$LD_LIBRARY_PATH"
25 |
26 |
27 | source activate py3k
28 |
--------------------------------------------------------------------------------
/addmissing.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 |
4 |
5 | def addmissing(ifile, ofile, N):
6 |     with ifile, ofile:
7 |         # ids 1..N that have not yet appeared in the input
8 |         S = set(range(1, N+1))
9 |         for line in ifile:
10 |             l = line.strip().split(',')
11 |             if l[0] != 'image_id':
12 |                 S.discard(int(l[0]))  # discard tolerates duplicate or out-of-range ids
13 |             ofile.write(line)
14 |         for s in sorted(S):  # emit missing ids in order, with an empty label
15 |             ofile.write("%d,\n" % s)
16 |
17 |
18 | if __name__ == '__main__':
19 | parser = argparse.ArgumentParser(description='Add missing items to submission file')
20 | parser.add_argument(
21 | '--input', '-i', type=argparse.FileType('r'), default=sys.stdin,
22 | metavar='PATH',
23 | help="Input csv (default: standard input).")
24 | parser.add_argument(
25 | '--output', '-o', type=argparse.FileType('w'), default=sys.stdout,
26 | metavar='PATH',
27 | help="Output csv (default: standard output)")
28 | parser.add_argument('--range', '-r', type=int, default=39706,
29 | help="range of ids to be filled in (default: 39706)")
30 | args = parser.parse_args()
31 | addmissing(args.input, args.output, args.range)
32 |
--------------------------------------------------------------------------------
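
Note: the default of 39706 presumably matches the number of test images; ids absent from the model's output (e.g. images that failed to download) are appended with an empty label so the submission covers every id. A typical invocation, with illustrative file names:

    python addmissing.py -i submission.csv -o submission_padded.csv -r 39706
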
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 |
--------------------------------------------------------------------------------
/installs.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | apt-get update
4 | apt-get install -y build-essential binutils git imagemagick unzip parallel gcc g++
5 |
6 |
7 | echo "Checking for CUDA and installing."
8 | # Check for CUDA and try to install.
9 | if ! dpkg-query -W cuda-9-1; then
10 | curl -O http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1604/x86_64/cuda-repo-ubuntu1604_9.1.85-1_amd64.deb
11 | dpkg -i ./cuda-repo-ubuntu1604_9.1.85-1_amd64.deb
12 | apt-key adv --fetch-keys http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1604/x86_64/7fa2af80.pub
13 | apt-get update
14 | apt-get install cuda-9-1 -y
15 | fi
16 | # Enable persistence mode
17 | nvidia-smi -pm 1
18 |
19 | # CUDNN
20 | export CUDNN_VERSION="7.1.4.18"
21 | apt-get update
22 | apt-get install -y --no-install-recommends \
23 | libcudnn7=$CUDNN_VERSION-1+cuda9.1 \
24 | libcudnn7-dev=$CUDNN_VERSION-1+cuda9.1
25 |
26 |
27 | curl -O http://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1604/x86_64/nvidia-machine-learning-repo-ubuntu1604_1.0.0-1_amd64.deb
28 | dpkg -i ./nvidia-machine-learning-repo-ubuntu1604_1.0.0-1_amd64.deb && apt-get update && apt-get install -y libnccl2=2.1.15-1+cuda9.1 libnccl-dev=2.1.15-1+cuda9.1
29 |
30 | nvidia-smi
31 |
32 | apt-get update
33 | apt-get install -y --no-install-recommends \
34 | build-essential \
35 | cmake \
36 | git \
37 | curl \
38 | vim \
39 | ca-certificates \
40 | libjpeg-dev \
41 | libpng-dev
42 |
43 | # mkdir -p /mnt/ram-disk
44 | # mount -t tmpfs -o size=50g tmpfs /mnt/ram-disk
45 | # mkdir -p /mnt/ram-disk/imaterialist_fashion
46 | # chown -R saurabh_daptardar:saurabh_daptardar /mnt/ram-disk/imaterialist_fashion
47 |
48 | # mkdir -p /mnt/disks
49 | # mkdir -p /mnt/disks/imaterialist_fashion
50 | # mount -t ext4 /dev/sdb /mnt/disks/imaterialist_fashion
51 |
--------------------------------------------------------------------------------
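
Note: a quick post-install check that the driver, cuDNN, and NCCL landed as intended (paths assume the default CUDA install location, as env.sh does):

    nvidia-smi                                 # driver loaded, GPUs visible
    dpkg -l | grep -E 'libcudnn7|libnccl2'     # pinned cuDNN / NCCL versions
    /usr/local/cuda/bin/nvcc --version         # CUDA 9.1 toolkit
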
/cmds.txt:
--------------------------------------------------------------------------------
1 |
2 | # Create regional persistent SSD disk
3 |
4 | gcloud beta compute disks create imaterialist-fashion-ssd --size 50 --type pd-ssd --region us-central1 --replica-zones us-central1-f
5 |
6 | #################################################################
7 | # High-CPU instance for downloading dataset
8 |
9 | gcloud beta compute --project=deccanlearners instances create downloader-vm --zone=us-central1-f --machine-type=custom-24-22272 --subnet=default --network-tier=PREMIUM --no-restart-on-failure --maintenance-policy=TERMINATE --preemptible --service-account=206684283285-compute@developer.gserviceaccount.com --scopes=https://www.googleapis.com/auth/devstorage.read_only,https://www.googleapis.com/auth/logging.write,https://www.googleapis.com/auth/monitoring.write,https://www.googleapis.com/auth/servicecontrol,https://www.googleapis.com/auth/service.management.readonly,https://www.googleapis.com/auth/trace.append --tags=http-server,https-server --image=ubuntu-1804-bionic-v20180522 --image-project=ubuntu-os-cloud --boot-disk-size=10GB --boot-disk-type=pd-ssd --boot-disk-device-name=downloader-vm
10 |
11 | gcloud compute --project=deccanlearners firewall-rules create default-allow-http --direction=INGRESS --priority=1000 --network=default --action=ALLOW --rules=tcp:80 --source-ranges=0.0.0.0/0 --target-tags=http-server
12 |
13 | gcloud compute --project=deccanlearners firewall-rules create default-allow-https --direction=INGRESS --priority=1000 --network=default --action=ALLOW --rules=tcp:443 --source-ranges=0.0.0.0/0 --target-tags=https-server
14 |
15 | #################################################################
16 |
17 | gcloud beta compute instances attach-disk downloader-vm --disk imaterialist-fashion-ssd --disk-scope regional
18 |
19 |
20 | ########################################################
21 | # Format Disk
22 |
23 | sudo lsblk
24 |
25 | sudo mkfs.ext4 -m 0 -F -E lazy_itable_init=0,lazy_journal_init=0,discard /dev/sdb
26 |
27 | sudo mkdir -p /mnt/disks/imaterialist_fashion
28 |
29 | sudo mount -o discard,defaults /dev/sdb /mnt/disks/imaterialist_fashion
30 |
31 | sudo chmod a+w /mnt/disks/imaterialist_fashion
32 |
33 | sudo cp /etc/fstab /etc/fstab.backup
34 |
35 | sudo blkid /dev/sdb
36 |
37 | # In /etc/fstab (replace [UUID_VALUE] with the UUID reported by blkid):
38 | UUID=[UUID_VALUE] /mnt/disks/imaterialist_fashion ext4 discard,defaults,nofail 0 2
39 |
40 | # OR append the entry directly:
41 | echo UUID=`sudo blkid -s UUID -o value /dev/sdb` /mnt/disks/imaterialist_fashion ext4 discard,defaults,nofail 0 2 | sudo tee -a /etc/fstab
42 |
43 | ########################################################
44 | sudo apt install build-essential binutils git imagemagick unzip parallel
45 |
46 | git clone --recursive https://github.com/sourabhd/objrec
47 |
48 |
49 | #########################################################
50 |
51 | mkdir -p /mnt/disks/imaterialist_fashion/data
52 | mkdir -p /mnt/disks/imaterialist_fashion/data/input
53 | mkdir -p /mnt/disks/imaterialist_fashion/data/output
54 | chown -R saurabh_daptardar:saurabh_daptardar /mnt/disks/imaterialist_fashion/data
55 | #########################################################
56 |
57 | gcloud compute scp /data/datasets/kaggle_fashion/data/input/train_tiny.json saurabh_daptardar@downloader-vm:/mnt/disks/imaterialist_fashion/data/input/
58 |
59 | gcloud compute scp /data/datasets/kaggle_fashion/data/input/train_small.json saurabh_daptardar@downloader-vm:/mnt/disks/imaterialist_fashion/data/input/
60 |
61 | gcloud compute scp /home/sourabhd/.kaggle/competitions/imaterialist-challenge-fashion-2018/train.json.zip saurabh_daptardar@downloader-vm:/mnt/disks/imaterialist_fashion/data/input/
62 |
65 | gcloud compute scp /data/datasets/kaggle_fashion/data/input/validation.json saurabh_daptardar@downloader-vm:/mnt/disks/imaterialist_fashion/data/input/
66 |
67 | gcloud compute scp /data/datasets/kaggle_fashion/data/input/test.json saurabh_daptardar@downloader-vm:/mnt/disks/imaterialist_fashion/data/input/
68 |
69 | gcloud compute scp /data/datasets/kaggle_fashion/data/input/train_tiny.txt saurabh_daptardar@downloader-vm:/mnt/disks/imaterialist_fashion/data/input/
70 |
71 | gcloud compute scp /data/datasets/kaggle_fashion/data/input/train_small.txt saurabh_daptardar@downloader-vm:/mnt/disks/imaterialist_fashion/data/input/
72 |
73 | gcloud compute scp /data/datasets/kaggle_fashion/data/input/train.txt.tar.bz2 saurabh_daptardar@downloader-vm:/mnt/disks/imaterialist_fashion/data/input/
74 |
75 | gcloud compute scp /data/datasets/kaggle_fashion/data/input/validation.txt saurabh_daptardar@downloader-vm:/mnt/disks/imaterialist_fashion/data/input/
76 |
77 | gcloud compute scp /data/datasets/kaggle_fashion/data/input/test.txt saurabh_daptardar@downloader-vm:/mnt/disks/imaterialist_fashion/data/input/
78 |
79 | cd /mnt/disks/imaterialist_fashion/data/input
80 | unzip train.json.zip
81 | tar xvf train.txt.tar.bz2
82 |
83 | #########################################################
84 |
85 | mkdir -p /mnt/disks/imaterialist_fashion/data/input/img_train
86 | mkdir -p /mnt/disks/imaterialist_fashion/data/input/img_validation
87 | mkdir -p /mnt/disks/imaterialist_fashion/data/input/img_test
88 |
89 |
90 |
91 | ############################################################
92 |
93 | gcloud beta compute --project=deccanlearners instances create gpu-vm --description=iMaterialist\ challenge\ CVPR --zone=us-central1-f --machine-type=custom-8-196608-ext --subnet=default --address=35.206.83.26 --network-tier=STANDARD --metadata=shutdown-script=\#\!/bin/bash$'\n'rsync\ -avz\ /mnt/ram-disk/imaterialist_fashion/\ /mnt/disks/imaterialist_fashion/,startup-script=\#\!/bin/bash$'\n'$'\n'\#\ Create\ RAM\ disk$'\n'mkdir\ -p\ /mnt/ram-disk$'\n'mount\ -t\ tmpfs\ -o\ size=50g\ tmpfs\ /mnt/ram-disk$'\n'mkdir\ -p\ /mnt/ram-disk/imaterialist_fashion$'\n'chown\ -R\ saurabh_daptardar:saurabh_daptardar\ /mnt/ram-disk/imaterialist_fashion$'\n'$'\n'\#\ Mount\ Persistent\ disk\ and\ rsync\ $'\n'mkdir\ -p\ /mnt/disks$'\n'mkdir\ -p\ /mnt/disks/imaterialist_fashion$'\n'mount\ -t\ ext4\ /dev/sdb\ /mnt/disks/imaterialist_fashion\ \&\&\ \\$'\n'rsync\ -avz\ /mnt/disks/imaterialist_fashion/\ /mnt/ram-disk/imaterialist_fashion/ --no-restart-on-failure --maintenance-policy=TERMINATE --preemptible --service-account=206684283285-compute@developer.gserviceaccount.com --scopes=https://www.googleapis.com/auth/devstorage.read_only,https://www.googleapis.com/auth/logging.write,https://www.googleapis.com/auth/monitoring.write,https://www.googleapis.com/auth/servicecontrol,https://www.googleapis.com/auth/service.management.readonly,https://www.googleapis.com/auth/trace.append --accelerator=type=nvidia-tesla-v100,count=8 --tags=http-server,https-server --image=ubuntu-1604-xenial-v20180522 --image-project=ubuntu-os-cloud --boot-disk-size=10GB --boot-disk-type=pd-ssd --boot-disk-device-name=gpu-vm
94 |
95 |
96 | ############################################################
97 | gcloud beta compute --project=deccanlearners instances create gpu-vm --description=VM\ for\ iMaterialist\ challenge\ CVPR --zone=us-central1-f --machine-type=custom-4-131072-ext --subnet=default --address=35.206.83.26 --network-tier=STANDARD --metadata=shutdown-script=\#\!/bin/bash$'\n'rsync\ -avz\ /mnt/ram-disk/imaterialist_fashion/\ /mnt/disks/imaterialist_fashion/,startup-script=\#\!/bin/bash$'\n'$'\n'\#\ Create\ RAM\ disk$'\n'mkdir\ -p\ /mnt/ram-disk$'\n'mount\ -t\ tmpfs\ -o\ size=50g\ tmpfs\ /mnt/ram-disk$'\n'mkdir\ -p\ /mnt/ram-disk/imaterialist_fashion$'\n'chown\ -R\ saurabh_daptardar:saurabh_daptardar\ /mnt/ram-disk/imaterialist_fashion$'\n'$'\n'\#\ Mount\ Persistent\ disk\ and\ rsync\ $'\n'mkdir\ -p\ /mnt/disks$'\n'mkdir\ -p\ /mnt/disks/imaterialist_fashion$'\n'mount\ -t\ ext4\ /dev/sdb\ /mnt/disks/imaterialist_fashion\ \&\&\ \\$'\n'rsync\ -avz\ /mnt/disks/imaterialist_fashion/\ /mnt/ram-disk/imaterialist_fashion/ --no-restart-on-failure --maintenance-policy=TERMINATE --preemptible --service-account=206684283285-compute@developer.gserviceaccount.com --scopes=https://www.googleapis.com/auth/devstorage.read_only,https://www.googleapis.com/auth/logging.write,https://www.googleapis.com/auth/monitoring.write,https://www.googleapis.com/auth/servicecontrol,https://www.googleapis.com/auth/service.management.readonly,https://www.googleapis.com/auth/trace.append --accelerator=type=nvidia-tesla-p100,count=4 --tags=http-server,https-server --image=ubuntu-1604-xenial-v20180522 --image-project=ubuntu-os-cloud --boot-disk-size=10GB --boot-disk-type=pd-standard --boot-disk-device-name=gpu-vm
98 |
99 | ############################################################
100 |
101 | gcloud beta compute --project=deccanlearners instances create gpu-vm --description=VM\ for\ CVPR\ FGVC\ 2018 --zone=us-central1-f --machine-type=custom-8-196608-ext --subnet=default --address=35.206.83.26 --network-tier=STANDARD --metadata=shutdown-script=\#\!/bin/bash$'\n'rsync\ -avz\ /mnt/ram-disk/imaterialist_fashion/\ /mnt/disks/imaterialist_fashion/,startup-script=\#\!/bin/bash$'\n'$'\n'\#\ Create\ RAM\ disk$'\n'mkdir\ -p\ /mnt/ram-disk$'\n'mount\ -t\ tmpfs\ -o\ size=50g\ tmpfs\ /mnt/ram-disk$'\n'mkdir\ -p\ /mnt/ram-disk/imaterialist_fashion$'\n'chown\ -R\ saurabh_daptardar:saurabh_daptardar\ /mnt/ram-disk/imaterialist_fashion$'\n'$'\n'\#\ Mount\ Persistent\ disk\ and\ rsync\ $'\n'mkdir\ -p\ /mnt/disks$'\n'mkdir\ -p\ /mnt/disks/imaterialist_fashion$'\n'mount\ -t\ ext4\ /dev/sdb\ /mnt/disks/imaterialist_fashion\ \&\&\ \\$'\n'rsync\ -avz\ /mnt/disks/imaterialist_fashion/\ /mnt/ram-disk/imaterialist_fashion/ --no-restart-on-failure --maintenance-policy=TERMINATE --preemptible --service-account=206684283285-compute@developer.gserviceaccount.com --scopes=https://www.googleapis.com/auth/devstorage.read_only,https://www.googleapis.com/auth/logging.write,https://www.googleapis.com/auth/monitoring.write,https://www.googleapis.com/auth/servicecontrol,https://www.googleapis.com/auth/service.management.readonly,https://www.googleapis.com/auth/trace.append --accelerator=type=nvidia-tesla-v100,count=8 --tags=http-server,https-server --image=ubuntu-1604-xenial-v20180522 --image-project=ubuntu-os-cloud --boot-disk-size=10GB --boot-disk-type=pd-ssd --boot-disk-device-name=gpu-vm
102 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 |     <one line to give the program's name and a brief idea of what it does.>
635 |     Copyright (C) <year>  <name of author>
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 | along with this program. If not, see <https://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 |     <program>  Copyright (C) <year>  <name of author>
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/licenses/why-not-lgpl.html>.
675 |
--------------------------------------------------------------------------------
/TrainLoop.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "colab_type": "text",
7 | "id": "4OGEu9nITbnO"
8 | },
9 | "source": [
10 | "# Install Torch"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 1,
16 | "metadata": {
17 | "colab": {
18 | "autoexec": {
19 | "startup": false,
20 | "wait_interval": 0
21 | }
22 | },
23 | "colab_type": "code",
24 | "id": "uX687hj69g9g"
25 | },
26 | "outputs": [],
27 | "source": [
28 | "torchver = \"0.4.0\""
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": 2,
34 | "metadata": {
35 | "colab": {
36 | "autoexec": {
37 | "startup": false,
38 | "wait_interval": 0
39 | },
40 | "base_uri": "https://localhost:8080/",
41 | "height": 306
42 | },
43 | "colab_type": "code",
44 | "executionInfo": {
45 | "elapsed": 1967,
46 | "status": "ok",
47 | "timestamp": 1527015382182,
48 | "user": {
49 | "displayName": "Sourabh Daptardar",
50 | "photoUrl": "//lh4.googleusercontent.com/-onn5Q0_MiKQ/AAAAAAAAAAI/AAAAAAAACDI/iOxkSEz16nA/s50-c-k-no/photo.jpg",
51 | "userId": "115812262388010820083"
52 | },
53 | "user_tz": -330
54 | },
55 | "id": "_gX52NUpzIYC",
56 | "outputId": "311649d4-8385-4c39-b984-1846818c2388"
57 | },
58 | "outputs": [
59 | {
60 | "name": "stdout",
61 | "output_type": "stream",
62 | "text": [
63 | "/bin/sh: 1: /opt/bin/nvidia-smi: not found\n",
64 | "Thu May 31 10:15:31 2018 \n",
65 | "+-----------------------------------------------------------------------------+\n",
66 | "| NVIDIA-SMI 396.26 Driver Version: 396.26 |\n",
67 | "|-------------------------------+----------------------+----------------------+\n",
68 | "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
69 | "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
70 | "|===============================+======================+======================|\n",
71 | "| 0 Tesla V100-SXM2... Off | 00000000:00:04.0 Off | 0 |\n",
72 | "| N/A 34C P0 37W / 300W | 98MiB / 16160MiB | 0% Default |\n",
73 | "+-------------------------------+----------------------+----------------------+\n",
74 | "| 1 Tesla V100-SXM2... Off | 00000000:00:05.0 Off | 0 |\n",
75 | "| N/A 36C P0 35W / 300W | 0MiB / 16160MiB | 0% Default |\n",
76 | "+-------------------------------+----------------------+----------------------+\n",
77 | "| 2 Tesla V100-SXM2... Off | 00000000:00:06.0 Off | 0 |\n",
78 | "| N/A 37C P0 38W / 300W | 0MiB / 16160MiB | 0% Default |\n",
79 | "+-------------------------------+----------------------+----------------------+\n",
80 | "| 3 Tesla V100-SXM2... Off | 00000000:00:07.0 Off | 0 |\n",
81 | "| N/A 36C P0 35W / 300W | 0MiB / 16160MiB | 0% Default |\n",
82 | "+-------------------------------+----------------------+----------------------+\n",
83 | "| 4 Tesla V100-SXM2... Off | 00000000:00:08.0 Off | 0 |\n",
84 | "| N/A 34C P0 38W / 300W | 0MiB / 16160MiB | 0% Default |\n",
85 | "+-------------------------------+----------------------+----------------------+\n",
86 | "| 5 Tesla V100-SXM2... Off | 00000000:00:09.0 Off | 0 |\n",
87 | "| N/A 35C P0 37W / 300W | 0MiB / 16160MiB | 0% Default |\n",
88 | "+-------------------------------+----------------------+----------------------+\n",
89 | "| 6 Tesla V100-SXM2... Off | 00000000:00:0A.0 Off | 0 |\n",
90 | "| N/A 35C P0 37W / 300W | 0MiB / 16160MiB | 0% Default |\n",
91 | "+-------------------------------+----------------------+----------------------+\n",
92 | "| 7 Tesla V100-SXM2... Off | 00000000:00:0B.0 Off | 0 |\n",
93 | "| N/A 35C P0 37W / 300W | 0MiB / 16160MiB | 0% Default |\n",
94 | "+-------------------------------+----------------------+----------------------+\n",
95 | " \n",
96 | "+-----------------------------------------------------------------------------+\n",
97 | "| Processes: GPU Memory |\n",
98 | "| GPU PID Type Process name Usage |\n",
99 | "|=============================================================================|\n",
100 | "| 0 1655 G /usr/lib/xorg/Xorg 98MiB |\n",
101 | "+-----------------------------------------------------------------------------+\n"
102 | ]
103 | }
104 | ],
105 | "source": [
106 | "!/opt/bin/nvidia-smi || /usr/bin/nvidia-smi"
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "execution_count": 3,
112 | "metadata": {
113 | "colab": {
114 | "autoexec": {
115 | "startup": false,
116 | "wait_interval": 0
117 | },
118 | "base_uri": "https://localhost:8080/",
119 | "height": 51
120 | },
121 | "colab_type": "code",
122 | "executionInfo": {
123 | "elapsed": 2037,
124 | "status": "ok",
125 | "timestamp": 1527015384289,
126 | "user": {
127 | "displayName": "Sourabh Daptardar",
128 | "photoUrl": "//lh4.googleusercontent.com/-onn5Q0_MiKQ/AAAAAAAAAAI/AAAAAAAACDI/iOxkSEz16nA/s50-c-k-no/photo.jpg",
129 | "userId": "115812262388010820083"
130 | },
131 | "user_tz": -330
132 | },
133 | "id": "Z0wFaqgbE4wI",
134 | "outputId": "feb79b53-fdc9-45eb-92c7-7dff6334c183"
135 | },
136 | "outputs": [],
137 | "source": [
138 | "# !ls /colabtools"
139 | ]
140 | },
141 | {
142 | "cell_type": "code",
143 | "execution_count": 4,
144 | "metadata": {
145 | "colab": {
146 | "autoexec": {
147 | "startup": false,
148 | "wait_interval": 0
149 | },
150 | "base_uri": "https://localhost:8080/",
151 | "height": 34
152 | },
153 | "colab_type": "code",
154 | "executionInfo": {
155 | "elapsed": 2121,
156 | "status": "ok",
157 | "timestamp": 1527015386438,
158 | "user": {
159 | "displayName": "Sourabh Daptardar",
160 | "photoUrl": "//lh4.googleusercontent.com/-onn5Q0_MiKQ/AAAAAAAAAAI/AAAAAAAACDI/iOxkSEz16nA/s50-c-k-no/photo.jpg",
161 | "userId": "115812262388010820083"
162 | },
163 | "user_tz": -330
164 | },
165 | "id": "G4WvjiCDzWPR",
166 | "outputId": "708f088a-9f78-4a08-9811-145b8874105b"
167 | },
168 | "outputs": [
169 | {
170 | "name": "stdout",
171 | "output_type": "stream",
172 | "text": [
173 | "Python 3.6.5 :: Anaconda, Inc.\r\n"
174 | ]
175 | }
176 | ],
177 | "source": [
178 | "!python --version"
179 | ]
180 | },
181 | {
182 | "cell_type": "code",
183 | "execution_count": 5,
184 | "metadata": {
185 | "colab": {
186 | "autoexec": {
187 | "startup": false,
188 | "wait_interval": 0
189 | },
190 | "base_uri": "https://localhost:8080/",
191 | "height": 187
192 | },
193 | "colab_type": "code",
194 | "executionInfo": {
195 | "elapsed": 5041,
196 | "status": "ok",
197 | "timestamp": 1527015392777,
198 | "user": {
199 | "displayName": "Sourabh Daptardar",
200 | "photoUrl": "//lh4.googleusercontent.com/-onn5Q0_MiKQ/AAAAAAAAAAI/AAAAAAAACDI/iOxkSEz16nA/s50-c-k-no/photo.jpg",
201 | "userId": "115812262388010820083"
202 | },
203 | "user_tz": -330
204 | },
205 | "id": "z85P4eDZNSdu",
206 | "outputId": "26542569-3f96-4bc5-de5b-3d8c51c98f75"
207 | },
208 | "outputs": [
209 | {
210 | "name": "stdout",
211 | "output_type": "stream",
212 | "text": [
213 | "\u001b[33mSkipping pillow as it is not installed.\u001b[0m\n",
214 | "Collecting pillow-simd\n",
215 | "\u001b[31mtorchvision 0.2.1 requires pillow>=4.1.1, which is not installed.\u001b[0m\n",
216 | "Installing collected packages: pillow-simd\n",
217 | " Found existing installation: Pillow-SIMD 5.1.1.post0\n",
218 | " Uninstalling Pillow-SIMD-5.1.1.post0:\n",
219 | " Successfully uninstalled Pillow-SIMD-5.1.1.post0\n",
220 | "Successfully installed pillow-simd-5.1.1.post0\n"
221 | ]
222 | }
223 | ],
224 | "source": [
225 | "!pip3 uninstall -y pillow\n",
226 | "!CC=\"cc -mavx2\" pip3 install -U --force-reinstall pillow-simd\n"
227 | ]
228 | },
229 | {
230 | "cell_type": "code",
231 | "execution_count": 6,
232 | "metadata": {
233 | "colab": {
234 | "autoexec": {
235 | "startup": false,
236 | "wait_interval": 0
237 | },
238 | "base_uri": "https://localhost:8080/",
239 | "height": 309
240 | },
241 | "colab_type": "code",
242 | "executionInfo": {
243 | "elapsed": 3414,
244 | "status": "ok",
245 | "timestamp": 1527015396225,
246 | "user": {
247 | "displayName": "Sourabh Daptardar",
248 | "photoUrl": "//lh4.googleusercontent.com/-onn5Q0_MiKQ/AAAAAAAAAAI/AAAAAAAACDI/iOxkSEz16nA/s50-c-k-no/photo.jpg",
249 | "userId": "115812262388010820083"
250 | },
251 | "user_tz": -330
252 | },
253 | "id": "7FiDFXCiT8wS",
254 | "outputId": "6a582c41-86d8-4fac-bdfd-b1008bba2099"
255 | },
256 | "outputs": [],
257 | "source": [
258 | "\n",
259 | "# !pip3 install ipdb\n"
260 | ]
261 | },
262 | {
263 | "cell_type": "code",
264 | "execution_count": 7,
265 | "metadata": {
266 | "colab": {
267 | "autoexec": {
268 | "startup": false,
269 | "wait_interval": 0
270 | },
271 | "base_uri": "https://localhost:8080/",
272 | "height": 272
273 | },
274 | "colab_type": "code",
275 | "executionInfo": {
276 | "elapsed": 6145,
277 | "status": "ok",
278 | "timestamp": 1527015402414,
279 | "user": {
280 | "displayName": "Sourabh Daptardar",
281 | "photoUrl": "//lh4.googleusercontent.com/-onn5Q0_MiKQ/AAAAAAAAAAI/AAAAAAAACDI/iOxkSEz16nA/s50-c-k-no/photo.jpg",
282 | "userId": "115812262388010820083"
283 | },
284 | "user_tz": -330
285 | },
286 | "id": "TAP3KzaO_3mr",
287 | "outputId": "d48fd867-01ff-479d-a37e-82c0ac00ce44"
288 | },
289 | "outputs": [
290 | {
291 | "name": "stdout",
292 | "output_type": "stream",
293 | "text": [
294 | "36\n",
295 | "PIL\n"
296 | ]
297 | }
298 | ],
299 | "source": [
300 | "\n",
301 | "from os import path\n",
302 | "from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag\n",
303 | "platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())\n",
304 | "\n",
305 | "pver = !python --version |& awk '{print $2 }' | awk -F. '{ print $1$2}'\n",
306 | "pyver = pver[0]\n",
307 | "print(pyver)\n",
308 | "\n",
309 | "# cver = !echo \"cu`nvcc --version | sed \"s/ /\\n/g\" | grep -i release -A 1 | tail -n 1 | tr -d [\\.,]`\"\n",
310 | "# cudaver = cver[0]\n",
311 | "cudaver = 'cu91'\n",
312 | "\n",
313 | "# accelerator = cudaver if path.exists('/opt/bin/nvidia-smi') or path.exists('/usr/bin/nvidia-smi') else 'cpu'\n",
314 | "# print(accelerator)\n",
315 | "\n",
316 | "# torchurl = \"http://download.pytorch.org/whl/{0}/torch-{1}-cp{2}-cp{2}m-linux_x86_64.whl\".format(accelerator, torchver, pyver)\n",
317 | "# print(torchurl)\n",
318 | "\n",
319 | "# !pip3 install http://download.pytorch.org/whl/cu91/torch-0.4.0-cp36-cp36m-linux_x86_64.whl \n",
320 | "# !pip3 install torchvision\n",
321 | "\n",
322 | "import torch\n",
323 | "import torchvision\n",
324 | "print(torchvision.get_image_backend())"
325 | ]
326 | },
327 | {
328 | "cell_type": "code",
329 | "execution_count": 8,
330 | "metadata": {
331 | "colab": {
332 | "autoexec": {
333 | "startup": false,
334 | "wait_interval": 0
335 | },
336 | "base_uri": "https://localhost:8080/",
337 | "height": 34
338 | },
339 | "colab_type": "code",
340 | "executionInfo": {
341 | "elapsed": 3083,
342 | "status": "ok",
343 | "timestamp": 1527015405574,
344 | "user": {
345 | "displayName": "Sourabh Daptardar",
346 | "photoUrl": "//lh4.googleusercontent.com/-onn5Q0_MiKQ/AAAAAAAAAAI/AAAAAAAACDI/iOxkSEz16nA/s50-c-k-no/photo.jpg",
347 | "userId": "115812262388010820083"
348 | },
349 | "user_tz": -330
350 | },
351 | "id": "a4CFa1WLgoUX",
352 | "outputId": "cd632861-265b-4c66-bc57-fd7e2d6228ff"
353 | },
354 | "outputs": [],
355 | "source": [
356 | "#!pip3 install tqdm"
357 | ]
358 | },
359 | {
360 | "cell_type": "markdown",
361 | "metadata": {
362 | "colab_type": "text",
363 | "id": "GZt8MRT5RfK6"
364 | },
365 | "source": [
366 | "# Imports"
367 | ]
368 | },
369 | {
370 | "cell_type": "code",
371 | "execution_count": 9,
372 | "metadata": {
373 | "colab": {
374 | "autoexec": {
375 | "startup": false,
376 | "wait_interval": 0
377 | }
378 | },
379 | "colab_type": "code",
380 | "id": "ZptSyG9oSN1c"
381 | },
382 | "outputs": [],
383 | "source": [
384 | "import torch\n",
385 | "import os\n",
386 | "import sys\n",
387 | "import logging\n",
388 | "import io\n",
389 | "import time\n",
390 | "import shutil\n",
391 | "from tqdm import tqdm\n",
392 | "from matplotlib.pyplot import imshow\n",
393 | "import numpy as np\n",
394 | "from PIL import Image\n",
395 | "import torch\n",
396 | "import torch.nn as nn\n",
397 | "import torch.nn.parallel\n",
398 | "import torch.backends.cudnn as cudnn\n",
399 | "import torch.distributed as dist\n",
400 | "import torch.optim as optim\n",
401 | "import torch.optim.lr_scheduler as lr_scheduler\n",
402 | "import torch.utils.data\n",
403 | "import torch.utils.data.distributed\n",
404 | "import torchvision.transforms as transforms\n",
405 | "import torchvision.datasets as datasets\n",
406 | "import torchvision.models as models\n",
407 | "from argparse import Namespace\n",
408 | "from collections import OrderedDict\n",
409 | "from scipy.sparse import coo_matrix\n",
410 | "import socket\n",
411 | "from datetime import datetime\n",
412 | "import json\n",
413 | "import re\n",
414 | "import hashlib\n",
415 | "import subprocess\n",
416 | "from copy import deepcopy, copy\n",
417 | "from pprint import pprint\n",
418 | "import torch.utils.data as data\n",
419 | "from copy import copy\n",
420 | "import numpy as np\n",
421 | "import json\n",
422 | "from collections import namedtuple\n",
423 | "from PIL import Image \n",
424 | "from torchvision import get_image_backend\n",
425 | "from torch.utils.data.distributed import DistributedSampler\n",
426 | "import torch.nn.init as weight_init"
427 | ]
428 | },
429 | {
430 | "cell_type": "code",
431 | "execution_count": 10,
432 | "metadata": {
433 | "colab": {
434 | "autoexec": {
435 | "startup": false,
436 | "wait_interval": 0
437 | }
438 | },
439 | "colab_type": "code",
440 | "id": "IEEo0VYsZhvO"
441 | },
442 | "outputs": [],
443 | "source": [
444 | "%matplotlib inline"
445 | ]
446 | },
447 | {
448 | "cell_type": "markdown",
449 | "metadata": {
450 | "colab_type": "text",
451 | "id": "N1BQLwQTWcKU"
452 | },
453 | "source": [
454 | "# Parameters"
455 | ]
456 | },
457 | {
458 | "cell_type": "code",
459 | "execution_count": null,
460 | "metadata": {},
461 | "outputs": [],
462 | "source": [
463 | "def get_hostname_timestamp_id():\n",
464 | " return socket.gethostname() + '_' + re.sub(r'\\W+', '', str(datetime.now()))"
465 | ]
466 | },
467 | {
468 | "cell_type": "code",
469 | "execution_count": null,
470 | "metadata": {},
471 | "outputs": [],
472 | "source": [
473 | "def get_output_fname():\n",
474 | " return \"%s_%s_%s\" % (args.author, args.arch, get_hostname_timestamp_id())"
475 | ]
476 | },
477 | {
478 | "cell_type": "code",
479 | "execution_count": null,
480 | "metadata": {
481 | "colab": {
482 | "autoexec": {
483 | "startup": false,
484 | "wait_interval": 0
485 | }
486 | },
487 | "colab_type": "code",
488 | "id": "eM2a7qmqWh3Q"
489 | },
490 | "outputs": [],
491 | "source": [
492 | "args = Namespace()\n",
493 | "# base_dir = '/content/fashion'\n",
494 | "# args.perm_dir = '/data/datasets/kaggle_fashion'\n",
495 | "# args.base_dir = '/data/datasets/kaggle_fashion'\n",
496 | "args.perm_dir = '/mnt/disks/imaterialist_fashion'\n",
497 | "args.base_dir = '/mnt/ram-disk/imaterialist_fashion'\n",
498 | "args.data_dir = args.base_dir + os.sep + 'data'\n",
499 | "args.input_dir = args.data_dir + os.sep + 'input'\n",
500 | "args.output_dir = args.data_dir + os.sep + 'output'\n",
501 | "args.train_zip = args.input_dir + os.sep + 'train_data.zip'\n",
502 | "args.val_zip = args.input_dir + os.sep + 'validation_data.zip'\n",
503 | "args.train_dir = args.input_dir + os.sep + 'img_train'\n",
504 | "args.val_dir = args.input_dir + os.sep + 'img_val'\n",
505 | "args.test_dir = args.input_dir + os.sep + 'img_test'\n",
506 | "args.train_id = \"1rx1rL8RUAggN4hKlrYLtpdQagtUWmIbO\"\n",
507 | "args.val_id = \"1U19eWiBFJ6wGcFk47l6g9mmoWp1i4hPY\"\n",
508 | "# args.train_labels_id = \"1NOoWniR3ioqPKbVWoaWGy4HPDzZAAJX9\"\n",
509 | "args.train_labels_id = \"1X7TpWyxxtmCT5rw__7OKus_W4fh8xpKO\" # small dataset\n",
510 | "args.val_labels_id = \"1d9RuQTx5E8qFxraIu6B4rDTOC4sx2xXT\"\n",
511 | "args.test_labels_id = \"1VwzGCJfOL13pk1Wi-xPHQ6mVnofy9_Z4\"\n",
512 | "# args.train_labels_json = args.input_dir + os.sep + 'train.json'\n",
513 | "args.train_labels_json = args.input_dir + os.sep + 'train_small.json' \n",
514 | "# args.train_labels_json = args.input_dir + os.sep + 'train_tiny.json' \n",
515 | "args.val_labels_json = args.input_dir + os.sep + 'validation.json'\n",
516 | "args.test_labels_json = args.input_dir + os.sep + 'test.json'\n",
517 | "args.debug_weights = False\n",
518 | "args.test_overfit = False\n",
519 | "args.num_labels = 228\n",
520 | "args.batch_size = 16\n",
521 | "# args.batch_size = 64\n",
522 | "args.image_min_size = 256\n",
523 | "args.nw_input_size = 224\n",
524 | "args.num_workers = 4\n",
525 | "args.imagenet_mean = [0.485, 0.456, 0.406]\n",
526 | "args.imagenet_std = [0.229, 0.224, 0.225]\n",
527 | "args.pretrain_dset_mean = args.imagenet_mean\n",
528 | "args.pretrain_dset_std = args.imagenet_std\n",
529 | "args.world_size = 1\n",
530 | "args.dist_url = 'file://' + args.output_dir + os.sep + 'dfile'\n",
531 | "args.dist_backend = 'gloo'\n",
532 | "args.distributed = args.world_size > 1\n",
533 | "args.arch = 'resnet101'\n",
534 | "# args.arch = 'resnet152'\n",
535 | "args.fv_size = 2048\n",
536 | "args.pretrained = True\n",
537 | "args.resume = False\n",
538 | "args.start_epoch = 0\n",
539 | "args.small=1e-12 # small value used for avoiding div by zero\n",
540 | "args.optimizer_learning_rate = 1e-4 # Adam optimizer initial learning rate\n",
541 | "args.scheduler_patience = 1 # Number of epochs with no improvement after which learning rate will be reduced\n",
542 | "args.scheduler_threshold = 1e-6 # learning rate scheduler threshold for measuring the new optimum, to only focus on significant changes\n",
543 | "args.scheduler_factor = 0.1 # learning rate scheduler factor by which the learning rate will be reduced. new_lr = lr * factor\n",
544 | "args.earlystopping_patience = 1 # early stopping patience is the number of epochs with no improvement after which training will be stopped\n",
545 | "args.earlystopping_min_delta = 1e-5 # minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute change of less than min_delta, will count as no improvement\n",
546 | "args.evaluate = False\n",
547 | "args.epochs = 2\n",
548 | "args.print_freq = args.batch_size\n",
549 | "args.ckpt_dir = args.output_dir + os.sep + 'ckpt'\n",
550 | "args.ckpt = args.ckpt_dir + os.sep + 'ckpt_%s.pth.tar' % (args.arch,)\n",
551 | "args.best = args.ckpt_dir + os.sep + 'best_%s.pth.tar' % (args.arch,)\n",
552 | "args.threshold = 0.5\n",
553 | "args.sub_dir = args.output_dir + os.sep + 'submissions'\n",
554 | "args.author = 'deccanlearners'\n",
555 | "args.output_id = get_output_fname()\n",
556 | "args.output_file = args.sub_dir + os.sep + 'output_%s.csv' % args.output_id\n",
557 | "args.params_file = args.sub_dir + os.sep + 'params_%s.json' % args.output_id\n",
558 | "args.min_img_bytes = 4792"
559 | ]
560 | },
561 | {
562 | "cell_type": "code",
563 | "execution_count": null,
564 | "metadata": {
565 | "colab": {
566 | "autoexec": {
567 | "startup": false,
568 | "wait_interval": 0
569 | },
570 | "base_uri": "https://localhost:8080/",
571 | "height": 928
572 | },
573 | "colab_type": "code",
574 | "executionInfo": {
575 | "elapsed": 1332,
576 | "status": "error",
577 | "timestamp": 1527015425722,
578 | "user": {
579 | "displayName": "Sourabh Daptardar",
580 | "photoUrl": "//lh4.googleusercontent.com/-onn5Q0_MiKQ/AAAAAAAAAAI/AAAAAAAACDI/iOxkSEz16nA/s50-c-k-no/photo.jpg",
581 | "userId": "115812262388010820083"
582 | },
583 | "user_tz": -330
584 | },
585 | "id": "90hY9m66UYPd",
586 | "outputId": "de44a4f5-386e-4243-b3bb-2da97134ce99"
587 | },
588 | "outputs": [],
589 | "source": [
590 | "print(torch.backends.cudnn.version())\n",
591 | "print(torch.cuda.is_available())\n",
592 | "print(torch.cuda.get_device_name(0))\n"
593 | ]
594 | },
595 | {
596 | "cell_type": "code",
597 | "execution_count": null,
598 | "metadata": {
599 | "colab": {
600 | "autoexec": {
601 | "startup": false,
602 | "wait_interval": 0
603 | }
604 | },
605 | "colab_type": "code",
606 | "id": "Kc-OLRcoHDBl"
607 | },
608 | "outputs": [],
609 | "source": [
610 | "cudnn.benchmark = True"
611 | ]
612 | },
613 | {
614 | "cell_type": "code",
615 | "execution_count": null,
616 | "metadata": {},
617 | "outputs": [],
618 | "source": [
619 | "def mkdir_p(d):\n",
620 | " os.makedirs(d, exist_ok=True)\n",
621 | "\n",
622 | "def sha1_hash(fname, blocksize=4096):\n",
623 | " \"\"\" compute sha1hash of a file \"\"\"\n",
624 | " hash = ''\n",
625 | " if not os.path.exists(fname):\n",
626 | " errmsg = \"File %s does not exist\" % (fname)\n",
627 | " print(errmsg)\n",
628 | " return ''\n",
629 | " try:\n",
630 | " hasher = hashlib.sha1()\n",
631 | " with open(fname, 'rb') as f:\n",
632 | " buf = f.read(blocksize)\n",
633 | " while len(buf) > 0:\n",
634 | " hasher.update(buf)\n",
635 | " buf = f.read(blocksize)\n",
636 | " hash = hasher.hexdigest()\n",
637 | " except:\n",
638 | " print(\"Exception in hashing file\")\n",
639 | " raise\n",
640 | " return hash\n",
641 | "\n",
642 | "\n",
643 | "def rsync_and_verify(src, dst, verify=False, max_attempts=1):\n",
644 | " \"\"\"Rsync src to dst and verify if copy is done\"\"\"\n",
645 | "\n",
646 | " print('Rsync %s to %s on %s\\n' % (src,\n",
647 | " dst,\n",
648 | " socket.gethostname()))\n",
649 | " sys.stdout.flush()\n",
650 | " src_ = deepcopy(src)\n",
651 | " dst_ = deepcopy(dst)\n",
652 | " src_cred = ''\n",
653 | " src_path = ''\n",
654 | " dst_cred = ''\n",
655 | " dst_path = ''\n",
656 | " rsync_path = ''\n",
657 | "\n",
658 | " if ':' in src:\n",
659 | " src_cred, src_path = src.split(':')\n",
660 | " else:\n",
661 | " src_cred = ''\n",
662 | " src_path = src\n",
663 | "\n",
664 | " if ':' in dst:\n",
665 | " dst_cred, dst_path = dst.split(':')\n",
666 | " else:\n",
667 | " dst_cred = ''\n",
668 | " dst_path = dst\n",
669 | "\n",
670 | " if src_cred == '':\n",
671 | " mkdir_p(src_path)\n",
672 | " else:\n",
673 | " rsync_path = '--rsync-path=' + '\"' + 'mkdir -p' + ' ' + src_path + ' ' + '&&' + ' ' + 'rsync' + '\"'\n",
674 | " \n",
675 | " if dst_cred == '':\n",
676 | " mkdir_p(dst_path)\n",
677 | " else:\n",
678 | " rsync_path = '--rsync-path=' + '\"' + 'mkdir -p' + ' ' + src_path + ' ' + '&&' + ' ' + 'rsync' + '\"'\n",
679 | "\n",
680 | " if src_[-1] != os.sep:\n",
681 | " src_ = src_ + os.sep\n",
682 | " \n",
683 | " if dst_[-1] != os.sep:\n",
684 | " dst_ = dst_ + os.sep\n",
685 | "\n",
686 | " for attempt in range(max_attempts):\n",
687 | " print('attempt %d' % attempt)\n",
688 | " try:\n",
689 | " copycmd = 'rsync -av' + ' ' + rsync_path + ' ' + src_ + ' ' + dst_ \n",
690 | " pprint(copycmd)\n",
691 | " sys.stdout.flush()\n",
692 | " output = subprocess.check_output(copycmd,\n",
693 | " shell=True)\n",
694 | " pprint(output)\n",
695 | " sys.stdout.flush()\n",
696 | "\n",
697 | " if verify:\n",
698 | " # Verify if the copying is done correctly\n",
699 | " if os.path.isdir(src):\n",
700 | " for fl in os.listdir(src):\n",
701 | " sfile = src + os.sep + fl\n",
702 | " dfile = dst + os.sep + fl\n",
703 | " shash = sha1_hash(sfile)\n",
704 | " dhash = sha1_hash(dfile)\n",
705 | " if shash != dhash:\n",
706 | " print('Hashes of files %s and %s do not match.' % (sfile, dfile))\n",
707 | " print('Error in copying. Quitting ...\\n')\n",
708 | " sys.stdout.flush()\n",
709 | " raise Exception('hash mismatch')\n",
710 | " print('.', end='')\n",
711 | " sys.stdout.flush()\n",
712 | " else:\n",
713 | " shash = sha1_hash(src)\n",
714 | " dhash = sha1_hash(dst)\n",
715 | " if shash != dhash:\n",
716 | " print('Hashes of files %s and %s do not match.' % (src, dst))\n",
717 | " print('Error in copying. Quitting ...\\n')\n",
718 | " sys.stdout.flush()\n",
719 | " raise Exception('hash mismatch')\n",
720 | " print('Hash check passed')\n",
721 | " sys.stdout.flush()\n",
722 | "\n",
723 | " break # break if successful\n",
724 | " # except Exception, arg:\n",
725 | " except:\n",
726 | " # print('Error:', arg)\n",
727 | " print('Error in rsync')\n",
728 | " pass # else retry\n"
729 | ]
730 | },
731 | {
732 | "cell_type": "code",
733 | "execution_count": null,
734 | "metadata": {
735 | "colab": {
736 | "autoexec": {
737 | "startup": false,
738 | "wait_interval": 0
739 | }
740 | },
741 | "colab_type": "code",
742 | "id": "jB9hgpyUfbqG"
743 | },
744 | "outputs": [],
745 | "source": [
746 | "os.makedirs(args.base_dir, exist_ok=True)\n",
747 | "os.makedirs(args.data_dir, exist_ok=True)\n",
748 | "os.makedirs(args.input_dir, exist_ok=True)\n",
749 | "os.makedirs(args.output_dir, exist_ok=True)\n",
750 | "os.makedirs(args.ckpt_dir, exist_ok=True)\n",
751 | "os.makedirs(args.sub_dir, exist_ok=True)"
752 | ]
753 | },
754 | {
755 | "cell_type": "code",
756 | "execution_count": null,
757 | "metadata": {},
758 | "outputs": [],
759 | "source": [
760 | "rsync_and_verify(args.perm_dir, args.base_dir)"
761 | ]
762 | },
763 | {
764 | "cell_type": "markdown",
765 | "metadata": {
766 | "colab_type": "text",
767 | "id": "dBC_aI1vRknn"
768 | },
769 | "source": [
770 | "# Download Dataset"
771 | ]
772 | },
773 | {
774 | "cell_type": "code",
775 | "execution_count": null,
776 | "metadata": {
777 | "colab": {
778 | "autoexec": {
779 | "startup": false,
780 | "wait_interval": 0
781 | }
782 | },
783 | "colab_type": "code",
784 | "id": "g5eP3RxWV5L5"
785 | },
786 | "outputs": [],
787 | "source": [
788 | "# from google.colab import auth\n",
789 | "# auth.authenticate_user()"
790 | ]
791 | },
792 | {
793 | "cell_type": "code",
794 | "execution_count": null,
795 | "metadata": {
796 | "colab": {
797 | "autoexec": {
798 | "startup": false,
799 | "wait_interval": 0
800 | }
801 | },
802 | "colab_type": "code",
803 | "id": "68MYkyHJWP0m"
804 | },
805 | "outputs": [],
806 | "source": [
807 | "# from googleapiclient.discovery import build\n",
808 | "# import io\n",
809 | "# from googleapiclient.http import MediaIoBaseDownload\n",
810 | "# import json\n",
811 | "\n",
812 | "# def md5_hash(fname, blocksize=4096):\n",
813 | "# \"\"\" compute md5hash of a file \"\"\"\n",
814 | "# import hashlib\n",
815 | "# hash = ''\n",
816 | "# if not os.path.exists(fname):\n",
817 | "# errmsg = \"File %s does not exist\" % (fname)\n",
818 | "# print(errmsg)\n",
819 | "# return ''\n",
820 | "# try:\n",
821 | "# hasher = hashlib.md5()\n",
822 | "# with open(fname, 'rb') as f:\n",
823 | "# buf = f.read(blocksize)\n",
824 | "# while len(buf) > 0:\n",
825 | "# hasher.update(buf)\n",
826 | "# buf = f.read(blocksize)\n",
827 | "# hash = hasher.hexdigest()\n",
828 | "# except:\n",
829 | "# print(\"Exception in hashing file\")\n",
830 | "# raise\n",
831 | "# return hash\n",
832 | "\n",
833 | "# def _download(drive_service, file_id, loc):\n",
834 | "# request = drive_service.files().get_media(fileId=file_id)\n",
835 | "# fh = io.FileIO(loc, mode='wb')\n",
836 | "# downloader = MediaIoBaseDownload(fh, request, chunksize=1024*1024)\n",
837 | "# prev_progress = 0\n",
838 | "# done = False\n",
839 | "# with tqdm(total=100) as pbar:\n",
840 | "# while done is False:\n",
841 | "# status, done = downloader.next_chunk()\n",
842 | "# if status:\n",
843 | "# # print(\"Download %d%%.\" % int(status.progress() * 100))\n",
844 | "# pbar.update(int(100 *(status.progress() - prev_progress)))\n",
845 | "# prev_progress = status.progress()\n",
846 | "# print(\"Download Complete!\")\n",
847 | "# file_size = os.path.getsize(loc)\n",
848 | "# print(\"Downloaded %d bytes\" % (file_size))\n",
849 | "\n",
850 | "# def download(file_id, loc):\n",
851 | "# \"\"\"Downloads a file to local file system.\"\"\" \n",
852 | "# drive_service = build('drive', 'v3')\n",
853 | " \n",
854 | "# request_mdata = drive_service.files().list(fields=\"files(md5Checksum, originalFilename, id)\")\n",
855 | "# rh = io.BytesIO()\n",
856 | "# downloader_mdata = MediaIoBaseDownload(rh, request_mdata, chunksize=1024*1024)\n",
857 | "# done = False\n",
858 | "# while not done:\n",
859 | "# _, done = downloader_mdata.next_chunk()\n",
860 | "# mdata = json.loads(rh.getvalue())\n",
861 | "# found = False\n",
862 | "# md5drive = ''\n",
863 | "# fname = ''\n",
864 | "# for x in mdata['files']:\n",
865 | "# if x['id'] == file_id:\n",
866 | "# found = True\n",
867 | "# md5drive = x['md5Checksum']\n",
868 | "# fname = x['originalFilename']\n",
869 | "# break\n",
870 | "# if not found:\n",
871 | "# print(\"{:s} : not found on gdrive\".format(file_id))\n",
872 | "# else:\n",
873 | "# if os.path.exists(loc):\n",
874 | "# if md5drive == md5_hash(loc):\n",
875 | "# print(\"{:s} : file already present on colab\".format(loc))\n",
876 | "# else:\n",
877 | "# print(\"{:s} [gdrive] and {:s} [colab] : md5 mismatch ... downloading\".format(fname, loc))\n",
878 | "# _download(drive_service, file_id, loc)\n",
879 | "# else:\n",
880 | "# print(\"{:s} not present on colab ... downloading ...\".format(loc))\n",
881 | "# _download(drive_service, file_id, loc)\n",
882 | " \n"
883 | ]
884 | },
885 | {
886 | "cell_type": "code",
887 | "execution_count": null,
888 | "metadata": {
889 | "colab": {
890 | "autoexec": {
891 | "startup": false,
892 | "wait_interval": 0
893 | },
894 | "base_uri": "https://localhost:8080/",
895 | "height": 102
896 | },
897 | "colab_type": "code",
898 | "executionInfo": {
899 | "elapsed": 8187,
900 | "status": "ok",
901 | "timestamp": 1527001525917,
902 | "user": {
903 | "displayName": "Sourabh Daptardar",
904 | "photoUrl": "//lh4.googleusercontent.com/-onn5Q0_MiKQ/AAAAAAAAAAI/AAAAAAAACDI/iOxkSEz16nA/s50-c-k-no/photo.jpg",
905 | "userId": "115812262388010820083"
906 | },
907 | "user_tz": -330
908 | },
909 | "id": "jOBHlpMKaE-F",
910 | "outputId": "71d5fe19-81cd-4428-90c7-8e0f720772b4"
911 | },
912 | "outputs": [],
913 | "source": [
914 | "# download(args.train_id, args.train_zip)\n",
915 | "# download(args.val_id, args.val_zip)\n",
916 | "# download(args.train_labels_id, args.train_labels_json)\n",
917 | "# download(args.val_labels_id, args.val_labels_json)\n",
918 | "# download(args.test_labels_id, args.test_labels_json)"
919 | ]
920 | },
921 | {
922 | "cell_type": "code",
923 | "execution_count": null,
924 | "metadata": {
925 | "colab": {
926 | "autoexec": {
927 | "startup": false,
928 | "wait_interval": 0
929 | }
930 | },
931 | "colab_type": "code",
932 | "id": "mA1kgVVEdSWI"
933 | },
934 | "outputs": [],
935 | "source": [
936 | "# import shutil\n",
937 | "# shutil.unpack_archive(args.train_zip, args.input_dir)\n",
938 | "# shutil.unpack_archive(args.val_zip, args.input_dir)\n"
939 | ]
940 | },
941 | {
942 | "cell_type": "code",
943 | "execution_count": null,
944 | "metadata": {
945 | "colab": {
946 | "autoexec": {
947 | "startup": false,
948 | "wait_interval": 0
949 | },
950 | "base_uri": "https://localhost:8080/",
951 | "height": 153
952 | },
953 | "colab_type": "code",
954 | "executionInfo": {
955 | "elapsed": 2944,
956 | "status": "ok",
957 | "timestamp": 1527001543581,
958 | "user": {
959 | "displayName": "Sourabh Daptardar",
960 | "photoUrl": "//lh4.googleusercontent.com/-onn5Q0_MiKQ/AAAAAAAAAAI/AAAAAAAACDI/iOxkSEz16nA/s50-c-k-no/photo.jpg",
961 | "userId": "115812262388010820083"
962 | },
963 | "user_tz": -330
964 | },
965 | "id": "X4Eugqw2fJRQ",
966 | "outputId": "76b1998f-7f77-4cfb-ea39-39fa50283aca"
967 | },
968 | "outputs": [],
969 | "source": [
970 | "# !ls -ltr /content/fashion/data/input"
971 | ]
972 | },
973 | {
974 | "cell_type": "code",
975 | "execution_count": null,
976 | "metadata": {
977 | "colab": {
978 | "autoexec": {
979 | "startup": false,
980 | "wait_interval": 0
981 | },
982 | "base_uri": "https://localhost:8080/",
983 | "height": 204
984 | },
985 | "colab_type": "code",
986 | "executionInfo": {
987 | "elapsed": 2245,
988 | "status": "ok",
989 | "timestamp": 1527001545898,
990 | "user": {
991 | "displayName": "Sourabh Daptardar",
992 | "photoUrl": "//lh4.googleusercontent.com/-onn5Q0_MiKQ/AAAAAAAAAAI/AAAAAAAACDI/iOxkSEz16nA/s50-c-k-no/photo.jpg",
993 | "userId": "115812262388010820083"
994 | },
995 | "user_tz": -330
996 | },
997 | "id": "_XnHDCmclB9B",
998 | "outputId": "7a665530-a452-41d9-ba5f-578562e6da35"
999 | },
1000 | "outputs": [],
1001 | "source": [
1002 | "# !ls -ltr /content/fashion/data/input/train_data | head"
1003 | ]
1004 | },
1005 | {
1006 | "cell_type": "code",
1007 | "execution_count": null,
1008 | "metadata": {
1009 | "colab": {
1010 | "autoexec": {
1011 | "startup": false,
1012 | "wait_interval": 0
1013 | },
1014 | "base_uri": "https://localhost:8080/",
1015 | "height": 204
1016 | },
1017 | "colab_type": "code",
1018 | "executionInfo": {
1019 | "elapsed": 2216,
1020 | "status": "ok",
1021 | "timestamp": 1527001548219,
1022 | "user": {
1023 | "displayName": "Sourabh Daptardar",
1024 | "photoUrl": "//lh4.googleusercontent.com/-onn5Q0_MiKQ/AAAAAAAAAAI/AAAAAAAACDI/iOxkSEz16nA/s50-c-k-no/photo.jpg",
1025 | "userId": "115812262388010820083"
1026 | },
1027 | "user_tz": -330
1028 | },
1029 | "id": "4bmRgA9ilISL",
1030 | "outputId": "db0b777e-bbc6-4326-c467-64779ca51b3b"
1031 | },
1032 | "outputs": [],
1033 | "source": [
1034 | "# !ls -ltr /content/fashion/data/input/validation_data | head"
1035 | ]
1036 | },
1037 | {
1038 | "cell_type": "markdown",
1039 | "metadata": {
1040 | "colab_type": "text",
1041 | "id": "c6LY5l-SRtWw"
1042 | },
1043 | "source": [
1044 | "# Dataset"
1045 | ]
1046 | },
1047 | {
1048 | "cell_type": "code",
1049 | "execution_count": null,
1050 | "metadata": {
1051 | "colab": {
1052 | "autoexec": {
1053 | "startup": false,
1054 | "wait_interval": 0
1055 | }
1056 | },
1057 | "colab_type": "code",
1058 | "id": "VMcIw45smeZE"
1059 | },
1060 | "outputs": [],
1061 | "source": [
1062 | "import torch.utils.data as data\n",
1063 | "from copy import copy\n",
1064 | "import numpy as np\n",
1065 | "\n",
1066 | "def fetch_labels(annotations, num_labels):\n",
1067 | " labels = OrderedDict()\n",
1068 | " for x in annotations:\n",
1069 | " arr = np.zeros((num_labels,), dtype=np.float32)\n",
1070 | " for y in map(int, x['labelId']):\n",
1071 | " arr[y-1] = 1.0\n",
1072 | " labels[int(x['imageId'])] = copy(arr)\n",
1073 | " return labels\n",
1074 | "\n",
1075 | "def json_to_dict(fpath):\n",
1076 | " import json\n",
1077 | " with open(fpath) as f: \n",
1078 | " D = json.load(f)\n",
1079 | " return D\n",
1080 | "\n",
1081 | "def get_labelinfo(annotations):\n",
1082 | " from collections import namedtuple\n",
1083 | " labelinfo = namedtuple('labelinfo', \"set min max count\")\n",
1084 | " labelinfo.set = set()\n",
1085 | " for x in annotations:\n",
1086 | " labelinfo.set.update(map(int, x['labelId']))\n",
1087 | " labelinfo.min = min(labelinfo.set)\n",
1088 | " labelinfo.max = max(labelinfo.set)\n",
1089 | " labelinfo.count = len(labelinfo.set)\n",
1090 | " return labelinfo\n",
1091 | "\n",
1092 | "def has_file_allowed_extension(filename, extensions):\n",
1093 | " \"\"\"Checks if a file is an allowed extension.\n",
1094 | " Args:\n",
1095 | " filename (string): path to a file\n",
1096 | " Returns:\n",
1097 | " bool: True if the filename ends with a known image extension\n",
1098 | " \"\"\"\n",
1099 | " filename_lower = filename.lower()\n",
1100 | " return any(filename_lower.endswith(ext) for ext in extensions)\n",
1101 | "\n",
1102 | "\n",
1103 | "def pil_loader(path):\n",
1104 | " from PIL import Image \n",
1105 | " # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)\n",
1106 | " with open(path, 'rb') as f:\n",
1107 | " img = Image.open(f)\n",
1108 | " return img.convert('RGB')\n",
1109 | "\n",
1110 | "\n",
1111 | "def accimage_loader(path):\n",
1112 | " import accimage\n",
1113 | " try:\n",
1114 | " return accimage.Image(path)\n",
1115 | " except IOError:\n",
1116 | " # Potentially a decoding problem, fall back to PIL.Image\n",
1117 | " return pil_loader(path)\n",
1118 | "\n",
1119 | "\n",
1120 | "def default_loader(path):\n",
1121 | " from torchvision import get_image_backend\n",
1122 | " if get_image_backend() == 'accimage':\n",
1123 | " return accimage_loader(path)\n",
1124 | " else:\n",
1125 | " return pil_loader(path)\n",
1126 | "\n",
1127 | " \n",
1128 | "class FashionDataset(data.Dataset):\n",
1129 | " \"\"\"Fashion dataset CVPR challenge.\n",
1130 | " Adapted from torchvision ImageFolder.\n",
1131 | " Similar to ImageFolder with the following differences:\n",
1132 | " 1. Multilabel\n",
1133 | " 2. Directory structure where all images are directly in the root folder\n",
1134 | " 3. Labels are read from json file\n",
1135 | " \n",
1136 | " Args:\n",
1137 | " root (string): Root directory path.\n",
1138 | " loader (callable): A function to load a sample given its path.\n",
1139 | " extensions (list[string]): A list of allowed extensions.\n",
1140 | " transform (callable, optional): A function/transform that takes in\n",
1141 | " a sample and returns a transformed version.\n",
1142 | " E.g, ``transforms.RandomCrop`` for images.\n",
1143 | " target_transform (callable, optional): A function/transform that takes\n",
1144 | " in the target and transforms it.\n",
1145 | " \n",
1146 | " \"\"\"\n",
1147 | "\n",
1148 | " def __init__(self, root, metadata_file, num_labels=228, transform=None, target_transform=None,\n",
1149 | " loader=default_loader, test=False, min_img_bytes=4792):\n",
1150 | " extensions = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']\n",
1151 | " self.test = test\n",
1152 | " self.num_labels = num_labels\n",
1153 | " self.images_ = OrderedDict()\n",
1154 | " self.images = OrderedDict()\n",
1155 | " self.metadata_file = metadata_file\n",
1156 | " self.metadata = json_to_dict(self.metadata_file)\n",
1157 | " self.transform = transform\n",
1158 | " self.root = root\n",
1159 | " self.target_transform = target_transform\n",
1160 | " self.loader = loader\n",
1161 | " self.corrupt = 0\n",
1162 | " self.corrupt_ids = set()\n",
1163 | " self.labels = OrderedDict()\n",
1164 | " self.labels_ = OrderedDict()\n",
1165 | " \n",
1166 | " # Fetch labels\n",
1167 | " if not self.test:\n",
1168 | " self.labels_ = fetch_labels(self.metadata['annotations'], self.num_labels)\n",
1169 | "\n",
1170 | " # Create Image list\n",
1171 | " for x in self.metadata['images']:\n",
1172 | " self.images_[int(x['imageId'])] = '%s%s%d.jpg' % (root, os.sep, int(x['imageId']))\n",
1173 | " \n",
1174 | " # Remove corrupt image files\n",
1175 | " ids = self.images_.keys()\n",
1176 | " for i in tqdm(ids):\n",
1177 | " ## Correct but slow\n",
1178 | "# try:\n",
1179 | "# img = self.loader(self.images_[i])\n",
1180 | "# img.close()\n",
1181 | "# except:\n",
1182 | "# self.corrupt += 1\n",
1183 | "# self.corrupt_ids.add(i)\n",
1184 | " ## Optimistic \n",
1185 | " if os.path.getsize(self.images_[i]) < min_img_bytes:\n",
1186 | " self.corrupt += 1\n",
1187 | " self.corrupt_ids.add(i)\n",
1188 | "\n",
1189 | " for i in ids:\n",
1190 | " if i not in self.corrupt_ids:\n",
1191 | " self.images[i] = copy(self.images_[i])\n",
1192 | " if not self.test:\n",
1193 | " self.labels[i] = copy(self.labels_[i])\n",
1194 | " self.image_ids = list(self.images.keys())\n",
1195 | " \n",
1196 | " if not self.test:\n",
1197 | " self.labelinfo = get_labelinfo(self.metadata['annotations'])\n",
1198 | " \n",
1199 | " def __getitem__(self, index):\n",
1200 | " \"\"\"\n",
1201 | " Args:\n",
1202 | " index (int): Index\n",
1203 | " Returns:\n",
1204 | " tuple: (sample, target) where target is class_index of the target class.\n",
1205 | " \"\"\"\n",
1206 | " if not self.test:\n",
1207 | " path, target = self.images[self.image_ids[index]], self.labels[self.image_ids[index]]\n",
1208 | " else:\n",
1209 | " path = self.images[self.image_ids[index]]\n",
1210 | " sample = self.loader(path)\n",
1211 | " if self.transform is not None:\n",
1212 | " sample = self.transform(sample)\n",
1213 | " if not self.test:\n",
1214 | " if self.target_transform is not None:\n",
1215 | " target = self.target_transform(target)\n",
1216 | " \n",
1217 | " if self.test:\n",
1218 | " return sample\n",
1219 | " else:\n",
1220 | " return sample, target\n",
1221 | "\n",
1222 | " def __len__(self):\n",
1223 | " return len(self.images)\n",
1224 | " \n",
1225 | " def __repr__(self):\n",
1226 | " fmt_str = 'Dataset ' + self.__class__.__name__ + '\\n'\n",
1227 | " fmt_str += ' Number of datapoints: {}\\n'.format(self.__len__())\n",
1228 | " fmt_str += ' Number of corrupt datapoints discarded: {}\\n'.format(self.corrupt)\n",
1229 | " if not self.test:\n",
1230 | " fmt_str += ' Number of labels: {}\\n'.format(self.labelinfo.count)\n",
1231 | " fmt_str += ' Root Location: {}\\n'.format(self.root)\n",
1232 | " fmt_str += ' Metadata file: {}\\n'.format(self.metadata_file)\n",
1233 | " tmp = ' Transforms (if any): '\n",
1234 | " fmt_str += '{0}{1}\\n'.format(tmp, self.transform.__repr__().replace('\\n', '\\n' + ' ' * len(tmp)))\n",
1235 | " if not self.test:\n",
1236 | " tmp = ' Target Transforms (if any): '\n",
1237 | " fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\\n', '\\n' + ' ' * len(tmp)))\n",
1238 | " tmp = ' Loader: '\n",
1239 | " fmt_str += '\\n{0}{1}'.format(tmp, self.loader.__name__)\n",
1240 | " return fmt_str\n",
1241 | "\n",
1242 | " "
1243 | ]
1244 | },
1245 | {
1246 | "cell_type": "code",
1247 | "execution_count": null,
1248 | "metadata": {
1249 | "colab": {
1250 | "autoexec": {
1251 | "startup": false,
1252 | "wait_interval": 0
1253 | }
1254 | },
1255 | "colab_type": "code",
1256 | "id": "BuB8IhRXrZzK"
1257 | },
1258 | "outputs": [],
1259 | "source": [
1260 | "import torchvision.transforms as transforms\n",
1261 | "\n",
1262 | "def create_transforms(args):\n",
1263 | " if args.test_overfit:\n",
1264 | " train_tform = transforms.Compose([transforms.Resize(args.image_min_size),\n",
1265 | " transforms.CenterCrop(args.nw_input_size),\n",
1266 | " transforms.ToTensor(),\n",
1267 | " transforms.Normalize(mean=args.pretrain_dset_mean,\n",
1268 | " std=args.pretrain_dset_std)\n",
1269 | " ])\n",
1270 | " else:\n",
1271 | " train_tform = transforms.Compose([transforms.RandomResizedCrop(args.nw_input_size),\n",
1272 | " transforms.RandomHorizontalFlip(),\n",
1273 | " transforms.ToTensor(),\n",
1274 | " transforms.Normalize(mean=args.pretrain_dset_mean,\n",
1275 | " std=args.pretrain_dset_std)\n",
1276 | " ])\n",
1277 | "\n",
1278 | " val_tform = transforms.Compose([transforms.Resize(args.image_min_size),\n",
1279 | " transforms.CenterCrop(args.nw_input_size),\n",
1280 | " transforms.ToTensor(),\n",
1281 | " transforms.Normalize(mean=args.pretrain_dset_mean,\n",
1282 | " std=args.pretrain_dset_std)\n",
1283 | " ])\n",
1284 | " return (train_tform, val_tform)"
1285 | ]
1286 | },
1287 | {
1288 | "cell_type": "code",
1289 | "execution_count": null,
1290 | "metadata": {},
1291 | "outputs": [],
1292 | "source": [
1293 | "train_tform, val_tform = create_transforms(args)"
1294 | ]
1295 | },
1296 | {
1297 | "cell_type": "code",
1298 | "execution_count": null,
1299 | "metadata": {
1300 | "colab": {
1301 | "autoexec": {
1302 | "startup": false,
1303 | "wait_interval": 0
1304 | },
1305 | "base_uri": "https://localhost:8080/",
1306 | "height": 459
1307 | },
1308 | "colab_type": "code",
1309 | "executionInfo": {
1310 | "elapsed": 981,
1311 | "status": "ok",
1312 | "timestamp": 1527001551536,
1313 | "user": {
1314 | "displayName": "Sourabh Daptardar",
1315 | "photoUrl": "//lh4.googleusercontent.com/-onn5Q0_MiKQ/AAAAAAAAAAI/AAAAAAAACDI/iOxkSEz16nA/s50-c-k-no/photo.jpg",
1316 | "userId": "115812262388010820083"
1317 | },
1318 | "user_tz": -330
1319 | },
1320 | "id": "b7JjnbO4a1bU",
1321 | "outputId": "d20390c8-2b46-4399-bd70-7486521b4976",
1322 | "scrolled": true
1323 | },
1324 | "outputs": [],
1325 | "source": [
1326 | "train_dset = FashionDataset(args.train_dir, args.train_labels_json, args.num_labels, transform=train_tform, min_img_bytes=args.min_img_bytes)\n",
1327 | "val_dset = FashionDataset(args.val_dir, args.val_labels_json, args.num_labels, transform=val_tform, min_img_bytes=args.min_img_bytes)\n",
1328 | "test_dset = FashionDataset(args.test_dir, args.test_labels_json, args.num_labels, transform=val_tform, test=True, min_img_bytes=args.min_img_bytes) # same transform as validation\n",
1329 | "\n",
1330 | "\n",
1331 | "print(train_dset)\n",
1332 | "print(val_dset)\n",
1333 | "print(test_dset)"
1334 | ]
1335 | },
1336 | {
1337 | "cell_type": "code",
1338 | "execution_count": null,
1339 | "metadata": {
1340 | "colab": {
1341 | "autoexec": {
1342 | "startup": false,
1343 | "wait_interval": 0
1344 | }
1345 | },
1346 | "colab_type": "code",
1347 | "id": "-EoLW0no-em7"
1348 | },
1349 | "outputs": [],
1350 | "source": [
1351 | "def tensor_to_numpy(t, avg, std):\n",
1352 | " return (255.0 * (np.transpose(np.asarray(t), (1, 2, 0)) * std + avg)).astype(np.uint8)\n",
1353 | " "
1354 | ]
1355 | },
1356 | {
1357 | "cell_type": "code",
1358 | "execution_count": null,
1359 | "metadata": {
1360 | "colab": {
1361 | "autoexec": {
1362 | "startup": false,
1363 | "wait_interval": 0
1364 | },
1365 | "base_uri": "https://localhost:8080/",
1366 | "height": 439
1367 | },
1368 | "colab_type": "code",
1369 | "executionInfo": {
1370 | "elapsed": 1703,
1371 | "status": "ok",
1372 | "timestamp": 1527001554370,
1373 | "user": {
1374 | "displayName": "Sourabh Daptardar",
1375 | "photoUrl": "//lh4.googleusercontent.com/-onn5Q0_MiKQ/AAAAAAAAAAI/AAAAAAAACDI/iOxkSEz16nA/s50-c-k-no/photo.jpg",
1376 | "userId": "115812262388010820083"
1377 | },
1378 | "user_tz": -330
1379 | },
1380 | "id": "L8qbOaD8HvYi",
1381 | "outputId": "4e64893b-2094-4f5f-c9cd-6469eb2eaa8e"
1382 | },
1383 | "outputs": [],
1384 | "source": [
1385 | "rnd1 = np.random.randint(len(train_dset))\n",
1386 | "im1, lbl1 = train_dset[rnd1]\n",
1387 | "imshow(tensor_to_numpy(im1, args.pretrain_dset_mean, args.pretrain_dset_std))\n",
1388 | "print(lbl1)"
1389 | ]
1390 | },
1391 | {
1392 | "cell_type": "code",
1393 | "execution_count": null,
1394 | "metadata": {
1395 | "colab": {
1396 | "autoexec": {
1397 | "startup": false,
1398 | "wait_interval": 0
1399 | },
1400 | "base_uri": "https://localhost:8080/",
1401 | "height": 439
1402 | },
1403 | "colab_type": "code",
1404 | "executionInfo": {
1405 | "elapsed": 1502,
1406 | "status": "ok",
1407 | "timestamp": 1527001555965,
1408 | "user": {
1409 | "displayName": "Sourabh Daptardar",
1410 | "photoUrl": "//lh4.googleusercontent.com/-onn5Q0_MiKQ/AAAAAAAAAAI/AAAAAAAACDI/iOxkSEz16nA/s50-c-k-no/photo.jpg",
1411 | "userId": "115812262388010820083"
1412 | },
1413 | "user_tz": -330
1414 | },
1415 | "id": "3h5_M6G0QBYZ",
1416 | "outputId": "01e18eb7-caf7-4bcb-82b1-b296b9784185"
1417 | },
1418 | "outputs": [],
1419 | "source": [
1420 | "rnd2 = np.random.randint(len(val_dset))\n",
1421 | "im2, lbl2 = val_dset[rnd2]\n",
1422 | "imshow(tensor_to_numpy(im2, args.pretrain_dset_mean, args.pretrain_dset_std))\n",
1423 | "print(lbl2)"
1424 | ]
1425 | },
1426 | {
1427 | "cell_type": "code",
1428 | "execution_count": null,
1429 | "metadata": {},
1430 | "outputs": [],
1431 | "source": [
1432 | "rnd3 = np.random.randint(len(test_dset))\n",
1433 | "im3 = test_dset[rnd3]\n",
1434 | "imshow(tensor_to_numpy(im3, args.pretrain_dset_mean, args.pretrain_dset_std))\n"
1435 | ]
1436 | },
1437 | {
1438 | "cell_type": "markdown",
1439 | "metadata": {
1440 | "colab_type": "text",
1441 | "id": "iU_VDQm2Rtro"
1442 | },
1443 | "source": [
1444 | "# DataLoader"
1445 | ]
1446 | },
1447 | {
1448 | "cell_type": "code",
1449 | "execution_count": null,
1450 | "metadata": {
1451 | "colab": {
1452 | "autoexec": {
1453 | "startup": false,
1454 | "wait_interval": 0
1455 | }
1456 | },
1457 | "colab_type": "code",
1458 | "id": "4KcwWuuHoxoo"
1459 | },
1460 | "outputs": [],
1461 | "source": [
1462 | "if args.distributed:\n",
1463 | " dist.init_process_group(backend=args.dist_backend,\n",
1464 | " init_method=args.dist_url,\n",
1465 | " world_size=args.world_size)\n"
1466 | ]
1467 | },
1468 | {
1469 | "cell_type": "code",
1470 | "execution_count": null,
1471 | "metadata": {
1472 | "colab": {
1473 | "autoexec": {
1474 | "startup": false,
1475 | "wait_interval": 0
1476 | }
1477 | },
1478 | "colab_type": "code",
1479 | "id": "EjMlN6vqHtsE"
1480 | },
1481 | "outputs": [],
1482 | "source": [
1483 | "from torch.utils.data.distributed import DistributedSampler \n",
1484 | "\n",
1485 | "\n",
1486 | "if args.distributed:\n",
1487 | " train_sampler = DistributedSampler(train_dset)\n",
1488 | "else:\n",
1489 | " train_sampler = None\n",
1490 | "\n",
1491 | "train_loader = torch.utils.data.DataLoader(train_dset,\n",
1492 | " batch_size=args.batch_size,\n",
1493 | " shuffle=(train_sampler is None),\n",
1494 | " num_workers=args.num_workers,\n",
1495 | " pin_memory=True,\n",
1496 | " sampler=train_sampler\n",
1497 | " )\n",
1498 | "\n",
1499 | "val_loader = torch.utils.data.DataLoader(val_dset,\n",
1500 | " batch_size=args.batch_size,\n",
1501 | " shuffle=False,\n",
1502 | " num_workers=args.num_workers,\n",
1503 | " pin_memory=True\n",
1504 | " )\n",
1505 | "\n",
1506 | "test_loader = torch.utils.data.DataLoader(test_dset,\n",
1507 | " batch_size=args.batch_size,\n",
1508 | " shuffle=False,\n",
1509 | " num_workers=args.num_workers,\n",
1510 | " pin_memory=True\n",
1511 | " )\n",
1512 | "\n"
1513 | ]
1514 | },
1515 | {
1516 | "cell_type": "code",
1517 | "execution_count": null,
1518 | "metadata": {
1519 | "colab": {
1520 | "autoexec": {
1521 | "startup": false,
1522 | "wait_interval": 0
1523 | }
1524 | },
1525 | "colab_type": "code",
1526 | "id": "uzNG4-7x6Ovt"
1527 | },
1528 | "outputs": [],
1529 | "source": [
1530 | "# train_images, train_labels = next(iter(train_loader))"
1531 | ]
1532 | },
1533 | {
1534 | "cell_type": "code",
1535 | "execution_count": null,
1536 | "metadata": {
1537 | "colab": {
1538 | "autoexec": {
1539 | "startup": false,
1540 | "wait_interval": 0
1541 | },
1542 | "base_uri": "https://localhost:8080/",
1543 | "height": 731
1544 | },
1545 | "colab_type": "code",
1546 | "executionInfo": {
1547 | "elapsed": 1409,
1548 | "status": "ok",
1549 | "timestamp": 1527001561251,
1550 | "user": {
1551 | "displayName": "Sourabh Daptardar",
1552 | "photoUrl": "//lh4.googleusercontent.com/-onn5Q0_MiKQ/AAAAAAAAAAI/AAAAAAAACDI/iOxkSEz16nA/s50-c-k-no/photo.jpg",
1553 | "userId": "115812262388010820083"
1554 | },
1555 | "user_tz": -330
1556 | },
1557 | "id": "CDOOVYOHbcbl",
1558 | "outputId": "46f258e9-ee85-4435-8df8-57b8912d5ced"
1559 | },
1560 | "outputs": [],
1561 | "source": [
1562 | "# rnd11 = np.random.randint(args.batch_size)\n",
1563 | "# print(train_images[rnd11,:,:,:])\n",
1564 | "# print(train_labels[rnd11, :])"
1565 | ]
1566 | },
1567 | {
1568 | "cell_type": "code",
1569 | "execution_count": null,
1570 | "metadata": {
1571 | "colab": {
1572 | "autoexec": {
1573 | "startup": false,
1574 | "wait_interval": 0
1575 | }
1576 | },
1577 | "colab_type": "code",
1578 | "id": "IZ7R4Mgb7F3b"
1579 | },
1580 | "outputs": [],
1581 | "source": [
1582 | "# val_images, val_labels = next(iter(val_loader))"
1583 | ]
1584 | },
1585 | {
1586 | "cell_type": "code",
1587 | "execution_count": null,
1588 | "metadata": {
1589 | "colab": {
1590 | "autoexec": {
1591 | "startup": false,
1592 | "wait_interval": 0
1593 | },
1594 | "base_uri": "https://localhost:8080/",
1595 | "height": 731
1596 | },
1597 | "colab_type": "code",
1598 | "executionInfo": {
1599 | "elapsed": 918,
1600 | "status": "ok",
1601 | "timestamp": 1527001564208,
1602 | "user": {
1603 | "displayName": "Sourabh Daptardar",
1604 | "photoUrl": "//lh4.googleusercontent.com/-onn5Q0_MiKQ/AAAAAAAAAAI/AAAAAAAACDI/iOxkSEz16nA/s50-c-k-no/photo.jpg",
1605 | "userId": "115812262388010820083"
1606 | },
1607 | "user_tz": -330
1608 | },
1609 | "id": "CIOHOwAqaRkX",
1610 | "outputId": "7b674419-1de5-4765-c7e4-c2db097f3170"
1611 | },
1612 | "outputs": [],
1613 | "source": [
1614 | "# rnd21 = np.random.randint(args.batch_size)\n",
1615 | "# print(val_images[rnd21,:,:,:])\n",
1616 | "# print(val_labels[rnd21, :])"
1617 | ]
1618 | },
1619 | {
1620 | "cell_type": "markdown",
1621 | "metadata": {
1622 | "colab_type": "text",
1623 | "id": "cQAQrfMJRtv3"
1624 | },
1625 | "source": [
1626 | "\n",
1627 | "# Model"
1628 | ]
1629 | },
1630 | {
1631 | "cell_type": "code",
1632 | "execution_count": null,
1633 | "metadata": {
1634 | "colab": {
1635 | "autoexec": {
1636 | "startup": false,
1637 | "wait_interval": 0
1638 | }
1639 | },
1640 | "colab_type": "code",
1641 | "id": "jNxCNVyu98GF"
1642 | },
1643 | "outputs": [],
1644 | "source": [
1645 | "import torch.nn.init as weight_init\n",
1646 | "\n",
1647 | "\n",
1648 | "class FCWithLogSigmoid(nn.Module):\n",
1649 | " \n",
1650 | " def __init__(self, num_inputs, num_outputs):\n",
1651 | " super(FCWithLogSigmoid, self).__init__()\n",
1652 | " self.linear = nn.Linear(num_inputs, num_outputs)\n",
1653 | " self.logsigmoid = nn.LogSigmoid()\n",
1654 | " \n",
1655 | " def forward(self, x):\n",
1656 | " return self.logsigmoid(self.linear(x))\n",
1657 | "\n",
1658 | "\n",
1659 | "def create_model(arch, num_labels=228, fv_size=2048, pretrained=True, resume=False, distributed=False):\n",
1660 | " if pretrained:\n",
1661 | " print(\"=> using pre-trained model '{}'\".format(arch))\n",
1662 | " model = models.__dict__[arch](pretrained=True)\n",
1663 | " else:\n",
1664 | " print(\"=> creating model '{}'\".format(arch))\n",
1665 | " model = models.__dict__[arch]()\n",
1666 | " model.fc = FCWithLogSigmoid(fv_size, num_labels)\n",
1667 | " if not distributed:\n",
1668 | " if arch.startswith('alexnet') or arch.startswith('vgg'):\n",
1669 | " model.features = torch.nn.DataParallel(model.features)\n",
1670 | " model.cuda()\n",
1671 | " else:\n",
1672 | " model = torch.nn.DataParallel(model).cuda()\n",
1673 | " else:\n",
1674 | " model.cuda()\n",
1675 | " model = torch.nn.parallel.DistributedDataParallel(model)\n",
1676 | " return model\n"
1677 | ]
1678 | },
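1679 | {
1680 | "cell_type": "markdown",
1681 | "metadata": {},
1682 | "source": [
1683 | "A minimal sanity check added as a sketch (not part of the original notebook): `FCWithLogSigmoid` emits log-probabilities in (-inf, 0], so `torch.exp` of the model output recovers per-label probabilities, which is how the metric and `test()` code below interpret it. The shapes use the defaults from `create_model`."
1684 | ]
1685 | },
1686 | {
1687 | "cell_type": "code",
1688 | "execution_count": null,
1689 | "metadata": {},
1690 | "outputs": [],
1691 | "source": [
1692 | "# Hedged sketch: push random features through the custom head and check output ranges.\n",
1693 | "head = FCWithLogSigmoid(2048, 228)  # fv_size and num_labels defaults from create_model\n",
1694 | "log_p = head(torch.randn(4, 2048))\n",
1695 | "assert (log_p <= 0).all()  # LogSigmoid outputs are non-positive log-probabilities\n",
1696 | "probs = torch.exp(log_p)\n",
1697 | "assert ((probs > 0) & (probs <= 1)).all()  # exp() recovers probabilities in (0, 1]"
1698 | ]
1699 | },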
1679 | {
1680 | "cell_type": "code",
1681 | "execution_count": null,
1682 | "metadata": {
1683 | "colab": {
1684 | "autoexec": {
1685 | "startup": false,
1686 | "wait_interval": 0
1687 | }
1688 | },
1689 | "colab_type": "code",
1690 | "id": "9eMpjddlO6BC"
1691 | },
1692 | "outputs": [],
1693 | "source": [
1694 | "def count_parameters(model):\n",
1695 | " \"\"\"source: https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/9\"\"\"\n",
1696 | " return sum(p.numel() for p in model.parameters() if p.requires_grad)"
1697 | ]
1698 | },
1699 | {
1700 | "cell_type": "code",
1701 | "execution_count": null,
1702 | "metadata": {
1703 | "colab": {
1704 | "autoexec": {
1705 | "startup": false,
1706 | "wait_interval": 0
1707 | },
1708 | "base_uri": "https://localhost:8080/",
1709 | "height": 34
1710 | },
1711 | "colab_type": "code",
1712 | "executionInfo": {
1713 | "elapsed": 1945,
1714 | "status": "ok",
1715 | "timestamp": 1527001568263,
1716 | "user": {
1717 | "displayName": "Sourabh Daptardar",
1718 | "photoUrl": "//lh4.googleusercontent.com/-onn5Q0_MiKQ/AAAAAAAAAAI/AAAAAAAACDI/iOxkSEz16nA/s50-c-k-no/photo.jpg",
1719 | "userId": "115812262388010820083"
1720 | },
1721 | "user_tz": -330
1722 | },
1723 | "id": "CDhvVsKk_cNI",
1724 | "outputId": "e3b9d484-4d9e-4fbd-c7d2-799d0d838ab9"
1725 | },
1726 | "outputs": [],
1727 | "source": [
1728 | "model = create_model(args.arch,\n",
1729 | " num_labels=args.num_labels,\n",
1730 | " fv_size=args.fv_size,\n",
1731 | " pretrained=args.pretrained,\n",
1732 | " resume=args.resume,\n",
1733 | " distributed=args.distributed)"
1734 | ]
1735 | },
1736 | {
1737 | "cell_type": "code",
1738 | "execution_count": null,
1739 | "metadata": {
1740 | "colab": {
1741 | "autoexec": {
1742 | "startup": false,
1743 | "wait_interval": 0
1744 | },
1745 | "base_uri": "https://localhost:8080/",
1746 | "height": 34
1747 | },
1748 | "colab_type": "code",
1749 | "executionInfo": {
1750 | "elapsed": 990,
1751 | "status": "ok",
1752 | "timestamp": 1527001569282,
1753 | "user": {
1754 | "displayName": "Sourabh Daptardar",
1755 | "photoUrl": "//lh4.googleusercontent.com/-onn5Q0_MiKQ/AAAAAAAAAAI/AAAAAAAACDI/iOxkSEz16nA/s50-c-k-no/photo.jpg",
1756 | "userId": "115812262388010820083"
1757 | },
1758 | "user_tz": -330
1759 | },
1760 | "id": "yG1C75oXPLx8",
1761 | "outputId": "c1d7369f-4563-4fbf-d155-62227edccd93"
1762 | },
1763 | "outputs": [],
1764 | "source": [
1765 | "print(\"Neural Network has \", count_parameters(model), \" trainable parameters\")"
1766 | ]
1767 | },
1768 | {
1769 | "cell_type": "code",
1770 | "execution_count": null,
1771 | "metadata": {},
1772 | "outputs": [],
1773 | "source": [
1774 | "class WeightUpdateTracker:\n",
1775 | " \n",
1776 | " def __init__(self, model):\n",
1777 | " with torch.no_grad():\n",
1778 | " self.num_param_tensors = len(list(model.parameters()))\n",
1779 | " self.prev_pnorms = torch.zeros(self.num_param_tensors) \n",
1780 | " self.curr_pnorms = self.parameter_norms(model) \n",
1781 | "\n",
1782 | " def parameter_norms(self, model):\n",
1783 | " with torch.no_grad():\n",
1784 | " pnorms = torch.zeros(self.num_param_tensors)\n",
1785 | " for i, x in enumerate(list(model.parameters())):\n",
1786 | " pnorms[i] = x.norm().item()\n",
1787 | " return pnorms\n",
1788 | " \n",
1789 | " def track(self, model):\n",
1790 | " with torch.no_grad():\n",
1791 | " self.prev_pnorms = self.curr_pnorms.clone()\n",
1792 | " self.curr_pnorms = self.parameter_norms(model)\n",
1793 | " self.delta = (self.curr_pnorms - self.prev_pnorms) / self.prev_pnorms\n",
1794 | "\n",
1795 | " \n",
1796 | " def __repr__(self):\n",
1797 | " with torch.no_grad():\n",
1798 | " return self.delta.__repr__()\n",
1799 | " "
1800 | ]
1801 | },
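1802 | {
1803 | "cell_type": "markdown",
1804 | "metadata": {},
1805 | "source": [
1806 | "A hedged usage sketch for `WeightUpdateTracker` (toy model, illustrative values, not from the original run): after each optimizer step, `track()` records the relative change in the norm of every parameter tensor; relative updates of roughly 1e-3 per step are a common rule of thumb for a healthy learning rate."
1807 | ]
1808 | },
1809 | {
1810 | "cell_type": "code",
1811 | "execution_count": null,
1812 | "metadata": {},
1813 | "outputs": [],
1814 | "source": [
1815 | "# Toy example (illustrative): watch relative weight-norm changes across one SGD step.\n",
1816 | "toy = nn.Linear(8, 2)\n",
1817 | "toy_opt = optim.SGD(toy.parameters(), lr=0.1)\n",
1818 | "toy_tracker = WeightUpdateTracker(toy)\n",
1819 | "toy(torch.randn(4, 8)).sum().backward()\n",
1820 | "toy_opt.step()\n",
1821 | "toy_tracker.track(toy)\n",
1822 | "print(toy_tracker)  # per-tensor relative change (curr - prev) / prev"
1823 | ]
1824 | },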
1802 | {
1803 | "cell_type": "markdown",
1804 | "metadata": {
1805 | "colab_type": "text",
1806 | "id": "VIILcEp9Rtz-"
1807 | },
1808 | "source": [
1809 | "# Loss Function\n"
1810 | ]
1811 | },
1812 | {
1813 | "cell_type": "code",
1814 | "execution_count": null,
1815 | "metadata": {},
1816 | "outputs": [],
1817 | "source": [
1818 | "criterion = torch.nn.BCEWithLogitsLoss().cuda()"
1819 | ]
1820 | },
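1821 | {
1822 | "cell_type": "markdown",
1823 | "metadata": {},
1824 | "source": [
1825 | "A small numeric check (added sketch, random data): `BCEWithLogitsLoss` is the numerically stable fusion of a sigmoid with `BCELoss`, which is why it expects raw logits as input."
1826 | ]
1827 | },
1828 | {
1829 | "cell_type": "code",
1830 | "execution_count": null,
1831 | "metadata": {},
1832 | "outputs": [],
1833 | "source": [
1834 | "# Illustration: BCEWithLogitsLoss(z, y) equals BCELoss(sigmoid(z), y) up to float precision.\n",
1835 | "z = torch.randn(4, 228)\n",
1836 | "y = torch.randint(0, 2, (4, 228)).float()\n",
1837 | "fused = torch.nn.BCEWithLogitsLoss()(z, y)\n",
1838 | "manual = torch.nn.BCELoss()(torch.sigmoid(z), y)\n",
1839 | "print(fused.item(), manual.item())  # the two values should agree"
1840 | ]
1841 | },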
1821 | {
1822 | "cell_type": "markdown",
1823 | "metadata": {
1824 | "colab_type": "text",
1825 | "id": "PICCxotzRt4z"
1826 | },
1827 | "source": [
1828 | "# Update Rule"
1829 | ]
1830 | },
1831 | {
1832 | "cell_type": "code",
1833 | "execution_count": null,
1834 | "metadata": {
1835 | "colab": {
1836 | "autoexec": {
1837 | "startup": false,
1838 | "wait_interval": 0
1839 | }
1840 | },
1841 | "colab_type": "code",
1842 | "id": "zaX2mCHTDgSi"
1843 | },
1844 | "outputs": [],
1845 | "source": [
1846 | "optimizer = optim.Adam(model.parameters(),\n",
1847 | " amsgrad=True,\n",
1848 | " lr=args.optimizer_learning_rate,\n",
1849 | " betas=(0.9, 0.999),\n",
1850 | " eps=1e-8,\n",
1851 | " weight_decay=0.0\n",
1852 | " )\n",
1853 | "scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,\n",
1854 | " mode='max', # F1 measure\n",
1855 | " patience=args.scheduler_patience,\n",
1856 | " threshold=args.scheduler_threshold,\n",
1857 | " factor=args.scheduler_factor,\n",
1858 | " verbose=1\n",
1859 | " )\n"
1860 | ]
1861 | },
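1862 | {
1863 | "cell_type": "markdown",
1864 | "metadata": {},
1865 | "source": [
1866 | "An illustrative sketch with made-up F1 values (not from the original run): in `mode='max'`, `ReduceLROnPlateau` multiplies the learning rate by `factor` once the monitored metric fails to improve for more than `patience` consecutive `step()` calls."
1867 | ]
1868 | },
1869 | {
1870 | "cell_type": "code",
1871 | "execution_count": null,
1872 | "metadata": {},
1873 | "outputs": [],
1874 | "source": [
1875 | "# Toy demonstration of ReduceLROnPlateau in 'max' mode (made-up F1 values).\n",
1876 | "demo_opt = optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.1)\n",
1877 | "demo_sched = lr_scheduler.ReduceLROnPlateau(demo_opt, mode='max', patience=2, factor=0.5)\n",
1878 | "for f1 in [0.40, 0.45, 0.45, 0.45, 0.45]:  # F1 plateaus after the second step\n",
1879 | "    demo_sched.step(f1)\n",
1880 | "    print(demo_opt.param_groups[0]['lr'])  # halves to 0.05 once patience is exhausted"
1881 | ]
1882 | },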
1862 | {
1863 | "cell_type": "markdown",
1864 | "metadata": {
1865 | "colab_type": "text",
1866 | "id": "tCm_msJ0RuIu"
1867 | },
1868 | "source": [
1869 | "# Training Loop\n"
1870 | ]
1871 | },
1872 | {
1873 | "cell_type": "code",
1874 | "execution_count": null,
1875 | "metadata": {},
1876 | "outputs": [],
1877 | "source": [
1878 | "def load_checkpoint(model, optimizer, scheduler, args, resume=True, ckpt=None):\n",
1879 | " \"\"\"optionally resume from a checkpoint.\"\"\"\n",
1880 | " best_f1 = 0\n",
1881 | " if args.resume:\n",
1882 | " if os.path.isfile(ckpt):\n",
1883 | " print(\"=> loading checkpoint '{}'\".format(ckpt))\n",
1884 | " checkpoint = torch.load(ckpt)\n",
1885 | " args.start_epoch = checkpoint['epoch']\n",
1886 | " best_f1 = checkpoint['best_f1']\n",
1887 | " model.load_state_dict(checkpoint['state_dict'])\n",
1888 | " optimizer.load_state_dict(checkpoint['optimizer'])\n",
1889 | " # scheduler.load_state_dict(checkpoint['scheduler'])\n",
1890 | " print(\"=> loaded checkpoint '{}' (epoch {})\"\n",
1891 | " .format(args.resume, checkpoint['epoch']))\n",
1892 | " else:\n",
1893 | " print(\"=> no checkpoint found at '{}'\".format(ckpt))\n",
1894 | " best_f1 = 0\n",
1895 | " return (model, optimizer, scheduler, args, best_f1)"
1896 | ]
1897 | },
1898 | {
1899 | "cell_type": "code",
1900 | "execution_count": null,
1901 | "metadata": {},
1902 | "outputs": [],
1903 | "source": [
1904 | "def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_model_filename='model_best.pth.tar'):\n",
1905 | " torch.save(state, filename)\n",
1906 | " if is_best:\n",
1907 | " shutil.copyfile(filename, best_model_filename)"
1908 | ]
1909 | },
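1910 | {
1911 | "cell_type": "markdown",
1912 | "metadata": {},
1913 | "source": [
1914 | "A hedged round-trip sketch (the file name below is illustrative): the checkpoint is a plain dict, so `torch.save`/`torch.load` preserves the epoch counter, the best F1 so far, and both state dicts."
1915 | ]
1916 | },
1917 | {
1918 | "cell_type": "code",
1919 | "execution_count": null,
1920 | "metadata": {},
1921 | "outputs": [],
1922 | "source": [
1923 | "# Round-trip check with a throwaway file name (illustrative only).\n",
1924 | "demo_state = {\n",
1925 | "    'epoch': 1,\n",
1926 | "    'arch': args.arch,\n",
1927 | "    'state_dict': model.state_dict(),\n",
1928 | "    'best_f1': 0.0,\n",
1929 | "    'optimizer': optimizer.state_dict(),\n",
1930 | "}\n",
1931 | "save_checkpoint(demo_state, is_best=False, filename='/tmp/demo_ckpt.pth.tar')\n",
1932 | "print(torch.load('/tmp/demo_ckpt.pth.tar')['epoch'])  # -> 1"
1933 | ]
1934 | },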
1910 | {
1911 | "cell_type": "code",
1912 | "execution_count": null,
1913 | "metadata": {},
1914 | "outputs": [],
1915 | "source": [
1916 | "class F1MicroAverageMeter(object):\n",
1917 | " \"\"\"Computes and stores F1 store\"\"\"\n",
1918 | " def __init__(self, threshold=0.5, small=1e-12):\n",
1919 | " self.threshold = threshold\n",
1920 | " self.small = small\n",
1921 | " self.reset()\n",
1922 | "\n",
1923 | " def reset(self):\n",
1924 | " self.TP = 0.0\n",
1925 | " self.FP = 0.0\n",
1926 | " self.FN = 0.0\n",
1927 | " self.TN = 0.0\n",
1928 | " self.precision = 0.0\n",
1929 | " self.recall = 0.0\n",
1930 | " self.f1 = 0.0\n",
1931 | "\n",
1932 | " def update(self, labels, pred):\n",
1933 | " tp, fp, fn, tn = self.confusion_matrix_(labels, pred)\n",
1934 | " self.TP += tp\n",
1935 | " self.FP += fp\n",
1936 | " self.FN += fn\n",
1937 | " self.TN += tn\n",
1938 | " self.precision = self.TP / (self.small + self.TP + self.FP)\n",
1939 | " self.recall = self.TP / (self.small + self.TP + self.FN)\n",
1940 | " self.f1 = (2.0 * self.precision * self.recall) / (self.small + self.precision + self.recall)\n",
1941 | " \n",
1942 | " def confusion_matrix_(self, labels, pred):\n",
1943 | " with torch.no_grad():\n",
1944 | " real = labels\n",
1945 | " fake = 1.0 - real\n",
1946 | " pos = pred.ge(self.threshold)\n",
1947 | " pos = pos.float()\n",
1948 | " neg = 1.0 - pos\n",
1949 | " tp = torch.sum(real * pos).item()\n",
1950 | " fp = torch.sum(fake * pos).item()\n",
1951 | " fn = torch.sum(real * neg).item()\n",
1952 | " tn = torch.sum(fake * neg).item()\n",
1953 | " return (tp, fp, fn, tn)\n",
1954 | " "
1955 | ]
1956 | },
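1957 | {
1958 | "cell_type": "markdown",
1959 | "metadata": {},
1960 | "source": [
1961 | "A quick cross-check (added sketch; assumes scikit-learn is installed): on the same thresholded predictions, the streaming meter should match `sklearn.metrics.f1_score` with `average='micro'`."
1962 | ]
1963 | },
1964 | {
1965 | "cell_type": "code",
1966 | "execution_count": null,
1967 | "metadata": {},
1968 | "outputs": [],
1969 | "source": [
1970 | "from sklearn.metrics import f1_score\n",
1971 | "\n",
1972 | "demo_labels = torch.randint(0, 2, (16, 228)).float()\n",
1973 | "demo_probs = torch.rand(16, 228)\n",
1974 | "meter = F1MicroAverageMeter(threshold=0.5)\n",
1975 | "meter.update(demo_labels, demo_probs)\n",
1976 | "ref = f1_score(demo_labels.numpy().astype(int), demo_probs.ge(0.5).int().numpy(), average='micro')\n",
1977 | "print(meter.f1, ref)  # should agree up to the meter's small epsilon"
1978 | ]
1979 | },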
1957 | {
1958 | "cell_type": "code",
1959 | "execution_count": null,
1960 | "metadata": {},
1961 | "outputs": [],
1962 | "source": [
1963 | "class AverageMeter(object):\n",
1964 | " \"\"\"Computes and stores the average and current value\"\"\"\n",
1965 | " def __init__(self):\n",
1966 | " self.reset()\n",
1967 | "\n",
1968 | " def reset(self):\n",
1969 | " self.val = 0\n",
1970 | " self.avg = 0\n",
1971 | " self.sum = 0\n",
1972 | " self.count = 0\n",
1973 | "\n",
1974 | " def update(self, val, n=1):\n",
1975 | " self.val = val\n",
1976 | " self.sum += val * n\n",
1977 | " self.count += n\n",
1978 | " self.avg = self.sum / self.count"
1979 | ]
1980 | },
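1981 | {
1982 | "cell_type": "markdown",
1983 | "metadata": {},
1984 | "source": [
1985 | "A brief usage note (added sketch): `update(val, n)` treats `val` as a per-item average over `n` items, so `avg` stays the correct overall mean even when batch sizes vary."
1986 | ]
1987 | },
1988 | {
1989 | "cell_type": "code",
1990 | "execution_count": null,
1991 | "metadata": {},
1992 | "outputs": [],
1993 | "source": [
1994 | "m = AverageMeter()\n",
1995 | "m.update(0.5, n=10)  # batch of 10 items with mean value 0.5\n",
1996 | "m.update(1.0, n=2)   # smaller batch with mean value 1.0\n",
1997 | "print(m.avg)  # (0.5*10 + 1.0*2) / 12 = 0.5833..."
1998 | ]
1999 | },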
1981 | {
1982 | "cell_type": "code",
1983 | "execution_count": null,
1984 | "metadata": {},
1985 | "outputs": [],
1986 | "source": [
1987 | "def adjust_learning_rate(optimizer, scheduler, epoch, measure, args):\n",
1988 | " if not args.test_overfit:\n",
1989 | " scheduler.step(measure)\n"
1990 | ]
1991 | },
1992 | {
1993 | "cell_type": "code",
1994 | "execution_count": null,
1995 | "metadata": {},
1996 | "outputs": [],
1997 | "source": [
1998 | "def train(train_loader, model, criterion, optimizer, epoch):\n",
1999 | " batch_time = AverageMeter()\n",
2000 | " data_time = AverageMeter()\n",
2001 | " losses = AverageMeter()\n",
2002 | " cmpoint5 = F1MicroAverageMeter(threshold=0.5)\n",
2003 | "\n",
2004 | " # switch to train mode\n",
2005 | " model.train()\n",
2006 | "\n",
2007 | " end = time.time()\n",
2008 | " for i, (input, target) in enumerate(train_loader):\n",
2009 | " # measure data loading time\n",
2010 | " data_time.update(time.time() - end)\n",
2011 | "\n",
2012 | " target = target.cuda(non_blocking=True)\n",
2013 | "\n",
2014 | " # compute output\n",
2015 | " output = model(input)\n",
2016 | " loss = criterion(output, target)\n",
2017 | "\n",
2018 | " # measure F1 and record loss\n",
2019 | " losses.update(loss.item(), input.size(0))\n",
2020 | " cmpoint5.update(target, torch.exp(output))\n",
2021 | "\n",
2022 | " # compute gradient and do SGD step\n",
2023 | " optimizer.zero_grad()\n",
2024 | " loss.backward()\n",
2025 | " optimizer.step()\n",
2026 | "\n",
2027 | " # measure elapsed time\n",
2028 | " batch_time.update(time.time() - end)\n",
2029 | " end = time.time()\n",
2030 | " \n",
2031 | " \n",
2032 | "\n",
2033 | " if i % args.print_freq == 0:\n",
2034 | " print('Epoch: [{0}][{1}/{2}]\\t'\n",
2035 | " 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\\t'\n",
2036 | " 'Data {data_time.val:.3f} ({data_time.avg:.3f})\\t'\n",
2037 | " 'Loss {loss.val:.4f} ({loss.avg:.4f})\\t'\n",
2038 | " 'Precision {cmpoint5.precision:.3f}\\t'\n",
2039 | " 'Recall {cmpoint5.recall:.3f}\\t'\n",
2040 | " 'F1 {cmpoint5.f1:.3f}'.format(\n",
2041 | " epoch, i, len(train_loader), batch_time=batch_time,\n",
2042 | " data_time=data_time, loss=losses, cmpoint5=cmpoint5))"
2043 | ]
2044 | },
2045 | {
2046 | "cell_type": "code",
2047 | "execution_count": null,
2048 | "metadata": {},
2049 | "outputs": [],
2050 | "source": [
2051 | "def validate(val_loader, model, criterion):\n",
2052 | " batch_time = AverageMeter()\n",
2053 | " losses = AverageMeter()\n",
2054 | " cmpoint5 = F1MicroAverageMeter(threshold=0.5)\n",
2055 | "\n",
2056 | " # switch to evaluate mode\n",
2057 | " model.eval()\n",
2058 | "\n",
2059 | " with torch.no_grad():\n",
2060 | " end = time.time()\n",
2061 | " for i, (input, target) in enumerate(val_loader):\n",
2062 | " target = target.cuda(non_blocking=True)\n",
2063 | "\n",
2064 | " # compute output\n",
2065 | " output = model(input)\n",
2066 | " loss = criterion(output, target)\n",
2067 | "\n",
2068 | " # measure F1 and record loss\n",
2069 | " losses.update(loss.item(), input.size(0))\n",
2070 | " cmpoint5.update(target, torch.exp(output))\n",
2071 | " \n",
2072 | " # measure elapsed time\n",
2073 | " batch_time.update(time.time() - end)\n",
2074 | " end = time.time()\n",
2075 | "\n",
2076 | " if i % args.print_freq == 0:\n",
2077 | " print('Test: [{0}/{1}]\\t'\n",
2078 | " 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\\t'\n",
2079 | " 'Loss {loss.val:.4f} ({loss.avg:.4f})\\t'\n",
2080 | " 'Precision {cmpoint5.precision:.3f}\\t'\n",
2081 | " 'Recall {cmpoint5.recall:.3f}\\t'\n",
2082 | " 'F1 {cmpoint5.f1:.3f}'.format(\n",
2083 | " i, len(val_loader), batch_time=batch_time, loss=losses,\n",
2084 | " cmpoint5=cmpoint5))\n",
2085 | "\n",
2086 | " print(' * Precision {cmpoint5.precision:.3f} Recall {cmpoint5.recall:.3f} F1 {cmpoint5.f1:.3f}'\n",
2087 | " .format(cmpoint5=cmpoint5))\n",
2088 | "\n",
2089 | " return cmpoint5.f1"
2090 | ]
2091 | },
2092 | {
2093 | "cell_type": "code",
2094 | "execution_count": null,
2095 | "metadata": {},
2096 | "outputs": [],
2097 | "source": [
2098 | "def test(ofname, pfname, args, test_dset, test_loader, best_model_ckpt, model, threshold=0.5, epoch=0):\n",
2099 | " \n",
2100 | "# checkpoint = torch.load(best_model_ckpt)\n",
2101 | "# model.load_state_dict(checkpoint['state_dict'])\n",
2102 | " \n",
2103 | " batch_time = AverageMeter()\n",
2104 | " res = OrderedDict()\n",
2105 | "\n",
2106 | " # switch to evaluate mode\n",
2107 | " model.eval()\n",
2108 | "\n",
2109 | " with torch.no_grad():\n",
2110 | " end = time.time()\n",
2111 | " for i, input in enumerate(test_loader):\n",
2112 | " # compute output\n",
2113 | " output = model(input)\n",
2114 | " spout = coo_matrix(torch.exp(output).ge(threshold).int().cpu().numpy())\n",
2115 | " for p in zip(spout.row, spout.col):\n",
2116 | " imid = test_dset.image_ids[i* args.batch_size+p[0]]\n",
2117 | " if imid not in res.keys():\n",
2118 | " res[imid] = [p[1]+1]\n",
2119 | " else:\n",
2120 | " res[imid].append(p[1]+1)\n",
2121 | " \n",
2122 | " # measure elapsed time\n",
2123 | " batch_time.update(time.time() - end)\n",
2124 | " end = time.time()\n",
2125 | "\n",
2126 | " if i % args.print_freq == 0:\n",
2127 | " print('Test: [{0}/{1}]\\t'\n",
2128 | " 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\\t'.format(\n",
2129 | " i, len(test_loader), batch_time=batch_time))\n",
2130 | " \n",
2131 | " ofname_ = \"%s%s%03d_%s\" % (os.path.dirname(ofname), os.sep, epoch, os.path.basename(ofname))\n",
2132 | " with open(ofname_, \"w\") as ofd:\n",
2133 | " ofd.write(\"image_id,label_id\\n\")\n",
2134 | " for k, v in res.items():\n",
2135 | " ofd.write(\"%d,%s\\n\" % (k, \" \".join(map(str,v))))\n",
2136 | " \n",
2137 | " pfname_ = \"%s%s%03d_%s\" % (os.path.dirname(pfname), os.sep, epoch, os.path.basename(pfname))\n",
2138 | " with open(pfname_, \"w\") as pfd:\n",
2139 | " json.dump(vars(args), pfd, sort_keys=True, indent=4)\n",
2140 | " \n",
2141 | " print(\"Output written to %s\\n\" % ofname_)\n",
2142 | " print(\"Program parameters written to %s\\n\" % pfname_)\n",
2143 | " sys.stdout.flush()"
2144 | ]
2145 | },
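2146 | {
2147 | "cell_type": "markdown",
2148 | "metadata": {},
2149 | "source": [
2150 | "An illustrative snippet (toy batch, not from the original run): `coo_matrix` turns the thresholded prediction matrix into (row, col) pairs, where the row indexes the image within the batch and the column is the zero-based label id, hence the `+1` when `test()` writes the submission."
2151 | ]
2152 | },
2153 | {
2154 | "cell_type": "code",
2155 | "execution_count": null,
2156 | "metadata": {},
2157 | "outputs": [],
2158 | "source": [
2159 | "from scipy.sparse import coo_matrix\n",
2160 | "import numpy as np\n",
2161 | "\n",
2162 | "demo_pred = np.array([[0, 1, 0],\n",
2163 | "                      [1, 0, 1]])  # 2 images x 3 labels after thresholding\n",
2164 | "demo_sp = coo_matrix(demo_pred)\n",
2165 | "for r, c in zip(demo_sp.row, demo_sp.col):\n",
2166 | "    print('image', r, '-> label id', c + 1)  # 1-indexed ids, as written to the submission"
2167 | ]
2168 | },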
2146 | {
2147 | "cell_type": "code",
2148 | "execution_count": null,
2149 | "metadata": {},
2150 | "outputs": [],
2151 | "source": [
2152 | "def train_loop(train_loader, val_loader, test_loader, test_dset, args, optimizer, scheduler, model, criterion, threshold=0.5):\n",
2153 | " if args.evaluate:\n",
2154 | " validate(val_loader, model, criterion)\n",
2155 | " else:\n",
2156 | " model, optimizer, scheduler, args, best_f1 = load_checkpoint(model, optimizer, scheduler, args, resume=args.resume, ckpt=args.ckpt)\n",
2157 | " wut = None\n",
2158 | " if args.debug_weights:\n",
2159 | " wut = WeightUpdateTracker(model)\n",
2160 | " for epoch in range(args.start_epoch, args.epochs):\n",
2161 | " if args.distributed:\n",
2162 | " train_sampler.set_epoch(epoch)\n",
2163 | " # adjust_learning_rate(optimizer, epoch)\n",
2164 | "\n",
2165 | " # train for one epoch\n",
2166 | " train(train_loader, model, criterion, optimizer, epoch)\n",
2167 | "\n",
2168 | " if args.debug_weights:\n",
2169 | " # debug: track weight updates\n",
2170 | " wut.track(model)\n",
2171 | " print(wut)\n",
2172 | "\n",
2173 | " # evaluate on validation set\n",
2174 | " f1 = validate(val_loader, model, criterion)\n",
2175 | "\n",
2176 | " # remember best f1 and save checkpoint\n",
2177 | " is_best = f1 > best_f1\n",
2178 | " best_f1 = max(f1, best_f1)\n",
2179 | " save_checkpoint({\n",
2180 | " 'epoch': epoch + 1,\n",
2181 | " 'arch': args.arch,\n",
2182 | " 'state_dict': model.state_dict(),\n",
2183 | " 'best_f1': best_f1,\n",
2184 | " 'optimizer' : optimizer.state_dict(),\n",
2185 | " # 'scheduler' : scheduler.state_dict(),\n",
2186 | " }, is_best, filename=args.ckpt, best_model_filename=args.best)\n",
2187 | "\n",
2188 | " if is_best:\n",
2189 | " print(\"BEST: \", epoch)\n",
2190 | " sys.stdout.flush()\n",
2191 | " adjust_learning_rate(optimizer, scheduler, epoch, f1, args)\n",
2192 | " test(args.output_file, args.params_file, args, test_dset, test_loader, args.best, model, threshold=args.threshold, epoch=epoch) \n",
2193 | " rsync_and_verify(args.base_dir, args.perm_dir)\n"
2194 | ]
2195 | },
2196 | {
2197 | "cell_type": "code",
2198 | "execution_count": null,
2199 | "metadata": {},
2200 | "outputs": [],
2201 | "source": [
2202 | "train_loop(train_loader, val_loader, test_loader, test_dset, args, optimizer, scheduler, model, criterion, threshold=args.threshold)"
2203 | ]
2204 | },
2205 | {
2206 | "cell_type": "markdown",
2207 | "metadata": {},
2208 | "source": [
2209 | "# Inference"
2210 | ]
2211 | },
2212 | {
2213 | "cell_type": "code",
2214 | "execution_count": null,
2215 | "metadata": {},
2216 | "outputs": [],
2217 | "source": [
2218 | "# Move inference inside training loop for results from partially trained model\n",
2219 | "#test(args.output_file, args.params_file, args, test_dset, test_loader, args.best, model, threshold=args.threshold)"
2220 | ]
2221 | },
2222 | {
2223 | "cell_type": "markdown",
2224 | "metadata": {
2225 | "colab_type": "text",
2226 | "id": "gevaiXFORuTH"
2227 | },
2228 | "source": [
2229 | "# Save Results"
2230 | ]
2231 | }
2261 | ],
2262 | "metadata": {
2263 | "accelerator": "GPU",
2264 | "colab": {
2265 | "collapsed_sections": [],
2266 | "default_view": {},
2267 | "name": "TrainLoop.ipynb",
2268 | "provenance": [],
2269 | "version": "0.3.2",
2270 | "views": {}
2271 | },
2272 | "kernelspec": {
2273 | "display_name": "Python [default]",
2274 | "language": "python",
2275 | "name": "python3"
2276 | },
2277 | "language_info": {
2278 | "codemirror_mode": {
2279 | "name": "ipython",
2280 | "version": 3
2281 | },
2282 | "file_extension": ".py",
2283 | "mimetype": "text/x-python",
2284 | "name": "python",
2285 | "nbconvert_exporter": "python",
2286 | "pygments_lexer": "ipython3",
2287 | "version": "3.6.5"
2288 | }
2289 | },
2290 | "nbformat": 4,
2291 | "nbformat_minor": 1
2292 | }
2293 |
--------------------------------------------------------------------------------