├── .gitignore
├── README.md
├── data
│   ├── butterfly_GT.bmp
│   ├── butterfly_GT_bicubic_x3.bmp
│   ├── butterfly_GT_srcnn_x3.bmp
│   ├── ppt3.bmp
│   ├── ppt3_bicubic_x3.bmp
│   ├── ppt3_srcnn_x3.bmp
│   ├── zebra.bmp
│   ├── zebra_bicubic_x3.bmp
│   └── zebra_srcnn_x3.bmp
├── datasets.py
├── models.py
├── prepare.py
├── test.py
├── thumbnails
│   └── fig1.png
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | .idea
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SRCNN
2 |
3 | This repository is an implementation of the paper ["Image Super-Resolution Using Deep Convolutional Networks"](https://arxiv.org/abs/1501.00092).
4 |
5 |
6 |
7 | ## Differences from the original
8 |
9 | - Added zero-padding
10 | - Used Adam instead of SGD
11 | - Removed the weight initialization
12 |
13 | ## Requirements
14 |
15 | - PyTorch 1.0.0
16 | - NumPy 1.15.4
17 | - Pillow 5.4.1
18 | - h5py 2.8.0
19 | - tqdm 4.30.0
20 |
21 | ## Train
22 |
23 | The 91-image and Set5 datasets, converted to HDF5, can be downloaded from the links below.
24 |
25 | | Dataset | Scale | Type | Link |
26 | |---------|-------|------|------|
27 | | 91-image | 2 | Train | [Download](https://www.dropbox.com/s/2hsah93sxgegsry/91-image_x2.h5?dl=0) |
28 | | 91-image | 3 | Train | [Download](https://www.dropbox.com/s/curldmdf11iqakd/91-image_x3.h5?dl=0) |
29 | | 91-image | 4 | Train | [Download](https://www.dropbox.com/s/22afykv4amfxeio/91-image_x4.h5?dl=0) |
30 | | Set5 | 2 | Eval | [Download](https://www.dropbox.com/s/r8qs6tp395hgh8g/Set5_x2.h5?dl=0) |
31 | | Set5 | 3 | Eval | [Download](https://www.dropbox.com/s/58ywjac4te3kbqq/Set5_x3.h5?dl=0) |
32 | | Set5 | 4 | Eval | [Download](https://www.dropbox.com/s/0rz86yn3nnrodlb/Set5_x4.h5?dl=0) |
33 |
34 | Otherwise, you can use `prepare.py` to create a custom dataset.
35 |
36 | ```bash
37 | python train.py --train-file "BLAH_BLAH/91-image_x3.h5" \
38 | --eval-file "BLAH_BLAH/Set5_x3.h5" \
39 | --outputs-dir "BLAH_BLAH/outputs" \
40 | --scale 3 \
41 | --lr 1e-4 \
42 | --batch-size 16 \
43 | --num-epochs 400 \
44 | --num-workers 8 \
45 | --seed 123
46 | ```
47 |
48 | ## Test
49 |
50 | Pre-trained weights can be downloaded from the links below.
51 |
52 | | Model | Scale | Link |
53 | |-------|-------|------|
54 | | 9-5-5 | 2 | [Download](https://www.dropbox.com/s/rxluu1y8ptjm4rn/srcnn_x2.pth?dl=0) |
55 | | 9-5-5 | 3 | [Download](https://www.dropbox.com/s/zn4fdobm2kw0c58/srcnn_x3.pth?dl=0) |
56 | | 9-5-5 | 4 | [Download](https://www.dropbox.com/s/pd5b2ketm0oamhj/srcnn_x4.pth?dl=0) |
57 |
58 | The results are stored in the same directory as the query image.
59 |
60 | ```bash
61 | python test.py --weights-file "BLAH_BLAH/srcnn_x3.pth" \
62 | --image-file "data/butterfly_GT.bmp" \
63 | --scale 3
64 | ```
65 |
66 | ## Results
67 |
68 | We used the 9-5-5 network setting for the experiments (see `thumbnails/fig1.png` for the architecture).
69 |
70 | PSNR was calculated on the Y channel.
71 |
72 | ### Set5
73 |
74 | | Eval. Mat | Scale | SRCNN | SRCNN (Ours) |
75 | |-----------|-------|-------|--------------|
76 | | PSNR | 2 | 36.66 | 36.65 |
77 | | PSNR | 3 | 32.75 | 33.29 |
78 | | PSNR | 4 | 30.49 | 30.25 |
79 |
80 |
81 | Image comparisons (the original, bicubic, and SRCNN outputs are in `data/`):
82 |
83 | | Original | BICUBIC x3 | SRCNN x3 (27.53 dB) |
84 | |----------|------------|---------------------|
85 |
86 | | Original | BICUBIC x3 | SRCNN x3 (29.30 dB) |
87 | |----------|------------|---------------------|
88 |
89 | | Original | BICUBIC x3 | SRCNN x3 (28.58 dB) |
90 | |----------|------------|---------------------|
91 |
--------------------------------------------------------------------------------
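
To reproduce HDF5 files like the ones linked in the README, an invocation along these lines should work (paths are placeholders; the flags match the argparse definitions in `prepare.py` below):

```bash
# Training set: 33x33 Y-channel patches at scale 3.
python prepare.py --images-dir "BLAH_BLAH/91-image" \
                  --output-path "BLAH_BLAH/91-image_x3.h5" \
                  --patch-size 33 \
                  --stride 14 \
                  --scale 3

# Evaluation set: whole images, stored one per HDF5 group entry.
python prepare.py --images-dir "BLAH_BLAH/Set5" \
                  --output-path "BLAH_BLAH/Set5_x3.h5" \
                  --scale 3 \
                  --eval
```
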
/data/butterfly_GT.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/SRCNN-pytorch/064dbaac09859f5fa1b35608ab90145e2d60828b/data/butterfly_GT.bmp
--------------------------------------------------------------------------------
/data/butterfly_GT_bicubic_x3.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/SRCNN-pytorch/064dbaac09859f5fa1b35608ab90145e2d60828b/data/butterfly_GT_bicubic_x3.bmp
--------------------------------------------------------------------------------
/data/butterfly_GT_srcnn_x3.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/SRCNN-pytorch/064dbaac09859f5fa1b35608ab90145e2d60828b/data/butterfly_GT_srcnn_x3.bmp
--------------------------------------------------------------------------------
/data/ppt3.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/SRCNN-pytorch/064dbaac09859f5fa1b35608ab90145e2d60828b/data/ppt3.bmp
--------------------------------------------------------------------------------
/data/ppt3_bicubic_x3.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/SRCNN-pytorch/064dbaac09859f5fa1b35608ab90145e2d60828b/data/ppt3_bicubic_x3.bmp
--------------------------------------------------------------------------------
/data/ppt3_srcnn_x3.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/SRCNN-pytorch/064dbaac09859f5fa1b35608ab90145e2d60828b/data/ppt3_srcnn_x3.bmp
--------------------------------------------------------------------------------
/data/zebra.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/SRCNN-pytorch/064dbaac09859f5fa1b35608ab90145e2d60828b/data/zebra.bmp
--------------------------------------------------------------------------------
/data/zebra_bicubic_x3.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/SRCNN-pytorch/064dbaac09859f5fa1b35608ab90145e2d60828b/data/zebra_bicubic_x3.bmp
--------------------------------------------------------------------------------
/data/zebra_srcnn_x3.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/SRCNN-pytorch/064dbaac09859f5fa1b35608ab90145e2d60828b/data/zebra_srcnn_x3.bmp
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import numpy as np
3 | from torch.utils.data import Dataset
4 |
5 |
6 | class TrainDataset(Dataset):
7 | def __init__(self, h5_file):
8 | super(TrainDataset, self).__init__()
9 | self.h5_file = h5_file
10 |
11 | def __getitem__(self, idx):
12 | with h5py.File(self.h5_file, 'r') as f:
13 |             return np.expand_dims(f['lr'][idx] / 255., 0), np.expand_dims(f['hr'][idx] / 255., 0)  # add channel dim, scale to [0, 1]
14 |
15 | def __len__(self):
16 | with h5py.File(self.h5_file, 'r') as f:
17 | return len(f['lr'])
18 |
19 |
20 | class EvalDataset(Dataset):
21 | def __init__(self, h5_file):
22 | super(EvalDataset, self).__init__()
23 | self.h5_file = h5_file
24 |
25 | def __getitem__(self, idx):
26 | with h5py.File(self.h5_file, 'r') as f:
27 | return np.expand_dims(f['lr'][str(idx)][:, :] / 255., 0), np.expand_dims(f['hr'][str(idx)][:, :] / 255., 0)
28 |
29 | def __len__(self):
30 | with h5py.File(self.h5_file, 'r') as f:
31 | return len(f['lr'])
32 |
--------------------------------------------------------------------------------
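
The two dataset classes mirror the two HDF5 layouts written by `prepare.py`: a training file holds two flat patch arrays (`lr`, `hr`), while an eval file holds one dataset per image, keyed by its stringified index. A minimal inspection sketch (file names are placeholders):

```python
import h5py

# Training file: flat arrays of patches.
with h5py.File('91-image_x3.h5', 'r') as f:
    print(f['lr'].shape, f['hr'].shape)  # e.g. (N, 33, 33) and (N, 33, 33)

# Eval file: one dataset per image, keyed by the stringified index.
with h5py.File('Set5_x3.h5', 'r') as f:
    print(f['lr']['0'].shape)  # full-size Y channel of the first image
```
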
/models.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 |
4 | class SRCNN(nn.Module):
5 | def __init__(self, num_channels=1):
6 | super(SRCNN, self).__init__()
7 |         self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=9 // 2)  # patch extraction and representation
8 |         self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=5 // 2)  # non-linear mapping
9 |         self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=5 // 2)  # reconstruction
10 | self.relu = nn.ReLU(inplace=True)
11 |
12 | def forward(self, x):
13 | x = self.relu(self.conv1(x))
14 | x = self.relu(self.conv2(x))
15 | x = self.conv3(x)
16 | return x
17 |
--------------------------------------------------------------------------------
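
Because each convolution is zero-padded by half its kernel size, the output keeps the input's spatial dimensions, which is what lets the model train on same-size LR/HR patch pairs. A quick shape check, as a sketch run from the repository root:

```python
import torch
from models import SRCNN

model = SRCNN(num_channels=1)
x = torch.randn(1, 1, 33, 33)  # one 33x33 Y-channel patch
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 1, 33, 33]) -- spatial size preserved
```
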
/prepare.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import h5py
4 | import numpy as np
5 | import PIL.Image as pil_image
6 | from utils import convert_rgb_to_y
7 |
8 |
9 | def train(args):
10 | h5_file = h5py.File(args.output_path, 'w')
11 |
12 | lr_patches = []
13 | hr_patches = []
14 |
15 | for image_path in sorted(glob.glob('{}/*'.format(args.images_dir))):
16 | hr = pil_image.open(image_path).convert('RGB')
17 | hr_width = (hr.width // args.scale) * args.scale
18 | hr_height = (hr.height // args.scale) * args.scale
19 | hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)
20 | lr = hr.resize((hr_width // args.scale, hr_height // args.scale), resample=pil_image.BICUBIC)
21 | lr = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)
22 | hr = np.array(hr).astype(np.float32)
23 | lr = np.array(lr).astype(np.float32)
24 | hr = convert_rgb_to_y(hr)
25 | lr = convert_rgb_to_y(lr)
26 |
27 | for i in range(0, lr.shape[0] - args.patch_size + 1, args.stride):
28 | for j in range(0, lr.shape[1] - args.patch_size + 1, args.stride):
29 | lr_patches.append(lr[i:i + args.patch_size, j:j + args.patch_size])
30 | hr_patches.append(hr[i:i + args.patch_size, j:j + args.patch_size])
31 |
32 | lr_patches = np.array(lr_patches)
33 | hr_patches = np.array(hr_patches)
34 |
35 | h5_file.create_dataset('lr', data=lr_patches)
36 | h5_file.create_dataset('hr', data=hr_patches)
37 |
38 | h5_file.close()
39 |
40 |
41 | def eval(args):
42 | h5_file = h5py.File(args.output_path, 'w')
43 |
44 | lr_group = h5_file.create_group('lr')
45 | hr_group = h5_file.create_group('hr')
46 |
47 | for i, image_path in enumerate(sorted(glob.glob('{}/*'.format(args.images_dir)))):
48 | hr = pil_image.open(image_path).convert('RGB')
49 | hr_width = (hr.width // args.scale) * args.scale
50 | hr_height = (hr.height // args.scale) * args.scale
51 | hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)
52 | lr = hr.resize((hr_width // args.scale, hr_height // args.scale), resample=pil_image.BICUBIC)
53 | lr = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)
54 | hr = np.array(hr).astype(np.float32)
55 | lr = np.array(lr).astype(np.float32)
56 | hr = convert_rgb_to_y(hr)
57 | lr = convert_rgb_to_y(lr)
58 |
59 | lr_group.create_dataset(str(i), data=lr)
60 | hr_group.create_dataset(str(i), data=hr)
61 |
62 | h5_file.close()
63 |
64 |
65 | if __name__ == '__main__':
66 | parser = argparse.ArgumentParser()
67 | parser.add_argument('--images-dir', type=str, required=True)
68 | parser.add_argument('--output-path', type=str, required=True)
69 | parser.add_argument('--patch-size', type=int, default=33)
70 | parser.add_argument('--stride', type=int, default=14)
71 | parser.add_argument('--scale', type=int, default=2)
72 | parser.add_argument('--eval', action='store_true')
73 | args = parser.parse_args()
74 |
75 | if not args.eval:
76 | train(args)
77 | else:
78 | eval(args)
79 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | import torch.backends.cudnn as cudnn
5 | import numpy as np
6 | import PIL.Image as pil_image
7 |
8 | from models import SRCNN
9 | from utils import convert_rgb_to_ycbcr, convert_ycbcr_to_rgb, calc_psnr
10 |
11 |
12 | if __name__ == '__main__':
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('--weights-file', type=str, required=True)
15 | parser.add_argument('--image-file', type=str, required=True)
16 | parser.add_argument('--scale', type=int, default=3)
17 | args = parser.parse_args()
18 |
19 | cudnn.benchmark = True
20 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
21 |
22 | model = SRCNN().to(device)
23 |
24 | state_dict = model.state_dict()
25 | for n, p in torch.load(args.weights_file, map_location=lambda storage, loc: storage).items():
26 | if n in state_dict.keys():
27 | state_dict[n].copy_(p)
28 | else:
29 | raise KeyError(n)
30 |
31 | model.eval()
32 |
33 | image = pil_image.open(args.image_file).convert('RGB')
34 |
35 | image_width = (image.width // args.scale) * args.scale
36 | image_height = (image.height // args.scale) * args.scale
37 |     image = image.resize((image_width, image_height), resample=pil_image.BICUBIC)  # resize so dimensions are divisible by the scale
38 |     image = image.resize((image.width // args.scale, image.height // args.scale), resample=pil_image.BICUBIC)  # bicubic downscale
39 |     image = image.resize((image.width * args.scale, image.height * args.scale), resample=pil_image.BICUBIC)  # bicubic upscale back
40 | image.save(args.image_file.replace('.', '_bicubic_x{}.'.format(args.scale)))
41 |
42 | image = np.array(image).astype(np.float32)
43 | ycbcr = convert_rgb_to_ycbcr(image)
44 |
45 | y = ycbcr[..., 0]
46 | y /= 255.
47 | y = torch.from_numpy(y).to(device)
48 | y = y.unsqueeze(0).unsqueeze(0)
49 |
50 | with torch.no_grad():
51 | preds = model(y).clamp(0.0, 1.0)
52 |
53 |     psnr = calc_psnr(y, preds)  # note: PSNR here is against the bicubic input y, not the original image
54 | print('PSNR: {:.2f}'.format(psnr))
55 |
56 | preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)
57 |
58 | output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])
59 | output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
60 | output = pil_image.fromarray(output)
61 | output.save(args.image_file.replace('.', '_srcnn_x{}.'.format(args.scale)))
62 |
--------------------------------------------------------------------------------
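
Since `test.py` reports PSNR against its own bicubic input rather than the original image, here is a sketch for measuring PSNR against the ground-truth Y channel instead (it assumes the `_srcnn_x3` output from the Test step already exists):

```python
import numpy as np
import PIL.Image as pil_image
import torch
from utils import convert_rgb_to_y, calc_psnr

scale = 3
gt = pil_image.open('data/butterfly_GT.bmp').convert('RGB')
# Match the resizing test.py applies before degrading the image.
gt = gt.resize(((gt.width // scale) * scale, (gt.height // scale) * scale),
               resample=pil_image.BICUBIC)
sr = pil_image.open('data/butterfly_GT_srcnn_x3.bmp').convert('RGB')

gt_y = torch.from_numpy(convert_rgb_to_y(np.array(gt).astype(np.float32))) / 255.
sr_y = torch.from_numpy(convert_rgb_to_y(np.array(sr).astype(np.float32))) / 255.
print('PSNR: {:.2f}'.format(calc_psnr(gt_y, sr_y)))
```
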
/thumbnails/fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjn870/SRCNN-pytorch/064dbaac09859f5fa1b35608ab90145e2d60828b/thumbnails/fig1.png
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import copy
4 |
5 | import torch
6 | from torch import nn
7 | import torch.optim as optim
8 | import torch.backends.cudnn as cudnn
9 | from torch.utils.data.dataloader import DataLoader
10 | from tqdm import tqdm
11 |
12 | from models import SRCNN
13 | from datasets import TrainDataset, EvalDataset
14 | from utils import AverageMeter, calc_psnr
15 |
16 |
17 | if __name__ == '__main__':
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument('--train-file', type=str, required=True)
20 | parser.add_argument('--eval-file', type=str, required=True)
21 | parser.add_argument('--outputs-dir', type=str, required=True)
22 | parser.add_argument('--scale', type=int, default=3)
23 | parser.add_argument('--lr', type=float, default=1e-4)
24 | parser.add_argument('--batch-size', type=int, default=16)
25 | parser.add_argument('--num-epochs', type=int, default=400)
26 | parser.add_argument('--num-workers', type=int, default=8)
27 | parser.add_argument('--seed', type=int, default=123)
28 | args = parser.parse_args()
29 |
30 | args.outputs_dir = os.path.join(args.outputs_dir, 'x{}'.format(args.scale))
31 |
32 | if not os.path.exists(args.outputs_dir):
33 | os.makedirs(args.outputs_dir)
34 |
35 | cudnn.benchmark = True
36 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
37 |
38 | torch.manual_seed(args.seed)
39 |
40 | model = SRCNN().to(device)
41 | criterion = nn.MSELoss()
42 | optimizer = optim.Adam([
43 | {'params': model.conv1.parameters()},
44 | {'params': model.conv2.parameters()},
45 |         {'params': model.conv3.parameters(), 'lr': args.lr * 0.1}  # smaller lr for the last layer, as in the paper
46 | ], lr=args.lr)
47 |
48 | train_dataset = TrainDataset(args.train_file)
49 | train_dataloader = DataLoader(dataset=train_dataset,
50 | batch_size=args.batch_size,
51 | shuffle=True,
52 | num_workers=args.num_workers,
53 | pin_memory=True,
54 | drop_last=True)
55 | eval_dataset = EvalDataset(args.eval_file)
56 | eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)
57 |
58 | best_weights = copy.deepcopy(model.state_dict())
59 | best_epoch = 0
60 | best_psnr = 0.0
61 |
62 | for epoch in range(args.num_epochs):
63 | model.train()
64 | epoch_losses = AverageMeter()
65 |
66 | with tqdm(total=(len(train_dataset) - len(train_dataset) % args.batch_size)) as t:
67 | t.set_description('epoch: {}/{}'.format(epoch, args.num_epochs - 1))
68 |
69 | for data in train_dataloader:
70 | inputs, labels = data
71 |
72 | inputs = inputs.to(device)
73 | labels = labels.to(device)
74 |
75 | preds = model(inputs)
76 |
77 | loss = criterion(preds, labels)
78 |
79 | epoch_losses.update(loss.item(), len(inputs))
80 |
81 | optimizer.zero_grad()
82 | loss.backward()
83 | optimizer.step()
84 |
85 | t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))
86 | t.update(len(inputs))
87 |
88 | torch.save(model.state_dict(), os.path.join(args.outputs_dir, 'epoch_{}.pth'.format(epoch)))
89 |
90 | model.eval()
91 | epoch_psnr = AverageMeter()
92 |
93 | for data in eval_dataloader:
94 | inputs, labels = data
95 |
96 | inputs = inputs.to(device)
97 | labels = labels.to(device)
98 |
99 | with torch.no_grad():
100 | preds = model(inputs).clamp(0.0, 1.0)
101 |
102 | epoch_psnr.update(calc_psnr(preds, labels), len(inputs))
103 |
104 | print('eval psnr: {:.2f}'.format(epoch_psnr.avg))
105 |
106 | if epoch_psnr.avg > best_psnr:
107 | best_epoch = epoch
108 | best_psnr = epoch_psnr.avg
109 | best_weights = copy.deepcopy(model.state_dict())
110 |
111 | print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))
112 | torch.save(best_weights, os.path.join(args.outputs_dir, 'best.pth'))
113 |
--------------------------------------------------------------------------------
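
`train.py` writes per-epoch checkpoints plus `best.pth`, the weights with the highest eval PSNR. Loading the best checkpoint for inference, as a sketch (the exact path depends on `--outputs-dir` and `--scale`):

```python
import torch
from models import SRCNN

model = SRCNN()
state_dict = torch.load('BLAH_BLAH/outputs/x3/best.pth', map_location='cpu')
model.load_state_dict(state_dict)
model.eval()  # switch to inference mode
```
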
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | def convert_rgb_to_y(img):
6 | if type(img) == np.ndarray:
7 | return 16. + (64.738 * img[:, :, 0] + 129.057 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256.
8 | elif type(img) == torch.Tensor:
9 | if len(img.shape) == 4:
10 | img = img.squeeze(0)
11 | return 16. + (64.738 * img[0, :, :] + 129.057 * img[1, :, :] + 25.064 * img[2, :, :]) / 256.
12 | else:
13 |         raise TypeError('unknown image type: {}'.format(type(img)))
14 |
15 |
16 | def convert_rgb_to_ycbcr(img):
17 | if type(img) == np.ndarray:
18 | y = 16. + (64.738 * img[:, :, 0] + 129.057 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256.
19 | cb = 128. + (-37.945 * img[:, :, 0] - 74.494 * img[:, :, 1] + 112.439 * img[:, :, 2]) / 256.
20 | cr = 128. + (112.439 * img[:, :, 0] - 94.154 * img[:, :, 1] - 18.285 * img[:, :, 2]) / 256.
21 | return np.array([y, cb, cr]).transpose([1, 2, 0])
22 | elif type(img) == torch.Tensor:
23 | if len(img.shape) == 4:
24 | img = img.squeeze(0)
25 | y = 16. + (64.738 * img[0, :, :] + 129.057 * img[1, :, :] + 25.064 * img[2, :, :]) / 256.
26 | cb = 128. + (-37.945 * img[0, :, :] - 74.494 * img[1, :, :] + 112.439 * img[2, :, :]) / 256.
27 | cr = 128. + (112.439 * img[0, :, :] - 94.154 * img[1, :, :] - 18.285 * img[2, :, :]) / 256.
28 |         return torch.stack([y, cb, cr], 0).permute(1, 2, 0)  # (3, H, W) -> (H, W, 3)
29 | else:
30 |         raise TypeError('unknown image type: {}'.format(type(img)))
31 |
32 |
33 | def convert_ycbcr_to_rgb(img):
34 | if type(img) == np.ndarray:
35 | r = 298.082 * img[:, :, 0] / 256. + 408.583 * img[:, :, 2] / 256. - 222.921
36 | g = 298.082 * img[:, :, 0] / 256. - 100.291 * img[:, :, 1] / 256. - 208.120 * img[:, :, 2] / 256. + 135.576
37 | b = 298.082 * img[:, :, 0] / 256. + 516.412 * img[:, :, 1] / 256. - 276.836
38 | return np.array([r, g, b]).transpose([1, 2, 0])
39 | elif type(img) == torch.Tensor:
40 | if len(img.shape) == 4:
41 | img = img.squeeze(0)
42 | r = 298.082 * img[0, :, :] / 256. + 408.583 * img[2, :, :] / 256. - 222.921
43 | g = 298.082 * img[0, :, :] / 256. - 100.291 * img[1, :, :] / 256. - 208.120 * img[2, :, :] / 256. + 135.576
44 | b = 298.082 * img[0, :, :] / 256. + 516.412 * img[1, :, :] / 256. - 276.836
45 |         return torch.stack([r, g, b], 0).permute(1, 2, 0)  # (3, H, W) -> (H, W, 3)
46 | else:
47 |         raise TypeError('unknown image type: {}'.format(type(img)))
48 |
49 |
50 | def calc_psnr(img1, img2):
51 | return 10. * torch.log10(1. / torch.mean((img1 - img2) ** 2))
52 |
53 |
54 | class AverageMeter(object):
55 | def __init__(self):
56 | self.reset()
57 |
58 | def reset(self):
59 | self.val = 0
60 | self.avg = 0
61 | self.sum = 0
62 | self.count = 0
63 |
64 | def update(self, val, n=1):
65 | self.val = val
66 | self.sum += val * n
67 | self.count += n
68 | self.avg = self.sum / self.count
69 |
--------------------------------------------------------------------------------
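
`calc_psnr` assumes inputs scaled to [0, 1], so the peak value is 1 and PSNR reduces to 10 * log10(1 / MSE). A tiny numeric check:

```python
import torch
from utils import calc_psnr

a = torch.zeros(8, 8)
b = torch.full((8, 8), 0.1)  # uniform error of 0.1 -> MSE = 0.01
print(calc_psnr(a, b))       # 10 * log10(1 / 0.01) = 20.0 dB
```
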