├── .gitignore ├── dat └── ptb.pkl ├── environment.yml ├── readme.txt └── src ├── cifar_main.py ├── cifar_model.py ├── ptb_main.py └── ptb_model.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | src/.idea/ 3 | -------------------------------------------------------------------------------- /dat/ptb.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joansj/pytorch-intro/ef7df3bb6039c6697dc920539d8cd6495e6e2e1c/dat/ptb.pkl -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: dl 2 | channels: 3 | - !!python/unicode 4 | 'soumith' 5 | - !!python/unicode 6 | 'defaults' 7 | dependencies: 8 | - !!python/unicode 9 | '_license=1.1=py27_1' 10 | - !!python/unicode 11 | 'alabaster=0.7.10=py27_0' 12 | - !!python/unicode 13 | 'anaconda-client=1.6.3=py27_0' 14 | - !!python/unicode 15 | 'anaconda=custom=py27_0' 16 | - !!python/unicode 17 | 'anaconda-navigator=1.5.2=py27_0' 18 | - !!python/unicode 19 | 'anaconda-project=0.4.1=py27_0' 20 | - !!python/unicode 21 | 'argcomplete=1.0.0=py27_1' 22 | - !!python/unicode 23 | 'asn1crypto=0.22.0=py27_0' 24 | - !!python/unicode 25 | 'astroid=1.4.9=py27_0' 26 | - !!python/unicode 27 | 'astropy=1.3.2=np112py27_0' 28 | - !!python/unicode 29 | 'babel=2.4.0=py27_0' 30 | - !!python/unicode 31 | 'backports=1.0=py27_0' 32 | - !!python/unicode 33 | 'backports_abc=0.5=py27_0' 34 | - !!python/unicode 35 | 'beautifulsoup4=4.6.0=py27_0' 36 | - !!python/unicode 37 | 'bitarray=0.8.1=py27_0' 38 | - !!python/unicode 39 | 'blaze=0.10.1=py27_0' 40 | - !!python/unicode 41 | 'bleach=1.5.0=py27_0' 42 | - !!python/unicode 43 | 'bokeh=0.12.5=py27_0' 44 | - !!python/unicode 45 | 'boto=2.46.1=py27_0' 46 | - !!python/unicode 47 | 'bottleneck=1.2.0=np112py27_0' 48 | - !!python/unicode 49 | 'cairo=1.14.8=0' 50 
| - !!python/unicode 51 | 'cdecimal=2.3=py27_2' 52 | - !!python/unicode 53 | 'cffi=1.10.0=py27_0' 54 | - !!python/unicode 55 | 'chardet=3.0.2=py27_0' 56 | - !!python/unicode 57 | 'chest=0.2.3=py27_0' 58 | - !!python/unicode 59 | 'click=6.7=py27_0' 60 | - !!python/unicode 61 | 'cloudpickle=0.2.2=py27_0' 62 | - !!python/unicode 63 | 'clyent=1.2.2=py27_0' 64 | - !!python/unicode 65 | 'colorama=0.3.9=py27_0' 66 | - !!python/unicode 67 | 'configobj=5.0.6=py27_0' 68 | - !!python/unicode 69 | 'configparser=3.5.0=py27_0' 70 | - !!python/unicode 71 | 'contextlib2=0.5.4=py27_0' 72 | - !!python/unicode 73 | 'cryptography=1.8.1=py27_0' 74 | - !!python/unicode 75 | 'curl=7.52.1=0' 76 | - !!python/unicode 77 | 'cycler=0.10.0=py27_0' 78 | - !!python/unicode 79 | 'cython=0.25.2=py27_0' 80 | - !!python/unicode 81 | 'cytoolz=0.8.2=py27_0' 82 | - !!python/unicode 83 | 'dask=0.14.3=py27_0' 84 | - !!python/unicode 85 | 'datashape=0.5.4=py27_0' 86 | - !!python/unicode 87 | 'dbus=1.10.10=0' 88 | - !!python/unicode 89 | 'decorator=4.0.11=py27_0' 90 | - !!python/unicode 91 | 'dill=0.2.6=py27_0' 92 | - !!python/unicode 93 | 'docutils=0.13.1=py27_0' 94 | - !!python/unicode 95 | 'entrypoints=0.2.2=py27_1' 96 | - !!python/unicode 97 | 'enum34=1.1.6=py27_0' 98 | - !!python/unicode 99 | 'et_xmlfile=1.0.1=py27_0' 100 | - !!python/unicode 101 | 'expat=2.1.0=0' 102 | - !!python/unicode 103 | 'fastcache=1.0.2=py27_1' 104 | - !!python/unicode 105 | 'flask=0.12.1=py27_0' 106 | - !!python/unicode 107 | 'flask-cors=3.0.2=py27_0' 108 | - !!python/unicode 109 | 'fontconfig=2.12.1=3' 110 | - !!python/unicode 111 | 'freetype=2.5.5=2' 112 | - !!python/unicode 113 | 'funcsigs=1.0.2=py27_0' 114 | - !!python/unicode 115 | 'functools32=3.2.3.2=py27_0' 116 | - !!python/unicode 117 | 'futures=3.1.1=py27_0' 118 | - !!python/unicode 119 | 'get_terminal_size=1.0.0=py27_0' 120 | - !!python/unicode 121 | 'gevent=1.2.1=py27_0' 122 | - !!python/unicode 123 | 'glib=2.50.2=1' 124 | - !!python/unicode 125 | 
'greenlet=0.4.12=py27_0' 126 | - !!python/unicode 127 | 'grin=1.2.1=py27_3' 128 | - !!python/unicode 129 | 'gst-plugins-base=1.8.0=0' 130 | - !!python/unicode 131 | 'gstreamer=1.8.0=0' 132 | - !!python/unicode 133 | 'h5py=2.7.0=np112py27_0' 134 | - !!python/unicode 135 | 'harfbuzz=0.9.39=2' 136 | - !!python/unicode 137 | 'hdf5=1.8.17=1' 138 | - !!python/unicode 139 | 'heapdict=1.0.0=py27_1' 140 | - !!python/unicode 141 | 'html5lib=0.999=py27_0' 142 | - !!python/unicode 143 | 'icu=54.1=0' 144 | - !!python/unicode 145 | 'idna=2.5=py27_0' 146 | - !!python/unicode 147 | 'imagesize=0.7.1=py27_0' 148 | - !!python/unicode 149 | 'ipaddress=1.0.18=py27_0' 150 | - !!python/unicode 151 | 'ipykernel=4.6.1=py27_0' 152 | - !!python/unicode 153 | 'ipython=5.3.0=py27_0' 154 | - !!python/unicode 155 | 'ipython_genutils=0.2.0=py27_0' 156 | - !!python/unicode 157 | 'ipywidgets=6.0.0=py27_0' 158 | - !!python/unicode 159 | 'isort=4.2.5=py27_0' 160 | - !!python/unicode 161 | 'itsdangerous=0.24=py27_0' 162 | - !!python/unicode 163 | 'jbig=2.1=0' 164 | - !!python/unicode 165 | 'jdcal=1.3=py27_0' 166 | - !!python/unicode 167 | 'jedi=0.10.2=py27_2' 168 | - !!python/unicode 169 | 'jinja2=2.9.6=py27_0' 170 | - !!python/unicode 171 | 'jpeg=9b=0' 172 | - !!python/unicode 173 | 'jsonschema=2.6.0=py27_0' 174 | - !!python/unicode 175 | 'jupyter=1.0.0=py27_3' 176 | - !!python/unicode 177 | 'jupyter_client=5.0.1=py27_0' 178 | - !!python/unicode 179 | 'jupyter_console=5.1.0=py27_0' 180 | - !!python/unicode 181 | 'jupyter_core=4.3.0=py27_0' 182 | - !!python/unicode 183 | 'lazy-object-proxy=1.2.2=py27_0' 184 | - !!python/unicode 185 | 'libffi=3.2.1=1' 186 | - !!python/unicode 187 | 'libgcc=5.2.0=0' 188 | - !!python/unicode 189 | 'libgfortran=3.0.0=1' 190 | - !!python/unicode 191 | 'libiconv=1.14=0' 192 | - !!python/unicode 193 | 'libpng=1.6.27=0' 194 | - !!python/unicode 195 | 'libsodium=1.0.10=0' 196 | - !!python/unicode 197 | 'libtiff=4.0.6=3' 198 | - !!python/unicode 199 | 'libxcb=1.12=1' 200 | - 
!!python/unicode 201 | 'libxml2=2.9.4=0' 202 | - !!python/unicode 203 | 'libxslt=1.1.29=0' 204 | - !!python/unicode 205 | 'llvmlite=0.18.0=py27_0' 206 | - !!python/unicode 207 | 'locket=0.2.0=py27_1' 208 | - !!python/unicode 209 | 'lxml=3.7.3=py27_0' 210 | - !!python/unicode 211 | 'markupsafe=0.23=py27_2' 212 | - !!python/unicode 213 | 'matplotlib=2.0.2=np112py27_0' 214 | - !!python/unicode 215 | 'mistune=0.7.4=py27_0' 216 | - !!python/unicode 217 | 'mkl=2017.0.1=0' 218 | - !!python/unicode 219 | 'mkl-service=1.1.2=py27_3' 220 | - !!python/unicode 221 | 'mpmath=0.19=py27_1' 222 | - !!python/unicode 223 | 'multipledispatch=0.4.9=py27_0' 224 | - !!python/unicode 225 | 'nbconvert=5.1.1=py27_0' 226 | - !!python/unicode 227 | 'nbformat=4.3.0=py27_0' 228 | - !!python/unicode 229 | 'networkx=1.11=py27_0' 230 | - !!python/unicode 231 | 'nltk=3.2.2=py27_0' 232 | - !!python/unicode 233 | 'nose=1.3.7=py27_1' 234 | - !!python/unicode 235 | 'notebook=5.0.0=py27_0' 236 | - !!python/unicode 237 | 'numba=0.33.0=np112py27_0' 238 | - !!python/unicode 239 | 'numexpr=2.6.2=np112py27_0' 240 | - !!python/unicode 241 | 'numpy=1.12.1=py27_0' 242 | - !!python/unicode 243 | 'numpydoc=0.6.0=py27_0' 244 | - !!python/unicode 245 | 'odo=0.5.0=py27_1' 246 | - !!python/unicode 247 | 'olefile=0.44=py27_0' 248 | - !!python/unicode 249 | 'openpyxl=2.4.7=py27_0' 250 | - !!python/unicode 251 | 'openssl=1.0.2k=2' 252 | - !!python/unicode 253 | 'packaging=16.8=py27_0' 254 | - !!python/unicode 255 | 'pandas=0.20.1=np112py27_0' 256 | - !!python/unicode 257 | 'pandocfilters=1.4.1=py27_0' 258 | - !!python/unicode 259 | 'partd=0.3.8=py27_0' 260 | - !!python/unicode 261 | 'path.py=10.3.1=py27_0' 262 | - !!python/unicode 263 | 'pathlib2=2.2.1=py27_0' 264 | - !!python/unicode 265 | 'patsy=0.4.1=py27_0' 266 | - !!python/unicode 267 | 'pcre=8.39=1' 268 | - !!python/unicode 269 | 'pep8=1.7.0=py27_0' 270 | - !!python/unicode 271 | 'pexpect=4.2.1=py27_0' 272 | - !!python/unicode 273 | 'pickleshare=0.7.4=py27_0' 274 
| - !!python/unicode 275 | 'pillow=4.1.1=py27_0' 276 | - !!python/unicode 277 | 'pip=9.0.1=py27_1' 278 | - !!python/unicode 279 | 'pixman=0.34.0=0' 280 | - !!python/unicode 281 | 'ply=3.10=py27_0' 282 | - !!python/unicode 283 | 'prompt_toolkit=1.0.14=py27_0' 284 | - !!python/unicode 285 | 'psutil=5.2.2=py27_0' 286 | - !!python/unicode 287 | 'ptyprocess=0.5.1=py27_0' 288 | - !!python/unicode 289 | 'py=1.4.33=py27_0' 290 | - !!python/unicode 291 | 'pyasn1=0.2.3=py27_0' 292 | - !!python/unicode 293 | 'pycairo=1.10.0=py27_0' 294 | - !!python/unicode 295 | 'pycosat=0.6.2=py27_0' 296 | - !!python/unicode 297 | 'pycparser=2.17=py27_0' 298 | - !!python/unicode 299 | 'pycrypto=2.6.1=py27_4' 300 | - !!python/unicode 301 | 'pycurl=7.43.0=py27_2' 302 | - !!python/unicode 303 | 'pyflakes=1.5.0=py27_0' 304 | - !!python/unicode 305 | 'pygments=2.2.0=py27_0' 306 | - !!python/unicode 307 | 'pylint=1.6.4=py27_1' 308 | - !!python/unicode 309 | 'pyopenssl=17.0.0=py27_0' 310 | - !!python/unicode 311 | 'pyparsing=2.1.4=py27_0' 312 | - !!python/unicode 313 | 'pyqt=5.6.0=py27_2' 314 | - !!python/unicode 315 | 'pytables=3.3.0=np112py27_0' 316 | - !!python/unicode 317 | 'pytest=3.0.7=py27_0' 318 | - !!python/unicode 319 | 'python=2.7.13=0' 320 | - !!python/unicode 321 | 'python-dateutil=2.6.0=py27_0' 322 | - !!python/unicode 323 | 'pytz=2017.2=py27_0' 324 | - !!python/unicode 325 | 'pywavelets=0.5.2=np112py27_0' 326 | - !!python/unicode 327 | 'pyyaml=3.12=py27_0' 328 | - !!python/unicode 329 | 'pyzmq=16.0.2=py27_0' 330 | - !!python/unicode 331 | 'qt=5.6.2=4' 332 | - !!python/unicode 333 | 'qtawesome=0.4.4=py27_0' 334 | - !!python/unicode 335 | 'qtconsole=4.3.0=py27_0' 336 | - !!python/unicode 337 | 'qtpy=1.2.1=py27_0' 338 | - !!python/unicode 339 | 'readline=6.2=2' 340 | - !!python/unicode 341 | 'redis=3.2.0=0' 342 | - !!python/unicode 343 | 'redis-py=2.10.5=py27_0' 344 | - !!python/unicode 345 | 'requests=2.14.2=py27_0' 346 | - !!python/unicode 347 | 'rope=0.9.4=py27_1' 348 | - 
!!python/unicode 349 | 'ruamel_yaml=0.11.14=py27_1' 350 | - !!python/unicode 351 | 'scandir=1.5=py27_0' 352 | - !!python/unicode 353 | 'scikit-image=0.13.0=np112py27_0' 354 | - !!python/unicode 355 | 'scikit-learn=0.18.1=np112py27_1' 356 | - !!python/unicode 357 | 'scipy=0.19.0=np112py27_0' 358 | - !!python/unicode 359 | 'seaborn=0.7.1=py27_0' 360 | - !!python/unicode 361 | 'setuptools=27.2.0=py27_0' 362 | - !!python/unicode 363 | 'simplegeneric=0.8.1=py27_1' 364 | - !!python/unicode 365 | 'singledispatch=3.4.0.3=py27_0' 366 | - !!python/unicode 367 | 'sip=4.18=py27_0' 368 | - !!python/unicode 369 | 'six=1.10.0=py27_0' 370 | - !!python/unicode 371 | 'snowballstemmer=1.2.1=py27_0' 372 | - !!python/unicode 373 | 'sockjs-tornado=1.0.3=py27_0' 374 | - !!python/unicode 375 | 'sphinx=1.5.6=py27_0' 376 | - !!python/unicode 377 | 'spyder=3.1.4=py27_0' 378 | - !!python/unicode 379 | 'sqlalchemy=1.1.9=py27_0' 380 | - !!python/unicode 381 | 'sqlite=3.13.0=0' 382 | - !!python/unicode 383 | 'ssl_match_hostname=3.4.0.2=py27_1' 384 | - !!python/unicode 385 | 'statsmodels=0.8.0=np112py27_0' 386 | - !!python/unicode 387 | 'subprocess32=3.2.7=py27_0' 388 | - !!python/unicode 389 | 'sympy=1.0=py27_0' 390 | - !!python/unicode 391 | 'terminado=0.6=py27_0' 392 | - !!python/unicode 393 | 'testpath=0.3=py27_0' 394 | - !!python/unicode 395 | 'tk=8.5.18=0' 396 | - !!python/unicode 397 | 'toolz=0.8.2=py27_0' 398 | - !!python/unicode 399 | 'tornado=4.5.1=py27_0' 400 | - !!python/unicode 401 | 'traitlets=4.3.2=py27_0' 402 | - !!python/unicode 403 | 'unicodecsv=0.14.1=py27_0' 404 | - !!python/unicode 405 | 'wcwidth=0.1.7=py27_0' 406 | - !!python/unicode 407 | 'werkzeug=0.12.2=py27_0' 408 | - !!python/unicode 409 | 'wheel=0.29.0=py27_0' 410 | - !!python/unicode 411 | 'widgetsnbextension=2.0.0=py27_0' 412 | - !!python/unicode 413 | 'wrapt=1.10.10=py27_0' 414 | - !!python/unicode 415 | 'xlrd=1.0.0=py27_0' 416 | - !!python/unicode 417 | 'xlsxwriter=0.9.6=py27_0' 418 | - !!python/unicode 419 | 
'xlwt=1.2.0=py27_0' 420 | - !!python/unicode 421 | 'xz=5.2.2=1' 422 | - !!python/unicode 423 | 'yaml=0.1.6=0' 424 | - !!python/unicode 425 | 'zeromq=4.1.5=0' 426 | - !!python/unicode 427 | 'zlib=1.2.8=3' 428 | - !!python/unicode 429 | 'cuda80=1.0=0' 430 | - !!python/unicode 431 | 'pytorch=0.1.12=py27_2cu80' 432 | - !!python/unicode 433 | 'torchvision=0.1.8=py27_2' 434 | - pip: 435 | - backports-abc==0.5 436 | - backports.shutil-get-terminal-size==1.0.0 437 | - backports.ssl-match-hostname==3.4.0.2 438 | - et-xmlfile==1.0.1 439 | - ipython-genutils==0.2.0 440 | - jupyter-client==5.0.1 441 | - jupyter-console==5.1.0 442 | - jupyter-core==4.3.0 443 | - prompt-toolkit==1.0.14 444 | - tables==3.3.0 445 | - torch==0.1.12.post2 446 | - tqdm==4.11.2 447 | prefix: !!python/unicode '/home/jsj/miniconda2/envs/dl' 448 | 449 | -------------------------------------------------------------------------------- /readme.txt: -------------------------------------------------------------------------------- 1 | This is just a couple of simple scripts to illustrate how to use PyTorch. There are two examples: 2 | 3 | 1) Convolutional residual network (bottleneck variant) inspired by [R1] -- We use CIFAR-10 to train/test. 4 | 5 | 2) LSTM-based word language model inspired by [R2] -- We use a custom-processed version of the Penn Treebank data set. 6 | 7 | [R1] He et al. (2015), "Deep residual learning for image recognition", Proc. of the IEEE Conf. on Computer Vision and Pattern Recognition (CVPR), pp. 770-778. https://arxiv.org/abs/1512.03385 8 | [R2] Zaremba et al. (2015), "Recurrent neural network regularization", Int. Conf. on Learning Representations (ICLR). 
https://arxiv.org/abs/1409.2329 9 | 10 | -------------------------------------------------------------------------------- /src/cifar_main.py: -------------------------------------------------------------------------------- 1 | import sys,argparse 2 | import numpy as np 3 | from tqdm import tqdm 4 | 5 | # Parse arguments 6 | parser=argparse.ArgumentParser(description='Main script using CIFAR-10') 7 | parser.add_argument('--seed',default=17724,type=int,required=False,help='(default=%(default)d)') 8 | parser.add_argument('--log_interval',default=200,type=int,required=False,help='(default=%(default)d)') 9 | parser.add_argument('--data_folder',default='../dat/',type=str,required=False,help='(default=%(default)s') 10 | parser.add_argument('--batch_size',default=128,type=int,required=False,help='(default=%(default)d)') 11 | parser.add_argument('--num_epochs',default=300,type=int,required=False,help='(default=%(default)d)') 12 | parser.add_argument('--learning_rate',default=1e-3,type=float,required=False,help='(default=%(default)f)') 13 | args=parser.parse_args() 14 | print '*'*100,'\n',args,'\n','*'*100 15 | 16 | # Import pytorch stuff 17 | import torch 18 | import torchvision 19 | 20 | # Set random seed 21 | np.random.seed(args.seed) 22 | torch.manual_seed(args.seed) 23 | if torch.cuda.is_available(): torch.cuda.manual_seed(args.seed) 24 | else: print '[CUDA unavailable]'; sys.exit() 25 | 26 | # Model import 27 | import cifar_model 28 | 29 | ######################################################################################################################## 30 | # Load data 31 | ######################################################################################################################## 32 | 33 | print 'Load data...' 
34 | 35 | # Set some data set parameters 36 | image_size=32 37 | image_channels=3 38 | num_classes=10 39 | 40 | # Prepare data augmentation 41 | train_transform=torchvision.transforms.Compose([ 42 | torchvision.transforms.RandomCrop(image_size,padding=4), # Random crop sub-parts of the image 43 | torchvision.transforms.RandomHorizontalFlip(), # Horizontal flip with probability=0.5 44 | torchvision.transforms.ToTensor(), # Conversion from PILImage to Tensor (also normalizes between 0 and 1) 45 | torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), # Standardize 46 | ]) 47 | test_transform=torchvision.transforms.Compose([ 48 | torchvision.transforms.ToTensor(), 49 | torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 50 | ]) 51 | 52 | # Prepare data loaders 53 | train_set=torchvision.datasets.CIFAR10(root=args.data_folder,train=True,transform=train_transform,download=True) # Downloads data and points to it 54 | train_loader=torch.utils.data.DataLoader(dataset=train_set,batch_size=args.batch_size,shuffle=True,num_workers=3) # Provides an iterator over the elements of the data set 55 | test_set=torchvision.datasets.CIFAR10(root=args.data_folder,train=False,transform=test_transform,download=True) 56 | test_loader=torch.utils.data.DataLoader(dataset=test_set,batch_size=args.batch_size,shuffle=False,num_workers=3) 57 | 58 | ######################################################################################################################## 59 | # Inits 60 | ######################################################################################################################## 61 | 62 | print 'Init...' 
63 | 64 | # Instantiate and init the model, and move it to the GPU 65 | #model=cifar_model.ConvNet((image_size,image_size,image_channels),num_classes).cuda() 66 | model=cifar_model.ResNet((image_size,image_size,image_channels),num_classes).cuda() 67 | #model=torch.nn.DataParallel(model,device_ids=[0,2]) # Just to have an idea how easy it is to parallelize on GPUs 68 | 69 | # Define loss function 70 | criterion=torch.nn.CrossEntropyLoss() 71 | 72 | # Define optimizer 73 | optimizer=torch.optim.Adam(model.parameters(),lr=args.learning_rate) 74 | 75 | ######################################################################################################################## 76 | # Train model 77 | ######################################################################################################################## 78 | 79 | print 'Train...' 80 | 81 | # Set model to training mode (we're using batch normalization) 82 | model.train() 83 | 84 | # Loop training epochs 85 | lossvals=[] 86 | for e in tqdm(range(args.num_epochs),desc='Epoch',ncols=100,ascii=True): 87 | 88 | # Loop batches 89 | for images,labels in tqdm(train_loader,desc='> Batch',ncols=100,ascii=True): 90 | 91 | # Wrap the variables into the gradient propagation chain and move them to the GPU 92 | images=torch.autograd.Variable(images).cuda() 93 | labels=torch.autograd.Variable(labels).cuda() 94 | #print images[0]; sys.exit() 95 | 96 | # Forward pass 97 | outputs=model.forward(images) 98 | loss=criterion(outputs,labels) 99 | 100 | # Backward pass 101 | optimizer.zero_grad() 102 | loss.backward() 103 | optimizer.step() 104 | 105 | # Log stuff 106 | lossvals.append(loss.data.cpu().numpy()) 107 | if len(lossvals)%args.log_interval==0: 108 | msg='Epoch %d, iter %.2e: \tLoss=%.4f \tSmooth loss=%.4f'%(e+1,len(lossvals),lossvals[-1],np.mean(lossvals[-args.log_interval:])) 109 | tqdm.write(msg) 110 | 111 | 
######################################################################################################################## 112 | # Test model 113 | ######################################################################################################################## 114 | 115 | print 'Test...' 116 | 117 | # Change model to evaluation mode (we're using batch normalization) 118 | model.eval() 119 | 120 | # Loop images 121 | hits=[] 122 | for images,labels in tqdm(test_loader,desc='Evaluation',ncols=100,ascii=True): 123 | 124 | # Wrap the images as volatile Variables (inference only: no gradient graph is recorded) and move inputs/labels to the GPU 125 | images=torch.autograd.Variable(images,volatile=True).cuda() 126 | labels=labels.cuda() 127 | 128 | # Forward pass 129 | outputs=model.forward(images) 130 | 131 | # Eval: the predicted class is the index of the highest score along dim 1; compare it with the labels 132 | _,predicted=torch.max(outputs.data,1) 133 | correct=(labels==predicted).int().cpu().numpy() 134 | hits+=list(correct) 135 | 136 | # Report 137 | print '='*100,'\nTest accuracy = %.1f%%\n'%(100*np.mean(hits)),'='*100 138 | 139 | ######################################################################################################################## 140 | -------------------------------------------------------------------------------- /src/cifar_model.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | 4 | ######################################################################################################################## 5 | # Small convnet -- Just for example 6 | ######################################################################################################################## 7 | 8 | class ConvNet(torch.nn.Module): 9 | 10 | def __init__(self,image_size,num_classes): 11 | super(ConvNet,self).__init__() 12 | # Get image size (we only deal with square images) 13 | size,_,num_channels=image_size 14 | 15 | # Define a convolutional layer 16 |
self.conv1=torch.nn.Conv2d(num_channels,10,kernel_size=3,stride=1,padding=1) 17 | # Define a rectified linear unit 18 | self.relu=torch.nn.ReLU() 19 | # Define a pooling layer 20 | self.pool=torch.nn.MaxPool2d(kernel_size=2,stride=2) 21 | # Define another convolutional layer 22 | self.conv2=torch.nn.Conv2d(10,20,kernel_size=3,stride=1,padding=1) 23 | # We do not need to define model relus nor pooling (no parameters to train, we can reuse the same ones) 24 | 25 | # Define final fully-connected layers 26 | self.fc1=torch.nn.Linear(20*8*8,120) 27 | self.fc2=torch.nn.Linear(120,num_classes) 28 | return 29 | 30 | def forward(self,x): 31 | # First stage: convolution -> relu -> pooling 32 | y=self.pool(self.relu(self.conv1(x))) 33 | # Second stage: convolution -> relu -> pooling 34 | y=self.pool(self.relu(self.conv2(y))) 35 | # Reshape to batch_size-by-whatever 36 | y=y.view(x.size(0),-1) 37 | # Last stage: fc -> relu -> fc 38 | y=self.fc2(self.relu(self.fc1(y))) 39 | # Return predictions 40 | return y 41 | 42 | ######################################################################################################################## 43 | 44 | 45 | ######################################################################################################################## 46 | # He et al. "Deep residual learning for image recognition." Proceedings of the IEEE Conf. on Computer Vision and Pattern Recognition (CVPR), pp. 770-778. 2016. 
47 | # https://arxiv.org/abs/1512.03385 48 | ######################################################################################################################## 49 | 50 | class ResNet(torch.nn.Module): 51 | 52 | def __init__(self,image_size,num_classes): 53 | super(ResNet,self).__init__() 54 | # Get image size (we only deal with square images) 55 | size,_,num_channels=image_size 56 | 57 | # Some configuration of our model (corresponding to ResNet50) 58 | channels_init=64 59 | compression=4 60 | structure=[(3,64,1),(4,128,2),(6,256,2),(3,512,2)] 61 | 62 | # Define initial module (non-residual) 63 | self.conv1=torch.nn.Conv2d(num_channels,channels_init,kernel_size=3,stride=1,padding=1) 64 | self.bn1=torch.nn.BatchNorm2d(channels_init) 65 | self.relu=torch.nn.ReLU() 66 | 67 | # Stack residual blocks following structure 68 | self.layers=[] 69 | channels_current=channels_init 70 | size_current=size 71 | # Loop structure elements 72 | for n,channels,stride in structure: 73 | # Loop internal structure layers 74 | for i in range(n): 75 | # Only do a stride!=1 at first piece of the super-block 76 | if i>0: stride=1 77 | # Instantiate and append block 78 | b=Block(channels_current,channels,compression,stride).cuda() 79 | self.layers.append(b) 80 | # Update current number of channels and image size 81 | channels_current=channels 82 | size_current=size_current//stride 83 | # Convert a list of layers into a nested, sequential operation (options for more elaborate relations than nested exist, see ModuleList) 84 | self.layers=torch.nn.Sequential(*self.layers) 85 | 86 | # Define an average pooling layer 87 | self.avgpool=torch.nn.AvgPool2d(size_current) 88 | 89 | # Define the final classification layer 90 | self.fc=torch.nn.Linear(channels_current,num_classes) 91 | 92 | return 93 | 94 | def forward(self,x): 95 | # Apply initial module 96 | y=self.relu(self.bn1(self.conv1(x))) 97 | # Apply all blocks 98 | y=self.layers(y) 99 | # Apply pooling 100 | y=self.avgpool(y) 101 | # 
Reshape 102 | y=y.view(x.size(0),-1) 103 | # Apply classification layer 104 | y=self.fc(y) 105 | # Return prediction 106 | return y 107 | 108 | ######################################################################################################################## 109 | 110 | class Block(torch.nn.Module): 111 | 112 | def __init__(self,channels_current,channels,compression,stride): 113 | super(Block,self).__init__() 114 | # Set the number of internal channels 115 | channels_internal=channels//compression 116 | 117 | # Define a relu 118 | self.relu=torch.nn.ReLU() 119 | # Define the three sequential convolutions + batch normalization 120 | self.conv1=torch.nn.Conv2d(channels_current,channels_internal,kernel_size=1) 121 | self.bn1=torch.nn.BatchNorm2d(channels_internal) 122 | self.conv2=torch.nn.Conv2d(channels_internal,channels_internal,kernel_size=3,stride=stride,padding=1) 123 | self.bn2=torch.nn.BatchNorm2d(channels_internal) 124 | self.conv3=torch.nn.Conv2d(channels_internal,channels,kernel_size=1) 125 | self.bn3=torch.nn.BatchNorm2d(channels) 126 | 127 | # Create the shortcut 128 | self.shortcut=torch.nn.Sequential() 129 | # If number of channels changed or we did some downsampling, the shortcut needs to take care of that 130 | if channels!=channels_current or stride!=1: 131 | # A list of layers: convolution + batch normalization 132 | self.shortcut=torch.nn.Sequential( 133 | torch.nn.Conv2d(channels_current,channels,kernel_size=1,stride=stride), 134 | torch.nn.BatchNorm2d(channels) 135 | ) 136 | 137 | return 138 | 139 | def forward(self,x): 140 | # Apply the three sequential convolutions + batch normalization 141 | y=self.relu(self.bn1(self.conv1(x))) 142 | y=self.relu(self.bn2(self.conv2(y))) 143 | y=self.bn3(self.conv3(y)) 144 | # Add the shortcut 145 | y+=self.shortcut(x) 146 | # Activation 147 | y=self.relu(y) 148 | # Return predictions 149 | return y 150 | 151 | 
######################################################################################################################## 152 | -------------------------------------------------------------------------------- /src/ptb_main.py: -------------------------------------------------------------------------------- 1 | import sys,argparse 2 | import numpy as np 3 | import cPickle as pickle 4 | from tqdm import tqdm 5 | 6 | # Parse arguments 7 | parser=argparse.ArgumentParser(description='Main script using Penn Treebank') 8 | parser.add_argument('--seed',default=333,type=int,required=False,help='(default=%(default)d)') 9 | parser.add_argument('--filename_in',default='../dat/ptb.pkl',type=str,required=False,help='(default=%(default)s)') 10 | parser.add_argument('--batch_size',default=20,type=int,required=False,help='(default=%(default)d)') 11 | parser.add_argument('--num_epochs',default=40,type=int,required=False,help='(default=%(default)d)') 12 | parser.add_argument('--bptt',default=35,type=int,required=False,help='(default=%(default)d)') 13 | parser.add_argument('--learning_rate',default=20,type=float,required=False,help='(default=%(default)f)') 14 | parser.add_argument('--clip_norm',default=0.25,type=float,required=False,help='(default=%(default)f)') 15 | parser.add_argument('--anneal_factor',default=2.0,type=float,required=False,help='(default=%(default)f)') 16 | args=parser.parse_args() 17 | print '*'*100,'\n',args,'\n','*'*100 18 | 19 | # Import pytorch stuff 20 | import torch 21 | import torchvision 22 | 23 | # Set random seed 24 | np.random.seed(args.seed) 25 | torch.manual_seed(args.seed) 26 | if torch.cuda.is_available(): torch.cuda.manual_seed(args.seed) 27 | else: sys.exit('[CUDA unavailable]') # message goes to stderr and the exit status is 1, so callers see the failure 28 | 29 | # Model import 30 | import ptb_model 31 | 32 | ######################################################################################################################## 33 | # Load data 34 |
######################################################################################################################## 35 | 36 | print 'Load data...' 37 | 38 | # Load numpy data 39 | data_train,data_valid,data_test,vocabulary_size=pickle.load(open(args.filename_in,'rb')) 40 | 41 | # Make it pytorch 42 | data_train=torch.LongTensor(data_train.astype(np.int64)) 43 | data_valid=torch.LongTensor(data_valid.astype(np.int64)) 44 | data_test=torch.LongTensor(data_test.astype(np.int64)) 45 | 46 | # Make batches 47 | num_batches=data_train.size(0)//args.batch_size # Get number of batches 48 | data_train=data_train[:num_batches*args.batch_size] # Trim last elements 49 | data_train=data_train.view(args.batch_size,-1) # Reshape 50 | num_batches=data_valid.size(0)//args.batch_size 51 | data_valid=data_valid[:num_batches*args.batch_size] 52 | data_valid=data_valid.view(args.batch_size,-1) 53 | num_batches=data_test.size(0)//args.batch_size 54 | data_test=data_test[:num_batches*args.batch_size] 55 | data_test=data_test.view(args.batch_size,-1) 56 | 57 | 58 | ######################################################################################################################## 59 | # Inits 60 | ######################################################################################################################## 61 | 62 | print 'Init...' 
63 | 64 | # Instantiate and init the model, and move it to the GPU 65 | model=ptb_model.BasicRNNLM(vocabulary_size).cuda() 66 | 67 | # Define loss function 68 | criterion=torch.nn.CrossEntropyLoss(size_average=False) 69 | 70 | # Define optimizer 71 | optimizer=torch.optim.SGD(model.parameters(),lr=args.learning_rate) 72 | 73 | ######################################################################################################################## 74 | # Train/test routines 75 | ######################################################################################################################## 76 | 77 | def train(data,model,criterion,optimizer): 78 | 79 | # Set model to training mode (we're using dropout) 80 | model.train() 81 | # Get initial hidden and memory states 82 | states=model.get_initial_states(data.size(0)) 83 | 84 | # Loop sequence length (train) 85 | for i in tqdm(range(0,data.size(1)-1,args.bptt),desc='> Train',ncols=100,ascii=True): 86 | 87 | # Get the chunk and wrap the variables into the gradient propagation chain + move them to the GPU 88 | seqlen=int(np.min([args.bptt,data.size(1)-1-i])) 89 | x=torch.autograd.Variable(data[:,i:i+seqlen]).cuda() 90 | y=torch.autograd.Variable(data[:,i+1:i+seqlen+1]).cuda() 91 | 92 | # Truncated backpropagation 93 | states=model.detach(states) # Otherwise the model would try to backprop all the way to the start of the data set 94 | 95 | # Forward pass 96 | logits,states=model.forward(x,states) 97 | loss=criterion(logits,y.view(-1)) 98 | 99 | # Backward pass 100 | optimizer.zero_grad() 101 | loss.backward() 102 | torch.nn.utils.clip_grad_norm(model.parameters(),args.clip_norm) 103 | optimizer.step() 104 | 105 | return model 106 | 107 | 108 | def eval(data,model,criterion): 109 | 110 | # Set model to evaluation mode (we're using dropout) 111 | model.eval() 112 | # Get initial hidden and memory states 113 | states=model.get_initial_states(data.size(0)) 114 | 115 | # Loop sequence length (validation) 116 | 
total_loss=0 117 | num_loss=0 118 | for i in tqdm(range(0,data.size(1)-1,args.bptt),desc='> Eval',ncols=100,ascii=True): 119 | 120 | # Get the chunk and wrap the variables into the gradient propagation chain + move them to the GPU 121 | seqlen=int(np.min([args.bptt,data.size(1)-1-i])) 122 | x=torch.autograd.Variable(data[:,i:i+seqlen],volatile=True).cuda() 123 | y=torch.autograd.Variable(data[:,i+1:i+seqlen+1],volatile=True).cuda() 124 | 125 | # Truncated backpropagation 126 | states=model.detach(states) # Otherwise the model would try to backprop all the way to the start of the data set 127 | 128 | # Forward pass 129 | logits,states=model.forward(x,states) 130 | loss=criterion(logits,y.view(-1)) 131 | 132 | # Log stuff 133 | total_loss+=loss.data.cpu().numpy() 134 | num_loss+=np.prod(y.size()) 135 | 136 | return float(total_loss)/float(num_loss) 137 | 138 | ######################################################################################################################## 139 | # Train/validation/test 140 | ######################################################################################################################## 141 | 142 | print 'Train...' 143 | 144 | # Loop training epochs 145 | lr=args.learning_rate 146 | best_val_loss=np.inf 147 | for e in tqdm(range(args.num_epochs),desc='Epoch',ncols=100,ascii=True): 148 | 149 | # Train 150 | model=train(data_train,model,criterion,optimizer) 151 | 152 | # Validation 153 | val_loss=eval(data_valid,model,criterion) 154 | 155 | # Anneal learning rate 156 | if val_loss