├── README.md ├── result ├── kinetics_classnames.json └── structure.png ├── test.py ├── train.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # SlowFast 2 | 3 | A PyTorch implementation of SlowFast based on the ICCV 2019 paper 4 | [SlowFast Networks for Video Recognition](https://arxiv.org/abs/1812.03982). 5 | 6 | ![Network Architecture](result/structure.png) 7 | 8 | ## Requirements 9 | 10 | - [Anaconda](https://www.anaconda.com/download/) 11 | - [PyTorch](https://pytorch.org) 12 | 13 | ``` 14 | conda install pytorch=1.9.1 torchvision cudatoolkit -c pytorch 15 | ``` 16 | 17 | - [PyTorchVideo](https://pytorchvideo.org) 18 | 19 | ``` 20 | pip install pytorchvideo 21 | ``` 22 | 23 | ## Dataset 24 | 25 | The [kinetics-400](https://github.com/cvdfoundation/kinetics-dataset) dataset is used in this repo; you can download it 26 | from the official website. The data directory structure is shown as follows: 27 | 28 | ``` 29 | ├──data 30 | ├── train 31 | ├── abseiling 32 | ├── _4YTwq0-73Y_000044_000054.mp4 33 | └── ... 34 | ... 35 | ├── archery 36 | same structure as abseiling 37 | ├── test 38 | same structure as train 39 | ...
40 | ``` 41 | 42 | ## Usage 43 | 44 | ### Train Model 45 | 46 | ``` 47 | python train.py --batch_size 16 48 | optional arguments: 49 | --data_root Datasets root path [default value is 'data'] 50 | --batch_size Number of videos in each mini-batch [default value is 8] 51 | --epochs Number of epochs over the model to train [default value is 10] 52 | --save_root Result saved root path [default value is 'result'] 53 | ``` 54 | 55 | ### Test Model 56 | 57 | ``` 58 | python test.py --video_path data/test/beatboxing/5s_gFWie1Ys_000069_000079.mp4 59 | optional arguments: 60 | --model_path Model path [default value is 'result/slow_fast.pth'] 61 | --video_path Video path [default value is 'data/test/applauding/_V-dzjftmCQ_000023_000033.mp4'] 62 | ``` 63 | -------------------------------------------------------------------------------- /result/kinetics_classnames.json: -------------------------------------------------------------------------------- 1 | { 2 | "\"sharpening knives\"": 290, 3 | "\"eating ice cream\"": 115, 4 | "\"cutting nails\"": 81, 5 | "\"changing wheel\"": 53, 6 | "\"bench pressing\"": 19, 7 | "deadlifting": 88, 8 | "\"eating carrots\"": 111, 9 | "marching": 192, 10 | "\"throwing discus\"": 358, 11 | "\"playing flute\"": 231, 12 | "\"cooking on campfire\"": 72, 13 | "\"breading or breadcrumbing\"": 33, 14 | "\"playing badminton\"": 218, 15 | "\"ripping paper\"": 276, 16 | "\"playing saxophone\"": 244, 17 | "\"milking cow\"": 197, 18 | "\"juggling balls\"": 169, 19 | "\"flying kite\"": 130, 20 | "capoeira": 43, 21 | "\"making jewelry\"": 187, 22 | "drinking": 100, 23 | "\"playing cymbals\"": 228, 24 | "\"cleaning gutters\"": 61, 25 | "\"hurling (sport)\"": 161, 26 | "\"playing organ\"": 239, 27 | "\"tossing coin\"": 361, 28 | "wrestling": 395, 29 | "\"driving car\"": 103, 30 | "headbutting": 150, 31 | "\"gymnastics tumbling\"": 147, 32 | "\"making bed\"": 186, 33 | "abseiling": 0, 34 | "\"holding snake\"": 155, 35 | "\"rock climbing\"": 278, 36 | "\"cooking 
egg\"": 71, 37 | "\"long jump\"": 182, 38 | "\"bee keeping\"": 17, 39 | "\"trimming or shaving beard\"": 365, 40 | "\"cleaning shoes\"": 63, 41 | "\"dancing gangnam style\"": 86, 42 | "\"catching or throwing softball\"": 50, 43 | "\"ice skating\"": 164, 44 | "jogging": 168, 45 | "\"eating spaghetti\"": 116, 46 | "bobsledding": 28, 47 | "\"assembling computer\"": 8, 48 | "\"playing cricket\"": 227, 49 | "\"playing monopoly\"": 238, 50 | "\"golf putting\"": 143, 51 | "\"making pizza\"": 188, 52 | "\"javelin throw\"": 166, 53 | "\"peeling potatoes\"": 211, 54 | "clapping": 57, 55 | "\"brushing hair\"": 36, 56 | "\"flipping pancake\"": 129, 57 | "\"drinking beer\"": 101, 58 | "\"dribbling basketball\"": 99, 59 | "\"playing bagpipes\"": 219, 60 | "somersaulting": 325, 61 | "\"canoeing or kayaking\"": 42, 62 | "\"riding unicycle\"": 275, 63 | "texting": 355, 64 | "\"tasting beer\"": 352, 65 | "\"hockey stop\"": 154, 66 | "\"playing clarinet\"": 225, 67 | "\"waxing legs\"": 389, 68 | "\"curling hair\"": 80, 69 | "\"running on treadmill\"": 281, 70 | "\"tai chi\"": 346, 71 | "\"driving tractor\"": 104, 72 | "\"shaving legs\"": 293, 73 | "\"sharpening pencil\"": 291, 74 | "\"making sushi\"": 190, 75 | "\"spray painting\"": 327, 76 | "situp": 305, 77 | "\"playing kickball\"": 237, 78 | "\"sticking tongue out\"": 331, 79 | "headbanging": 149, 80 | "\"folding napkins\"": 132, 81 | "\"playing piano\"": 241, 82 | "skydiving": 312, 83 | "\"dancing charleston\"": 85, 84 | "\"ice fishing\"": 163, 85 | "tickling": 359, 86 | "bandaging": 13, 87 | "\"high jump\"": 151, 88 | "\"making a sandwich\"": 185, 89 | "\"riding mountain bike\"": 271, 90 | "\"cutting pineapple\"": 82, 91 | "\"feeding goats\"": 125, 92 | "\"dancing macarena\"": 87, 93 | "\"playing basketball\"": 220, 94 | "krumping": 179, 95 | "\"high kick\"": 152, 96 | "\"balloon blowing\"": 12, 97 | "\"playing accordion\"": 217, 98 | "\"playing chess\"": 224, 99 | "\"hula hooping\"": 159, 100 | "\"pushing wheelchair\"": 263, 
101 | "\"riding camel\"": 268, 102 | "\"blowing out candles\"": 27, 103 | "\"extinguishing fire\"": 121, 104 | "\"using computer\"": 373, 105 | "\"jumpstyle dancing\"": 173, 106 | "yawning": 397, 107 | "writing": 396, 108 | "\"jumping into pool\"": 172, 109 | "\"doing laundry\"": 96, 110 | "\"egg hunting\"": 118, 111 | "\"sanding floor\"": 284, 112 | "\"moving furniture\"": 200, 113 | "\"exercising arm\"": 119, 114 | "\"sword fighting\"": 345, 115 | "\"sign language interpreting\"": 303, 116 | "\"counting money\"": 74, 117 | "bartending": 15, 118 | "\"cleaning windows\"": 65, 119 | "\"blasting sand\"": 23, 120 | "\"petting cat\"": 213, 121 | "sniffing": 320, 122 | "bowling": 31, 123 | "\"playing poker\"": 242, 124 | "\"taking a shower\"": 347, 125 | "\"washing hands\"": 382, 126 | "\"water sliding\"": 384, 127 | "\"presenting weather forecast\"": 254, 128 | "tobogganing": 360, 129 | "celebrating": 51, 130 | "\"getting a haircut\"": 138, 131 | "snorkeling": 321, 132 | "\"weaving basket\"": 390, 133 | "\"playing squash or racquetball\"": 245, 134 | "parasailing": 206, 135 | "\"news anchoring\"": 202, 136 | "\"belly dancing\"": 18, 137 | "windsurfing": 393, 138 | "\"braiding hair\"": 32, 139 | "\"crossing river\"": 78, 140 | "\"laying bricks\"": 181, 141 | "\"roller skating\"": 280, 142 | "hopscotch": 156, 143 | "\"playing trumpet\"": 248, 144 | "\"dying hair\"": 108, 145 | "\"trimming trees\"": 366, 146 | "\"pumping fist\"": 256, 147 | "\"playing keyboard\"": 236, 148 | "snowboarding": 322, 149 | "\"garbage collecting\"": 136, 150 | "\"playing controller\"": 226, 151 | "dodgeball": 94, 152 | "\"recording music\"": 266, 153 | "\"country line dancing\"": 75, 154 | "\"dancing ballet\"": 84, 155 | "gargling": 137, 156 | "ironing": 165, 157 | "\"push up\"": 260, 158 | "\"frying vegetables\"": 135, 159 | "\"ski jumping\"": 307, 160 | "\"mowing lawn\"": 201, 161 | "\"getting a tattoo\"": 139, 162 | "\"rock scissors paper\"": 279, 163 | "cheerleading": 55, 164 | "\"using 
remote controller (not gaming)\"": 374, 165 | "\"shaking head\"": 289, 166 | "sailing": 282, 167 | "\"training dog\"": 363, 168 | "hurdling": 160, 169 | "\"fixing hair\"": 128, 170 | "\"climbing ladder\"": 67, 171 | "\"filling eyebrows\"": 126, 172 | "\"springboard diving\"": 329, 173 | "\"eating watermelon\"": 117, 174 | "\"drumming fingers\"": 106, 175 | "\"waxing back\"": 386, 176 | "\"playing didgeridoo\"": 229, 177 | "\"swimming backstroke\"": 339, 178 | "\"biking through snow\"": 22, 179 | "\"washing feet\"": 380, 180 | "\"mopping floor\"": 198, 181 | "\"throwing ball\"": 357, 182 | "\"eating doughnuts\"": 113, 183 | "\"drinking shots\"": 102, 184 | "\"tying bow tie\"": 368, 185 | "dining": 91, 186 | "\"surfing water\"": 337, 187 | "\"sweeping floor\"": 338, 188 | "\"grooming dog\"": 145, 189 | "\"catching fish\"": 47, 190 | "\"pumping gas\"": 257, 191 | "\"riding or walking with horse\"": 273, 192 | "\"massaging person's head\"": 196, 193 | "archery": 5, 194 | "\"ice climbing\"": 162, 195 | "\"playing recorder\"": 243, 196 | "\"decorating the christmas tree\"": 89, 197 | "\"peeling apples\"": 210, 198 | "snowmobiling": 324, 199 | "\"playing ukulele\"": 249, 200 | "\"eating burger\"": 109, 201 | "\"building cabinet\"": 38, 202 | "\"stomping grapes\"": 332, 203 | "\"drop kicking\"": 105, 204 | "\"passing American football (not in game)\"": 209, 205 | "applauding": 3, 206 | "hugging": 158, 207 | "\"eating hotdog\"": 114, 208 | "\"pole vault\"": 253, 209 | "\"reading newspaper\"": 265, 210 | "\"snatch weight lifting\"": 318, 211 | "zumba": 399, 212 | "\"playing ice hockey\"": 235, 213 | "breakdancing": 34, 214 | "\"feeding fish\"": 124, 215 | "\"shredding paper\"": 300, 216 | "\"catching or throwing frisbee\"": 49, 217 | "\"exercising with an exercise ball\"": 120, 218 | "\"pushing cart\"": 262, 219 | "\"swimming butterfly stroke\"": 341, 220 | "\"riding scooter\"": 274, 221 | "spraying": 328, 222 | "\"folding paper\"": 133, 223 | "\"golf driving\"": 142, 224 | 
"\"robot dancing\"": 277, 225 | "\"bending back\"": 20, 226 | "testifying": 354, 227 | "\"waxing chest\"": 387, 228 | "\"carving pumpkin\"": 46, 229 | "\"hitting baseball\"": 153, 230 | "\"riding elephant\"": 269, 231 | "\"brushing teeth\"": 37, 232 | "\"pull ups\"": 255, 233 | "\"riding a bike\"": 267, 234 | "skateboarding": 306, 235 | "\"cleaning pool\"": 62, 236 | "\"playing paintball\"": 240, 237 | "\"massaging back\"": 193, 238 | "\"shoveling snow\"": 299, 239 | "\"surfing crowd\"": 336, 240 | "unboxing": 371, 241 | "faceplanting": 122, 242 | "trapezing": 364, 243 | "\"swinging legs\"": 343, 244 | "hoverboarding": 157, 245 | "\"playing violin\"": 250, 246 | "\"wrapping present\"": 394, 247 | "\"blowing nose\"": 26, 248 | "\"kicking field goal\"": 174, 249 | "\"picking fruit\"": 214, 250 | "\"swinging on something\"": 344, 251 | "\"giving or receiving award\"": 140, 252 | "\"planting trees\"": 215, 253 | "\"water skiing\"": 383, 254 | "\"washing dishes\"": 379, 255 | "\"punching bag\"": 258, 256 | "\"massaging legs\"": 195, 257 | "\"throwing axe\"": 356, 258 | "\"salsa dancing\"": 283, 259 | "bookbinding": 29, 260 | "\"tying tie\"": 370, 261 | "\"skiing crosscountry\"": 309, 262 | "\"shining shoes\"": 295, 263 | "\"making snowman\"": 189, 264 | "\"front raises\"": 134, 265 | "\"doing nails\"": 97, 266 | "\"massaging feet\"": 194, 267 | "\"playing drums\"": 230, 268 | "smoking": 316, 269 | "\"punching person (boxing)\"": 259, 270 | "cartwheeling": 45, 271 | "\"passing American football (in game)\"": 208, 272 | "\"shaking hands\"": 288, 273 | "plastering": 216, 274 | "\"watering plants\"": 385, 275 | "kissing": 176, 276 | "slapping": 314, 277 | "\"playing harmonica\"": 233, 278 | "welding": 391, 279 | "\"smoking hookah\"": 317, 280 | "\"scrambling eggs\"": 285, 281 | "\"cooking chicken\"": 70, 282 | "\"pushing car\"": 261, 283 | "\"opening bottle\"": 203, 284 | "\"cooking sausages\"": 73, 285 | "\"catching or throwing baseball\"": 48, 286 | "\"swimming breast 
stroke\"": 340, 287 | "digging": 90, 288 | "\"playing xylophone\"": 252, 289 | "\"doing aerobics\"": 95, 290 | "\"playing trombone\"": 247, 291 | "knitting": 178, 292 | "\"waiting in line\"": 377, 293 | "\"tossing salad\"": 362, 294 | "squat": 330, 295 | "vault": 376, 296 | "\"using segway\"": 375, 297 | "\"crawling baby\"": 77, 298 | "\"reading book\"": 264, 299 | "motorcycling": 199, 300 | "barbequing": 14, 301 | "\"cleaning floor\"": 60, 302 | "\"playing cello\"": 223, 303 | "drawing": 98, 304 | "auctioning": 9, 305 | "\"carrying baby\"": 44, 306 | "\"diving cliff\"": 93, 307 | "busking": 41, 308 | "\"cutting watermelon\"": 83, 309 | "\"scuba diving\"": 286, 310 | "\"riding mechanical bull\"": 270, 311 | "\"making tea\"": 191, 312 | "\"playing tennis\"": 246, 313 | "crying": 79, 314 | "\"dunking basketball\"": 107, 315 | "\"cracking neck\"": 76, 316 | "\"arranging flowers\"": 7, 317 | "\"building shed\"": 39, 318 | "\"golf chipping\"": 141, 319 | "\"tasting food\"": 353, 320 | "\"shaving head\"": 292, 321 | "\"answering questions\"": 2, 322 | "\"climbing tree\"": 68, 323 | "\"skipping rope\"": 311, 324 | "kitesurfing": 177, 325 | "\"juggling fire\"": 170, 326 | "laughing": 180, 327 | "paragliding": 205, 328 | "\"contact juggling\"": 69, 329 | "slacklining": 313, 330 | "\"arm wrestling\"": 6, 331 | "\"making a cake\"": 184, 332 | "\"finger snapping\"": 127, 333 | "\"grooming horse\"": 146, 334 | "\"opening present\"": 204, 335 | "\"tapping pen\"": 351, 336 | "singing": 304, 337 | "\"shot put\"": 298, 338 | "\"cleaning toilet\"": 64, 339 | "\"spinning poi\"": 326, 340 | "\"setting table\"": 287, 341 | "\"tying knot (not on a tie)\"": 369, 342 | "\"blowing glass\"": 24, 343 | "\"eating chips\"": 112, 344 | "\"tap dancing\"": 349, 345 | "\"climbing a rope\"": 66, 346 | "\"brush painting\"": 35, 347 | "\"chopping wood\"": 56, 348 | "\"stretching leg\"": 334, 349 | "\"petting animal (not cat)\"": 212, 350 | "\"baking cookies\"": 11, 351 | "\"stretching arm\"": 333, 
352 | "beatboxing": 16, 353 | "jetskiing": 167, 354 | "\"bending metal\"": 21, 355 | "sneezing": 319, 356 | "\"folding clothes\"": 131, 357 | "\"sled dog racing\"": 315, 358 | "\"tapping guitar\"": 350, 359 | "\"bouncing on trampoline\"": 30, 360 | "\"waxing eyebrows\"": 388, 361 | "\"air drumming\"": 1, 362 | "\"kicking soccer ball\"": 175, 363 | "\"washing hair\"": 381, 364 | "\"riding mule\"": 272, 365 | "\"blowing leaves\"": 25, 366 | "\"strumming guitar\"": 335, 367 | "\"playing cards\"": 222, 368 | "snowkiting": 323, 369 | "\"playing bass guitar\"": 221, 370 | "\"applying cream\"": 4, 371 | "\"shooting basketball\"": 296, 372 | "\"walking the dog\"": 378, 373 | "\"triple jump\"": 367, 374 | "\"shearing sheep\"": 294, 375 | "\"clay pottery making\"": 58, 376 | "\"bungee jumping\"": 40, 377 | "\"unloading truck\"": 372, 378 | "\"shuffling cards\"": 301, 379 | "\"shooting goal (soccer)\"": 297, 380 | "\"tango dancing\"": 348, 381 | "\"side kick\"": 302, 382 | "\"grinding meat\"": 144, 383 | "yoga": 398, 384 | "\"hammer throw\"": 148, 385 | "\"changing oil\"": 52, 386 | "\"checking tires\"": 54, 387 | "parkour": 207, 388 | "\"eating cake\"": 110, 389 | "\"skiing slalom\"": 310, 390 | "\"juggling soccer ball\"": 171, 391 | "whistling": 392, 392 | "\"feeding birds\"": 123, 393 | "\"playing volleyball\"": 251, 394 | "\"swing dancing\"": 342, 395 | "\"skiing (not slalom or crosscountry)\"": 308, 396 | "lunge": 183, 397 | "\"disc golfing\"": 92, 398 | "\"clean and jerk\"": 59, 399 | "\"playing guitar\"": 232, 400 | "\"baby waking up\"": 10, 401 | "\"playing harp\"": 234 402 | } -------------------------------------------------------------------------------- /result/structure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leftthomas/SlowFast/7c848524906d9028dd44a03034bab155eed67d25/result/structure.png -------------------------------------------------------------------------------- /test.py: 
# --------------------------------------------------------------------------------
# Classify a single video with a trained SlowFast model and print its top-5 labels.
import argparse
import json

import torch
from pytorchvideo.data.encoded_video import EncodedVideo
from pytorchvideo.models import create_slowfast

from utils import num_classes, clip_duration, test_transform

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Test Model')
    parser.add_argument('--model_path', default='result/slow_fast.pth', type=str, help='Model path')
    parser.add_argument('--video_path', default='data/test/applauding/_V-dzjftmCQ_000023_000033.mp4', type=str,
                        help='Video path')

    opt = parser.parse_args()
    model_path, video_path = opt.model_path, opt.video_path
    slow_fast = create_slowfast(model_num_class=num_classes)
    # deserialize the checkpoint onto the CPU first, then move the whole model to the GPU
    slow_fast.load_state_dict(torch.load(model_path, 'cpu'))
    slow_fast = slow_fast.cuda().eval()
    with open('result/kinetics_classnames.json', 'r') as f:
        kinetics_classnames = json.load(f)

    # create an id to label name mapping; multi-word names in the json carry literal
    # double quotes, which are stripped here
    kinetics_id_to_classname = {v: str(k).replace('"', "") for k, v in kinetics_classnames.items()}

    # only the first `clip_duration` seconds of the video are classified
    video = EncodedVideo.from_path(video_path, decode_audio=False)
    video_data = video.get_clip(start_sec=0, end_sec=clip_duration)
    video_data = test_transform(video_data)
    # add a leading batch dimension to every pathway tensor
    inputs = [i.cuda()[None, ...] for i in video_data['video']]
    # inference only: do not record operations for autograd
    with torch.no_grad():
        pred = slow_fast(inputs)

    # get the predicted classes
    pred_classes = pred.topk(k=5).indices
    pred_class_names = [kinetics_id_to_classname[int(i)] for i in pred_classes[0]]
    print('predicted labels: {}'.format(pred_class_names))

# --------------------------------------------------------------------------------
# /train.py:
# --------------------------------------------------------------------------------
# Train SlowFast on a labeled video dataset, evaluate after every epoch, log the
# metrics to csv and keep the checkpoint with the best top-1 accuracy.
import argparse
import math
import os
import random

import numpy as np
import pandas as pd
import torch
from pytorchvideo.data import make_clip_sampler, labeled_video_dataset
from pytorchvideo.models import create_slowfast
from torch.backends import cudnn
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm import tqdm

from utils import train_transform, test_transform, clip_duration, num_classes

# for reproducibility
random.seed(1)
np.random.seed(1)
torch.manual_seed(1)
cudnn.deterministic = True
cudnn.benchmark = False


# train for one epoch
def train(model, data_loader, train_optimizer):
    """Run one training epoch and return ``(mean_loss, accuracy)``.

    Besides its arguments this function reads the module-level names
    ``train_data``, ``batch_size``, ``loss_criterion``, ``epoch`` and ``epochs``,
    which are only assigned inside the ``__main__`` block below; it therefore
    cannot be called before that block has set them.

    Args:
        model: network called with the list of pathway tensors of a batch.
        data_loader: iterable of batches with ``'video'`` and ``'label'`` entries.
        train_optimizer: optimizer stepping the parameters of ``model``.

    Returns:
        Tuple of the per-sample mean loss and the top-1 accuracy as a
        fraction in ``[0, 1]``.
    """
    model.train()
    total_loss, total_acc, total_num = 0.0, 0, 0
    # the progress-bar length is derived from the number of videos and the batch size
    train_bar = tqdm(data_loader, total=math.ceil(train_data.num_videos / batch_size), dynamic_ncols=True)
    for batch in train_bar:
        video, label = [i.cuda() for i in batch['video']], batch['label'].cuda()
        num_samples = video[0].size(0)
        train_optimizer.zero_grad()
        pred = model(video)
        loss = loss_criterion(pred, label)
        # weight the batch loss by the batch size so the running value is a per-sample mean
        total_loss += loss.item() * num_samples
        total_acc += (torch.eq(pred.argmax(dim=-1), label)).sum().item()
        loss.backward()
        train_optimizer.step()

        total_num += num_samples
        train_bar.set_description('Train Epoch: [{}/{}] Loss: {:.4f} Acc: {:.2f}%'
                                  .format(epoch, epochs, total_loss / total_num, total_acc * 100 / total_num))

    return total_loss / total_num, total_acc / total_num


# test for one epoch
def val(model, data_loader):
    """Evaluate ``model`` once over ``data_loader`` and return ``(top1, top5)``.

    Reads the module-level names ``test_data``, ``batch_size``, ``epoch`` and
    ``epochs`` assigned inside the ``__main__`` block below.

    Args:
        model: network called with the list of pathway tensors of a batch.
        data_loader: iterable of batches with ``'video'`` and ``'label'`` entries.

    Returns:
        Tuple of top-1 and top-5 accuracy, each as a fraction in ``[0, 1]``.
    """
    model.eval()
    with torch.no_grad():
        total_top_1, total_top_5, total_num = 0, 0, 0
        test_bar = tqdm(data_loader, total=math.ceil(test_data.num_videos / batch_size), dynamic_ncols=True)
        for batch in test_bar:
            video, label = [i.cuda() for i in batch['video']], batch['label'].cuda()
            pred = model(video)
            total_top_1 += (torch.eq(pred.argmax(dim=-1), label)).sum().item()
            # a sample is a top-5 hit when its label matches any of the 5 highest-scoring classes
            top_5_hits = torch.eq(pred.topk(k=5, dim=-1).indices, label.unsqueeze(dim=-1))
            total_top_5 += torch.any(top_5_hits, dim=-1).sum().item()
            total_num += video[0].size(0)
            test_bar.set_description('Test Epoch: [{}/{}] | Top-1:{:.2f}% | Top-5:{:.2f}%'
                                     .format(epoch, epochs, total_top_1 * 100 / total_num,
                                             total_top_5 * 100 / total_num))
    return total_top_1 / total_num, total_top_5 / total_num


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train Model')
    # common args
    parser.add_argument('--data_root', default='data', type=str, help='Datasets root path')
    parser.add_argument('--batch_size', default=8, type=int, help='Number of videos in each mini-batch')
    parser.add_argument('--epochs', default=10, type=int, help='Number of epochs over the model to train')
    parser.add_argument('--save_root', default='result', type=str, help='Result saved root path')

    # args parse
    args = parser.parse_args()
    data_root, batch_size, epochs, save_root = args.data_root, args.batch_size, args.epochs, args.save_root

    # data prepare: random clips for training, one constant clip per video for testing
    train_data = labeled_video_dataset('{}/train'.format(data_root), make_clip_sampler('random', clip_duration),
                                       transform=train_transform, decode_audio=False)
    test_data = labeled_video_dataset('{}/test'.format(data_root),
                                      make_clip_sampler('constant_clips_per_video', clip_duration, 1),
                                      transform=test_transform, decode_audio=False)
    train_loader = DataLoader(train_data, batch_size=batch_size, num_workers=8)
    test_loader = DataLoader(test_data, batch_size=batch_size, num_workers=8)

    # model define, loss setup and optimizer config
    slow_fast = create_slowfast(model_num_class=num_classes).cuda()
    # slow_fast = torch.hub.load('facebookresearch/pytorchvideo:main', model='slowfast_r50', pretrained=True)
    loss_criterion = CrossEntropyLoss()
    # NOTE(review): lr=1e-1 is 100x Adam's default of 1e-3 -- confirm this is intended
    optimizer = Adam(slow_fast.parameters(), lr=1e-1)

    # training loop
    results = {'loss': [], 'acc': [], 'top-1': [], 'top-5': []}
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs
    os.makedirs(save_root, exist_ok=True)
    best_acc = 0.0
    for epoch in range(1, epochs + 1):
        train_loss, train_acc = train(slow_fast, train_loader, optimizer)
        results['loss'].append(train_loss)
        results['acc'].append(train_acc * 100)
        top_1, top_5 = val(slow_fast, test_loader)
        results['top-1'].append(top_1 * 100)
        results['top-5'].append(top_5 * 100)
        # save statistics (the csv is rewritten in full after every epoch)
        data_frame = pd.DataFrame(data=results, index=range(1, epoch + 1))
        data_frame.to_csv('{}/metrics.csv'.format(save_root), index_label='epoch')

        # keep only the checkpoint with the best top-1 accuracy seen so far
        if top_1 > best_acc:
            best_acc = top_1
            torch.save(slow_fast.state_dict(), '{}/slow_fast.pth'.format(save_root))

# --------------------------------------------------------------------------------
# /utils.py:
# --------------------------------------------------------------------------------
# Shared preprocessing constants and video transforms used by train.py and test.py.
import torch
from pytorchvideo.transforms import ApplyTransformToKey, UniformTemporalSubsample, RandomShortSideScale, \
    ShortSideScale, Normalize
from torch import nn
from torchvision.transforms import Compose, Lambda, RandomCrop, RandomHorizontalFlip, CenterCrop

side_size = 256
max_size = 320
mean = [0.45, 0.45, 0.45]
std = [0.225, 0.225, 0.225]
crop_size = 256
num_frames = 32
sampling_rate = 2
frames_per_second = 30
# clip length in seconds: (32 * 2) / 30 with the values above
clip_duration = (num_frames * sampling_rate) / frames_per_second
num_classes = 400


def _scale_to_unit(x):
    """Divide ``x`` by 255.0 (presumably mapping 0-255 pixel values to [0, 1])."""
    # A named module-level function, unlike a lambda, can be pickled; DataLoader
    # worker processes need that when they are started with the 'spawn' method.
    return x / 255.0


class PackPathway(nn.Module):
    """
    Transform for converting video frames as a list of tensors.

    The input is returned unchanged as the fast pathway; the slow pathway keeps
    ``frames.shape[1] // alpha`` evenly spaced entries along dimension 1.
    NOTE(review): dimension 1 is assumed to be the temporal axis, i.e. frames
    shaped (C, T, H, W) -- confirm against the upstream transforms.
    """

    def __init__(self, alpha=4):
        super().__init__()
        # ratio between the number of fast-pathway and slow-pathway frames
        self.alpha = alpha

    def forward(self, frames):
        """Return ``[slow_pathway, fast_pathway]`` built from ``frames``."""
        fast_pathway = frames
        # perform temporal sampling from the fast pathway.
        num_slow_frames = frames.shape[1] // self.alpha
        # Build the indices on the CPU exactly as before, then move them next to
        # `frames`: index_select requires the index tensor on the same device.
        slow_indices = torch.linspace(0, frames.shape[1] - 1, num_slow_frames).long().to(frames.device)
        slow_pathway = torch.index_select(frames, 1, slow_indices)
        frame_list = [slow_pathway, fast_pathway]
        return frame_list


# training pipeline: temporal subsample -> scale to [0, 1] -> normalize ->
# random short-side scale -> random crop -> random horizontal flip -> pack pathways
train_transform = ApplyTransformToKey(key="video", transform=Compose(
    [UniformTemporalSubsample(num_frames), Lambda(_scale_to_unit), Normalize(mean, std),
     RandomShortSideScale(min_size=side_size, max_size=max_size), RandomCrop(crop_size), RandomHorizontalFlip(),
     PackPathway()]))
# test pipeline: same as training but with deterministic short-side scale and center crop
test_transform = ApplyTransformToKey(key="video", transform=Compose(
    [UniformTemporalSubsample(num_frames), Lambda(_scale_to_unit), Normalize(mean, std),
     ShortSideScale(size=side_size), CenterCrop(crop_size), PackPathway()]))

# --------------------------------------------------------------------------------