├── .gitignore
├── README.md
├── data
│   ├── butterfly_GT.bmp
│   ├── butterfly_GT_bicubic_x3.bmp
│   ├── butterfly_GT_srcnn_x3.bmp
│   ├── ppt3.bmp
│   ├── ppt3_bicubic_x3.bmp
│   ├── ppt3_srcnn_x3.bmp
│   ├── zebra.bmp
│   ├── zebra_bicubic_x3.bmp
│   └── zebra_srcnn_x3.bmp
├── datasets.py
├── models.py
├── prepare.py
├── test.py
├── thumbnails
│   └── fig1.png
├── train.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
.DS_Store
.idea
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SRCNN

This repository is an implementation of ["Image Super-Resolution Using Deep Convolutional Networks"](https://arxiv.org/abs/1501.00092).
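
The network in `models.py` is three stride-1 convolutions with kernel sizes 9, 5 and 5. Along each axis, a stride-1 convolution with kernel size $k$ and padding $p$ maps a side length $H$ to $H + 2p - k + 1$, so the `padding=k // 2` used in this repository keeps the output the same size as the input, while unpadded convolutions (the original SRCNN setup, per the differences listed below) shrink it. A small sketch of that arithmetic, assuming stride 1 and no dilation; `conv_out_size` and `srcnn_output_size` are hypothetical helpers, not part of this repository:

```python
def conv_out_size(size, kernel, padding=0, stride=1):
    """Output side length of a 2-D convolution along one axis (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


def srcnn_output_size(size, kernels=(9, 5, 5), pad=True):
    """Trace a patch side length through SRCNN's three convolutions."""
    sizes = [size]
    for k in kernels:
        # pad=True mirrors models.py (padding = k // 2); pad=False uses no padding
        size = conv_out_size(size, k, padding=k // 2 if pad else 0)
        sizes.append(size)
    return sizes


# 33x33 patches, prepare.py's default --patch-size
print(srcnn_output_size(33, pad=True))   # [33, 33, 33, 33]
print(srcnn_output_size(33, pad=False))  # [33, 25, 21, 17]
```

With padding, the input and label patches can be compared pixel for pixel by `nn.MSELoss`; without it, the labels would have to be center-cropped to the smaller output size.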

## Differences from the original

- Added zero-padding
- Used Adam instead of SGD
- Removed the weight initialization

## Requirements

- PyTorch 1.0.0
- NumPy 1.15.4
- Pillow 5.4.1
- h5py 2.8.0
- tqdm 4.30.0

## Train

The 91-image and Set5 datasets, converted to HDF5, can be downloaded from the links below.

| Dataset | Scale | Type | Link |
|---------|-------|------|------|
| 91-image | 2 | Train | [Download](https://www.dropbox.com/s/2hsah93sxgegsry/91-image_x2.h5?dl=0) |
| 91-image | 3 | Train | [Download](https://www.dropbox.com/s/curldmdf11iqakd/91-image_x3.h5?dl=0) |
| 91-image | 4 | Train | [Download](https://www.dropbox.com/s/22afykv4amfxeio/91-image_x4.h5?dl=0) |
| Set5 | 2 | Eval | [Download](https://www.dropbox.com/s/r8qs6tp395hgh8g/Set5_x2.h5?dl=0) |
| Set5 | 3 | Eval | [Download](https://www.dropbox.com/s/58ywjac4te3kbqq/Set5_x3.h5?dl=0) |
| Set5 | 4 | Eval | [Download](https://www.dropbox.com/s/0rz86yn3nnrodlb/Set5_x4.h5?dl=0) |

Alternatively, you can use `prepare.py` to create a custom dataset.

```bash
python train.py --train-file "BLAH_BLAH/91-image_x3.h5" \
                --eval-file "BLAH_BLAH/Set5_x3.h5" \
                --outputs-dir "BLAH_BLAH/outputs" \
                --scale 3 \
                --lr 1e-4 \
                --batch-size 16 \
                --num-epochs 400 \
                --num-workers 8 \
                --seed 123
```

## Test

Pre-trained weights can be downloaded from the links below.

| Model | Scale | Link |
|-------|-------|------|
| 9-5-5 | 2 | [Download](https://www.dropbox.com/s/rxluu1y8ptjm4rn/srcnn_x2.pth?dl=0) |
| 9-5-5 | 3 | [Download](https://www.dropbox.com/s/zn4fdobm2kw0c58/srcnn_x3.pth?dl=0) |
| 9-5-5 | 4 | [Download](https://www.dropbox.com/s/pd5b2ketm0oamhj/srcnn_x4.pth?dl=0) |

The results (the bicubic-upscaled input, `<image>_bicubic_x<scale>`, and the SRCNN output, `<image>_srcnn_x<scale>`) are saved in the same directory as the input image.

```bash
python test.py --weights-file "BLAH_BLAH/srcnn_x3.pth" \
               --image-file "data/butterfly_GT.bmp" \
               --scale 3
```

## Results

We used the following network settings for the experiments: $f_1 = 9$, $f_2 = 5$, $f_3 = 5$, $n_1 = 64$, $n_2 = 32$, $n_3 = 1$, where $f_i$ is the kernel size and $n_i$ the number of filters of the $i$-th convolution layer (see `models.py`).

PSNR was calculated on the Y channel.

### Set5

| Eval. Metric | Scale | SRCNN | SRCNN (Ours) |
|--------------|-------|-------|--------------|
| PSNR | 2 | 36.66 | 36.65 |
| PSNR | 3 | 32.75 | 33.29 |
| PSNR | 4 | 30.49 | 30.25 |
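
The PSNR above is measured on luma only: RGB is converted to the Y channel of 8-bit BT.601 YCbCr, scaled to [0, 1], and scored as $10 \log_{10}(1 / \mathrm{MSE})$, which is what `calc_psnr` in `utils.py` computes. A NumPy-only sketch of that metric, assuming the standard 65.738 red coefficient; `rgb_to_y`, `psnr` and `psnr_y` are hypothetical helper names, and, like `calc_psnr`, it does not crop any border pixels:

```python
import numpy as np


def rgb_to_y(img):
    """Y channel of 8-bit BT.601 YCbCr for an (H, W, 3) RGB array in [0, 255]."""
    img = np.asarray(img, dtype=np.float64)
    return 16. + (65.738 * img[..., 0] + 129.057 * img[..., 1] + 25.064 * img[..., 2]) / 256.


def psnr(a, b, peak=1.0):
    """PSNR in dB between two arrays whose values lie in [0, peak]."""
    diff = np.asarray(a, dtype=np.float64) - np.asarray(b, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return float('inf')  # identical inputs
    return 10. * np.log10(peak ** 2 / mse)


def psnr_y(rgb_a, rgb_b):
    """Y-channel PSNR of two RGB images in [0, 255], with Y scaled to [0, 1]."""
    return psnr(rgb_to_y(rgb_a) / 255., rgb_to_y(rgb_b) / 255.)


# Example: a random "HR" image versus a noisy copy of it
rng = np.random.RandomState(0)
hr = rng.randint(0, 256, size=(32, 32, 3)).astype(np.float64)
noisy = np.clip(hr + rng.normal(0, 5, size=hr.shape), 0, 255)
print('Y-channel PSNR: {:.2f} dB'.format(psnr_y(hr, noisy)))
```

Scoring Y alone ignores chroma errors, which is consistent with SRCNN only predicting the Y channel here; the Cb/Cr channels of the output are taken from the bicubic image.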

Figure: x3 comparison on three test images (columns: Original, BICUBIC x3, SRCNN x3), with SRCNN PSNRs of 27.53 dB, 29.30 dB and 28.58 dB.
--------------------------------------------------------------------------------
/data/butterfly_GT.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/SRCNN-pytorch/064dbaac09859f5fa1b35608ab90145e2d60828b/data/butterfly_GT.bmp
--------------------------------------------------------------------------------
/data/butterfly_GT_bicubic_x3.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/SRCNN-pytorch/064dbaac09859f5fa1b35608ab90145e2d60828b/data/butterfly_GT_bicubic_x3.bmp
--------------------------------------------------------------------------------
/data/butterfly_GT_srcnn_x3.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/SRCNN-pytorch/064dbaac09859f5fa1b35608ab90145e2d60828b/data/butterfly_GT_srcnn_x3.bmp
--------------------------------------------------------------------------------
/data/ppt3.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/SRCNN-pytorch/064dbaac09859f5fa1b35608ab90145e2d60828b/data/ppt3.bmp
--------------------------------------------------------------------------------
/data/ppt3_bicubic_x3.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/SRCNN-pytorch/064dbaac09859f5fa1b35608ab90145e2d60828b/data/ppt3_bicubic_x3.bmp
--------------------------------------------------------------------------------
/data/ppt3_srcnn_x3.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/SRCNN-pytorch/064dbaac09859f5fa1b35608ab90145e2d60828b/data/ppt3_srcnn_x3.bmp
--------------------------------------------------------------------------------
/data/zebra.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/SRCNN-pytorch/064dbaac09859f5fa1b35608ab90145e2d60828b/data/zebra.bmp
--------------------------------------------------------------------------------
/data/zebra_bicubic_x3.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/SRCNN-pytorch/064dbaac09859f5fa1b35608ab90145e2d60828b/data/zebra_bicubic_x3.bmp
--------------------------------------------------------------------------------
/data/zebra_srcnn_x3.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/SRCNN-pytorch/064dbaac09859f5fa1b35608ab90145e2d60828b/data/zebra_srcnn_x3.bmp
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
import h5py
import numpy as np
from torch.utils.data import Dataset


class TrainDataset(Dataset):
    # Reads fixed-size LR/HR patch pairs written by prepare.py (train mode)
    def __init__(self, h5_file):
        super(TrainDataset, self).__init__()
        self.h5_file = h5_file

    def __getitem__(self, idx):
        # Returns (1, H, W) LR and HR patches scaled to [0, 1]
        with h5py.File(self.h5_file, 'r') as f:
            return np.expand_dims(f['lr'][idx] / 255., 0), np.expand_dims(f['hr'][idx] / 255., 0)

    def __len__(self):
        with h5py.File(self.h5_file, 'r') as f:
            return len(f['lr'])


class EvalDataset(Dataset):
    # Reads whole images written by prepare.py --eval, stored as datasets '0', '1', ...
    def __init__(self, h5_file):
        super(EvalDataset, self).__init__()
        self.h5_file = h5_file

    def __getitem__(self, idx):
        with h5py.File(self.h5_file, 'r') as f:
            return np.expand_dims(f['lr'][str(idx)][:, :] / 255., 0), np.expand_dims(f['hr'][str(idx)][:, :] / 255., 0)

    def __len__(self):
        with h5py.File(self.h5_file, 'r') as f:
            return len(f['lr'])

--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
from torch import nn


class SRCNN(nn.Module):
    # 9-5-5 SRCNN: patch extraction (9x9, 64 filters), non-linear mapping
    # (5x5, 32 filters) and reconstruction (5x5). Padding of kernel_size // 2
    # keeps the output the same spatial size as the input.
    def __init__(self, num_channels=1):
        super(SRCNN, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=9 // 2)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=5 // 2)
        self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=5 // 2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.conv3(x)
        return x
--------------------------------------------------------------------------------
/prepare.py:
--------------------------------------------------------------------------------
import argparse
import glob
import h5py
import numpy as np
import PIL.Image as pil_image
from utils import convert_rgb_to_y


def train(args):
    h5_file = h5py.File(args.output_path, 'w')

    lr_patches = []
    hr_patches = []

    for image_path in sorted(glob.glob('{}/*'.format(args.images_dir))):
        hr = pil_image.open(image_path).convert('RGB')
        # Resize HR so both dimensions are multiples of the scale factor
        hr_width = (hr.width // args.scale) * args.scale
        hr_height = (hr.height // args.scale) * args.scale
        hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)
        # Simulate the network input: bicubic downscale, then bicubic upscale back to HR size
        lr = hr.resize((hr_width // args.scale, hr_height // args.scale), resample=pil_image.BICUBIC)
        lr = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)
        hr = np.array(hr).astype(np.float32)
        lr = np.array(lr).astype(np.float32)
        # Keep only the Y (luma) channel
        hr = convert_rgb_to_y(hr)
        lr = convert_rgb_to_y(lr)

        # Cut aligned, overlapping patch_size x patch_size patches every `stride` pixels
        for i in range(0, lr.shape[0] - args.patch_size + 1, args.stride):
            for j in range(0, lr.shape[1] - args.patch_size + 1, args.stride):
                lr_patches.append(lr[i:i + args.patch_size, j:j + args.patch_size])
                hr_patches.append(hr[i:i + args.patch_size, j:j + args.patch_size])

    lr_patches = np.array(lr_patches)
    hr_patches = np.array(hr_patches)

    h5_file.create_dataset('lr', data=lr_patches)
    h5_file.create_dataset('hr', data=hr_patches)

    h5_file.close()


def eval(args):
    h5_file = h5py.File(args.output_path, 'w')

    # Evaluation images differ in size, so each one is stored as its own dataset
    lr_group = h5_file.create_group('lr')
    hr_group = h5_file.create_group('hr')

    for i, image_path in enumerate(sorted(glob.glob('{}/*'.format(args.images_dir)))):
        hr = pil_image.open(image_path).convert('RGB')
        hr_width = (hr.width // args.scale) * args.scale
        hr_height = (hr.height // args.scale) * args.scale
        hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)
        lr = hr.resize((hr_width // args.scale, hr_height // args.scale), resample=pil_image.BICUBIC)
        lr = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)
        hr = np.array(hr).astype(np.float32)
        lr = np.array(lr).astype(np.float32)
        hr = convert_rgb_to_y(hr)
        lr = convert_rgb_to_y(lr)

        lr_group.create_dataset(str(i), data=lr)
        hr_group.create_dataset(str(i), data=hr)

    h5_file.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--images-dir', type=str, required=True)
    parser.add_argument('--output-path', type=str, required=True)
    parser.add_argument('--patch-size', type=int, default=33)
    parser.add_argument('--stride', type=int, default=14)
    parser.add_argument('--scale', type=int, default=2)
    parser.add_argument('--eval', action='store_true')
    args = parser.parse_args()

    if not args.eval:
        train(args)
    else:
        eval(args)
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import argparse
import os

import torch
import torch.backends.cudnn as cudnn
import numpy as np
import PIL.Image as pil_image

from models import SRCNN
from utils import convert_rgb_to_ycbcr, convert_ycbcr_to_rgb, calc_psnr


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights-file', type=str, required=True)
    parser.add_argument('--image-file', type=str, required=True)
    parser.add_argument('--scale', type=int, default=3)
    args = parser.parse_args()

    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model = SRCNN().to(device)

    # Copy the saved weights into the model, failing on any unexpected key
    state_dict = model.state_dict()
    for n, p in torch.load(args.weights_file, map_location=lambda storage, loc: storage).items():
        if n in state_dict.keys():
            state_dict[n].copy_(p)
        else:
            raise KeyError(n)

    model.eval()

    # Output names are built from the path without its extension
    # (str.replace('.', ...) would also rewrite dots in directory names)
    base, ext = os.path.splitext(args.image_file)

    image = pil_image.open(args.image_file).convert('RGB')

    # Degrade the input as prepare.py does: bicubic downscale, then bicubic upscale
    image_width = (image.width // args.scale) * args.scale
    image_height = (image.height // args.scale) * args.scale
    image = image.resize((image_width, image_height), resample=pil_image.BICUBIC)
    image = image.resize((image.width // args.scale, image.height // args.scale), resample=pil_image.BICUBIC)
    image = image.resize((image.width * args.scale, image.height * args.scale), resample=pil_image.BICUBIC)
    image.save('{}_bicubic_x{}{}'.format(base, args.scale, ext))

    image = np.array(image).astype(np.float32)
    ycbcr = convert_rgb_to_ycbcr(image)

    # The network only processes the Y channel, scaled to [0, 1]
    y = ycbcr[..., 0]
    y /= 255.
    y = torch.from_numpy(y).to(device)
    y = y.unsqueeze(0).unsqueeze(0)

    with torch.no_grad():
        preds = model(y).clamp(0.0, 1.0)

    # Note: this compares the output with the bicubic input, not with the original image
    psnr = calc_psnr(y, preds)
    print('PSNR: {:.2f}'.format(psnr))

    preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)

    # Recombine the predicted Y with the bicubic Cb/Cr channels and convert back to RGB
    output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])
    output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
    output = pil_image.fromarray(output)
    output.save('{}_srcnn_x{}{}'.format(base, args.scale, ext))
--------------------------------------------------------------------------------
/thumbnails/fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/SRCNN-pytorch/064dbaac09859f5fa1b35608ab90145e2d60828b/thumbnails/fig1.png
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import argparse
import os
import copy

import torch
from torch import nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from models import SRCNN
from datasets import TrainDataset, EvalDataset
from utils import AverageMeter, calc_psnr


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--train-file', type=str, required=True)
    parser.add_argument('--eval-file', type=str, required=True)
    parser.add_argument('--outputs-dir', type=str, required=True)
    parser.add_argument('--scale', type=int, default=3)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--num-epochs', type=int, default=400)
    parser.add_argument('--num-workers', type=int, default=8)
    parser.add_argument('--seed', type=int, default=123)
    args = parser.parse_args()

    # Checkpoints are written to <outputs-dir>/x<scale>/
    args.outputs_dir = os.path.join(args.outputs_dir, 'x{}'.format(args.scale))

    if not os.path.exists(args.outputs_dir):
        os.makedirs(args.outputs_dir)

    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    torch.manual_seed(args.seed)

    model = SRCNN().to(device)
    criterion = nn.MSELoss()
    # The last layer uses a 10x smaller learning rate than the first two
    optimizer = optim.Adam([
        {'params': model.conv1.parameters()},
        {'params': model.conv2.parameters()},
        {'params': model.conv3.parameters(), 'lr': args.lr * 0.1}
    ], lr=args.lr)

    train_dataset = TrainDataset(args.train_file)
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True,
                                  drop_last=True)
    eval_dataset = EvalDataset(args.eval_file)
    eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)

    best_weights = copy.deepcopy(model.state_dict())
    best_epoch = 0
    best_psnr = 0.0

    for epoch in range(args.num_epochs):
        model.train()
        epoch_losses = AverageMeter()

        # The progress total excludes the last partial batch, matching drop_last=True
        with tqdm(total=(len(train_dataset) - len(train_dataset) % args.batch_size)) as t:
            t.set_description('epoch: {}/{}'.format(epoch, args.num_epochs - 1))

            for data in train_dataloader:
                inputs, labels = data

                inputs = inputs.to(device)
                labels = labels.to(device)

                preds = model(inputs)

                loss = criterion(preds, labels)

                epoch_losses.update(loss.item(), len(inputs))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))
                t.update(len(inputs))

        torch.save(model.state_dict(), os.path.join(args.outputs_dir, 'epoch_{}.pth'.format(epoch)))

        model.eval()
        epoch_psnr = AverageMeter()

        for data in eval_dataloader:
            inputs, labels = data

            inputs = inputs.to(device)
            labels = labels.to(device)

            with torch.no_grad():
                preds = model(inputs).clamp(0.0, 1.0)

            epoch_psnr.update(calc_psnr(preds, labels).item(), len(inputs))

        print('eval psnr: {:.2f}'.format(epoch_psnr.avg))

        # Keep the weights with the best evaluation PSNR
        if epoch_psnr.avg > best_psnr:
            best_epoch = epoch
            best_psnr = epoch_psnr.avg
            best_weights = copy.deepcopy(model.state_dict())

    print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))
    torch.save(best_weights, os.path.join(args.outputs_dir, 'best.pth'))
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import torch
import numpy as np

# RGB <-> YCbCr helpers for 8-bit values in [0, 255] (ITU-R BT.601, Y in [16, 235]).
# NumPy branches expect (H, W, 3) arrays; tensor branches expect (3, H, W) or (1, 3, H, W).


def convert_rgb_to_y(img):
    if type(img) == np.ndarray:
        return 16. + (65.738 * img[:, :, 0] + 129.057 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256.
    elif type(img) == torch.Tensor:
        if len(img.shape) == 4:
            img = img.squeeze(0)
        return 16. + (65.738 * img[0, :, :] + 129.057 * img[1, :, :] + 25.064 * img[2, :, :]) / 256.
    else:
        raise Exception('Unknown Type', type(img))


def convert_rgb_to_ycbcr(img):
    if type(img) == np.ndarray:
        y = 16. + (65.738 * img[:, :, 0] + 129.057 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256.
        cb = 128. + (-37.945 * img[:, :, 0] - 74.494 * img[:, :, 1] + 112.439 * img[:, :, 2]) / 256.
        cr = 128. + (112.439 * img[:, :, 0] - 94.154 * img[:, :, 1] - 18.285 * img[:, :, 2]) / 256.
        return np.array([y, cb, cr]).transpose([1, 2, 0])
    elif type(img) == torch.Tensor:
        if len(img.shape) == 4:
            img = img.squeeze(0)
        y = 16. + (65.738 * img[0, :, :] + 129.057 * img[1, :, :] + 25.064 * img[2, :, :]) / 256.
        cb = 128. + (-37.945 * img[0, :, :] - 74.494 * img[1, :, :] + 112.439 * img[2, :, :]) / 256.
        cr = 128. + (112.439 * img[0, :, :] - 94.154 * img[1, :, :] - 18.285 * img[2, :, :]) / 256.
        # Stack the (H, W) planes into (3, H, W), then move channels last
        return torch.stack([y, cb, cr], 0).permute(1, 2, 0)
    else:
        raise Exception('Unknown Type', type(img))


def convert_ycbcr_to_rgb(img):
    if type(img) == np.ndarray:
        r = 298.082 * img[:, :, 0] / 256. + 408.583 * img[:, :, 2] / 256. - 222.921
        g = 298.082 * img[:, :, 0] / 256. - 100.291 * img[:, :, 1] / 256. - 208.120 * img[:, :, 2] / 256. + 135.576
        b = 298.082 * img[:, :, 0] / 256. + 516.412 * img[:, :, 1] / 256. - 276.836
        return np.array([r, g, b]).transpose([1, 2, 0])
    elif type(img) == torch.Tensor:
        if len(img.shape) == 4:
            img = img.squeeze(0)
        r = 298.082 * img[0, :, :] / 256. + 408.583 * img[2, :, :] / 256. - 222.921
        g = 298.082 * img[0, :, :] / 256. - 100.291 * img[1, :, :] / 256. - 208.120 * img[2, :, :] / 256. + 135.576
        b = 298.082 * img[0, :, :] / 256. + 516.412 * img[1, :, :] / 256. - 276.836
        # Stack the (H, W) planes into (3, H, W), then move channels last
        return torch.stack([r, g, b], 0).permute(1, 2, 0)
    else:
        raise Exception('Unknown Type', type(img))


def calc_psnr(img1, img2):
    # PSNR in dB for tensors scaled to [0, 1] (peak value 1)
    return 10. * torch.log10(1. / torch.mean((img1 - img2) ** 2))


class AverageMeter(object):
    """Tracks the latest value and the running, count-weighted average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

--------------------------------------------------------------------------------
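
The conversion helpers in `utils.py` above implement the 8-bit ITU-R BT.601 (studio-swing) RGB/YCbCr transform, with Y in [16, 235]. Written as a 3x3 matrix plus an offset vector, the forward and inverse coefficients can be checked against each other. The sketch below mirrors the NumPy branches of `convert_rgb_to_ycbcr` and `convert_ycbcr_to_rgb`, assuming a red coefficient of 65.738 for Y; the names `RGB_TO_YCBCR`, `YCBCR_TO_RGB`, `rgb_to_ycbcr` and `ycbcr_to_rgb` are hypothetical and not part of the repository:

```python
import numpy as np

# Forward coefficients (rows: Y, Cb, Cr; columns: R, G, B), divided by 256 as in utils.py
RGB_TO_YCBCR = np.array([
    [65.738, 129.057, 25.064],
    [-37.945, -74.494, 112.439],
    [112.439, -94.154, -18.285],
]) / 256.
YCBCR_OFFSET = np.array([16., 128., 128.])

# Inverse coefficients (rows: R, G, B; columns: Y, Cb, Cr) and constant terms from utils.py
YCBCR_TO_RGB = np.array([
    [298.082, 0., 408.583],
    [298.082, -100.291, -208.120],
    [298.082, 516.412, 0.],
]) / 256.
RGB_OFFSET = np.array([-222.921, 135.576, -276.836])


def rgb_to_ycbcr(img):
    """(H, W, 3) RGB in [0, 255] -> (H, W, 3) YCbCr."""
    return img @ RGB_TO_YCBCR.T + YCBCR_OFFSET


def ycbcr_to_rgb(img):
    """(H, W, 3) YCbCr -> (H, W, 3) RGB, unclipped."""
    return img @ YCBCR_TO_RGB.T + RGB_OFFSET


rng = np.random.RandomState(0)
rgb = rng.randint(0, 256, size=(16, 16, 3)).astype(np.float64)
round_trip = ycbcr_to_rgb(rgb_to_ycbcr(rgb))
# Maximum absolute round-trip error, expected to be far below one intensity level
print(np.abs(round_trip - rgb).max())
```

The matrix form makes the pairing explicit: the product of the inverse and forward matrices should be close to the identity, so any mistyped coefficient shows up as a round-trip error of roughly one intensity level or more.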