├── data
│   └── small
│       ├── images
│       │   ├── images_val    # 0.png – 29.png (30 files)
│       │   ├── images_test   # 0.png – 29.png (30 files)
│       │   └── images_train  # 0.png – 49.png (50 files)
│       ├── vocab.json
│       ├── val.json
│       ├── test.json
│       └── train.json
├── config.py
├── model
│   ├── metrics.py
│   ├── dataloader.py
│   ├── utils.py
│   └── model.py
├── README.md
├── .gitignore
├── caption.py
├── train.py
└── LICENSE

/data/small/images/*:
--------------------------------------------------------------------------------
Binary PNG files (listed in the tree above), available at
https://raw.githubusercontent.com/qs956/Latex_OCR_Pytorch/HEAD/data/small/images/
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qs956/Latex_OCR_Pytorch/HEAD/data/small/images/images_val/24.png -------------------------------------------------------------------------------- /data/small/images/images_val/25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qs956/Latex_OCR_Pytorch/HEAD/data/small/images/images_val/25.png -------------------------------------------------------------------------------- /data/small/images/images_val/26.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qs956/Latex_OCR_Pytorch/HEAD/data/small/images/images_val/26.png -------------------------------------------------------------------------------- /data/small/images/images_val/27.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qs956/Latex_OCR_Pytorch/HEAD/data/small/images/images_val/27.png -------------------------------------------------------------------------------- /data/small/images/images_val/28.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qs956/Latex_OCR_Pytorch/HEAD/data/small/images/images_val/28.png -------------------------------------------------------------------------------- /data/small/images/images_val/29.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qs956/Latex_OCR_Pytorch/HEAD/data/small/images/images_val/29.png -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # Data paths 2 | data_name = 'CROHME' # model name, used only when saving checkpoints 3 | vocab_path = './data/CROHME/vocab.json' 4 | train_set_path = './data/CROHME/train.json' 5 | val_set_path = './data/CROHME/val.json' 6 | 7 | 8 | # Model parameters 9 | emb_dim = 30 # word embedding dimension (80)
10 | attention_dim = 128 # attention layer dimension (256) 11 | decoder_dim = 128 # decoder dimension (128) 12 | dropout = 0.5 13 | buckets = [[240, 100], [320, 80], [400, 80], [400, 100], [480, 80], [480, 100], 14 | [560, 80], [560, 100], [640, 80], [640, 100], [720, 80], [720, 100], 15 | [720, 120], [720, 200], [800, 100], [800, 320], [1000, 200], 16 | [1000, 400], [1200, 200], [1600, 200], 17 | ] 18 | 19 | 20 | # Training parameters 21 | start_epoch = 0 22 | epochs = 250 # maximum number of epochs if early stopping is never triggered 23 | epochs_since_improvement = 0 # tracks the number of epochs without improvement of the validation score 24 | batch_size = 1 # training batch size 25 | test_batch_size = 2 # validation batch size 26 | encoder_lr = 1e-4 # encoder learning rate 27 | decoder_lr = 4e-4 # decoder learning rate 28 | grad_clip = 5. # gradient clipping threshold 29 | alpha_c = 1. # regularization parameter for 'doubly stochastic attention', as in the paper 30 | best_score = 0. # best score so far 31 | print_freq = 100 # print training status every print_freq batches 32 | # checkpoint = 'BEST_checkpoint_CROHME.pth.tar' # checkpoint file path (to resume training) 33 | checkpoint = None # checkpoint file path (to resume training) 34 | save_freq = 2 # saving interval (in epochs) -------------------------------------------------------------------------------- /data/small/vocab.json: -------------------------------------------------------------------------------- 1 | {"\\mathrm": 1, "b": 2, "\\}": 3, "P": 4, "\\:": 5, "\\quad": 6, "\\ell": 7, "\\rangle": 8, "\\tilde": 9, "x": 10, ",": 11, "\\left(": 12, "=": 13, "l": 14, "L": 15, "k": 16, "\\cos": 17, "w": 18, "+": 19, "\\bigtriangleup": 20, "\\scriptstyle": 21, "\\psi": 22, "\\left[": 23, "Q": 24, "f": 25, "\\bar": 26, "v": 27, "M": 28, "\\vspace": 29, "V": 30, "&": 31, "\\sin": 32, "\\{": 33, "\\dot": 34, ";": 35, "s": 36, "j": 37, "\\,": 38, ")": 39, "N": 40, "/": 41, "8": 42, "\\right.": 43, "\\Phi": 44, "{": 45, "\\rho": 46, "u": 47, "\\right]": 48, "\\equiv": 49, "K": 50, "}": 51, "\\begin{array}": 52, "\\sqrt": 53, "[": 54, "\\xi": 55, "c": 56, ">": 57, "5": 58, "\\in": 59, "\\partial": 60, "\\Sigma": 61, "a": 62, "\\exp": 63, "_": 64, "r": 65, "\\infty": 66, "0": 67, "\\hat": 68, "\\": 69, "g": 70, "e": 71,
"1": 72, "o": 73, "]": 74, "\\cal": 75, "p": 76, "*": 77, "d": 78, "\\int": 79, "\\forall": 80, "\\sim": 81, "A": 82, "\\eta": 83, "\\simeq": 84, "\\sum": 85, "G": 86, "\\omega": 87, "\\varphi": 88, "\\delta": 89, "\\cdots": 90, "\\overline": 91, "\\langle": 92, "\\theta": 93, "\\left\\{": 94, "(": 95, "\\Pi": 96, "|": 97, "\\pm": 98, "\\frac": 99, ".": 100, "\\pi": 101, ":": 102, "\\mu": 103, "\\dagger": 104, "\\sp": 105, "\\textstyle": 106, "\\right\\}": 107, "\\Omega": 108, "\\lambda": 109, "E": 110, "F": 111, "\\!": 112, "R": 113, "t": 114, "\\rbrace": 115, "\\kappa": 116, "H": 117, "4": 118, "\\chi": 119, "2": 120, "^": 121, "\\phi": 122, "q": 123, "\\alpha": 124, "S": 125, "m": 126, "W": 127, "\\prime": 128, "\\right)": 129, "z": 130, "B": 131, "\\widetilde": 132, "\\\\": 133, "\\beta": 134, "\\varepsilon": 135, "h": 136, "\\end{array}": 137, "6": 138, "\\Gamma": 139, "i": 140, "\\gamma": 141, "\\times": 142, "T": 143, "\\sinh": 144, "\\tau": 145, "\\qquad": 146, "\\lbrace": 147, "n": 148, "\\cosh": 149, "3": 150, "-": 151, "y": 152, "C": 153, "J": 154, "D": 155, "X": 156, "<start>": 157, "<end>": 158, "<unk>": 159, "<pad>": 0} -------------------------------------------------------------------------------- /model/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import distance 3 | from nltk.translate.bleu_score import sentence_bleu 4 | 5 | def evaluate(losses, top5accs, references, hypotheses): 6 | # compute evaluation metrics on the validation set to guide early stopping 7 | # Calculate scores 8 | bleu4 = 0.0 9 | for i,j in zip(references,hypotheses): 10 | bleu4 += max(sentence_bleu([i],j),0.01) 11 | bleu4 = bleu4/len(references) 12 | Exact_Match = exact_match_score(references, hypotheses) 13 | Edit_Distance = edit_distance(references, hypotheses) 14 | Score = bleu4 + Exact_Match + Edit_Distance/10 15 | print( 16 | '\n * LOSS:{loss.avg:.3f},TOP-3 ACCURACY:{top5.avg:.3f},BLEU-4:{bleu:.3f},Exact Match:{Exact_Match:.1f},Edit
Distance:{Edit_Distance:.3f},Score:{Score:.6f}'.format( 17 | loss=losses, 18 | top5=top5accs, 19 | bleu=bleu4, 20 | Exact_Match=Exact_Match, 21 | Edit_Distance=Edit_Distance, 22 | Score = Score)) 23 | return Score 24 | 25 | 26 | 27 | def exact_match_score(references, hypotheses): 28 | """Computes exact match scores. 29 | 30 | Args: 31 | references: list of list of tokens (one ref) 32 | hypotheses: list of list of tokens (one hypothesis) 33 | 34 | Returns: 35 | exact_match: (float) 1 is perfect 36 | 37 | """ 38 | exact_match = 0 39 | for ref, hypo in zip(references, hypotheses): 40 | if np.array_equal(ref, hypo): 41 | exact_match += 1 42 | 43 | return exact_match / float(max(len(hypotheses), 1)) 44 | 45 | def edit_distance(references, hypotheses): 46 | """Computes Levenshtein distance between two sequences. 47 | 48 | Args: 49 | references: list of list of tokens (one reference) 50 | hypotheses: list of list of tokens (one hypothesis) 51 | 52 | Returns: 53 | 1 - levenshtein distance: (higher is better, 1 is perfect) 54 | 55 | """ 56 | d_leven, len_tot = 0, 0 57 | for ref, hypo in zip(references, hypotheses): 58 | d_leven += distance.levenshtein(ref, hypo) 59 | len_tot += float(max(len(ref), len(hypo))) 60 | 61 | return 1. - d_leven / len_tot -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Latex_OCR_Pytorch 2 | 3 | A PyTorch implementation largely based on: 4 | 5 | [LinXueyuanStdio/LaTeX_OCR_PRO](https://github.com/LinXueyuanStdio/LaTeX_OCR_PRO) 6 | 7 | Thanks to @LinXueyuanStdio for the original work and guidance. This project follows the same approach as the project above, but changes a few implementation details: 8 | 9 | * Redefined dataset format, while keeping preprocessing similar to the original 10 | * Simplified code; only the core parts are kept for now, command-line control etc. will be added later 11 | * Memory optimization: a relatively low memory footprint supports larger training batches, though at the same batch size the measured speedup is small 12 | * Uses the checkpointing feature: if the encoder runs out of memory, computation automatically falls back to segmented (checkpointed) evaluation 13 | * Greedy decoding during training; beam search is used only at inference time 14 | * Scheduled sampling strategy 15 | 16 | Follow these papers: 17 | 18 | 1. [Show, Attend and Tell(Kelvin Xu...)](https://arxiv.org/abs/1502.03044) 19 | 2. 
[Harvard's paper and dataset](http://lstm.seas.harvard.edu/latex/) 20 | 21 | Follow these tutorials: 22 | 23 | 1. [Seq2Seq for LaTeX generation](https://guillaumegenthial.github.io/image-to-latex.html). 24 | 2. [a PyTorch Tutorial to Image Captioning](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning). 25 | 26 | ## Environment 27 | 1. Python >= 3.6 28 | 29 | 2. Pytorch >= 1.2 30 | 31 | ## Data 32 | Uses the [LinXueyuanStdio/Data-for-LaTeX_OCR](https://github.com/LinXueyuanStdio/Data-for-LaTeX_OCR) dataset. The original repository is large; a packaged download will be provided later. 33 | 34 | The small dataset from that repository is already included. 35 | Full printed-formula dataset: [Baidu Cloud](https://pan.baidu.com/s/1xIsgHDhVu85L8cGdqqG7kw ) (extraction code: tapj), [Google Drive](https://drive.google.com/open?id=1THp_O7uwavcjsnQXsxx_JPvYn9-gml7T) 36 | A custom split mixing CROHME 2011 and 2012: [Google Drive](https://drive.google.com/open?id=1KgpAzA7k8ayjPTstin6M8ykGsW8GR9bu) 37 | 38 | 39 | ## Trained model 40 | A model trained on the custom CROHME 2011/2012 split with the following parameters: [Google Drive](https://drive.google.com/open?id=1_geqm9a86TJKK9RpZ39d9X5655s4NXa9) 41 | emb_dim = 30 42 | attention_dim = 128 43 | decoder_dim = 128 44 | Model test results and a Colab notebook will be added later. 45 | 46 | ## Data format 47 | 48 | See get_latex_ocrdata in [utils.py](./model/utils.py) for generating the dataset files. 49 | 50 | Dataset files are in JSON format: a training set file, a validation set file, and a vocabulary file. 51 | 52 | Vocabulary format: 53 | 54 | a Python dict (symbol -> id) stored as JSON 55 | 56 | Dataset format: 57 | 58 | ``` 59 | 60 | train/val dataset 61 | ├── file_name1 image file name str 62 | │ ├── img_path: file path (up to the file name, including extension) str 63 | │ ├── size: image size [width, height] list 64 | │ ├── caption: the formula shown in the image; tokens must be separated by spaces str 65 | │ └── caption_len: len(caption.split()) int 66 | | ... 67 | eg: 68 | { 69 | "0.png": 70 | { 71 | "img_path":"./mydata/0.png", 72 | "size":[442,62], 73 | "caption":"\frac { a + b } { 2 }", 74 | "caption_len":9, 75 | } 76 | "2.png":... 77 | } 78 | 79 | 80 | ``` 81 | 82 | Image preprocessing 83 | 84 | See data_turn in dataloader.py; it mainly performs the following steps 85 | 86 | 1. Grayscale conversion 2. Crop to the formula region 3. Pad 8 pixels on each side 89 | 4. 
`[optional]` downsampling 90 | 91 | 92 | ## To do 93 | 94 | - [ ] Inference 95 | - [ ] Attention visualization 96 | - [x] Pretrained model 97 | - [x] Packaged training data 98 | - [ ] Perplexity metric -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /data/full 2 | /data/CROHME 3 | /other 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # train result 135 | *.tar 136 | *.zip -------------------------------------------------------------------------------- /model/dataloader.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import torch 3 | import json 4 | import cv2 5 | import numpy as np 6 | from config import vocab_path,buckets 7 | from torch.utils.data import Dataset 8 | from model.utils import load_json 9 | 10 | vocab = load_json(vocab_path) 11 | 12 | def get_new_size(old_size, buckets=buckets,ratio = 2): 13 | """Computes new size from buckets 14 | 15 | Args: 16 | old_size: (width, height) 17 | buckets: list of sizes 18 | 19 | Returns: 20 | new_size: original size or first bucket in iter order that matches the 21 | size. 
22 | 23 | """ 24 | if buckets is None: 25 | return old_size 26 | else: 27 | w, h = old_size[0]/ratio,old_size[1]/ratio 28 | for (idx,(w_b, h_b)) in enumerate(buckets): 29 | if w_b >= w and h_b >= h: 30 | return w_b, h_b,idx 31 | 32 | return old_size 33 | 34 | def data_turn(img_data,pad_size = [8,8,8,8],resize = False): 35 | # find the bounding box of the symbol region 36 | nnz_inds = np.where(img_data != 255) 37 | y_min = np.min(nnz_inds[1]) 38 | y_max = np.max(nnz_inds[1]) 39 | x_min = np.min(nnz_inds[0]) 40 | x_max = np.max(nnz_inds[0]) 41 | old_im = img_data[x_min:x_max+1,y_min:y_max+1] 42 | 43 | #pad the image 44 | top, left, bottom, right = pad_size 45 | old_size = (old_im.shape[0] + top + bottom, old_im.shape[1] + left + right) 46 | new_im = np.ones(old_size , dtype = np.uint8)*255 47 | new_im[top:top+old_im.shape[0],left:left+old_im.shape[1]] = old_im 48 | if resize: 49 | new_size = get_new_size(old_size, buckets)[:2] 50 | new_im = cv2.resize(new_im, new_size, interpolation = cv2.INTER_LANCZOS4) 51 | return new_im 52 | 53 | 54 | def label_transform(text,start_type = '<start>',end_type = '<end>',pad_type = '<pad>',max_len = 160): 55 | text = text.split() 56 | text = [start_type] + text + [end_type] 57 | # while len(text)']]*(int(max(cap_len_batch))-len(cap_batch[ii])) 136 | cap_batch = torch.LongTensor(cap_batch) 137 | yield torch.cat(img_batch,dim = 0),cap_batch,cap_len_batch 138 | img_batch,cap_batch,cap_len_batch = [],[],torch.zeros(self.batch_size).int() 139 | idx = 0 140 | if len(img_batch)==0: 141 | continue 142 | for ii in range(len(cap_batch)): 143 | cap_batch[ii] += [vocab['<pad>']]*(int(max(cap_len_batch))-len(cap_batch[ii])) 144 | cap_batch = torch.LongTensor(cap_batch) 145 | yield torch.cat(img_batch,dim = 0),cap_batch,cap_len_batch[:idx] 146 | 147 | def __iter__(self): 148 | return self.iter 149 | 150 | def __len__(self): # total number of batches 151 | count = 0 152 | for i in self.bucket_data: 153 | count += np.ceil(len(i)/self.batch_size) 154 | return int(count)
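The bucket lookup in `get_new_size` above returns the first bucket at least as large as the half-scaled image. A minimal pure-Python sketch of that selection rule, with the first six buckets copied from `config.py` and `pick_bucket` as a hypothetical name for illustration:

```python
# Sketch of the bucket-selection rule in get_new_size: halve the original
# (width, height), then return the first bucket that is at least as wide
# and as tall, together with its index; fall back to the original size.
buckets = [[240, 100], [320, 80], [400, 80], [400, 100], [480, 80], [480, 100]]

def pick_bucket(old_size, buckets=buckets, ratio=2):
    w, h = old_size[0] / ratio, old_size[1] / ratio
    for idx, (w_b, h_b) in enumerate(buckets):
        if w_b >= w and h_b >= h:
            return w_b, h_b, idx
    return old_size  # no bucket fits: keep the original size

print(pick_bucket((442, 62)))   # -> (240, 100, 0)
print(pick_bucket((630, 150)))  # -> (320, 80, 1)
```

Note that the no-fit fallback returns a 2-tuple while a match returns a 3-tuple, which is why `data_turn` slices the result with `[:2]`.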
-------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import json 4 | import cv2 5 | import torch 6 | 7 | def load_json(path): 8 | with open(path,'r')as f: 9 | data = json.load(f) 10 | return data 11 | 12 | def cal_word_freq(vocab,formuladataset): 13 | # word-frequency statistics, used to compute perplexity 14 | word_count = {} 15 | for i in vocab.values(): 16 | word_count[i] = 0 17 | count = 0 18 | for i in formuladataset.data.values(): 19 | words = i['caption'].split() 20 | for j in words: 21 | word_count[vocab[j]] += 1 22 | count += 1 23 | for i in word_count.keys(): 24 | word_count[i] = word_count[i]/count 25 | return word_count 26 | 27 | def get_latex_ocrdata(path,mode = 'val'): 28 | assert mode in ['val','train','test'] 29 | match = [] 30 | with open(path + 'matching/'+mode+'.matching.txt','r')as f: 31 | for i in f.readlines(): 32 | match.append(i[:-1]) 33 | 34 | formula = [] 35 | with open(path + 'formulas/'+mode+'.formulas.norm.txt','r')as f: 36 | for i in f.readlines(): 37 | formula.append(i[:-1]) 38 | 39 | vocab_temp = set() 40 | data = {} 41 | 42 | for i in match: 43 | img_path = path + 'images/images_' + mode + '/' + i.split()[0] 44 | try: 45 | img = cv2.imread(img_path) 46 | except Exception: 47 | print('Can\'t read ' + i.split()[0]) 48 | continue 49 | if img is None: 50 | continue 51 | size = (img.shape[1],img.shape[0]) 52 | del img 53 | temp = formula[int(i.split()[1])].replace('\\n','') 54 | # token = set() 55 | for j in temp.split(): 56 | # token.add(j) 57 | vocab_temp.add(j) 58 | data[i.split()[0]] = {'img_path':img_path,'size':size, 59 | 'caption':temp,'caption_len':len(temp.split())+2} # +2 for the start and end tokens 60 | # data[i.split()[0]] = {'img_path':path + 'images/images_' + mode + '/' + i.split()[0], 61 | # 'token':list(token),'caption':temp,'caption_len':len(temp.split())+2} # +2 for the start and end tokens 62 | vocab_temp = list(vocab_temp) 63 | vocab = {}
64 | for i in range(len(vocab_temp)): 65 | vocab[vocab_temp[i]] = i+1 66 | vocab['<start>'] = len(vocab) + 1 67 | vocab['<end>'] = len(vocab) + 1 68 | vocab['<unk>'] = len(vocab) + 1 69 | vocab['<pad>'] = 0 70 | return vocab,data 71 | 72 | 73 | def init_embedding(embeddings): 74 | """ 75 | Fills embedding tensor with values from the uniform distribution. 76 | :param embeddings: embedding tensor 77 | """ 78 | bias = np.sqrt(3.0 / embeddings.size(1)) 79 | torch.nn.init.uniform_(embeddings, -bias, bias) 80 | 81 | 82 | def load_embeddings(emb_file, word_map): 83 | """ 84 | Creates an embedding tensor for the specified word map, for loading into the model. 85 | :param emb_file: file containing embeddings (stored in GloVe format) 86 | :param word_map: word map 87 | :return: embeddings in the same order as the words in the word map, dimension of embeddings 88 | """ 89 | 90 | # Find embedding dimension 91 | with open(emb_file, 'r') as f: 92 | emb_dim = len(f.readline().split(' ')) - 1 93 | 94 | vocab = set(word_map.keys()) 95 | 96 | # Create tensor to hold embeddings, initialize 97 | embeddings = torch.FloatTensor(len(vocab), emb_dim) 98 | init_embedding(embeddings) 99 | 100 | # Read embedding file 101 | print("\nLoading embeddings...") 102 | for line in open(emb_file, 'r'): 103 | line = line.split(' ') 104 | 105 | emb_word = line[0] 106 | embedding = list(map(lambda t: float(t), filter(lambda n: n and not n.isspace(), line[1:]))) 107 | 108 | # Ignore word if not in train_vocab 109 | if emb_word not in vocab: 110 | continue 111 | 112 | embeddings[word_map[emb_word]] = torch.FloatTensor(embedding) 113 | 114 | return embeddings, emb_dim 115 | 116 | 117 | def clip_gradient(optimizer, grad_clip): 118 | """ 119 | Clips gradients element-wise to avoid gradient explosion. 120 | :param optimizer: optimizer with the gradients to be clipped 121 | :param grad_clip: clip value 122 | """ 123 | for group in optimizer.param_groups: 124 | for param in group['params']: 125 | if param.grad is not None: 126 | param.grad.data.clamp_(-grad_clip, grad_clip)
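`clip_gradient` above clamps every gradient element into `[-grad_clip, grad_clip]` (element-wise value clipping, not global-norm clipping). A pure-Python sketch of that clamp on a toy gradient list, with `clamp` a hypothetical helper so no PyTorch is needed:

```python
# Element-wise value clipping, as clip_gradient does via Tensor.clamp_:
# each component is limited to [-grad_clip, grad_clip] independently.
def clamp(grads, grad_clip):
    return [max(-grad_clip, min(grad_clip, g)) for g in grads]

print(clamp([-7.2, 0.3, 12.0], 5.0))  # -> [-5.0, 0.3, 5.0]
```

With `grad_clip = 5.` from `config.py`, an exploding component like `12.0` is capped at `5.0` while small gradients pass through unchanged.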
127 | 128 | 129 | def save_checkpoint(data_name, epoch, epochs_since_improvement, encoder, decoder, encoder_optimizer, 130 | decoder_optimizer,score, is_best): 131 | """ 132 | Saves model checkpoint. 133 | :param data_name: base name of processed dataset 134 | :param epoch: epoch number 135 | :param epochs_since_improvement: number of epochs since last improvement in validation score 136 | :param encoder: encoder model 137 | :param decoder: decoder model 138 | :param encoder_optimizer: optimizer to update encoder's weights, if fine-tuning 139 | :param decoder_optimizer: optimizer to update decoder's weights 140 | :param score: validation score for this epoch 141 | :param is_best: is this checkpoint the best so far? 142 | """ 143 | state = {'epoch': epoch, 144 | 'epochs_since_improvement': epochs_since_improvement, 145 | 'score': score, 146 | 'encoder': encoder, 147 | 'decoder': decoder, 148 | 'encoder_optimizer':encoder_optimizer, 149 | 'decoder_optimizer': decoder_optimizer} 150 | filename = 'checkpoint_' + data_name + '.pth.tar' 151 | torch.save(state, filename) 152 | # If this checkpoint is the best so far, store a copy so it doesn't get overwritten by a worse checkpoint 153 | if is_best: 154 | torch.save(state, 'BEST_' + filename) 155 | 156 | 157 | class AverageMeter(object): 158 | """ 159 | Tracks the current value, average, sum, and count of a variable. 160 | """ 161 | 162 | def __init__(self): 163 | self.reset() 164 | 165 | def reset(self): 166 | self.val = 0 167 | self.avg = 0 168 | self.sum = 0 169 | self.count = 0 170 | 171 | def update(self, val, n=1): 172 | self.val = val 173 | self.sum += val * n 174 | self.count += n 175 | self.avg = self.sum / self.count 176 | 177 | 178 | def adjust_learning_rate(optimizer, shrink_factor): 179 | """ 180 | Shrinks learning rate by a specified factor. 181 | :param optimizer: optimizer whose learning rate must be shrunk. 182 | :param shrink_factor: factor in interval (0, 1) to multiply learning rate with. 
183 | """ 184 | 185 | print("\nDECAYING learning rate.") 186 | for param_group in optimizer.param_groups: 187 | param_group['lr'] = param_group['lr'] * shrink_factor 188 | print("The new learning rate is %f\n" % (optimizer.param_groups[0]['lr'],)) 189 | 190 | 191 | def accuracy(scores, targets, k): 192 | """ 193 | Computes top-k accuracy, from predicted and true labels. 194 | :param scores: scores from the model 195 | :param targets: true labels 196 | :param k: k in top-k accuracy 197 | :return: top-k accuracy 198 | """ 199 | 200 | batch_size = targets.size(0) 201 | _, ind = scores.topk(k, 1, True, True) 202 | correct = ind.eq(targets.view(-1, 1).expand_as(ind)) 203 | correct_total = correct.view(-1).float().sum() # 0D tensor 204 | return correct_total.item() * (100.0 / batch_size) -------------------------------------------------------------------------------- /data/small/val.json: -------------------------------------------------------------------------------- 1 | {"0.png": {"img_path": "./data/small/images/images_val/0.png", "size": [442, 62], "caption": "d s ^ { 2 } = ( 1 - { \\frac { q c o s \\theta } { r } } ) ^ { \\frac { 2 } { 1 + \\alpha ^ { 2 } } } \\lbrace d r ^ { 2 } + r ^ { 2 } d \\theta ^ { 2 } + r ^ { 2 } s i n ^ { 2 } \\theta d \\varphi ^ { 2 } \\rbrace - { \\frac { d t ^ { 2 } } { ( 1 - { \\frac { q c o s \\theta } { r } } ) ^ { \\frac { 2 } { 1 + \\alpha ^ { 2 } } } } } .", "caption_len": 130}, "1.png": {"img_path": "./data/small/images/images_val/1.png", "size": [150, 59], "caption": "\\widetilde \\gamma _ { \\mathrm { h o p f } } \\simeq \\sum _ { n > 0 } \\widetilde { G } _ { n } { \\frac { ( - a ) ^ { n } } { 2 ^ { 2 n - 1 } } }", "caption_len": 53}, "2.png": {"img_path": "./data/small/images/images_val/2.png", "size": [180, 36], "caption": "( { \\cal L } _ { a } g ) _ { i j } = 0 , ( { \\cal L } _ { a } H ) _ { i j k } = 0 ,", "caption_len": 41}, "3.png": {"img_path": "./data/small/images/images_val/3.png", "size": [253, 50], "caption": "S _ 
{ s t a t } = 2 \\pi \\sqrt { N _ { 5 } ^ { ( 1 ) } N _ { 5 } ^ { ( 2 ) } N _ { 5 } ^ { ( 3 ) } } \\left( \\sqrt { n } + \\sqrt { \\bar { n } } \\right)", "caption_len": 63}, "4.png": {"img_path": "./data/small/images/images_val/4.png", "size": [107, 65], "caption": "\\hat { N } _ { 3 } = \\sum \\sp f _ { j = 1 } a _ { j } \\sp { \\dagger } a _ { j } .", "caption_len": 35}, "5.png": {"img_path": "./data/small/images/images_val/5.png", "size": [141, 35], "caption": "\\, ^ { * } d ^ { * } H = \\kappa ^ { * } d \\phi = J _ { B } .", "caption_len": 28}, "6.png": {"img_path": "./data/small/images/images_val/6.png", "size": [373, 55], "caption": "{ \\frac { \\phi ^ { \\prime \\prime } } { A } } + { \\frac { 1 } { A } } \\left( - { \\frac { 1 } { 2 } } { \\frac { A ^ { \\prime } } { A } } + 2 { \\frac { B ^ { \\prime } } { B } } + { \\frac { 2 } { r } } \\right) \\phi ^ { \\prime } - { \\frac { 2 } { r ^ { 2 } } } \\phi - \\lambda \\phi ( \\phi ^ { 2 } - \\eta ^ { 2 } ) = 0 .", "caption_len": 115}, "7.png": {"img_path": "./data/small/images/images_val/7.png", "size": [152, 36], "caption": "\\partial _ { \\mu } ( F ^ { \\mu u } - e j ^ { \\mu } x ^ { u } ) = 0 .", "caption_len": 30}, "8.png": {"img_path": "./data/small/images/images_val/8.png", "size": [382, 63], "caption": "V _ { n s } ( { \\tilde { x } } ) = \\left( \\frac { { \\tilde { m } } N ^ { 2 } } { 1 6 \\pi } \\right) N g ^ { 2 n s - 1 } { \\tilde { x } } ^ { 2 } \\left\\{ { \\tilde { x } } ^ { 2 } - \\frac { 2 { \\tilde { b } } } { 3 } { \\tilde { x } } + \\frac { { \\tilde { b } } ^ { 2 } } { 3 } - ( - 1 ) ^ { n s } { \\tilde { c } } \\right\\} .", "caption_len": 124}, "9.png": {"img_path": "./data/small/images/images_val/9.png", "size": [284, 49], "caption": "g _ { i j } ( x ) = { \\frac { 1 } { a ^ { 2 } } } \\delta _ { i j } , \\phi ^ { a } ( x ) = \\phi ^ { a } , \\quad ( a , \\phi ^ { a } \\! : \\mathrm { c o n s t . 
} )", "caption_len": 68}, "10.png": {"img_path": "./data/small/images/images_val/10.png", "size": [185, 63], "caption": "\\rho _ { L } ( q ) = \\sum _ { m = 1 } ^ { L } \\ P _ { L } ( m ) \\ { \\frac { 1 } { q ^ { m - 1 } } } .", "caption_len": 48}, "11.png": {"img_path": "./data/small/images/images_val/11.png", "size": [145, 55], "caption": "e x p \\left( - \\frac { \\partial } { \\partial \\alpha _ { j } } \\theta ^ { j k } \\frac { \\partial } { \\partial \\alpha _ { k } } \\right)", "caption_len": 38}, "12.png": {"img_path": "./data/small/images/images_val/12.png", "size": [149, 36], "caption": "L _ { 0 } = \\Phi ( w ) = \\bigtriangleup \\Phi ( w ) ,", "caption_len": 19}, "13.png": {"img_path": "./data/small/images/images_val/13.png", "size": [143, 39], "caption": "\\left( D ^ { * } D ^ { * } + m ^ { 2 } \\right) { \\cal H } = 0", "caption_len": 26}, "14.png": {"img_path": "./data/small/images/images_val/14.png", "size": [98, 54], "caption": "{ \\frac { d V } { d \\Phi } } = - { \\frac { w \\Phi } { \\Phi _ { \\! 
_ { 0 } } ^ { 2 } } } .", "caption_len": 38}, "15.png": {"img_path": "./data/small/images/images_val/15.png", "size": [369, 49], "caption": "g ( z , \\bar { z } ) = - \\frac { 1 } { 2 } \\left[ x ( z , \\bar { z } ) s + x ^ { * } ( z , \\bar { z } ) s ^ { * } + u ^ { * } ( z , \\bar { z } ) t + u ( z , \\bar { z } ) t ^ { * } \\right] ,", "caption_len": 82}, "16.png": {"img_path": "./data/small/images/images_val/16.png", "size": [107, 37], "caption": "x _ { \\mu } ^ { c } = x _ { \\mu } + A _ { \\mu } .", "caption_len": 24}, "17.png": {"img_path": "./data/small/images/images_val/17.png", "size": [145, 57], "caption": "s = { \\frac { S } { V } } = { \\frac { A _ { H } } { l _ { p } ^ { 8 } V } } = { \\frac { T ^ { 2 } } { \\gamma } } .", "caption_len": 51}, "18.png": {"img_path": "./data/small/images/images_val/18.png", "size": [275, 55], "caption": "\\psi ( \\gamma ) = \\exp { - ( { \\textstyle { \\frac { g ^ { 2 } } { 2 } } } ) \\int _ { \\gamma } d y ^ { a } \\int _ { \\gamma } d y ^ { a ^ { \\prime } } D _ { 1 } ( y - y ^ { \\prime } ) }", "caption_len": 69}, "19.png": {"img_path": "./data/small/images/images_val/19.png", "size": [356, 55], "caption": "E = E _ { 0 } + \\frac { 1 } { 2 \\sinh ( \\gamma ( 0 ) / 2 ) } \\sinh \\left( \\gamma ( 0 ) \\left( \\frac { 1 } { 2 } + c ( 0 ) \\right) \\right) h c u _ { \\mathrm { v i b } }", "caption_len": 59}, "20.png": {"img_path": "./data/small/images/images_val/20.png", "size": [152, 52], "caption": "\\langle T _ { z z } \\rangle = - 3 \\times \\frac { \\pi ^ { 2 } } { 1 4 4 0 a ^ { 4 } } .", "caption_len": 34}, "21.png": {"img_path": "./data/small/images/images_val/21.png", "size": [292, 53], "caption": "\\partial _ { u } \\xi _ { z } ^ { ( 1 ) } + { \\frac { 1 } { u } } \\xi _ { z } ^ { ( 1 ) } = { \\frac { 1 } { ( \\pi T R ) ^ { 2 } u } } \\left[ C _ { z } H _ { z z } ^ { \\prime } + C _ { t } H _ { t z } ^ { \\prime } \\right] .", "caption_len": 92}, "22.png": {"img_path": "./data/small/images/images_val/22.png", 
"size": [356, 38], "caption": "S \\sim \\tilde { \\psi } Q _ { o } \\tilde { \\psi } + g _ { s } ^ { 1 / 2 } \\tilde { \\psi } ^ { 3 } + \\tilde { \\phi } Q _ { c } \\tilde { \\phi } + g _ { s } \\tilde { \\phi } ^ { 3 } + \\tilde { \\phi } B ( g _ { s } ^ { 1 / 2 } \\tilde { \\psi } ) + \\cdots .", "caption_len": 91}, "23.png": {"img_path": "./data/small/images/images_val/23.png", "size": [415, 63], "caption": "C ( x ^ { \\prime } , x ^ { \\prime \\prime } ) = C \\Phi ( x ^ { \\prime } , x ^ { \\prime \\prime } ) \\ , \\quad \\Phi ( x ^ { \\prime } , x ^ { \\prime \\prime } ) = \\exp \\left[ - i e \\int _ { x ^ { \\prime \\prime } } ^ { x ^ { \\prime } } d x ^ { \\mu } A _ { \\mu } ( x ) \\right] \\ ,", "caption_len": 93}, "24.png": {"img_path": "./data/small/images/images_val/24.png", "size": [308, 73], "caption": "\\tilde { \\alpha } = \\alpha \\beta ^ { - m } = \\left( \\begin{array} { c c c } { \\omega _ { k } ^ { - 2 y } \\omega _ { 2 d } ^ { 2 m } } & { 0 } & { 0 } \\\\ { 0 } & { \\omega _ { k } ^ { y } \\omega _ { 2 d } ^ { - m } } & { 0 } \\\\ { 0 } & { 0 } & { \\omega _ { k } ^ { y } \\omega _ { 2 d } ^ { - m } } \\\\ \\end{array} \\right)", "caption_len": 119}, "25.png": {"img_path": "./data/small/images/images_val/25.png", "size": [331, 38], "caption": "d s ^ { 2 } = H ^ { - 2 } f ( r ) d t ^ { 2 } + H ^ { 2 / ( n - 1 ) } ( f ( r ) ^ { - 1 } d r ^ { 2 } + r ^ { 2 } d \\Omega _ { n } ^ { 2 } ) ,", "caption_len": 71}, "26.png": {"img_path": "./data/small/images/images_val/26.png", "size": [283, 37], "caption": "y ^ { 2 } = \\rho \\cosh \\beta \\sin \\theta \\sin \\phi \\qquad \\qquad y ^ { 3 } = \\rho \\cos \\theta", "caption_len": 26}, "27.png": {"img_path": "./data/small/images/images_val/27.png", "size": [350, 39], "caption": "e ^ { A } = e ^ { A _ { 0 } } \\left( t _ { 0 } - \\mathrm { s i g n } ( m ) t \\right) ^ { - \\frac { m } { 2 } } , \\chi = \\chi _ { 0 } \\left( t _ { 0 } - \\mathrm { s i g n } ( m ) t \\right) ^ { m } ,", "caption_len": 79}, 
"28.png": {"img_path": "./data/small/images/images_val/28.png", "size": [282, 50], "caption": "\\gamma _ { j } { \\cal P } _ { j i } = \\frac { 4 } { 3 } \\{ [ A d T ] [ t _ { 8 } ^ { c } , [ t _ { 8 } ^ { c } , { \\gamma } _ { j } ] ] [ A d T ^ { - 1 } ] \\} { A d { \\hat { g } } } _ { i j } .", "caption_len": 88}, "29.png": {"img_path": "./data/small/images/images_val/29.png", "size": [97, 49], "caption": "K _ { \\mu u } = \\frac { 1 } { 2 } \\dot { g } _ { \\mu u } .", "caption_len": 26}} -------------------------------------------------------------------------------- /data/small/test.json: -------------------------------------------------------------------------------- 1 | {"0.png": {"img_path": "./data/small/images/images_test/0.png", "size": [442, 62], "caption": "d s ^ { 2 } = ( 1 - { \\frac { q c o s \\theta } { r } } ) ^ { \\frac { 2 } { 1 + \\alpha ^ { 2 } } } \\lbrace d r ^ { 2 } + r ^ { 2 } d \\theta ^ { 2 } + r ^ { 2 } s i n ^ { 2 } \\theta d \\varphi ^ { 2 } \\rbrace - { \\frac { d t ^ { 2 } } { ( 1 - { \\frac { q c o s \\theta } { r } } ) ^ { \\frac { 2 } { 1 + \\alpha ^ { 2 } } } } } .", "caption_len": 130}, "1.png": {"img_path": "./data/small/images/images_test/1.png", "size": [150, 59], "caption": "\\widetilde \\gamma _ { \\mathrm { h o p f } } \\simeq \\sum _ { n > 0 } \\widetilde { G } _ { n } { \\frac { ( - a ) ^ { n } } { 2 ^ { 2 n - 1 } } }", "caption_len": 53}, "2.png": {"img_path": "./data/small/images/images_test/2.png", "size": [180, 36], "caption": "( { \\cal L } _ { a } g ) _ { i j } = 0 , ( { \\cal L } _ { a } H ) _ { i j k } = 0 ,", "caption_len": 41}, "3.png": {"img_path": "./data/small/images/images_test/3.png", "size": [253, 50], "caption": "S _ { s t a t } = 2 \\pi \\sqrt { N _ { 5 } ^ { ( 1 ) } N _ { 5 } ^ { ( 2 ) } N _ { 5 } ^ { ( 3 ) } } \\left( \\sqrt { n } + \\sqrt { \\bar { n } } \\right)", "caption_len": 63}, "4.png": {"img_path": "./data/small/images/images_test/4.png", "size": [107, 65], "caption": "\\hat { N } _ { 3 } = 
\\sum \\sp f _ { j = 1 } a _ { j } \\sp { \\dagger } a _ { j } .", "caption_len": 35}, "5.png": {"img_path": "./data/small/images/images_test/5.png", "size": [141, 35], "caption": "\\, ^ { * } d ^ { * } H = \\kappa ^ { * } d \\phi = J _ { B } .", "caption_len": 28}, "6.png": {"img_path": "./data/small/images/images_test/6.png", "size": [373, 55], "caption": "{ \\frac { \\phi ^ { \\prime \\prime } } { A } } + { \\frac { 1 } { A } } \\left( - { \\frac { 1 } { 2 } } { \\frac { A ^ { \\prime } } { A } } + 2 { \\frac { B ^ { \\prime } } { B } } + { \\frac { 2 } { r } } \\right) \\phi ^ { \\prime } - { \\frac { 2 } { r ^ { 2 } } } \\phi - \\lambda \\phi ( \\phi ^ { 2 } - \\eta ^ { 2 } ) = 0 .", "caption_len": 115}, "7.png": {"img_path": "./data/small/images/images_test/7.png", "size": [152, 36], "caption": "\\partial _ { \\mu } ( F ^ { \\mu u } - e j ^ { \\mu } x ^ { u } ) = 0 .", "caption_len": 30}, "8.png": {"img_path": "./data/small/images/images_test/8.png", "size": [382, 63], "caption": "V _ { n s } ( { \\tilde { x } } ) = \\left( \\frac { { \\tilde { m } } N ^ { 2 } } { 1 6 \\pi } \\right) N g ^ { 2 n s - 1 } { \\tilde { x } } ^ { 2 } \\left\\{ { \\tilde { x } } ^ { 2 } - \\frac { 2 { \\tilde { b } } } { 3 } { \\tilde { x } } + \\frac { { \\tilde { b } } ^ { 2 } } { 3 } - ( - 1 ) ^ { n s } { \\tilde { c } } \\right\\} .", "caption_len": 124}, "9.png": {"img_path": "./data/small/images/images_test/9.png", "size": [284, 49], "caption": "g _ { i j } ( x ) = { \\frac { 1 } { a ^ { 2 } } } \\delta _ { i j } , \\phi ^ { a } ( x ) = \\phi ^ { a } , \\quad ( a , \\phi ^ { a } \\! : \\mathrm { c o n s t . 
} )", "caption_len": 68}, "10.png": {"img_path": "./data/small/images/images_test/10.png", "size": [185, 63], "caption": "\\rho _ { L } ( q ) = \\sum _ { m = 1 } ^ { L } \\ P _ { L } ( m ) \\ { \\frac { 1 } { q ^ { m - 1 } } } .", "caption_len": 48}, "11.png": {"img_path": "./data/small/images/images_test/11.png", "size": [145, 55], "caption": "e x p \\left( - \\frac { \\partial } { \\partial \\alpha _ { j } } \\theta ^ { j k } \\frac { \\partial } { \\partial \\alpha _ { k } } \\right)", "caption_len": 38}, "12.png": {"img_path": "./data/small/images/images_test/12.png", "size": [149, 36], "caption": "L _ { 0 } = \\Phi ( w ) = \\bigtriangleup \\Phi ( w ) ,", "caption_len": 19}, "13.png": {"img_path": "./data/small/images/images_test/13.png", "size": [143, 39], "caption": "\\left( D ^ { * } D ^ { * } + m ^ { 2 } \\right) { \\cal H } = 0", "caption_len": 26}, "14.png": {"img_path": "./data/small/images/images_test/14.png", "size": [98, 54], "caption": "{ \\frac { d V } { d \\Phi } } = - { \\frac { w \\Phi } { \\Phi _ { \\! 
_ { 0 } } ^ { 2 } } } .", "caption_len": 38}, "15.png": {"img_path": "./data/small/images/images_test/15.png", "size": [369, 49], "caption": "g ( z , \\bar { z } ) = - \\frac { 1 } { 2 } \\left[ x ( z , \\bar { z } ) s + x ^ { * } ( z , \\bar { z } ) s ^ { * } + u ^ { * } ( z , \\bar { z } ) t + u ( z , \\bar { z } ) t ^ { * } \\right] ,", "caption_len": 82}, "16.png": {"img_path": "./data/small/images/images_test/16.png", "size": [107, 37], "caption": "x _ { \\mu } ^ { c } = x _ { \\mu } + A _ { \\mu } .", "caption_len": 24}, "17.png": {"img_path": "./data/small/images/images_test/17.png", "size": [145, 57], "caption": "s = { \\frac { S } { V } } = { \\frac { A _ { H } } { l _ { p } ^ { 8 } V } } = { \\frac { T ^ { 2 } } { \\gamma } } .", "caption_len": 51}, "18.png": {"img_path": "./data/small/images/images_test/18.png", "size": [275, 55], "caption": "\\psi ( \\gamma ) = \\exp { - ( { \\textstyle { \\frac { g ^ { 2 } } { 2 } } } ) \\int _ { \\gamma } d y ^ { a } \\int _ { \\gamma } d y ^ { a ^ { \\prime } } D _ { 1 } ( y - y ^ { \\prime } ) }", "caption_len": 69}, "19.png": {"img_path": "./data/small/images/images_test/19.png", "size": [356, 55], "caption": "E = E _ { 0 } + \\frac { 1 } { 2 \\sinh ( \\gamma ( 0 ) / 2 ) } \\sinh \\left( \\gamma ( 0 ) \\left( \\frac { 1 } { 2 } + c ( 0 ) \\right) \\right) h c u _ { \\mathrm { v i b } }", "caption_len": 59}, "20.png": {"img_path": "./data/small/images/images_test/20.png", "size": [152, 52], "caption": "\\langle T _ { z z } \\rangle = - 3 \\times \\frac { \\pi ^ { 2 } } { 1 4 4 0 a ^ { 4 } } .", "caption_len": 34}, "21.png": {"img_path": "./data/small/images/images_test/21.png", "size": [292, 53], "caption": "\\partial _ { u } \\xi _ { z } ^ { ( 1 ) } + { \\frac { 1 } { u } } \\xi _ { z } ^ { ( 1 ) } = { \\frac { 1 } { ( \\pi T R ) ^ { 2 } u } } \\left[ C _ { z } H _ { z z } ^ { \\prime } + C _ { t } H _ { t z } ^ { \\prime } \\right] .", "caption_len": 92}, "22.png": {"img_path": 
"./data/small/images/images_test/22.png", "size": [356, 38], "caption": "S \\sim \\tilde { \\psi } Q _ { o } \\tilde { \\psi } + g _ { s } ^ { 1 / 2 } \\tilde { \\psi } ^ { 3 } + \\tilde { \\phi } Q _ { c } \\tilde { \\phi } + g _ { s } \\tilde { \\phi } ^ { 3 } + \\tilde { \\phi } B ( g _ { s } ^ { 1 / 2 } \\tilde { \\psi } ) + \\cdots .", "caption_len": 91}, "23.png": {"img_path": "./data/small/images/images_test/23.png", "size": [415, 63], "caption": "C ( x ^ { \\prime } , x ^ { \\prime \\prime } ) = C \\Phi ( x ^ { \\prime } , x ^ { \\prime \\prime } ) \\ , \\quad \\Phi ( x ^ { \\prime } , x ^ { \\prime \\prime } ) = \\exp \\left[ - i e \\int _ { x ^ { \\prime \\prime } } ^ { x ^ { \\prime } } d x ^ { \\mu } A _ { \\mu } ( x ) \\right] \\ ,", "caption_len": 93}, "24.png": {"img_path": "./data/small/images/images_test/24.png", "size": [308, 73], "caption": "\\tilde { \\alpha } = \\alpha \\beta ^ { - m } = \\left( \\begin{array} { c c c } { \\omega _ { k } ^ { - 2 y } \\omega _ { 2 d } ^ { 2 m } } & { 0 } & { 0 } \\\\ { 0 } & { \\omega _ { k } ^ { y } \\omega _ { 2 d } ^ { - m } } & { 0 } \\\\ { 0 } & { 0 } & { \\omega _ { k } ^ { y } \\omega _ { 2 d } ^ { - m } } \\\\ \\end{array} \\right)", "caption_len": 119}, "25.png": {"img_path": "./data/small/images/images_test/25.png", "size": [331, 38], "caption": "d s ^ { 2 } = H ^ { - 2 } f ( r ) d t ^ { 2 } + H ^ { 2 / ( n - 1 ) } ( f ( r ) ^ { - 1 } d r ^ { 2 } + r ^ { 2 } d \\Omega _ { n } ^ { 2 } ) ,", "caption_len": 71}, "26.png": {"img_path": "./data/small/images/images_test/26.png", "size": [283, 37], "caption": "y ^ { 2 } = \\rho \\cosh \\beta \\sin \\theta \\sin \\phi \\qquad \\qquad y ^ { 3 } = \\rho \\cos \\theta", "caption_len": 26}, "27.png": {"img_path": "./data/small/images/images_test/27.png", "size": [350, 39], "caption": "e ^ { A } = e ^ { A _ { 0 } } \\left( t _ { 0 } - \\mathrm { s i g n } ( m ) t \\right) ^ { - \\frac { m } { 2 } } , \\chi = \\chi _ { 0 } \\left( t _ { 0 } - \\mathrm { s i g n } ( 
m ) t \\right) ^ { m } ,", "caption_len": 79}, "28.png": {"img_path": "./data/small/images/images_test/28.png", "size": [282, 50], "caption": "\\gamma _ { j } { \\cal P } _ { j i } = \\frac { 4 } { 3 } \\{ [ A d T ] [ t _ { 8 } ^ { c } , [ t _ { 8 } ^ { c } , { \\gamma } _ { j } ] ] [ A d T ^ { - 1 } ] \\} { A d { \\hat { g } } } _ { i j } .", "caption_len": 88}, "29.png": {"img_path": "./data/small/images/images_test/29.png", "size": [97, 49], "caption": "K _ { \\mu u } = \\frac { 1 } { 2 } \\dot { g } _ { \\mu u } .", "caption_len": 26}} -------------------------------------------------------------------------------- /caption.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | import torch 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import json,cv2 6 | import torchvision.transforms as transforms 7 | import matplotlib.pyplot as plt 8 | import matplotlib.cm as cm 9 | import skimage.transform 10 | import argparse 11 | from model.dataloader import data_turn 12 | from PIL import Image 13 | 14 | # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 15 | device = "cpu" 16 | 17 | 18 | def caption_image_beam_search(encoder, decoder, image_path, word_map, beam_size=5): 19 | """ 20 | Reads an image and captions it with beam search. 
21 | 22 | :param encoder: encoder model 23 | :param decoder: decoder model 24 | :param image_path: path to image 25 | :param word_map: word map 26 | :param beam_size: number of sequences to consider at each decode-step 27 | :return: caption, weights for visualization 28 | """ 29 | 30 | k = beam_size 31 | vocab_size = len(word_map) 32 | 33 | # Read and preprocess the image 34 | img = cv2.imread(image_path) 35 | img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) # convert from BGR to grayscale 36 | img = data_turn(img, resize=True) # image preprocessing 37 | image = torch.FloatTensor(img).to(device) 38 | 39 | with torch.no_grad(): 40 | # Encode 41 | image = image.unsqueeze(0).unsqueeze(0) # (1, 1, H, W) 42 | encoder_out = encoder(image) # (1, encoder_dim, enc_h, enc_w) 43 | enc_image_size = encoder_out.size(2), encoder_out.size(3) 44 | encoder_dim = encoder_out.size(1) # unlike a plain ResNet, the channel dimension here is dim 1, not the last dim 45 | 46 | # Flatten encoding 47 | encoder_out = encoder_out.view(1, -1, encoder_dim) # (1, num_pixels, encoder_dim) 48 | num_pixels = encoder_out.size(1) 49 | 50 | # We'll treat the problem as having a batch size of k 51 | encoder_out = encoder_out.expand(k, num_pixels, encoder_dim) # (k, num_pixels, encoder_dim) 52 | 53 | # Tensor to store top k previous words at each step; now they're just <start> 54 | k_prev_words = torch.LongTensor([[word_map['<start>']]] * k).to(device) # (k, 1) 55 | 56 | # Tensor to store top k sequences; now they're just <start> 57 | seqs = k_prev_words # (k, 1) 58 | 59 | # Tensor to store top k sequences' scores; now they're just 0 60 | top_k_scores = torch.zeros(k, 1).to(device) # (k, 1) 61 | 62 | # Tensor to store top k sequences' alphas; now they're just 1s 63 | seqs_alpha = torch.ones(k, 1, enc_image_size[0], enc_image_size[1]).to(device) # (k, 1, enc_h, enc_w) 64 | 65 | # Lists to store completed sequences, their alphas and scores 66 | complete_seqs = list() 67 | complete_seqs_alpha = list() 68 | complete_seqs_scores = list() 69 | 70 | # Start decoding 71 | step = 1 72 | # h, c =
decoder.init_hidden_state(encoder_out) 73 | h = decoder.init_hidden_state(encoder_out) 74 | 75 | # s <= k; a sequence drops out of the beam once it emits <end> 76 | while True: 77 | 78 | embeddings = decoder.embedding(k_prev_words).squeeze(1) # (s, embed_dim) 79 | 80 | awe, alpha = decoder.attention(encoder_out, h) # (s, encoder_dim), (s, num_pixels) 81 | # awe, _ = decoder.attention(encoder_out, h) # (s, encoder_dim), (s, num_pixels) 82 | 83 | alpha = alpha.view(-1, enc_image_size[0], enc_image_size[1]) # (s, enc_h, enc_w) 84 | 85 | gate = decoder.sigmoid(decoder.f_beta(h)) # gating scalar, (s, encoder_dim) 86 | awe = gate * awe 87 | 88 | # h, c = decoder.decode_step(torch.cat([embeddings, awe], dim=1), (h, c)) # (s, decoder_dim) 89 | h = decoder.decode_step(torch.cat([embeddings, awe], dim=1), h) # (s, decoder_dim) 90 | 91 | scores = decoder.fc(h) # (s, vocab_size) 92 | scores = F.log_softmax(scores, dim=1) 93 | 94 | # Add 95 | scores = top_k_scores.expand_as(scores) + scores # (s, vocab_size) 96 | 97 | # For the first step, all k beams have the same scores (since same k previous words, h, c) 98 | if step == 1: 99 | top_k_scores, top_k_words = scores[0].topk(k, 0, True, True) # (s) 100 | else: 101 | # Unroll and find the top scores and their unrolled indices 102 | top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True) # (s) 103 | 104 | # Convert unrolled indices to actual indices of scores (integer division so they stay usable as indices) 105 | prev_word_inds = top_k_words // vocab_size # (s) 106 | next_word_inds = top_k_words % vocab_size # (s) 107 | 108 | # Append the new words to the sequences, alphas 109 | seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1) # (s, step+1) 110 | seqs_alpha = torch.cat([seqs_alpha[prev_word_inds], alpha[prev_word_inds].unsqueeze(1)], 111 | dim=1) # (s, step+1, enc_h, enc_w) 112 | 113 | # Which sequences are incomplete (didn't reach <end>)?
114 | incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if 115 | next_word != word_map['<end>']] 116 | complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds)) 117 | 118 | # Set aside complete sequences 119 | if len(complete_inds) > 0: 120 | complete_seqs.extend(seqs[complete_inds].tolist()) 121 | complete_seqs_alpha.extend(seqs_alpha[complete_inds].tolist()) 122 | complete_seqs_scores.extend(top_k_scores[complete_inds]) 123 | k -= len(complete_inds) # reduce beam length accordingly 124 | 125 | # Proceed with incomplete sequences 126 | if k == 0: 127 | break 128 | seqs = seqs[incomplete_inds] 129 | seqs_alpha = seqs_alpha[incomplete_inds] 130 | h = h[prev_word_inds[incomplete_inds]] 131 | # c = c[prev_word_inds[incomplete_inds]] 132 | encoder_out = encoder_out[prev_word_inds[incomplete_inds]] 133 | top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1) 134 | k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1) 135 | 136 | # Break if things have been going on too long 137 | print('step', step) 138 | if step > 160: 139 | break 140 | step += 1 141 | 142 | complete_seqs_scores = np.array(complete_seqs_scores) 143 | i = np.argmax(complete_seqs_scores) 144 | # i = complete_seqs_scores.index(max(complete_seqs_scores)) 145 | seq = complete_seqs[i] 146 | alphas = complete_seqs_alpha[i] 147 | 148 | return seq, alphas 149 | # return seq 150 | 151 | 152 | def visualize_att(image_path, seq, alphas, rev_word_map, smooth=True): 153 | """ 154 | Visualizes caption with weights at every word. 155 | 156 | Adapted from paper authors' repo: https://github.com/kelvinxu/arctic-captions/blob/master/alpha_visualization.ipynb 157 | 158 | :param image_path: path to image that has been captioned 159 | :param seq: caption 160 | :param alphas: weights 161 | :param rev_word_map: reverse word mapping, i.e. ix2word 162 | :param smooth: smooth weights?
163 | """ 164 | image = Image.open(image_path) 165 | # image = image.resize([14 * 24, 14 * 24], Image.LANCZOS) 166 | 167 | words = [rev_word_map[ind] for ind in seq] 168 | print(words) 169 | print(alphas.shape) 170 | 171 | # for t in range(len(words)): 172 | # if t > 50: 173 | # break 174 | # plt.subplot(np.ceil(len(words) / 5.), 5, t + 1) 175 | 176 | # plt.text(0, 1, '%s' % (words[t]), color='black', backgroundcolor='white', fontsize=12) 177 | # plt.imshow(image) 178 | # current_alpha = alphas[t, :] 179 | # if smooth: 180 | # alpha = skimage.transform.pyramid_expand(current_alpha.numpy(), upscale=24, sigma=8) 181 | # else: 182 | # alpha = skimage.transform.resize(current_alpha.numpy(), [14 * 24, 14 * 24]) 183 | # if t == 0: 184 | # plt.imshow(alpha, alpha=0) 185 | # else: 186 | # plt.imshow(alpha, alpha=0.8) 187 | # plt.set_cmap(cm.Greys_r) 188 | # plt.axis('off') 189 | # plt.show() 190 | 191 | 192 | if __name__ == '__main__': 193 | parser = argparse.ArgumentParser(description='Show, Attend, and Tell - Tutorial - Generate Caption') 194 | 195 | parser.add_argument('--img', '-i', default='./data/CROHME/images/images_train/TrainData2_15_sub_33.png',help='path to image') 196 | parser.add_argument('--model', '-m',default='BEST_checkpoint_CROHME.pth.tar', help='path to model') 197 | parser.add_argument('--word_map', '-wm', default='./data/CROHME/vocab.json',help='path to word map JSON') 198 | parser.add_argument('--beam_size', '-b', default=3, type=int, help='beam size for beam search') 199 | parser.add_argument('--dont_smooth', dest='smooth', action='store_false', help='do not smooth alpha overlay') 200 | 201 | args = parser.parse_args() 202 | 203 | # Load model 204 | checkpoint = torch.load(args.model, map_location=str(device)) 205 | decoder = checkpoint['decoder'] 206 | decoder = decoder.to(device) 207 | decoder.eval() 208 | encoder = checkpoint['encoder'] 209 | encoder = encoder.to(device) 210 | encoder.eval() 211 | 212 | # Load word map (word2ix) 213 | with 
open(args.word_map, 'r') as j: 214 | word_map = json.load(j) 215 | rev_word_map = {v: k for k, v in word_map.items()} # ix2word 216 | 217 | # Encode, decode with attention and beam search 218 | seq, alphas = caption_image_beam_search(encoder, decoder, args.img, word_map, args.beam_size) 219 | print(seq) 220 | alphas = torch.FloatTensor(alphas) 221 | 222 | # Visualize caption and attention of best sequence 223 | visualize_att(args.img, seq, alphas, rev_word_map, args.smooth) -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import math 4 | from torch import nn 5 | import torch.nn.functional as F 6 | # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 7 | device = "cpu" 8 | 9 | class Encoder(nn.Module): 10 | def __init__(self): 11 | super(Encoder,self).__init__() 12 | self.conv1 = nn.Conv2d(1,64,3,stride=1,padding=1) 13 | self.maxpool1 = nn.MaxPool2d(2,stride=1,padding=1) 14 | 15 | self.conv2 = nn.Conv2d(64,128,3,stride=1,padding=1) 16 | self.maxpool2 = nn.MaxPool2d(2,stride=1,padding=1) 17 | 18 | self.conv3 = nn.Conv2d(128,256,3,stride=1,padding=1) 19 | 20 | self.conv4 = nn.Conv2d(256,256,3,stride=1,padding=1) 21 | self.maxpool3 = nn.MaxPool2d((2,1),stride=(2,1),padding=(1,0)) 22 | 23 | self.conv5 = nn.Conv2d(256,512,3,stride=1,padding=1) 24 | self.maxpool4 = nn.MaxPool2d((1,2),stride=(1,2),padding=(0,1)) 25 | 26 | self.conv6 = nn.Conv2d(512,512,3) 27 | def forward(self,x): 28 | #layer1 29 | x = self.conv1(x) 30 | x = self.maxpool1(x) 31 | x = F.relu(x) 32 | 33 | #layer2 34 | x = self.conv2(x) 35 | x = self.maxpool2(x) 36 | x = F.relu(x) 37 | 38 | #layer3 39 | x = self.conv3(x) 40 | x = F.relu(x) 41 | 42 | #layer4 43 | x = self.conv4(x) 44 | x = self.maxpool3(x) 45 | x = F.relu(x) 46 | 47 | #layer5 48 | x = self.conv5(x) 49 | x = self.maxpool4(x) 50 | x = F.relu(x) 51 | 52 | 
#layer6 53 | x = self.conv6(x) 54 | x = F.relu(x) 55 | 56 | # positional embedding 57 | x = x.permute(0,2,3,1) 58 | x = self.add_timing_signal_nd(x) 59 | x = x.permute(0,3,1,2) 60 | 61 | x = x.contiguous() 62 | return x 63 | # Adapted from: 64 | # https://github.com/tensorflow/tensor2tensor/blob/37465a1759e278e8f073cd04cd9b4fe377d3c740/tensor2tensor/layers/common_attention.py 65 | def add_timing_signal_nd(self, x, min_timescale=1.0, max_timescale=1.0e4): 66 | """Adds a bunch of sinusoids of different frequencies to a Tensor. 67 | 68 | Each channel of the input Tensor is incremented by a sinusoid of a different 69 | frequency and phase in one of the positional dimensions. 70 | 71 | This allows attention to learn to use absolute and relative positions. 72 | Timing signals should be added to some precursors of both the query and the 73 | memory inputs to attention. 74 | 75 | The use of relative position is possible because sin(a+b) and cos(a+b) can 76 | be expressed in terms of b, sin(a) and cos(a). 77 | 78 | x is a Tensor with n "positional" dimensions, e.g. one dimension for a 79 | sequence or two dimensions for an image 80 | 81 | We use a geometric sequence of timescales starting with 82 | min_timescale and ending with max_timescale. The number of different 83 | timescales is equal to channels // (n * 2). For each timescale, we 84 | generate the two sinusoidal signals sin(timestep/timescale) and 85 | cos(timestep/timescale). All of these sinusoids are concatenated in 86 | the channels dimension. 87 | 88 | Args: 89 | x: a Tensor with shape [batch, d1 ... dn, channels] 90 | min_timescale: a float 91 | max_timescale: a float 92 | 93 | Returns: 94 | a Tensor the same shape as x.
95 | 96 | """ 97 | static_shape = list(x.shape) # channels-last here, e.g. [2, 50, 120, 512] 98 | num_dims = len(static_shape) - 2 # 2 99 | channels = x.shape[-1] # 512 100 | num_timescales = channels // (num_dims * 2) # 512 // (2*2) = 128 101 | log_timescale_increment = ( 102 | math.log(float(max_timescale) / float(min_timescale)) / 103 | (float(num_timescales) - 1)) 104 | inv_timescales = min_timescale * torch.exp( 105 | torch.FloatTensor([i for i in range(num_timescales)]) * -log_timescale_increment) # len == 128 106 | for dim in range(num_dims): # dim == 0; 1 107 | length = x.shape[dim + 1] # dim + 1 skips the batch dimension 108 | position = torch.arange(length).float() # len == 50 109 | scaled_time = torch.reshape(position, (-1,1)) * torch.reshape(inv_timescales, (1,-1)) 110 | # [50,1] x [1,128] = [50,128] 111 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1).to(device) # [50, 256] 112 | prepad = dim * 2 * num_timescales # 0; 256 113 | postpad = channels - (dim + 1) * 2 * num_timescales # 512-(1;2)*2*128 = 256; 0 114 | signal = F.pad(signal, (prepad, postpad, 0, 0)) # [50, 512] 115 | for _ in range(1 + dim): # 1; 2 116 | signal = signal.unsqueeze(0) 117 | for _ in range(num_dims - 1 - dim): # 1, 0 118 | signal = signal.unsqueeze(-2) 119 | x += signal # [1, 14, 1, 512]; [1, 1, 14, 512] 120 | return x 121 | 122 | class Attention(nn.Module): 123 | """ 124 | Attention Network.
125 | """ 126 | 127 | def __init__(self, encoder_dim, decoder_dim, attention_dim): 128 | """ 129 | :param encoder_dim: feature size of encoded images 130 | :param decoder_dim: size of decoder's RNN 131 | :param attention_dim: size of the attention network 132 | """ 133 | super(Attention, self).__init__() 134 | self.encoder_att = nn.Linear(encoder_dim, attention_dim) # linear layer to transform encoded image 135 | self.decoder_att = nn.Linear(decoder_dim, attention_dim) # linear layer to transform decoder's output 136 | self.full_att = nn.Linear(attention_dim, 1) # linear layer to calculate values to be softmax-ed 137 | self.relu = nn.ReLU() 138 | self.softmax = nn.Softmax(dim=1) # softmax layer to calculate weights 139 | 140 | def forward(self, encoder_out, decoder_hidden): 141 | """ 142 | Forward propagation. 143 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim) 144 | :param decoder_hidden: previous decoder output, a tensor of dimension (batch_size, decoder_dim) 145 | :return: attention weighted encoding, weights 146 | """ 147 | att1 = self.encoder_att(encoder_out) # (batch_size, num_pixels, attention_dim) 148 | att2 = self.decoder_att(decoder_hidden) # (batch_size, attention_dim) 149 | att = self.full_att(self.relu(att1 + att2.unsqueeze(1))).squeeze(2) # (batch_size, num_pixels) 150 | alpha = self.softmax(att) # (batch_size, num_pixels) 151 | attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1) # (batch_size, encoder_dim) 152 | 153 | return attention_weighted_encoding, alpha 154 | 155 | 156 | class DecoderWithAttention(nn.Module): 157 | """ 158 | Decoder. 
159 | """ 160 | 161 | def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, encoder_dim=512, dropout=0.5, p=0): 162 | """ 163 | :param attention_dim: size of attention network 164 | :param embed_dim: embedding size 165 | :param decoder_dim: size of decoder's RNN 166 | :param vocab_size: size of vocabulary 167 | :param encoder_dim: feature size of encoded images 168 | :param dropout: dropout 169 | """ 170 | super(DecoderWithAttention, self).__init__() 171 | 172 | self.encoder_dim = encoder_dim 173 | self.attention_dim = attention_dim 174 | self.embed_dim = embed_dim 175 | self.decoder_dim = decoder_dim 176 | self.vocab_size = vocab_size 177 | self.dropout = dropout 178 | 179 | self.attention = Attention(encoder_dim, decoder_dim, attention_dim).to(device) # attention network 180 | 181 | self.embedding = nn.Embedding(vocab_size, embed_dim) # embedding layer 182 | self.dropout = nn.Dropout(p=self.dropout) 183 | # self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True) # decoding LSTMCell 184 | self.decode_step = nn.GRUCell(embed_dim + encoder_dim, decoder_dim, bias=True) # decoding GRUCell 185 | self.init_h = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial hidden state of the GRUCell 186 | self.init_c = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial cell state (only used by the LSTM variant) 187 | self.f_beta = nn.Linear(decoder_dim, encoder_dim) # linear layer to create a sigmoid-activated gate 188 | self.sigmoid = nn.Sigmoid() 189 | self.fc = nn.Linear(decoder_dim, vocab_size) # linear layer to find scores over vocabulary 190 | self.init_weights() # initialize some layers with the uniform distribution 191 | self.p = p # teacher forcing probability 192 | 193 | def init_weights(self): 194 | """ 195 | Initializes some parameters with values from the uniform distribution, for easier convergence.
196 | """ 197 | self.embedding.weight.data.uniform_(-0.1, 0.1) 198 | self.fc.bias.data.fill_(0) 199 | self.fc.weight.data.uniform_(-0.1, 0.1) 201 | 202 | def init_hidden_state(self, encoder_out): 203 | """ 204 | Initializes the decoder's GRU hidden state from the encoded image 205 | :param encoder_out: encoder output (batch_size, num_pixels, encoder_dim) 206 | :return: hidden state 207 | """ 208 | mean_encoder_out = encoder_out.mean(dim=1) 209 | h = self.init_h(mean_encoder_out) # (batch_size, decoder_dim) 210 | # c = self.init_c(mean_encoder_out) 211 | # return h, c 212 | return h 213 | 214 | def forward(self, encoder_out, encoded_captions, caption_lengths, p=1): 215 | """ 216 | Forward propagation. 217 | :param encoder_out: encoder output (batch_size, encoder_dim, enc_h, enc_w) 218 | :param encoded_captions: encoded caption tensors, not strings! (batch_size, max_caption_length) 219 | :param caption_lengths: caption lengths, a tensor of dimension (batch_size, 1) 220 | :return: scores for vocabulary, sorted encoded captions, decode lengths, weights, sort indices 221 | """ 222 | self.p = p 223 | batch_size = encoder_out.size(0) 224 | encoder_dim = encoder_out.size(1) # unlike a plain ResNet, the channel dimension here is dim 1, not the last dim 225 | vocab_size = self.vocab_size 226 | 227 | # Flatten the feature map into context vectors 228 | encoder_out = encoder_out.view(batch_size, -1, encoder_dim) # (batch_size, num_pixels, encoder_dim) 229 | num_pixels = encoder_out.size(1) 230 | 231 | # Sort input data by decreasing lengths; why?
apparent below 232 | caption_lengths, sort_ind = caption_lengths.sort(dim=0, descending=True) 233 | # print('sort_ind', sort_ind, 'encoder_out', encoder_out.shape, 'encoder_captions', encoded_captions.shape) 234 | encoder_out = encoder_out[sort_ind] 235 | encoded_captions = encoded_captions[sort_ind] 236 | 237 | # Embedding 238 | embeddings = self.embedding(encoded_captions) # (batch_size, max_caption_length, embed_dim) 239 | 240 | # Initialize the GRU state 241 | # h, c = self.init_hidden_state(encoder_out) # (batch_size, decoder_dim) 242 | h = self.init_hidden_state(encoder_out) # (batch_size, decoder_dim) 243 | 244 | # Decoding is finished as soon as <end> is generated, 245 | # so the effective decode length is lengths - 1 246 | decode_lengths = (caption_lengths - 1).tolist() 247 | # Create two tensors to hold word prediction scores and alphas 248 | predictions = torch.zeros(batch_size, max(decode_lengths), vocab_size).to(device) 249 | alphas = torch.zeros(batch_size, max(decode_lengths), num_pixels).to(device) 250 | 251 | # At each time step, decode from the decoder's previous state and the attention-weighted encoder output 252 | for t in range(max(decode_lengths)): 253 | # decode_lengths is sorted in descending order; batch_size_t is the number of sequences still being decoded at step t 254 | batch_size_t = sum([l > t for l in decode_lengths]) 255 | attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t], 256 | h[:batch_size_t]) 257 | gate = self.sigmoid(self.f_beta(h[:batch_size_t])) # gating scalar, (batch_size_t, encoder_dim) 258 | attention_weighted_encoding = gate * attention_weighted_encoding 259 | # h, c = self.decode_step( 260 | # torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1), 261 | # (h[:batch_size_t], c[:batch_size_t])) # (batch_size_t, decoder_dim) 262 | # teacher forcing: the first step must use the ground-truth <start> token; afterwards, with probability 1 - p, feed back the previous step's prediction instead 263 | if t == 0 or (np.random.rand() < self.p): 264 | h = self.decode_step( 265 | torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1), 266 | h[:batch_size_t]) # (batch_size_t, decoder_dim) 267 | else: 268 | h = self.decode_step( 269 | torch.cat([self.embedding(torch.argmax(predictions[:batch_size_t, t - 1,
:],dim = 1)), attention_weighted_encoding], dim=1), 270 | h[:batch_size_t]) # (batch_size_t, decoder_dim) 271 | preds = self.fc(self.dropout(h)) # (batch_size_t, vocab_size) 272 | predictions[:batch_size_t, t, :] = preds 273 | alphas[:batch_size_t, t, :] = alpha 274 | 275 | return predictions, encoded_captions, decode_lengths, alphas, sort_ind -------------------------------------------------------------------------------- /data/small/train.json: -------------------------------------------------------------------------------- 1 | {"0.png": {"img_path": "./data/small/images/images_train/0.png", "size": [442, 62], "caption": "d s ^ { 2 } = ( 1 - { \\frac { q c o s \\theta } { r } } ) ^ { \\frac { 2 } { 1 + \\alpha ^ { 2 } } } \\lbrace d r ^ { 2 } + r ^ { 2 } d \\theta ^ { 2 } + r ^ { 2 } s i n ^ { 2 } \\theta d \\varphi ^ { 2 } \\rbrace - { \\frac { d t ^ { 2 } } { ( 1 - { \\frac { q c o s \\theta } { r } } ) ^ { \\frac { 2 } { 1 + \\alpha ^ { 2 } } } } } .", "caption_len": 130}, "1.png": {"img_path": "./data/small/images/images_train/1.png", "size": [150, 59], "caption": "\\widetilde \\gamma _ { \\mathrm { h o p f } } \\simeq \\sum _ { n > 0 } \\widetilde { G } _ { n } { \\frac { ( - a ) ^ { n } } { 2 ^ { 2 n - 1 } } }", "caption_len": 53}, "2.png": {"img_path": "./data/small/images/images_train/2.png", "size": [180, 36], "caption": "( { \\cal L } _ { a } g ) _ { i j } = 0 , ( { \\cal L } _ { a } H ) _ { i j k } = 0 ,", "caption_len": 41}, "3.png": {"img_path": "./data/small/images/images_train/3.png", "size": [253, 50], "caption": "S _ { s t a t } = 2 \\pi \\sqrt { N _ { 5 } ^ { ( 1 ) } N _ { 5 } ^ { ( 2 ) } N _ { 5 } ^ { ( 3 ) } } \\left( \\sqrt { n } + \\sqrt { \\bar { n } } \\right)", "caption_len": 63}, "4.png": {"img_path": "./data/small/images/images_train/4.png", "size": [107, 65], "caption": "\\hat { N } _ { 3 } = \\sum \\sp f _ { j = 1 } a _ { j } \\sp { \\dagger } a _ { j } .", "caption_len": 35}, "5.png": {"img_path": 
"./data/small/images/images_train/5.png", "size": [141, 35], "caption": "\\, ^ { * } d ^ { * } H = \\kappa ^ { * } d \\phi = J _ { B } .", "caption_len": 28}, "6.png": {"img_path": "./data/small/images/images_train/6.png", "size": [373, 55], "caption": "{ \\frac { \\phi ^ { \\prime \\prime } } { A } } + { \\frac { 1 } { A } } \\left( - { \\frac { 1 } { 2 } } { \\frac { A ^ { \\prime } } { A } } + 2 { \\frac { B ^ { \\prime } } { B } } + { \\frac { 2 } { r } } \\right) \\phi ^ { \\prime } - { \\frac { 2 } { r ^ { 2 } } } \\phi - \\lambda \\phi ( \\phi ^ { 2 } - \\eta ^ { 2 } ) = 0 .", "caption_len": 115}, "7.png": {"img_path": "./data/small/images/images_train/7.png", "size": [152, 36], "caption": "\\partial _ { \\mu } ( F ^ { \\mu u } - e j ^ { \\mu } x ^ { u } ) = 0 .", "caption_len": 30}, "8.png": {"img_path": "./data/small/images/images_train/8.png", "size": [382, 63], "caption": "V _ { n s } ( { \\tilde { x } } ) = \\left( \\frac { { \\tilde { m } } N ^ { 2 } } { 1 6 \\pi } \\right) N g ^ { 2 n s - 1 } { \\tilde { x } } ^ { 2 } \\left\\{ { \\tilde { x } } ^ { 2 } - \\frac { 2 { \\tilde { b } } } { 3 } { \\tilde { x } } + \\frac { { \\tilde { b } } ^ { 2 } } { 3 } - ( - 1 ) ^ { n s } { \\tilde { c } } \\right\\} .", "caption_len": 124}, "9.png": {"img_path": "./data/small/images/images_train/9.png", "size": [284, 49], "caption": "g _ { i j } ( x ) = { \\frac { 1 } { a ^ { 2 } } } \\delta _ { i j } , \\phi ^ { a } ( x ) = \\phi ^ { a } , \\quad ( a , \\phi ^ { a } \\! : \\mathrm { c o n s t . 
} )", "caption_len": 68}, "10.png": {"img_path": "./data/small/images/images_train/10.png", "size": [185, 63], "caption": "\\rho _ { L } ( q ) = \\sum _ { m = 1 } ^ { L } \\ P _ { L } ( m ) \\ { \\frac { 1 } { q ^ { m - 1 } } } .", "caption_len": 48}, "11.png": {"img_path": "./data/small/images/images_train/11.png", "size": [145, 55], "caption": "e x p \\left( - \\frac { \\partial } { \\partial \\alpha _ { j } } \\theta ^ { j k } \\frac { \\partial } { \\partial \\alpha _ { k } } \\right)", "caption_len": 38}, "12.png": {"img_path": "./data/small/images/images_train/12.png", "size": [149, 36], "caption": "L _ { 0 } = \\Phi ( w ) = \\bigtriangleup \\Phi ( w ) ,", "caption_len": 19}, "13.png": {"img_path": "./data/small/images/images_train/13.png", "size": [143, 39], "caption": "\\left( D ^ { * } D ^ { * } + m ^ { 2 } \\right) { \\cal H } = 0", "caption_len": 26}, "14.png": {"img_path": "./data/small/images/images_train/14.png", "size": [98, 54], "caption": "{ \\frac { d V } { d \\Phi } } = - { \\frac { w \\Phi } { \\Phi _ { \\! 
_ { 0 } } ^ { 2 } } } .", "caption_len": 38}, "15.png": {"img_path": "./data/small/images/images_train/15.png", "size": [369, 49], "caption": "g ( z , \\bar { z } ) = - \\frac { 1 } { 2 } \\left[ x ( z , \\bar { z } ) s + x ^ { * } ( z , \\bar { z } ) s ^ { * } + u ^ { * } ( z , \\bar { z } ) t + u ( z , \\bar { z } ) t ^ { * } \\right] ,", "caption_len": 82}, "16.png": {"img_path": "./data/small/images/images_train/16.png", "size": [107, 37], "caption": "x _ { \\mu } ^ { c } = x _ { \\mu } + A _ { \\mu } .", "caption_len": 24}, "17.png": {"img_path": "./data/small/images/images_train/17.png", "size": [145, 57], "caption": "s = { \\frac { S } { V } } = { \\frac { A _ { H } } { l _ { p } ^ { 8 } V } } = { \\frac { T ^ { 2 } } { \\gamma } } .", "caption_len": 51}, "18.png": {"img_path": "./data/small/images/images_train/18.png", "size": [275, 55], "caption": "\\psi ( \\gamma ) = \\exp { - ( { \\textstyle { \\frac { g ^ { 2 } } { 2 } } } ) \\int _ { \\gamma } d y ^ { a } \\int _ { \\gamma } d y ^ { a ^ { \\prime } } D _ { 1 } ( y - y ^ { \\prime } ) }", "caption_len": 69}, "19.png": {"img_path": "./data/small/images/images_train/19.png", "size": [356, 55], "caption": "E = E _ { 0 } + \\frac { 1 } { 2 \\sinh ( \\gamma ( 0 ) / 2 ) } \\sinh \\left( \\gamma ( 0 ) \\left( \\frac { 1 } { 2 } + c ( 0 ) \\right) \\right) h c u _ { \\mathrm { v i b } }", "caption_len": 59}, "20.png": {"img_path": "./data/small/images/images_train/20.png", "size": [152, 52], "caption": "\\langle T _ { z z } \\rangle = - 3 \\times \\frac { \\pi ^ { 2 } } { 1 4 4 0 a ^ { 4 } } .", "caption_len": 34}, "21.png": {"img_path": "./data/small/images/images_train/21.png", "size": [292, 53], "caption": "\\partial _ { u } \\xi _ { z } ^ { ( 1 ) } + { \\frac { 1 } { u } } \\xi _ { z } ^ { ( 1 ) } = { \\frac { 1 } { ( \\pi T R ) ^ { 2 } u } } \\left[ C _ { z } H _ { z z } ^ { \\prime } + C _ { t } H _ { t z } ^ { \\prime } \\right] .", "caption_len": 92}, "22.png": {"img_path": 
"./data/small/images/images_train/22.png", "size": [356, 38], "caption": "S \\sim \\tilde { \\psi } Q _ { o } \\tilde { \\psi } + g _ { s } ^ { 1 / 2 } \\tilde { \\psi } ^ { 3 } + \\tilde { \\phi } Q _ { c } \\tilde { \\phi } + g _ { s } \\tilde { \\phi } ^ { 3 } + \\tilde { \\phi } B ( g _ { s } ^ { 1 / 2 } \\tilde { \\psi } ) + \\cdots .", "caption_len": 91}, "23.png": {"img_path": "./data/small/images/images_train/23.png", "size": [415, 63], "caption": "C ( x ^ { \\prime } , x ^ { \\prime \\prime } ) = C \\Phi ( x ^ { \\prime } , x ^ { \\prime \\prime } ) \\ , \\quad \\Phi ( x ^ { \\prime } , x ^ { \\prime \\prime } ) = \\exp \\left[ - i e \\int _ { x ^ { \\prime \\prime } } ^ { x ^ { \\prime } } d x ^ { \\mu } A _ { \\mu } ( x ) \\right] \\ ,", "caption_len": 93}, "24.png": {"img_path": "./data/small/images/images_train/24.png", "size": [308, 73], "caption": "\\tilde { \\alpha } = \\alpha \\beta ^ { - m } = \\left( \\begin{array} { c c c } { \\omega _ { k } ^ { - 2 y } \\omega _ { 2 d } ^ { 2 m } } & { 0 } & { 0 } \\\\ { 0 } & { \\omega _ { k } ^ { y } \\omega _ { 2 d } ^ { - m } } & { 0 } \\\\ { 0 } & { 0 } & { \\omega _ { k } ^ { y } \\omega _ { 2 d } ^ { - m } } \\\\ \\end{array} \\right)", "caption_len": 119}, "25.png": {"img_path": "./data/small/images/images_train/25.png", "size": [331, 38], "caption": "d s ^ { 2 } = H ^ { - 2 } f ( r ) d t ^ { 2 } + H ^ { 2 / ( n - 1 ) } ( f ( r ) ^ { - 1 } d r ^ { 2 } + r ^ { 2 } d \\Omega _ { n } ^ { 2 } ) ,", "caption_len": 71}, "26.png": {"img_path": "./data/small/images/images_train/26.png", "size": [283, 37], "caption": "y ^ { 2 } = \\rho \\cosh \\beta \\sin \\theta \\sin \\phi \\qquad \\qquad y ^ { 3 } = \\rho \\cos \\theta", "caption_len": 26}, "27.png": {"img_path": "./data/small/images/images_train/27.png", "size": [350, 39], "caption": "e ^ { A } = e ^ { A _ { 0 } } \\left( t _ { 0 } - \\mathrm { s i g n } ( m ) t \\right) ^ { - \\frac { m } { 2 } } , \\chi = \\chi _ { 0 } \\left( t _ { 0 } - \\mathrm { s i g 
n } ( m ) t \\right) ^ { m } ,", "caption_len": 79}, "28.png": {"img_path": "./data/small/images/images_train/28.png", "size": [282, 50], "caption": "\\gamma _ { j } { \\cal P } _ { j i } = \\frac { 4 } { 3 } \\{ [ A d T ] [ t _ { 8 } ^ { c } , [ t _ { 8 } ^ { c } , { \\gamma } _ { j } ] ] [ A d T ^ { - 1 } ] \\} { A d { \\hat { g } } } _ { i j } .", "caption_len": 88}, "29.png": {"img_path": "./data/small/images/images_train/29.png", "size": [97, 49], "caption": "K _ { \\mu u } = \\frac { 1 } { 2 } \\dot { g } _ { \\mu u } .", "caption_len": 26}, "30.png": {"img_path": "./data/small/images/images_train/30.png", "size": [268, 57], "caption": "X ( u ) = { \\frac { \\left( \\pm i + e ^ { 3 \\eta } \\right) \\left( - 1 + { e ^ { u } } \\right) \\left( 1 + { e ^ { u } } \\right) x _ { 1 } } { 2 { e ^ { u } } \\left( \\pm i + { e ^ { 3 \\eta + u } } \\right) } } ,", "caption_len": 77}, "31.png": {"img_path": "./data/small/images/images_train/31.png", "size": [148, 53], "caption": "\\beta ( g ) \\frac { \\partial } { \\partial g } = 2 g \\beta ( g ) \\frac { \\partial } { \\partial g ^ { 2 } }", "caption_len": 33}, "32.png": {"img_path": "./data/small/images/images_train/32.png", "size": [329, 38], "caption": "A = a r ^ { \\beta } , \\quad B = b r ^ { \\beta + 2 } ; \\qquad a / b = c ( \\beta + 2 ) / ( \\beta - 2 ) ,", "caption_len": 41}, "33.png": {"img_path": "./data/small/images/images_train/33.png", "size": [175, 37], "caption": "\\delta W _ { P \\mu } = A _ { \\mu } \\Phi + B _ { P \\mu } ^ { \\alpha } K _ { P } ^ { \\alpha } \\ .", "caption_len": 38}, "34.png": {"img_path": "./data/small/images/images_train/34.png", "size": [267, 52], "caption": "\\frac { 1 } { d - 2 } \\tilde { \\Pi } ^ { 2 } - \\tilde { \\Pi } _ { a b } \\tilde { \\Pi } ^ { a b } = \\frac { \\left( d - 1 \\right) \\left( d - 2 \\right) } { \\ell ^ { 2 } } + R", "caption_len": 61}, "35.png": {"img_path": "./data/small/images/images_train/35.png", "size": [128, 38], "caption": "\\hat { e } = e / 
\\varepsilon , \\hat { G } _ { 4 } = G _ { 4 } ,", "caption_len": 26}, "36.png": {"img_path": "./data/small/images/images_train/36.png", "size": [268, 39], "caption": "V _ { ( n , m ) } ( z , \\overline { { z } } ) = : \\exp i ( p _ { + } \\phi ( z ) + p _ { - } \\bar { \\phi } ( \\overline { { z } } ) ) : \\: .", "caption_len": 57}, "37.png": {"img_path": "./data/small/images/images_train/37.png", "size": [396, 40], "caption": "\\langle f | g \\rangle _ { { \\cal L } ^ { 1 | 2 } } = \\langle f _ { 0 } | g _ { 0 } \\rangle _ { \\cal L } ^ { s } + \\langle f _ { 1 } | g _ { 1 } \\rangle _ { \\cal L } ^ { s + 1 / 2 } + \\langle f _ { 2 } | g _ { 2 } \\rangle _ { \\cal L } ^ { s + 1 / 2 } + \\langle f _ { 3 } | g _ { 3 } \\rangle _ { \\cal L } ^ { s + 1 } ,", "caption_len": 123}, "38.png": {"img_path": "./data/small/images/images_train/38.png", "size": [384, 52], "caption": "\\tilde { s } ^ { 0 } ( x , y ) = i e ^ { 2 } \\int \\! d ^ { 4 } \\! z S _ { \\mathrm { F } } ( x , z ) \\gamma ^ { \\mu } S _ { \\mathrm { F } } ( z , y ) [ d _ { \\mu } ( x - z ) + d _ { \\mu } ( z - y ) ]", "caption_len": 85}, "39.png": {"img_path": "./data/small/images/images_train/39.png", "size": [132, 55], "caption": "\\left\\{ \\begin{array} { l c l l } { \\phi ( \\infty ) } & { = } & { 0 } & { , \\vspace { 3 m m } } \\\\ { \\phi ( 0 ) } & { = } & { 1 } & { . 
} \\\\ \\end{array} \\right.", "caption_len": 56}, "40.png": {"img_path": "./data/small/images/images_train/40.png", "size": [118, 52], "caption": "{ \\cal P } _ { \\delta x } \\equiv { \\frac { k ^ { 3 } } { 2 \\pi ^ { 2 } } } | \\delta x | ^ { 2 } ,", "caption_len": 39}, "41.png": {"img_path": "./data/small/images/images_train/41.png", "size": [191, 36], "caption": "\\psi ( x ) = - 2 \\phi ( x ) + 2 \\phi ( L ) + c ,", "caption_len": 22}, "42.png": {"img_path": "./data/small/images/images_train/42.png", "size": [225, 40], "caption": "{ } ^ { ( { } ^ { \\scriptstyle x } y ) } ( { } ^ { x } z ) = { } ^ { x } ( { } ^ { y } z ) , \\qquad \\forall x , y , z \\in X .", "caption_len": 53}, "43.png": {"img_path": "./data/small/images/images_train/43.png", "size": [231, 41], "caption": "\\delta ( L _ { 1 } + L _ { 2 } ) = 2 \\delta \\bar { \\theta } ( 1 + \\gamma ^ { ( p ) } ) T _ { ( p ) } ^ { u } \\partial _ { u } \\theta .", "caption_len": 52}, "44.png": {"img_path": "./data/small/images/images_train/44.png", "size": [371, 57], "caption": "\\frac { 1 } { 2 \\lambda f ^ { 2 } } \\int d ^ { 4 } X \\frac { d ^ { 4 } q } { \\left( 2 \\pi \\right) ^ { 4 } } \\left( \\varphi ( X ) \\right) ^ { 2 } \\tilde { \\pi } _ { 0 } ( q ) \\left[ \\partial _ { q } ^ { 2 } + \\frac { 4 i \\lambda } { q ^ { 2 } - \\Sigma ^ { 2 } ( q ) } \\right] \\tilde { \\pi } _ { 0 } ( q ) ,", "caption_len": 108}, "45.png": {"img_path": "./data/small/images/images_train/45.png", "size": [363, 50], "caption": "G ^ { \\mu u \\mu ^ { \\prime } u ^ { \\prime } } = g ^ { \\mu \\mu ^ { \\prime } } g ^ { u u ^ { \\prime } } + g ^ { \\mu u ^ { \\prime } } g ^ { u \\mu ^ { \\prime } } - { \\frac { 2 } { D } } g ^ { \\mu u } g ^ { \\mu ^ { \\prime } u ^ { \\prime } } + C g ^ { \\mu u } g ^ { \\mu ^ { \\prime } u ^ { \\prime } } \\: .", "caption_len": 114}, "46.png": {"img_path": "./data/small/images/images_train/46.png", "size": [345, 36], "caption": "[ M _ { \\mu u } , M _ { \\rho \\tau } ] = g _ { \\mu \\tau 
} M _ { u \\rho } - g _ { u \\tau } M _ { \\mu \\rho } + g _ { u \\rho } M _ { \\mu \\tau } - g _ { \\mu \\rho } M _ { u \\tau } ,", "caption_len": 70}, "47.png": {"img_path": "./data/small/images/images_train/47.png", "size": [185, 64], "caption": "A _ { 0 } = \\pm \\sqrt { { \\frac { 4 } { 3 ( 1 - \\alpha ) } } } e ^ { ( \\alpha - 1 ) \\phi } \\ .", "caption_len": 38}, "48.png": {"img_path": "./data/small/images/images_train/48.png", "size": [187, 56], "caption": "C _ { m } ( \\mu ) = { \\frac { 1 } { 2 \\pi i } } \\int _ { \\Gamma _ { r } } { \\frac { C _ { m } ( z ) } { z - \\mu } } d z ,", "caption_len": 52}, "49.png": {"img_path": "./data/small/images/images_train/49.png", "size": [308, 38], "caption": "\\xi = \\alpha ^ { - 1 } \\sqrt { \\rho } \\cosh ( 2 \\alpha ^ { 2 } t ) , \\quad \\eta = \\alpha ^ { - 1 } \\sqrt { \\rho } \\sinh ( 2 \\alpha ^ { 2 } t )", "caption_len": 48}} -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import time 2 | from config import * 3 | import torch.backends.cudnn as cudnn 4 | import torch.optim 5 | import torch.utils.data 6 | from torch import nn 7 | from tqdm import tqdm 8 | from torch.nn.utils.rnn import pack_padded_sequence 9 | from model.utils import * 10 | from model import metrics,dataloader,model 11 | from torch.utils.checkpoint import checkpoint as train_ck 12 | 13 | # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 14 | device = "cpu" 15 | 16 | 17 | model.device = device 18 | ''' 19 | If the network's input dimensions and types do not vary much, setting torch.backends.cudnn.benchmark = True can improve performance; 20 | if the inputs change at every iteration, cuDNN will search for the optimal configuration every time, which actually reduces performance. 21 | ''' 22 | cudnn.benchmark = True 23 | 24 | 25 | def main(): 26 | """ 27 | Training and validation.
28 | """ 29 | 30 | global best_score, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map 31 | 32 | # Vocabulary file 33 | word_map = load_json(vocab_path) 34 | 35 | # Initialize / load checkpoint 36 | if checkpoint is None: 37 | decoder = model.DecoderWithAttention(attention_dim=attention_dim, 38 | embed_dim=emb_dim, 39 | decoder_dim=decoder_dim, 40 | vocab_size=len(word_map), 41 | dropout=dropout) 42 | decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.parameters()), 43 | lr=decoder_lr) 44 | encoder = model.Encoder() 45 | # encoder_optimizer = None 46 | encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()), 47 | lr=encoder_lr) 48 | 49 | else: 50 | checkpoint = torch.load(checkpoint) 51 | start_epoch = checkpoint['epoch'] + 1 52 | epochs_since_improvement = checkpoint['epochs_since_improvement'] 53 | best_score = checkpoint['score'] 54 | decoder = checkpoint['decoder'] 55 | encoder_optimizer = checkpoint['encoder_optimizer'] 56 | decoder_optimizer = checkpoint['decoder_optimizer'] 57 | encoder = checkpoint['encoder'] 58 | # encoder_optimizer = checkpoint['encoder_optimizer'] 59 | # encoder_optimizer = None 60 | # if fine_tune_encoder is True and encoder_optimizer is None: 61 | # encoder.fine_tune(fine_tune_encoder) 62 | # encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()), 63 | # lr=encoder_lr) 64 | 65 | # Move to GPU, if available 66 | decoder = decoder.to(device) 67 | encoder = encoder.to(device) 68 | 69 | # Use cross-entropy loss 70 | criterion = nn.CrossEntropyLoss().to(device) 71 | 72 | # Custom datasets 73 | train_loader = dataloader.formuladataset(train_set_path,batch_size = batch_size,ratio = 5) 74 | val_loader = dataloader.formuladataset(val_set_path,batch_size = test_batch_size,ratio = 5) 75 | 76 | # # Count word frequencies over the validation set 77 | # words_freq = cal_word_freq(word_map,val_loader) 78 | # print(words_freq) 79 | p = 1  # teacher forcing probability 80
| # Epochs 81 | for epoch in range(start_epoch, epochs): 82 | train_loader.shuffle() 83 | val_loader.shuffle() 84 | # Decay the teacher forcing probability every 3 epochs 85 | if p > 0.05: 86 | if (epoch % 3 == 0 and epoch != 0): 87 | p *= 0.75 88 | else: 89 | p = 0 90 | print('start epoch:%d'%epoch,'p:%.2f'%p) 91 | 92 | # Decay the learning rates every 2 epochs without improvement; stop early after 30 epochs without improvement 93 | if epochs_since_improvement == 30: 94 | break 95 | if epochs_since_improvement > 0 and epochs_since_improvement % 2 == 0: 96 | adjust_learning_rate(decoder_optimizer, 0.7) 97 | adjust_learning_rate(encoder_optimizer, 0.8) 98 | # Dynamic learning rate scheduling 99 | # torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.8, 100 | # patience=4, verbose=True, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=1e-6, eps=1e-8) 101 | 102 | # One epoch's training 103 | train(train_loader=train_loader, 104 | encoder=encoder, 105 | decoder=decoder, 106 | criterion=criterion, 107 | encoder_optimizer=encoder_optimizer, 108 | decoder_optimizer=decoder_optimizer, 109 | epoch=epoch,p=p) 110 | 111 | # One epoch's validation 112 | recent_score = validate(val_loader=val_loader, 113 | encoder=encoder, 114 | decoder=decoder, 115 | criterion=criterion) 116 | if (p==0): 117 | print('Stop teacher forcing!') 118 | # Check if there was an improvement 119 | is_best = recent_score > best_score 120 | best_score = max(recent_score, best_score) 121 | if not is_best: 122 | epochs_since_improvement += 1 123 | print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,)) 124 | else: 125 | print('New Best Score! (%.4f)' % (best_score,)) 126 | epochs_since_improvement = 0 127 | 128 | if epoch % save_freq == 0: 129 | print('Saving...') 130 | save_checkpoint(data_name, epoch, epochs_since_improvement, encoder, decoder,encoder_optimizer, 131 | decoder_optimizer, recent_score, is_best) 132 | print('--------------------------------------------------------------------------') 133 | 134 | 135 | def
train(train_loader, encoder, decoder, criterion, encoder_optimizer,decoder_optimizer, epoch, p): 136 | """ 137 | Performs one epoch's training. 138 | :param train_loader: DataLoader for the training set 139 | :param encoder: encoder model 140 | :param decoder: decoder model 141 | :param criterion: loss function 142 | :param encoder_optimizer: optimizer to update encoder's weights (if fine-tuning) 143 | :param decoder_optimizer: optimizer to update decoder's weights 144 | :param epoch: epoch number 145 | """ 146 | 147 | decoder.train() # train mode (dropout and batch norm are used) 148 | encoder.train() 149 | 150 | batch_time = AverageMeter() # forward prop. + back prop. time 151 | losses = AverageMeter() # loss (per word decoded) 152 | top3accs = AverageMeter() # top-3 accuracy 153 | 154 | start = time.time() 155 | 156 | # Batches 157 | # for i, (imgs, caps, caplens) in tqdm(enumerate(train_loader)): 158 | for i, (imgs, caps, caplens) in enumerate(train_loader): 159 | # Move to GPU, if available 160 | imgs = imgs.to(device) 161 | caps = caps.to(device) 162 | caplens = caplens.to(device) 163 | 164 | # Forward prop.
165 | # try: 166 | # imgs = encoder(imgs) 167 | # scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens) 168 | # except: 169 | # imgs.requires_grad = True 170 | # imgs = train_ck(encoder,imgs) 171 | try: 172 | imgs = encoder(imgs) 173 | except: 174 | imgs = train_ck(encoder,imgs) 175 | scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens, p=p) 176 | 177 | # Since start and end tokens were added, the target caption starts at the second position and runs up to the end token 178 | targets = caps_sorted[:, 1:] 179 | 180 | # Remove timesteps that we didn't decode at, or are pads 181 | # pack_padded_sequence is an easy trick to do this 182 | # scores, _ = pack_padded_sequence(scores, decode_lengths, batch_first=True) 183 | # targets, _ = pack_padded_sequence(targets, decode_lengths, batch_first=True) 184 | scores = pack_padded_sequence(scores, decode_lengths, batch_first=True).data 185 | targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data 186 | 187 | # Calculate loss 188 | scores = scores.to(device) 189 | loss = criterion(scores, targets) 190 | 191 | # Add doubly stochastic attention regularization 192 | loss += alpha_c * ((1.
- alphas.sum(dim=1)) ** 2).mean() 193 | 194 | # Backpropagation 195 | encoder_optimizer.zero_grad() 196 | decoder_optimizer.zero_grad() 197 | loss.backward() 198 | 199 | # Gradient clipping 200 | if grad_clip is not None: 201 | clip_gradient(decoder_optimizer, grad_clip) 202 | # if encoder_optimizer is not None: 203 | # clip_gradient(encoder_optimizer, grad_clip) 204 | 205 | # Update weights 206 | decoder_optimizer.step() 207 | encoder_optimizer.step() 208 | # if encoder_optimizer is not None: 209 | # encoder_optimizer.step() 210 | 211 | # Keep track of metrics 212 | top3 = accuracy(scores, targets, 3) 213 | losses.update(loss.item(), sum(decode_lengths)) 214 | top3accs.update(top3, sum(decode_lengths)) 215 | batch_time.update(time.time() - start) 216 | 217 | start = time.time() 218 | 219 | # Print status 220 | if i % print_freq == 0: 221 | print('Epoch: [{0}][{1}/{2}]\t' 222 | 'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 223 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 224 | 'Top-3 Accuracy {top3.val:.3f} ({top3.avg:.3f})'.format(epoch, i, len(train_loader), 225 | batch_time=batch_time, 226 | loss=losses, 227 | top3=top3accs)) 228 | # if i % save_freq == 0: 229 | # save_checkpoint(data_name, epoch, epochs_since_improvement, encoder, decoder,encoder_optimizer, 230 | # decoder_optimizer, 0,0) 231 | del imgs, scores, caps_sorted, decode_lengths, alphas, sort_ind, loss, targets 232 | torch.cuda.empty_cache() 233 | 234 | 235 | def validate(val_loader, encoder, decoder, criterion): 236 | """ 237 | Performs one epoch's validation.
238 | :param val_loader: DataLoader for the validation set 239 | :param encoder: encoder model 240 | :param decoder: decoder model 241 | :param criterion: loss function 242 | :return: BLEU-4 score on the validation set 243 | """ 244 | decoder.eval() # inference mode (disables dropout and batch norm updates) 245 | if encoder is not None: 246 | encoder.eval() 247 | 248 | batch_time = AverageMeter() 249 | losses = AverageMeter() 250 | top3accs = AverageMeter() 251 | 252 | start = time.time() 253 | 254 | references = list() # references (true captions) for calculating BLEU-4 score 255 | hypotheses = list() # hypotheses (predictions) 256 | 257 | # explicitly disable gradient calculation to avoid CUDA memory error 258 | with torch.no_grad(): 259 | # Batches 260 | # for i, (imgs, caps, caplens, allcaps) in enumerate(val_loader): 261 | # for i, (imgs, caps, caplens) in tqdm(enumerate(val_loader)): 262 | for i, (imgs, caps, caplens) in enumerate(val_loader): 263 | 264 | # Move to device, if available 265 | imgs = imgs.to(device) 266 | caps = caps.to(device) 267 | caplens = caplens.to(device) 268 | 269 | # Forward prop. 270 | if encoder is not None: 271 | imgs = encoder(imgs) 272 | scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens, p=0) 273 | 274 | # Since we decoded starting with <start>, the targets are all words after <start>, up to <end> 275 | targets = caps_sorted[:, 1:] 276 | 277 | # Remove timesteps that we didn't decode at, or are pads 278 | # pack_padded_sequence is an easy trick to do this 279 | scores_copy = scores.clone() 280 | scores = pack_padded_sequence(scores, decode_lengths, batch_first=True).data 281 | targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data 282 | 283 | # Calculate loss 284 | loss = criterion(scores, targets) 285 | 286 | # Add doubly stochastic attention regularization 287 | loss += alpha_c * ((1.
- alphas.sum(dim=1)) ** 2).mean() 288 | 289 | # Keep track of metrics 290 | losses.update(loss.item(), sum(decode_lengths)) 291 | top3 = accuracy(scores, targets, 3) 292 | top3accs.update(top3, sum(decode_lengths)) 293 | batch_time.update(time.time() - start) 294 | 295 | start = time.time() 296 | 297 | if i % print_freq == 0: 298 | print('Validation: [{0}/{1}],' 299 | 'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f}),' 300 | 'Loss {loss.val:.4f} ({loss.avg:.4f}),' 301 | 'Top-3 Accuracy {top3.val:.3f} ({top3.avg:.3f}),'.format(i, len(val_loader), batch_time=batch_time, 302 | loss=losses, top3=top3accs)) 303 | 304 | # Store references (true captions), and hypothesis (prediction) for each image 305 | # If for n images, we have n hypotheses, and references a, b, c... for each image, we need - 306 | # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...] 307 | 308 | # References 309 | # allcaps = allcaps[sort_ind] # because images were sorted in the decoder 310 | # for j in range(allcaps.shape[0]): 311 | # img_caps = allcaps[j].tolist() 312 | # img_captions = list( 313 | # map(lambda c: [w for w in c if w not in {word_map['<start>'], word_map['<pad>']}], 314 | # img_caps)) # remove <start> and pads 315 | # references.append(img_captions) 316 | caplens = caplens[sort_ind] 317 | caps = caps[sort_ind] 318 | for i in range(len(caplens)): 319 | references.append(caps[i][1:caplens[i]].tolist()) 320 | # Hypotheses 321 | # Greedy decoding is used here for evaluation; beam search is generally used at inference time 322 | _, preds = torch.max(scores_copy, dim=2) 323 | preds = preds.tolist() 324 | temp_preds = list() 325 | for j, p in enumerate(preds): 326 | temp_preds.append(preds[j][:decode_lengths[j]]) # remove pads 327 | preds = temp_preds 328 | hypotheses.extend(preds) 329 | 330 | assert len(references) == len(hypotheses) 331 | 332 | Score = metrics.evaluate(losses, top3accs, references, hypotheses) 333 | return Score 334 | 335 | 336 | if __name__ == '__main__': 337 | main()
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU AFFERO GENERAL PUBLIC LICENSE 2 | Version 3, 19 November 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU Affero General Public License is a free, copyleft license for 11 | software and other kinds of works, specifically designed to ensure 12 | cooperation with the community in the case of network server software. 13 | 14 | The licenses for most software and other practical works are designed 15 | to take away your freedom to share and change the works. By contrast, 16 | our General Public Licenses are intended to guarantee your freedom to 17 | share and change all versions of a program--to make sure it remains free 18 | software for all its users. 19 | 20 | When we speak of free software, we are referring to freedom, not 21 | price. Our General Public Licenses are designed to make sure that you 22 | have the freedom to distribute copies of free software (and charge for 23 | them if you wish), that you receive source code or can get it if you 24 | want it, that you can change the software or use pieces of it in new 25 | free programs, and that you know you can do these things. 26 | 27 | Developers that use our General Public Licenses protect your rights 28 | with two steps: (1) assert copyright on the software, and (2) offer 29 | you this License which gives you legal permission to copy, distribute 30 | and/or modify the software. 31 | 32 | A secondary benefit of defending all users' freedom is that 33 | improvements made in alternate versions of the program, if they 34 | receive widespread use, become available for other developers to 35 | incorporate. 
Many developers of free software are heartened and 36 | encouraged by the resulting cooperation. However, in the case of 37 | software used on network servers, this result may fail to come about. 38 | The GNU General Public License permits making a modified version and 39 | letting the public access it on a server without ever releasing its 40 | source code to the public. 41 | 42 | The GNU Affero General Public License is designed specifically to 43 | ensure that, in such cases, the modified source code becomes available 44 | to the community. It requires the operator of a network server to 45 | provide the source code of the modified version running there to the 46 | users of that server. Therefore, public use of a modified version, on 47 | a publicly accessible server, gives the public access to the source 48 | code of the modified version. 49 | 50 | An older license, called the Affero General Public License and 51 | published by Affero, was designed to accomplish similar goals. This is 52 | a different license, not a version of the Affero GPL, but Affero has 53 | released a new version of the Affero GPL which permits relicensing under 54 | this license. 55 | 56 | The precise terms and conditions for copying, distribution and 57 | modification follow. 58 | 59 | TERMS AND CONDITIONS 60 | 61 | 0. Definitions. 62 | 63 | "This License" refers to version 3 of the GNU Affero General Public License. 64 | 65 | "Copyright" also means copyright-like laws that apply to other kinds of 66 | works, such as semiconductor masks. 67 | 68 | "The Program" refers to any copyrightable work licensed under this 69 | License. Each licensee is addressed as "you". "Licensees" and 70 | "recipients" may be individuals or organizations. 71 | 72 | To "modify" a work means to copy from or adapt all or part of the work 73 | in a fashion requiring copyright permission, other than the making of an 74 | exact copy. 
The resulting work is called a "modified version" of the 75 | earlier work or a work "based on" the earlier work. 76 | 77 | A "covered work" means either the unmodified Program or a work based 78 | on the Program. 79 | 80 | To "propagate" a work means to do anything with it that, without 81 | permission, would make you directly or secondarily liable for 82 | infringement under applicable copyright law, except executing it on a 83 | computer or modifying a private copy. Propagation includes copying, 84 | distribution (with or without modification), making available to the 85 | public, and in some countries other activities as well. 86 | 87 | To "convey" a work means any kind of propagation that enables other 88 | parties to make or receive copies. Mere interaction with a user through 89 | a computer network, with no transfer of a copy, is not conveying. 90 | 91 | An interactive user interface displays "Appropriate Legal Notices" 92 | to the extent that it includes a convenient and prominently visible 93 | feature that (1) displays an appropriate copyright notice, and (2) 94 | tells the user that there is no warranty for the work (except to the 95 | extent that warranties are provided), that licensees may convey the 96 | work under this License, and how to view a copy of this License. If 97 | the interface presents a list of user commands or options, such as a 98 | menu, a prominent item in the list meets this criterion. 99 | 100 | 1. Source Code. 101 | 102 | The "source code" for a work means the preferred form of the work 103 | for making modifications to it. "Object code" means any non-source 104 | form of a work. 105 | 106 | A "Standard Interface" means an interface that either is an official 107 | standard defined by a recognized standards body, or, in the case of 108 | interfaces specified for a particular programming language, one that 109 | is widely used among developers working in that language. 
110 | 111 | The "System Libraries" of an executable work include anything, other 112 | than the work as a whole, that (a) is included in the normal form of 113 | packaging a Major Component, but which is not part of that Major 114 | Component, and (b) serves only to enable use of the work with that 115 | Major Component, or to implement a Standard Interface for which an 116 | implementation is available to the public in source code form. A 117 | "Major Component", in this context, means a major essential component 118 | (kernel, window system, and so on) of the specific operating system 119 | (if any) on which the executable work runs, or a compiler used to 120 | produce the work, or an object code interpreter used to run it. 121 | 122 | The "Corresponding Source" for a work in object code form means all 123 | the source code needed to generate, install, and (for an executable 124 | work) run the object code and to modify the work, including scripts to 125 | control those activities. However, it does not include the work's 126 | System Libraries, or general-purpose tools or generally available free 127 | programs which are used unmodified in performing those activities but 128 | which are not part of the work. For example, Corresponding Source 129 | includes interface definition files associated with source files for 130 | the work, and the source code for shared libraries and dynamically 131 | linked subprograms that the work is specifically designed to require, 132 | such as by intimate data communication or control flow between those 133 | subprograms and other parts of the work. 134 | 135 | The Corresponding Source need not include anything that users 136 | can regenerate automatically from other parts of the Corresponding 137 | Source. 138 | 139 | The Corresponding Source for a work in source code form is that 140 | same work. 141 | 142 | 2. Basic Permissions. 
143 | 144 | All rights granted under this License are granted for the term of 145 | copyright on the Program, and are irrevocable provided the stated 146 | conditions are met. This License explicitly affirms your unlimited 147 | permission to run the unmodified Program. The output from running a 148 | covered work is covered by this License only if the output, given its 149 | content, constitutes a covered work. This License acknowledges your 150 | rights of fair use or other equivalent, as provided by copyright law. 151 | 152 | You may make, run and propagate covered works that you do not 153 | convey, without conditions so long as your license otherwise remains 154 | in force. You may convey covered works to others for the sole purpose 155 | of having them make modifications exclusively for you, or provide you 156 | with facilities for running those works, provided that you comply with 157 | the terms of this License in conveying all material for which you do 158 | not control copyright. Those thus making or running the covered works 159 | for you must do so exclusively on your behalf, under your direction 160 | and control, on terms that prohibit them from making any copies of 161 | your copyrighted material outside their relationship with you. 162 | 163 | Conveying under any other circumstances is permitted solely under 164 | the conditions stated below. Sublicensing is not allowed; section 10 165 | makes it unnecessary. 166 | 167 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 168 | 169 | No covered work shall be deemed part of an effective technological 170 | measure under any applicable law fulfilling obligations under article 171 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 172 | similar laws prohibiting or restricting circumvention of such 173 | measures. 

  When you convey a covered work, you waive any legal power to forbid
circumvention of technological measures to the extent such circumvention
is effected by exercising rights under this License with respect to
the covered work, and you disclaim any intention to limit operation or
modification of the work as a means of enforcing, against the work's
users, your or third parties' legal rights to forbid circumvention of
technological measures.

  4. Conveying Verbatim Copies.

  You may convey verbatim copies of the Program's source code as you
receive it, in any medium, provided that you conspicuously and
appropriately publish on each copy an appropriate copyright notice;
keep intact all notices stating that this License and any
non-permissive terms added in accord with section 7 apply to the code;
keep intact all notices of the absence of any warranty; and give all
recipients a copy of this License along with the Program.

  You may charge any price or no price for each copy that you convey,
and you may offer support or warranty protection for a fee.

  5. Conveying Modified Source Versions.

  You may convey a work based on the Program, or the modifications to
produce it from the Program, in the form of source code under the
terms of section 4, provided that you also meet all of these conditions:

    a) The work must carry prominent notices stating that you modified
    it, and giving a relevant date.

    b) The work must carry prominent notices stating that it is
    released under this License and any conditions added under section
    7. This requirement modifies the requirement in section 4 to
    "keep intact all notices".

    c) You must license the entire work, as a whole, under this
    License to anyone who comes into possession of a copy.
    This License will therefore apply, along with any applicable
    section 7 additional terms, to the whole of the work, and all its
    parts, regardless of how they are packaged. This License gives no
    permission to license the work in any other way, but it does not
    invalidate such permission if you have separately received it.

    d) If the work has interactive user interfaces, each must display
    Appropriate Legal Notices; however, if the Program has interactive
    interfaces that do not display Appropriate Legal Notices, your
    work need not make them do so.

  A compilation of a covered work with other separate and independent
works, which are not by their nature extensions of the covered work,
and which are not combined with it such as to form a larger program,
in or on a volume of a storage or distribution medium, is called an
"aggregate" if the compilation and its resulting copyright are not
used to limit the access or legal rights of the compilation's users
beyond what the individual works permit. Inclusion of a covered work
in an aggregate does not cause this License to apply to the other
parts of the aggregate.

  6. Conveying Non-Source Forms.

  You may convey a covered work in object code form under the terms
of sections 4 and 5, provided that you also convey the
machine-readable Corresponding Source under the terms of this License,
in one of these ways:

    a) Convey the object code in, or embodied in, a physical product
    (including a physical distribution medium), accompanied by the
    Corresponding Source fixed on a durable physical medium
    customarily used for software interchange.

    b) Convey the object code in, or embodied in, a physical product
    (including a physical distribution medium), accompanied by a
    written offer, valid for at least three years and valid for as
    long as you offer spare parts or customer support for that product
    model, to give anyone who possesses the object code either (1) a
    copy of the Corresponding Source for all the software in the
    product that is covered by this License, on a durable physical
    medium customarily used for software interchange, for a price no
    more than your reasonable cost of physically performing this
    conveying of source, or (2) access to copy the
    Corresponding Source from a network server at no charge.

    c) Convey individual copies of the object code with a copy of the
    written offer to provide the Corresponding Source. This
    alternative is allowed only occasionally and noncommercially, and
    only if you received the object code with such an offer, in accord
    with subsection 6b.

    d) Convey the object code by offering access from a designated
    place (gratis or for a charge), and offer equivalent access to the
    Corresponding Source in the same way through the same place at no
    further charge. You need not require recipients to copy the
    Corresponding Source along with the object code. If the place to
    copy the object code is a network server, the Corresponding Source
    may be on a different server (operated by you or a third party)
    that supports equivalent copying facilities, provided you maintain
    clear directions next to the object code saying where to find the
    Corresponding Source. Regardless of what server hosts the
    Corresponding Source, you remain obligated to ensure that it is
    available for as long as needed to satisfy these requirements.

    e) Convey the object code using peer-to-peer transmission, provided
    you inform other peers where the object code and Corresponding
    Source of the work are being offered to the general public at no
    charge under subsection 6d.

  A separable portion of the object code, whose source code is excluded
from the Corresponding Source as a System Library, need not be
included in conveying the object code work.

  A "User Product" is either (1) a "consumer product", which means any
tangible personal property which is normally used for personal, family,
or household purposes, or (2) anything designed or sold for incorporation
into a dwelling. In determining whether a product is a consumer product,
doubtful cases shall be resolved in favor of coverage. For a particular
product received by a particular user, "normally used" refers to a
typical or common use of that class of product, regardless of the status
of the particular user or of the way in which the particular user
actually uses, or expects or is expected to use, the product. A product
is a consumer product regardless of whether the product has substantial
commercial, industrial or non-consumer uses, unless such uses represent
the only significant mode of use of the product.

  "Installation Information" for a User Product means any methods,
procedures, authorization keys, or other information required to install
and execute modified versions of a covered work in that User Product from
a modified version of its Corresponding Source. The information must
suffice to ensure that the continued functioning of the modified object
code is in no case prevented or interfered with solely because
modification has been made.

  If you convey an object code work under this section in, or with, or
specifically for use in, a User Product, and the conveying occurs as
part of a transaction in which the right of possession and use of the
User Product is transferred to the recipient in perpetuity or for a
fixed term (regardless of how the transaction is characterized), the
Corresponding Source conveyed under this section must be accompanied
by the Installation Information. But this requirement does not apply
if neither you nor any third party retains the ability to install
modified object code on the User Product (for example, the work has
been installed in ROM).

  The requirement to provide Installation Information does not include a
requirement to continue to provide support service, warranty, or updates
for a work that has been modified or installed by the recipient, or for
the User Product in which it has been modified or installed. Access to a
network may be denied when the modification itself materially and
adversely affects the operation of the network or violates the rules and
protocols for communication across the network.

  Corresponding Source conveyed, and Installation Information provided,
in accord with this section must be in a format that is publicly
documented (and with an implementation available to the public in
source code form), and must require no special password or key for
unpacking, reading or copying.

  7. Additional Terms.

  "Additional permissions" are terms that supplement the terms of this
License by making exceptions from one or more of its conditions.
Additional permissions that are applicable to the entire Program shall
be treated as though they were included in this License, to the extent
that they are valid under applicable law.
If additional permissions
apply only to part of the Program, that part may be used separately
under those permissions, but the entire Program remains governed by
this License without regard to the additional permissions.

  When you convey a copy of a covered work, you may at your option
remove any additional permissions from that copy, or from any part of
it. (Additional permissions may be written to require their own
removal in certain cases when you modify the work.) You may place
additional permissions on material, added by you to a covered work,
for which you have or can give appropriate copyright permission.

  Notwithstanding any other provision of this License, for material you
add to a covered work, you may (if authorized by the copyright holders of
that material) supplement the terms of this License with terms:

    a) Disclaiming warranty or limiting liability differently from the
    terms of sections 15 and 16 of this License; or

    b) Requiring preservation of specified reasonable legal notices or
    author attributions in that material or in the Appropriate Legal
    Notices displayed by works containing it; or

    c) Prohibiting misrepresentation of the origin of that material, or
    requiring that modified versions of such material be marked in
    reasonable ways as different from the original version; or

    d) Limiting the use for publicity purposes of names of licensors or
    authors of the material; or

    e) Declining to grant rights under trademark law for use of some
    trade names, trademarks, or service marks; or

    f) Requiring indemnification of licensors and authors of that
    material by anyone who conveys the material (or modified versions of
    it) with contractual assumptions of liability to the recipient, for
    any liability that these contractual assumptions directly impose on
    those licensors and authors.

  All other non-permissive additional terms are considered "further
restrictions" within the meaning of section 10. If the Program as you
received it, or any part of it, contains a notice stating that it is
governed by this License along with a term that is a further
restriction, you may remove that term. If a license document contains
a further restriction but permits relicensing or conveying under this
License, you may add to a covered work material governed by the terms
of that license document, provided that the further restriction does
not survive such relicensing or conveying.

  If you add terms to a covered work in accord with this section, you
must place, in the relevant source files, a statement of the
additional terms that apply to those files, or a notice indicating
where to find the applicable terms.

  Additional terms, permissive or non-permissive, may be stated in the
form of a separately written license, or stated as exceptions;
the above requirements apply either way.

  8. Termination.

  You may not propagate or modify a covered work except as expressly
provided under this License. Any attempt otherwise to propagate or
modify it is void, and will automatically terminate your rights under
this License (including any patent licenses granted under the third
paragraph of section 11).

  However, if you cease all violation of this License, then your
license from a particular copyright holder is reinstated (a)
provisionally, unless and until the copyright holder explicitly and
finally terminates your license, and (b) permanently, if the copyright
holder fails to notify you of the violation by some reasonable means
prior to 60 days after the cessation.

  Moreover, your license from a particular copyright holder is
reinstated permanently if the copyright holder notifies you of the
violation by some reasonable means, this is the first time you have
received notice of violation of this License (for any work) from that
copyright holder, and you cure the violation prior to 30 days after
your receipt of the notice.

  Termination of your rights under this section does not terminate the
licenses of parties who have received copies or rights from you under
this License. If your rights have been terminated and not permanently
reinstated, you do not qualify to receive new licenses for the same
material under section 10.

  9. Acceptance Not Required for Having Copies.

  You are not required to accept this License in order to receive or
run a copy of the Program. Ancillary propagation of a covered work
occurring solely as a consequence of using peer-to-peer transmission
to receive a copy likewise does not require acceptance. However,
nothing other than this License grants you permission to propagate or
modify any covered work. These actions infringe copyright if you do
not accept this License. Therefore, by modifying or propagating a
covered work, you indicate your acceptance of this License to do so.

  10. Automatic Licensing of Downstream Recipients.

  Each time you convey a covered work, the recipient automatically
receives a license from the original licensors, to run, modify and
propagate that work, subject to this License. You are not responsible
for enforcing compliance by third parties with this License.

  An "entity transaction" is a transaction transferring control of an
organization, or substantially all assets of one, or subdividing an
organization, or merging organizations.
If propagation of a covered
work results from an entity transaction, each party to that
transaction who receives a copy of the work also receives whatever
licenses to the work the party's predecessor in interest had or could
give under the previous paragraph, plus a right to possession of the
Corresponding Source of the work from the predecessor in interest, if
the predecessor has it or can get it with reasonable efforts.

  You may not impose any further restrictions on the exercise of the
rights granted or affirmed under this License. For example, you may
not impose a license fee, royalty, or other charge for exercise of
rights granted under this License, and you may not initiate litigation
(including a cross-claim or counterclaim in a lawsuit) alleging that
any patent claim is infringed by making, using, selling, offering for
sale, or importing the Program or any portion of it.

  11. Patents.

  A "contributor" is a copyright holder who authorizes use under this
License of the Program or a work on which the Program is based. The
work thus licensed is called the contributor's "contributor version".

  A contributor's "essential patent claims" are all patent claims
owned or controlled by the contributor, whether already acquired or
hereafter acquired, that would be infringed by some manner, permitted
by this License, of making, using, or selling its contributor version,
but do not include claims that would be infringed only as a
consequence of further modification of the contributor version. For
purposes of this definition, "control" includes the right to grant
patent sublicenses in a manner consistent with the requirements of
this License.

  Each contributor grants you a non-exclusive, worldwide, royalty-free
patent license under the contributor's essential patent claims, to
make, use, sell, offer for sale, import and otherwise run, modify and
propagate the contents of its contributor version.

  In the following three paragraphs, a "patent license" is any express
agreement or commitment, however denominated, not to enforce a patent
(such as an express permission to practice a patent or covenant not to
sue for patent infringement). To "grant" such a patent license to a
party means to make such an agreement or commitment not to enforce a
patent against the party.

  If you convey a covered work, knowingly relying on a patent license,
and the Corresponding Source of the work is not available for anyone
to copy, free of charge and under the terms of this License, through a
publicly available network server or other readily accessible means,
then you must either (1) cause the Corresponding Source to be so
available, or (2) arrange to deprive yourself of the benefit of the
patent license for this particular work, or (3) arrange, in a manner
consistent with the requirements of this License, to extend the patent
license to downstream recipients. "Knowingly relying" means you have
actual knowledge that, but for the patent license, your conveying the
covered work in a country, or your recipient's use of the covered work
in a country, would infringe one or more identifiable patents in that
country that you have reason to believe are valid.

  If, pursuant to or in connection with a single transaction or
arrangement, you convey, or propagate by procuring conveyance of, a
covered work, and grant a patent license to some of the parties
receiving the covered work authorizing them to use, propagate, modify
or convey a specific copy of the covered work, then the patent license
you grant is automatically extended to all recipients of the covered
work and works based on it.

  A patent license is "discriminatory" if it does not include within
the scope of its coverage, prohibits the exercise of, or is
conditioned on the non-exercise of one or more of the rights that are
specifically granted under this License. You may not convey a covered
work if you are a party to an arrangement with a third party that is
in the business of distributing software, under which you make payment
to the third party based on the extent of your activity of conveying
the work, and under which the third party grants, to any of the
parties who would receive the covered work from you, a discriminatory
patent license (a) in connection with copies of the covered work
conveyed by you (or copies made from those copies), or (b) primarily
for and in connection with specific products or compilations that
contain the covered work, unless you entered into that arrangement,
or that patent license was granted, prior to 28 March 2007.

  Nothing in this License shall be construed as excluding or limiting
any implied license or other defenses to infringement that may
otherwise be available to you under applicable patent law.

  12. No Surrender of Others' Freedom.

  If conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License.
If you cannot convey a
covered work so as to satisfy simultaneously your obligations under this
License and any other pertinent obligations, then as a consequence you may
not convey it at all. For example, if you agree to terms that obligate you
to collect a royalty for further conveying from those to whom you convey
the Program, the only way you could satisfy both those terms and this
License would be to refrain entirely from conveying the Program.

  13. Remote Network Interaction; Use with the GNU General Public License.

  Notwithstanding any other provision of this License, if you modify the
Program, your modified version must prominently offer all users
interacting with it remotely through a computer network (if your version
supports such interaction) an opportunity to receive the Corresponding
Source of your version by providing access to the Corresponding Source
from a network server at no charge, through some standard or customary
means of facilitating copying of software. This Corresponding Source
shall include the Corresponding Source for any work covered by version 3
of the GNU General Public License that is incorporated pursuant to the
following paragraph.

  Notwithstanding any other provision of this License, you have
permission to link or combine any covered work with a work licensed
under version 3 of the GNU General Public License into a single
combined work, and to convey the resulting work. The terms of this
License will continue to apply to the part which is the covered work,
but the work with which it is combined will remain governed by version
3 of the GNU General Public License.

  14. Revised Versions of this License.

  The Free Software Foundation may publish revised and/or new versions of
the GNU Affero General Public License from time to time.
Such new versions
will be similar in spirit to the present version, but may differ in detail to
address new problems or concerns.

  Each version is given a distinguishing version number. If the
Program specifies that a certain numbered version of the GNU Affero General
Public License "or any later version" applies to it, you have the
option of following the terms and conditions either of that numbered
version or of any later version published by the Free Software
Foundation. If the Program does not specify a version number of the
GNU Affero General Public License, you may choose any version ever published
by the Free Software Foundation.

  If the Program specifies that a proxy can decide which future
versions of the GNU Affero General Public License can be used, that proxy's
public statement of acceptance of a version permanently authorizes you
to choose that version for the Program.

  Later license versions may give you additional or different
permissions. However, no additional obligations are imposed on any
author or copyright holder as a result of your choosing to follow a
later version.

  15. Disclaimer of Warranty.

  THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.

  16. Limitation of Liability.

  IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
SUCH DAMAGES.

  17. Interpretation of Sections 15 and 16.

  If the disclaimer of warranty and limitation of liability provided
above cannot be given local legal effect according to their terms,
reviewing courts shall apply local law that most closely approximates
an absolute waiver of all civil liability in connection with the
Program, unless a warranty or assumption of liability accompanies a
copy of the Program in return for a fee.

                     END OF TERMS AND CONDITIONS

            How to Apply These Terms to Your New Programs

  If you develop a new program, and you want it to be of the greatest
possible use to the public, the best way to achieve this is to make it
free software which everyone can redistribute and change under these terms.

  To do so, attach the following notices to the program. It is safest
to attach them to the start of each source file to most effectively
state the exclusion of warranty; and each file should have at least
the "copyright" line and a pointer to where the full notice is found.

    <one line to give the program's name and a brief idea of what it does.>
    Copyright (C) <year>  <name of author>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU Affero General Public License as published
    by the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU Affero General Public License for more details.

    You should have received a copy of the GNU Affero General Public License
    along with this program. If not, see <https://www.gnu.org/licenses/>.

Also add information on how to contact you by electronic and paper mail.

  If your software can interact with users remotely through a computer
network, you should also make sure that it provides a way for users to
get its source. For example, if your program is a web application, its
interface could display a "Source" link that leads users to an archive
of the code. There are many ways you could offer source, and different
solutions will be better for different programs; see section 13 for the
specific requirements.

  You should also get your employer (if you work as a programmer) or school,
if any, to sign a "copyright disclaimer" for the program, if necessary.
For more information on this, and how to apply and follow the GNU AGPL, see
<https://www.gnu.org/licenses/>.