├── README.md
├── activity_net.v1-3.json
├── activity_net_depth.csv
├── cone_activity_net300.pth
├── cone_kinetics_300.pth
├── cone_moments_300.pth
├── dataset.py
├── demo.ipynb
├── kinetics_depth.csv
├── metric.py
├── model.py
├── moments_depth.csv
├── pmath.py
├── poster.pdf
└── poster_jpg.jpg


/README.md:
--------------------------------------------------------------------------------
Code for **Searching for Actions on the Hyperbole**.

The current code includes all the essentials to re-implement our paper.
The demo code is a raw version; please leave a comment if you have any difficulty using it.

# Requirements
pytorch == 1.4.0
geoopt == 0.2.0

If you have any trouble using this code, please email me (uestc.longteng@gmail.com).

# Data Hierarchy and Seen/Unseen Split

activity_net_depth.csv
kinetics_depth.csv
moments_depth.csv
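
Each row in these files lists an action, then its ancestors up to `Root`, then a trailing 0/1 flag; going by the section title, that flag encodes the seen/unseen split (which value means "seen" is not spelled out here). The repository parses these files with `get_son2parent` in `dataset.py`. As a minimal, unofficial sketch of the same idea, the snippet below builds the child-to-parent map with Python's standard `csv` module and walks a node's ancestors; `parse_hierarchy` and `ancestors` are hypothetical helper names, not part of this repository.

```python
import csv
import io


def parse_hierarchy(csv_text):
    """Parse rows shaped like `action,parent,grandparent,Root,flag`.

    Returns (son2parent, flags): son2parent maps each node to its parent,
    and flags maps each leaf action to the integer in the last column.
    """
    son2parent, flags = {}, {}
    for row in csv.reader(io.StringIO(csv_text)):
        if not row:  # skip blank lines
            continue
        *path, flag = row  # path = [action, parent, ..., 'Root']
        for child, parent in zip(path, path[1:]):
            son2parent[child] = parent
        flags[path[0]] = int(flag)
    return son2parent, flags


def ancestors(node, son2parent):
    """Return the ancestors of `node`, nearest first, ending at the root."""
    chain = []
    while node in son2parent:
        node = son2parent[node]
        chain.append(node)
    return chain


if __name__ == "__main__":
    sample = (
        "Ballet,Dancing,Arts and Entertainment,Root,1\n"
        "Belly dance,Dancing,Arts and Entertainment,Root,0\n"
    )
    son2parent, flags = parse_hierarchy(sample)
    print(ancestors("Ballet", son2parent))  # ['Dancing', 'Arts and Entertainment', 'Root']
    print(flags)  # {'Ballet': 1, 'Belly dance': 0}
```

To read one of the shipped files, pass its contents, e.g. `parse_hierarchy(open('activity_net_depth.csv').read())`. Some parent names carry a trailing space (for example `Food and drink preparation `); the sketch keeps names verbatim so the same node gets the same key in every row.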

# Data files

I have uploaded the extracted video features (~8GB) to Microsoft OneDrive; please contact me to get access.

# Pretrained 300-D action embeddings

cone_activity_net300.pth
cone_kinetics_300.pth
cone_moments_300.pth

# Reference

[Searching for Actions on the Hyperbole](http://openaccess.thecvf.com/content_CVPR_2020/html/Long_Searching_for_Actions_on_the_Hyperbole_CVPR_2020_paper.html)

Teng Long, Pascal Mettes, Heng Tao Shen, Cees G. M. Snoek

@InProceedings{Long_2020_CVPR,
author = {Long, Teng and Mettes, Pascal and Shen, Heng Tao and Snoek, Cees G. M.},
title = {Searching for Actions on the Hyperbole},
booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2020}
40 | } 41 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /activity_net_depth.csv: -------------------------------------------------------------------------------- 1 | Ballet,Dancing,Arts and Entertainment,Root,1 2 | Tango,Dancing,Arts and Entertainment,Root,1 3 | Cheerleading,Dancing,Arts and Entertainment,Root,1 4 | Cumbia,Dancing,Arts and Entertainment,Root,1 5 | Breakdancing,Dancing,Arts and Entertainment,Root,1 6 | Belly dance,Dancing,Arts and Entertainment,Root,0 7 | Building sandcastles,Park activities,Arts and Entertainment,Root,1 8 | Fun sliding down,Park activities,Arts and Entertainment,Root,1 9 | Swinging at the playground,Park activities,Arts and Entertainment,Root,1 10 | Using the monkey bar,Park activities,Arts and Entertainment,Root,1 11 | Starting a campfire,Park activities,Arts and Entertainment,Root,0 12 | Drum corps,Playing musical instruments,Arts and Entertainment,Root,1 13 | Playing congas,Playing musical instruments,Arts and Entertainment,Root,1 14 | Playing drums,Playing musical instruments,Arts and Entertainment,Root,1 15 | Playing bagpipes,Playing musical instruments,Arts and Entertainment,Root,1 16 | Playing harmonica,Playing musical instruments,Arts and Entertainment,Root,1 17 | Playing saxophone,Playing musical instruments,Arts and Entertainment,Root,1 18 | Playing guitarra,Playing musical instruments,Arts and Entertainment,Root,1 19 | Playing flauta,Playing musical instruments,Arts and Entertainment,Root,1 20 | Playing piano,Playing musical instruments,Arts and Entertainment,Root,1 21 | Playing violin,Playing musical instruments,Arts and Entertainment,Root,1 22 | Playing accordion,Playing musical instruments,Arts and Entertainment,Root,0 23 | Having an ice cream,Eating and Drinking,Eating and drinking Activities,Root,1 24 | Drinking coffee,Eating and Drinking,Eating and drinking Activities,Root,1 25 | Drinking beer,Eating and Drinking,Eating and drinking Activities,Root,0 26 | Baking 
cookies,Food and drink preparation ,Eating and drinking Activities,Root,1 27 | Making a cake,Food and drink preparation ,Eating and drinking Activities,Root,1 28 | Making a lemonade,Food and drink preparation ,Eating and drinking Activities,Root,1 29 | Making an omelette,Food and drink preparation ,Eating and drinking Activities,Root,1 30 | Peeling potatoes,Food and drink preparation ,Eating and drinking Activities,Root,1 31 | Preparing pasta,Food and drink preparation ,Eating and drinking Activities,Root,1 32 | Preparing salad,Food and drink preparation ,Eating and drinking Activities,Root,1 33 | Making a sandwich,Food and drink preparation ,Eating and drinking Activities,Root,0 34 | Mixing drinks,Food and drink preparation ,Eating and drinking Activities,Root,0 35 | Washing dishes,Food and drink preparation ,Eating and drinking Activities,Root,1 36 | Clipping cat claws,Animals and Pets,Household Activities,Root,1 37 | Grooming dog,Animals and Pets,Household Activities,Root,1 38 | Bathing dog,Animals and Pets,Household Activities,Root,1 39 | Disc dog,Animals and Pets,Household Activities,Root,1 40 | Grooming horse,Animals and Pets,Household Activities,Root,1 41 | Walking the dog,Animals and Pets,Household Activities,Root,0 42 | Sharpening knives,Appliances/Tools/and Toys,Household Activities,Root,1 43 | Waxing skis,Appliances/Tools/and Toys,Household Activities,Root,1 44 | Welding,Appliances/Tools/and Toys,Household Activities,Root,0 45 | Mooping floor,Cleaning and Laundry,Household Activities,Root,1 46 | Cleaning windows,Cleaning and Laundry,Household Activities,Root,1 47 | Vacuuming floor,Cleaning and Laundry,Household Activities,Root,1 48 | Polishing forniture,Cleaning and Laundry,Household Activities,Root,1 49 | Ironing clothes,Cleaning and Laundry,Household Activities,Root,1 50 | Hand washing clothes,Cleaning and Laundry,Household Activities,Root,1 51 | Knitting,Cleaning and Laundry,Household Activities,Root,1 52 | Cleaning shoes,Cleaning and 
Laundry,Household Activities,Root,1 53 | Polishing shoes,Cleaning and Laundry,Household Activities,Root,0 54 | Shoveling snow,Exterior Maintenance/Repair/& Decoration,Household Activities,Root,1 55 | Fixing the roof,Exterior Maintenance/Repair/& Decoration,Household Activities,Root,1 56 | Roof shingle removal,Exterior Maintenance/Repair/& Decoration,Household Activities,Root,1 57 | Painting fence,Exterior Maintenance/Repair/& Decoration,Household Activities,Root,1 58 | Wrapping presents,Interior Maintenance/Repair/& Decoration,Household Activities,Root,1 59 | Painting furniture,Interior Maintenance/Repair/& Decoration,Household Activities,Root,1 60 | Chopping wood,Interior Maintenance/Repair/& Decoration,Household Activities,Root,1 61 | Carving jack-o-lanterns,Interior Maintenance/Repair/& Decoration,Household Activities,Root,1 62 | Cleaning sink,Interior Maintenance/Repair/& Decoration,Household Activities,Root,1 63 | Decorating the Christmas tree,Interior Maintenance/Repair/& Decoration,Household Activities,Root,1 64 | Hanging wallpaper,Interior Maintenance/Repair/& Decoration,Household Activities,Root,1 65 | Installing carpet,Interior Maintenance/Repair/& Decoration,Household Activities,Root,1 66 | Laying tile,Interior Maintenance/Repair/& Decoration,Household Activities,Root,1 67 | Plastering,Interior Maintenance/Repair/& Decoration,Household Activities,Root,0 68 | Painting,Interior Maintenance/Repair/& Decoration,Household Activities,Root,0 69 | Blowing leaves,Lawn/Garden/and Houseplants,Household Activities,Root,1 70 | Cutting the grass,Lawn/Garden/and Houseplants,Household Activities,Root,1 71 | Raking leaves,Lawn/Garden/and Houseplants,Household Activities,Root,1 72 | Spread mulch,Lawn/Garden/and Houseplants,Household Activities,Root,1 73 | Trimming branches or hedges,Lawn/Garden/and Houseplants,Household Activities,Root,1 74 | Mowing the lawn,Lawn/Garden/and Houseplants,Household Activities,Root,0 75 | Assembling bicycle,Vehicle repair and 
maintenance,Household Activities,Root,1 76 | Changing car wheel,Vehicle repair and maintenance,Household Activities,Root,1 77 | Hand car wash,Vehicle repair and maintenance,Household Activities,Root,1 78 | Removing ice from car,Vehicle repair and maintenance,Household Activities,Root,1 79 | Fixing bicycle,Vehicle repair and maintenance,Household Activities,Root,0 80 | Putting in contact lenses,Dress up,Personal Care,Root,1 81 | Putting on shoes,Dress up,Personal Care,Root,1 82 | Putting on makeup,Dress up,Personal Care,Root,1 83 | Brushing hair,Dress up,Personal Care,Root,1 84 | Doing nails,Dress up,Personal Care,Root,1 85 | Applying sunscreen,Dress up,Personal Care,Root,0 86 | Removing curlers,Grooming,Personal Care,Root,0 87 | Getting a tattoo,Grooming,Personal Care,Root,1 88 | Getting a piercing,Grooming,Personal Care,Root,1 89 | Getting a haircut,Grooming,Personal Care,Root,1 90 | Blow-drying hair,Grooming,Personal Care,Root,1 91 | Braiding hair,Grooming,Personal Care,Root,1 92 | Shaving legs,Grooming,Personal Care,Root,1 93 | Gargling mouthwash,Wash up,Personal Care,Root,1 94 | Washing face,Wash up,Personal Care,Root,1 95 | Brushing teeth,Wash up,Personal Care,Root,1 96 | Washing hands,Wash up,Personal Care,Root,1 97 | Shaving,Wash up,Personal Care,Root,1 98 | Playing blackjack,Playing games,Relaxing and Leisure,Root,1 99 | Beer pong,Playing games,Relaxing and Leisure,Root,1 100 | Hitting a pinata,Playing games,Relaxing and Leisure,Root,1 101 | Hula hoop,Playing games,Relaxing and Leisure,Root,1 102 | Kite flying,Playing games,Relaxing and Leisure,Root,1 103 | Playing pool,Playing games,Relaxing and Leisure,Root,1 104 | Playing rubik cube,Playing games,Relaxing and Leisure,Root,1 105 | Riding bumper cars,Playing games,Relaxing and Leisure,Root,1 106 | Rock-paper-scissors,Playing games,Relaxing and Leisure,Root,1 107 | Shuffleboard,Playing games,Relaxing and Leisure,Root,1 108 | Slacklining,Playing games,Relaxing and Leisure,Root,1 109 | Table soccer,Playing 
games,Relaxing and Leisure,Root,1 110 | Throwing darts,Playing games,Relaxing and Leisure,Root,0 111 | Tug of war,Playing games,Relaxing and Leisure,Root,1 112 | Hopscotch,Playing games,Relaxing and Leisure,Root,0 113 | Playing ten pins,Playing games,Relaxing and Leisure,Root,1 114 | Smoking hookah,Tobacco and drug use,Relaxing and Leisure,Root,1 115 | Smoking a cigarette,Tobacco and drug use,Relaxing and Leisure,Root,0 116 | Cricket,Bat-and-ball games,Sports/Exercise/and Recreation,Root,1 117 | Playing kickball,Bat-and-ball games,Sports/Exercise/and Recreation,Root,0 118 | Canoeing,Boating,Sports/Exercise/and Recreation,Root,1 119 | Rafting,Boating,Sports/Exercise/and Recreation,Root,1 120 | River tubing,Boating,Sports/Exercise/and Recreation,Root,0 121 | Sailing,Boating,Sports/Exercise/and Recreation,Root,1 122 | Kayaking,Boating,Sports/Exercise/and Recreation,Root,0 123 | Zumba,Doing aerobics,Sports/Exercise/and Recreation,Root,1 124 | Doing step aerobics,Doing aerobics,Sports/Exercise/and Recreation,Root,0 125 | Using the pommel horse,Doing gymnastics,Sports/Exercise/and Recreation,Root,1 126 | Using the balance beam,Doing gymnastics,Sports/Exercise/and Recreation,Root,1 127 | Tumbling,Doing gymnastics,Sports/Exercise/and Recreation,Root,1 128 | Using parallel bars,Doing gymnastics,Sports/Exercise/and Recreation,Root,1 129 | Using uneven bars,Doing gymnastics,Sports/Exercise/and Recreation,Root,1 130 | Baton twirling,Doing gymnastics,Sports/Exercise/and Recreation,Root,0 131 | Playing polo,Equestrian sports,Sports/Exercise/and Recreation,Root,1 132 | Horseback riding,Equestrian sports,Sports/Exercise/and Recreation,Root,1 133 | Camel ride,Equestrian sports,Sports/Exercise/and Recreation,Root,0 134 | BMX,Extreme sports,Sports/Exercise/and Recreation,Root,1 135 | Rock climbing,Extreme sports,Sports/Exercise/and Recreation,Root,1 136 | Powerbocking,Extreme sports,Sports/Exercise/and Recreation,Root,1 137 | Paintball,Extreme sports,Sports/Exercise/and 
Recreation,Root,1 138 | Bungee jumping,Extreme sports,Sports/Exercise/and Recreation,Root,1 139 | Doing motocross,Extreme sports,Sports/Exercise/and Recreation,Root,0 140 | High jump,Field sports,Sports/Exercise/and Recreation,Root,1 141 | Discus throw,Field sports,Sports/Exercise/and Recreation,Root,1 142 | Javelin throw,Field sports,Sports/Exercise/and Recreation,Root,1 143 | Long jump,Field sports,Sports/Exercise/and Recreation,Root,1 144 | Triple jump,Field sports,Sports/Exercise/and Recreation,Root,1 145 | Shot put,Field sports,Sports/Exercise/and Recreation,Root,1 146 | Hammer throw,Field sports,Sports/Exercise/and Recreation,Root,1 147 | Archery,Field sports,Sports/Exercise/and Recreation,Root,1 148 | Pole vault,Field sports,Sports/Exercise/and Recreation,Root,0 149 | Running a marathon,Field sports,Sports/Exercise/and Recreation,Root,1 150 | Capoeira,Martial arts,Sports/Exercise/and Recreation,Root,1 151 | Doing kickboxing,Martial arts,Sports/Exercise/and Recreation,Root,1 152 | Doing karate,Martial arts,Sports/Exercise/and Recreation,Root,1 153 | Tai chi,Martial arts,Sports/Exercise/and Recreation,Root,0 154 | Layup drill in basketball,Playing basketball,Sports/Exercise/and Recreation,Root,1 155 | Dodgeball,Playing basketball,Sports/Exercise/and Recreation,Root,0 156 | Playing ice hockey,Playing hockey,Sports/Exercise/and Recreation,Root,1 157 | Playing field hockey,Playing hockey,Sports/Exercise/and Recreation,Root,1 158 | Croquet,Playing hockey,Sports/Exercise/and Recreation,Root,1 159 | Hurling,Playing hockey,Sports/Exercise/and Recreation,Root,0 160 | Beach soccer,Playing soccer,Sports/Exercise/and Recreation,Root,1 161 | Futsal,Playing soccer,Sports/Exercise/and Recreation,Root,0 162 | Playing beach volleyball,Playing volleyball,Sports/Exercise/and Recreation,Root,1 163 | Volleyball,Playing volleyball,Sports/Exercise/and Recreation,Root,0 164 | Ping-pong,Racquet sports ,Sports/Exercise/and Recreation,Root,1 165 | Tennis serve with ball 
bouncing,Racquet sports ,Sports/Exercise/and Recreation,Root,1 166 | Playing squash,Racquet sports ,Sports/Exercise/and Recreation,Root,1 167 | Playing lacrosse,Racquet sports ,Sports/Exercise/and Recreation,Root,1 168 | Playing racquetball,Racquet sports ,Sports/Exercise/and Recreation,Root,1 169 | Playing badminton,Racquet sports ,Sports/Exercise/and Recreation,Root,0 170 | Bullfighting,Rodeo competitions,Sports/Exercise/and Recreation,Root,1 171 | Calf roping,Rodeo competitions,Sports/Exercise/and Recreation,Root,0 172 | Longboarding,Roller sports,Sports/Exercise/and Recreation,Root,1 173 | Rollerblading,Roller sports,Sports/Exercise/and Recreation,Root,1 174 | Skateboarding,Roller sports,Sports/Exercise/and Recreation,Root,0 175 | Ice fishing,Skiing/ice skating/snowboarding,Sports/Exercise/and Recreation,Root,1 176 | Curling,Skiing/ice skating/snowboarding,Sports/Exercise/and Recreation,Root,1 177 | Skiing,Skiing/ice skating/snowboarding,Sports/Exercise/and Recreation,Root,1 178 | Snow tubing,Skiing/ice skating/snowboarding,Sports/Exercise/and Recreation,Root,1 179 | Snowboarding,Skiing/ice skating/snowboarding,Sports/Exercise/and Recreation,Root,0 180 | Elliptical trainer,Using cardiovascular equipment,Sports/Exercise/and Recreation,Root,1 181 | Using the rowing machine,Using cardiovascular equipment,Sports/Exercise/and Recreation,Root,1 182 | Spinning,Using cardiovascular equipment,Sports/Exercise/and Recreation,Root,0 183 | Scuba diving,Water sports,Sports/Exercise/and Recreation,Root,1 184 | Surfing,Water sports,Sports/Exercise/and Recreation,Root,1 185 | Swimming,Water sports,Sports/Exercise/and Recreation,Root,1 186 | Wakeboarding,Water sports,Sports/Exercise/and Recreation,Root,1 187 | Waterskiing,Water sports,Sports/Exercise/and Recreation,Root,1 188 | Springboard diving,Water sports,Sports/Exercise/and Recreation,Root,1 189 | Plataform diving,Water sports,Sports/Exercise/and Recreation,Root,1 190 | Windsurfing,Water sports,Sports/Exercise/and 
Recreation,Root,1 191 | Playing water polo,Water sports,Sports/Exercise/and Recreation,Root,0 192 | Clean and jerk,Weightlifting,Sports/Exercise/and Recreation,Root,1 193 | Snatch,Weightlifting,Sports/Exercise/and Recreation,Root,0 194 | Doing crunches,Working out,Sports/Exercise/and Recreation,Root,1 195 | Kneeling,Working out,Sports/Exercise/and Recreation,Root,1 196 | Rope skipping,Working out,Sports/Exercise/and Recreation,Root,0 197 | Doing fencing,Wrestling,Sports/Exercise/and Recreation,Root,1 198 | Doing a powerbomb,Wrestling,Sports/Exercise/and Recreation,Root,1 199 | Arm wrestling,Wrestling,Sports/Exercise/and Recreation,Root,1 200 | Sumo,Wrestling,Sports/Exercise/and Recreation,Root,0 -------------------------------------------------------------------------------- /cone_activity_net300.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tenglon/hyperbolic_action/4af6da6e85a8af33dd54955067efee2836508048/cone_activity_net300.pth -------------------------------------------------------------------------------- /cone_kinetics_300.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tenglon/hyperbolic_action/4af6da6e85a8af33dd54955067efee2836508048/cone_kinetics_300.pth -------------------------------------------------------------------------------- /cone_moments_300.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tenglon/hyperbolic_action/4af6da6e85a8af33dd54955067efee2836508048/cone_moments_300.pth -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import torch, pickle, json 2 | from torch.utils.data import Dataset, DataLoader 3 | 4 | def get_son2parent(csv_path): 5 | son2parent = dict() 6 | with open (csv_path,'r') as fp: 7 | 
lines = fp.readlines() 8 | for line in lines: 9 | if line[-1] == '\n': 10 | line = line[:-1] 11 | tmp_list = line.split(',') 12 | for i in range(len(tmp_list) - 2): # ignore last digit. 13 | key, value = tmp_list[i], tmp_list[i+1] 14 | son2parent[key] = value 15 | if '' in son2parent.keys(): 16 | del son2parent[''] 17 | return son2parent 18 | 19 | class My_DS(Dataset): 20 | def __init__(self, X, y, emb): 21 | self.X = X 22 | self.y = y 23 | self.emb = emb 24 | 25 | def __len__(self): 26 | return self.X.shape[0] 27 | 28 | def __getitem__(self, index): 29 | x_batch = self.X[index, :] 30 | y_batch = self.y[index] 31 | a_batch = self.emb[y_batch,:] 32 | 33 | return x_batch, y_batch, a_batch 34 | 35 | def get_dataloader(Xtr, Xval, ytr, yval, emb, batch_size): 36 | 37 | train_dataset = My_DS(X=Xtr, y=ytr, emb=emb) 38 | train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=True) 39 | 40 | val_dataset = My_DS(X=Xval, y=yval, emb=emb) 41 | val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=True, drop_last=True) 42 | 43 | return train_loader, val_loader 44 | 45 | def split_act_data(feat, label, fns, anno_fn): # This only for ActivityNet. 46 | 47 | with open(anno_fn) as json_file: 48 | data = json.load(json_file) 49 | 50 | anno_tr = {key:value for key,value in data['database'].items() if value['annotations']} 51 | 52 | training_inx, validation_inx = [], [] 53 | 54 | for feat_fn in fns: 55 | # Using file name to get the corresponding clip_id, 找到文件名及对应clip_id 56 | key = feat_fn.split('.')[0].split('_')[-1] 57 | key = '_'.join(feat_fn.split('.')[0].split('_')[2:]) 58 | clip_inx = feat_fn.split('.')[0].split('_')[0] 59 | clip_inx = int(clip_inx) 60 | 61 | # Get the video annotation, i.e. the action label. 
        video_anno = anno_tr[key]
        # Append exactly one entry per file so both boolean masks stay aligned with `feat`;
        # files from any other subset are excluded from both splits.
        training_inx.append(video_anno['subset'] == 'training')
        validation_inx.append(video_anno['subset'] == 'validation')

    training_inx, validation_inx = torch.tensor(training_inx), torch.tensor(validation_inx)

    Xtr, Xval = feat[training_inx], feat[validation_inx]
    ytr, yval = label[training_inx], label[validation_inx]
    assert Xtr.dim() == 2 and ytr.dim() == 1

    return Xtr, Xval, ytr, yval

def get_activitynet_dataset(feat_path = './data.pickle', anno_fn = './activity_net.v1-3.json'):
    # Generate X, the features, and y, the labels.
    with open(feat_path,'rb') as f:
        data1 = pickle.load(f)

    label_set = list(set(data1['label']))
    label_set.sort() # make sure that the label set is sorted for one-hot encoding.

    label = [label_set.index(item) for item in data1['label']]
    label = torch.tensor(label)
    feat, fns = data1['feat'], data1['fn']

    Xtr, Xval, ytr, yval = split_act_data(feat, label, fns, anno_fn)

    return Xtr, Xval, ytr, yval, label_set


def get_emb(emb_type, emb_file, label_set):

    n_cls = len(label_set)
    if emb_type == 'rand':
        emb = torch.rand(n_cls,300)

    elif emb_type == 'wacv':
        with open(emb_file, 'rb') as pickle_file:
            content = pickle.load(pickle_file)
        emb = content['embedding']
        emb = torch.tensor(emb).cuda()

    elif emb_type == 'oh':
        # One-hot rows, i.e. an n_cls x n_cls identity matrix.
        one_hot = torch.zeros(n_cls, n_cls).long()
        emb = one_hot.scatter_(dim=1, index=torch.unsqueeze(torch.arange(n_cls), dim=1), src=torch.ones(n_cls, n_cls).long())
        emb = emb.float()

    elif emb_type == 'glove':
        data2 = torch.load(emb_file, map_location='cpu')
        emb_names = data2['objects']
        ext_emb = data2['embeddings'] # n_cls x dim

        emb = torch.zeros(n_cls,ext_emb.shape[1])
        emb = emb.cuda()
        for i in range(n_cls):
            pos = emb_names.index(label_set[i])
            emb[i] = ext_emb[pos,:]

    elif emb_type == 'hyp':
        data2 = torch.load(emb_file, map_location='cpu')
        emb_names = data2['objects']
        ext_emb = data2['embeddings'] # 272 x dim
        emb = torch.zeros(n_cls,ext_emb.shape[1])
        for i in range(n_cls):
            pos = emb_names.index(label_set[i])
            emb[i] = ext_emb[pos,:]

    elif emb_type == 'cone':
        data2 = torch.load(emb_file, map_location='cpu')
        if isinstance(data2, zip):
            data2 = dict(data2)
        emb_names = list(data2.keys())
        ext_emb = list(data2.values()) # 271 x dim, root is discarded
        ext_emb = torch.tensor(ext_emb)

        emb = torch.zeros(n_cls,ext_emb.shape[1])
        for i in range(n_cls):
            pos = emb_names.index(label_set[i])
            emb[i] = ext_emb[pos,:]

    else:
        raise ValueError('unknown emb_type: {}'.format(emb_type))

    return emb

def get_kineticslike_dataset(train_pth_path, valid_pth_path):
    data_train = torch.load(train_pth_path)
    data_val = torch.load(valid_pth_path)

    label_set = list(set(data_train['label']))
    label_set.sort() # Sorting matters: it makes the label -> index mapping deterministic.

    Xtr, Xval = data_train['feat'], data_val['feat']
    ytr = torch.tensor([label_set.index(item) for item in data_train['label']])
    yval = torch.tensor([label_set.index(item) for item in data_val['label']])

    return Xtr, Xval, ytr, yval, label_set


--------------------------------------------------------------------------------
/demo.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../')\n",
    "from dataset import get_activitynet_dataset, get_kineticslike_dataset\n",
    "from dataset import get_son2parent, get_emb, get_dataloader\n",
    "from metric import Metric\n",
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
    "\n",
    "dataset_name, emb_type = 'activitynet', 'cone'\n",
    "\n",
    "if dataset_name == 'activitynet':\n",
    "    tree_file = '../activity_net_depth_v5.csv'\n",
    "    if emb_type == 'cone' : emb_file = '../cone_activity_net300.pth' \n",
    "    if emb_type == 'glove' : emb_file = '../w2v_activity_300d.pth'\n",
    "    feat_path, anno_fn = '../data.pickle', '../activity_net.v1-3.json'\n",
    "    Xtr, Xval, ytr, yval, label_set = get_activitynet_dataset(feat_path, anno_fn)\n",
    "    \n",
    "elif dataset_name == 'kinetics':\n",
    "    tree_file = '../kinetics_depth_v2.csv'\n",
    "    if emb_type == 'cone' : emb_file = '../cones_kinetics300.pth'\n",
    "    if emb_type == 'glove' : emb_file = '../w2v_kinetics_300d.pth'\n",
    "    train_pth_path, valid_pth_path = '../kinetics_train.pth', '../kinetics_val.pth'\n",
    "    Xtr, Xval, ytr, yval, label_set = get_kineticslike_dataset(train_pth_path, valid_pth_path)\n",
    "    \n",
    "elif dataset_name == 'moments':\n",
    "    tree_file = '../moments_depth_v4.csv'\n",
    "    if emb_type == 'cone' : emb_file = '../cones_moments300.pth'\n",
    "    if emb_type == 'glove' : emb_file = '../w2v_moments_300d.pth'\n",
    "    train_pth_path, valid_pth_path = '../moments_train.pth', '../moments_val.pth'\n",
    "    Xtr, Xval, ytr, yval, label_set = get_kineticslike_dataset(train_pth_path, valid_pth_path)\n",
    "\n",
    "son2parent = get_son2parent(tree_file)\n",
    "emb = get_emb('hyp', emb_file, label_set)\n",
    "train_loader, val_loader = get_dataloader(Xtr, Xval, ytr, yval, emb, batch_size = 128)\n",
    "metric = Metric(label_set, son2parent)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### T=0.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss:1.291,acc:0.669,1hop_acc:0.829,2hop_acc:0.908,mAP:0.584,1hop_mAP:0.767,2hop_mAP:0.871\n",
      "mAP:0.620,1hop_mAP:0.958,2hop_mAP:0.983\n",
      "loss:1.185,acc:0.711,1hop_acc:0.838,2hop_acc:0.908,mAP:0.623,1hop_mAP:0.791,2hop_mAP:0.885\n",
      "mAP:0.712,1hop_mAP:0.955,2hop_mAP:0.979\n",
      "loss:1.226,acc:0.716,1hop_acc:0.837,2hop_acc:0.909,mAP:0.636,1hop_mAP:0.798,2hop_mAP:0.889\n",
      "mAP:0.720,1hop_mAP:0.960,2hop_mAP:0.982\n",
      "loss:1.216,acc:0.726,1hop_acc:0.849,2hop_acc:0.914,mAP:0.643,1hop_mAP:0.802,2hop_mAP:0.891\n",
      "mAP:0.730,1hop_mAP:0.960,2hop_mAP:0.981\n",
      "loss:1.272,acc:0.726,1hop_acc:0.849,2hop_acc:0.911,mAP:0.645,1hop_mAP:0.803,2hop_mAP:0.890\n",
      "mAP:0.730,1hop_mAP:0.962,2hop_mAP:0.981\n",
      "loss:1.268,acc:0.734,1hop_acc:0.849,2hop_acc:0.915,mAP:0.651,1hop_mAP:0.805,2hop_mAP:0.892\n",
      "mAP:0.740,1hop_mAP:0.960,2hop_mAP:0.981\n",
      "loss:1.262,acc:0.733,1hop_acc:0.847,2hop_acc:0.913,mAP:0.652,1hop_mAP:0.806,2hop_mAP:0.893\n",
      "mAP:0.740,1hop_mAP:0.959,2hop_mAP:0.981\n",
      "loss:1.299,acc:0.720,1hop_acc:0.840,2hop_acc:0.912,mAP:0.653,1hop_mAP:0.803,2hop_mAP:0.890\n",
      "mAP:0.741,1hop_mAP:0.959,2hop_mAP:0.981\n",
      "loss:1.333,acc:0.722,1hop_acc:0.843,2hop_acc:0.905,mAP:0.653,1hop_mAP:0.805,2hop_mAP:0.890\n",
      "mAP:0.740,1hop_mAP:0.956,2hop_mAP:0.978\n",
      "loss:1.379,acc:0.708,1hop_acc:0.831,2hop_acc:0.906,mAP:0.649,1hop_mAP:0.800,2hop_mAP:0.889\n",
      "mAP:0.737,1hop_mAP:0.957,2hop_mAP:0.979\n"
     ]
    }
   ],
   "source": [
    "from pmath import pair_wise_cos, pair_wise_eud, pair_wise_hyp\n",
    "from model import RegressNet\n",
    "import torch\n",
    "c, T, epochs, dist_func, eval_dist = .1,.1, 50, pair_wise_hyp, pair_wise_cos\n",
    "torch.manual_seed(42)\n",
    "emb = emb.cuda()\n",
    "\n",
    "train_loader, val_loader = get_dataloader(Xtr, Xval, ytr, yval, emb, batch_size = 128)\n",
    "model = RegressNet(T, c, dist_func, eval_dist, train_loader, val_loader, emb, metric)\n",
    "model = model.cuda()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)\n",
    "model._train(optimizer, epochs, T, eval_interval = 5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### T = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss:3.328,acc:0.635,1hop_acc:0.821,2hop_acc:0.898,mAP:0.568,1hop_mAP:0.811,2hop_mAP:0.895\n",
      "mAP:0.595,1hop_mAP:0.977,2hop_mAP:0.988\n",
      "loss:3.025,acc:0.731,1hop_acc:0.858,2hop_acc:0.920,mAP:0.646,1hop_mAP:0.847,2hop_mAP:0.913\n",
      "mAP:0.752,1hop_mAP:0.965,2hop_mAP:0.981\n",
      "loss:2.968,acc:0.744,1hop_acc:0.858,2hop_acc:0.919,mAP:0.667,1hop_mAP:0.849,2hop_mAP:0.913\n",
      "mAP:0.778,1hop_mAP:0.960,2hop_mAP:0.977\n",
      "loss:2.945,acc:0.749,1hop_acc:0.859,2hop_acc:0.919,mAP:0.675,1hop_mAP:0.847,2hop_mAP:0.911\n",
      "mAP:0.788,1hop_mAP:0.955,2hop_mAP:0.974\n",
      "loss:2.934,acc:0.751,1hop_acc:0.858,2hop_acc:0.918,mAP:0.678,1hop_mAP:0.844,2hop_mAP:0.909\n",
      "mAP:0.790,1hop_mAP:0.951,2hop_mAP:0.972\n",
      "loss:2.925,acc:0.750,1hop_acc:0.857,2hop_acc:0.918,mAP:0.679,1hop_mAP:0.842,2hop_mAP:0.908\n",
      "mAP:0.791,1hop_mAP:0.948,2hop_mAP:0.971\n",
      "loss:2.917,acc:0.750,1hop_acc:0.856,2hop_acc:0.918,mAP:0.679,1hop_mAP:0.840,2hop_mAP:0.908\n",
      "mAP:0.792,1hop_mAP:0.945,2hop_mAP:0.969\n",
      "loss:2.910,acc:0.751,1hop_acc:0.857,2hop_acc:0.918,mAP:0.680,1hop_mAP:0.840,2hop_mAP:0.907\n",
      "mAP:0.791,1hop_mAP:0.944,2hop_mAP:0.968\n",
      "loss:2.905,acc:0.751,1hop_acc:0.857,2hop_acc:0.919,mAP:0.680,1hop_mAP:0.839,2hop_mAP:0.907\n",
      "mAP:0.790,1hop_mAP:0.942,2hop_mAP:0.967\n",
      "loss:2.901,acc:0.750,1hop_acc:0.856,2hop_acc:0.918,mAP:0.679,1hop_mAP:0.838,2hop_mAP:0.907\n",
      "mAP:0.791,1hop_mAP:0.942,2hop_mAP:0.967\n"
     ]
    }
   ],
   "source": [
    "from pmath import pair_wise_cos, pair_wise_eud, pair_wise_hyp\n",
    "from model import RegressNet\n",
    "import torch\n",
    "c, T, epochs, dist_func, eval_dist = .1,1, 50, pair_wise_hyp, pair_wise_cos\n",
    "torch.manual_seed(42)\n",
    "emb = emb.cuda()\n",
    "\n",
    "train_loader, val_loader = get_dataloader(Xtr, Xval, ytr, yval, emb, batch_size = 128)\n",
    "model = RegressNet(T, c, dist_func, eval_dist, train_loader, val_loader, emb, metric)\n",
    "model = model.cuda()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)\n",
    "model._train(optimizer, epochs, T, eval_interval = 5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 2: 200-D hyperbolic embeddings for all three datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../')\n",
    "from dataset import get_activitynet_dataset, get_kineticslike_dataset\n",
    "from dataset import get_son2parent, get_emb, get_dataloader\n",
    "from metric import Metric\n",
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
    "\n",
    "dataset_name, emb_type = 'activitynet', 'cone'\n",
    "\n",
    "if dataset_name == 'activitynet':\n",
    "    tree_file = '../activity_net_depth_v5.csv'\n",
    "    if emb_type == 'cone' : emb_file = '../cone_activity_net200.pth' \n",
    "    if emb_type == 'glove' : emb_file = '../w2v_activity_300d.pth'\n",
    "    feat_path, anno_fn = '../data.pickle', '../activity_net.v1-3.json'\n",
    "    Xtr, Xval, ytr, yval, label_set = get_activitynet_dataset(feat_path, anno_fn)\n",
    "    \n",
    "elif dataset_name == 'kinetics':\n",
    "    tree_file = '../kinetics_depth_v2.csv'\n",
    "    if emb_type == 'cone' : emb_file = '../cones_kinetics300.pth'\n",
    "    if emb_type == 'glove' : emb_file = '../w2v_kinetics_300d.pth'\n",
    "    train_pth_path, valid_pth_path = '../kinetics_train.pth', '../kinetics_val.pth'\n",
    "    Xtr, Xval, ytr, yval, label_set = get_kineticslike_dataset(train_pth_path, valid_pth_path)\n",
    "    \n",
    "elif dataset_name == 'moments':\n",
    "    tree_file = '../moments_depth_v4.csv'\n",
    "    if emb_type == 'cone' : emb_file = '../cones_moments300.pth'\n",
    "    if emb_type == 'glove' : emb_file = '../w2v_moments_300d.pth'\n",
    "    train_pth_path, valid_pth_path = '../moments_train.pth', '../moments_val.pth'\n",
    "    Xtr, Xval, ytr, yval, label_set = get_kineticslike_dataset(train_pth_path, valid_pth_path)\n",
    "\n",
    "son2parent = get_son2parent(tree_file)\n",
    "emb = get_emb('hyp', emb_file, label_set)\n",
    "train_loader, val_loader = get_dataloader(Xtr, Xval, ytr, yval, emb, batch_size = 128)\n",
    "metric = Metric(label_set, son2parent)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss:3.318,acc:0.649,1hop_acc:0.822,2hop_acc:0.900,mAP:0.571,1hop_mAP:0.812,2hop_mAP:0.896\n",
      "mAP:0.608,1hop_mAP:0.974,2hop_mAP:0.986\n",
      "loss:3.034,acc:0.729,1hop_acc:0.858,2hop_acc:0.917,mAP:0.647,1hop_mAP:0.847,2hop_mAP:0.913\n",
      "mAP:0.756,1hop_mAP:0.966,2hop_mAP:0.981\n",
      "loss:2.980,acc:0.742,1hop_acc:0.860,2hop_acc:0.919,mAP:0.666,1hop_mAP:0.850,2hop_mAP:0.914\n",
      "mAP:0.778,1hop_mAP:0.961,2hop_mAP:0.978\n",
      "loss:2.955,acc:0.748,1hop_acc:0.858,2hop_acc:0.918,mAP:0.674,1hop_mAP:0.848,2hop_mAP:0.911\n",
      "mAP:0.787,1hop_mAP:0.955,2hop_mAP:0.974\n",
      "loss:2.942,acc:0.749,1hop_acc:0.859,2hop_acc:0.918,mAP:0.676,1hop_mAP:0.845,2hop_mAP:0.909\n",
      "mAP:0.789,1hop_mAP:0.952,2hop_mAP:0.972\n",
      "loss:2.933,acc:0.747,1hop_acc:0.858,2hop_acc:0.916,mAP:0.678,1hop_mAP:0.843,2hop_mAP:0.908\n",
      "mAP:0.789,1hop_mAP:0.950,2hop_mAP:0.971\n",
      "loss:2.925,acc:0.749,1hop_acc:0.858,2hop_acc:0.916,mAP:0.679,1hop_mAP:0.843,2hop_mAP:0.907\n",
      "mAP:0.789,1hop_mAP:0.948,2hop_mAP:0.970\n",
      "loss:2.921,acc:0.750,1hop_acc:0.858,2hop_acc:0.917,mAP:0.679,1hop_mAP:0.842,2hop_mAP:0.907\n",
      "mAP:0.789,1hop_mAP:0.945,2hop_mAP:0.969\n",
      "loss:2.915,acc:0.751,1hop_acc:0.859,2hop_acc:0.918,mAP:0.679,1hop_mAP:0.841,2hop_mAP:0.907\n",
      "mAP:0.790,1hop_mAP:0.945,2hop_mAP:0.969\n",
      "loss:2.909,acc:0.750,1hop_acc:0.858,2hop_acc:0.918,mAP:0.679,1hop_mAP:0.841,2hop_mAP:0.907\n",
      "mAP:0.791,1hop_mAP:0.945,2hop_mAP:0.968\n"
     ]
    }
   ],
   "source": [
    "from pmath import pair_wise_cos, pair_wise_eud, pair_wise_hyp\n",
    "from model import RegressNet\n",
    "import torch\n",
    "c, T, epochs, dist_func, eval_dist = .1,1, 50, pair_wise_hyp, pair_wise_cos\n",
    "torch.manual_seed(42)\n",
    "emb = emb.cuda()\n",
    "\n",
    "train_loader, val_loader = get_dataloader(Xtr, Xval, ytr, yval, emb, batch_size = 128)\n",
    "model = RegressNet(T, c, dist_func, eval_dist, train_loader, val_loader, emb, metric)\n",
    "model = model.cuda()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)\n",
    "model._train(optimizer, epochs, T, eval_interval = 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../')\n",
    "from dataset import get_activitynet_dataset, get_kineticslike_dataset\n",
    "from dataset import get_son2parent, get_emb, get_dataloader\n",
    "from metric import Metric\n",
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
    "\n",
    "dataset_name, emb_type = 'kinetics', 'cone'\n",
    "\n",
    "if dataset_name == 'activitynet':\n",
    "    tree_file = '../activity_net_depth_v5.csv'\n",
    "    if emb_type == 'cone' : emb_file = '../cone_activity_net200.pth' \n",
    "    if emb_type == 'glove' : emb_file = '../w2v_activity_300d.pth'\n",
    "    feat_path, anno_fn = '../data.pickle', '../activity_net.v1-3.json'\n",
    "    Xtr, Xval, ytr, yval, label_set = get_activitynet_dataset(feat_path, anno_fn)\n",
    "    \n",
    "elif dataset_name == 'kinetics':\n",
    "    tree_file = '../kinetics_depth_v2.csv'\n",
    "    if emb_type == 'cone' : emb_file = '../cone_kinetics_200.pth'\n",
    "    if emb_type == 'glove' : emb_file = '../w2v_kinetics_300d.pth'\n",
    "    train_pth_path, valid_pth_path = '../kinetics_train.pth', '../kinetics_val.pth'\n",
    "    Xtr, Xval, ytr, yval, label_set = get_kineticslike_dataset(train_pth_path, valid_pth_path)\n",
    "    \n",
    "elif dataset_name == 'moments':\n",
    "    tree_file = '../moments_depth_v4.csv'\n",
    "    if emb_type == 'cone' : emb_file = '../cones_moments300.pth'\n",
    "    if emb_type == 'glove' : emb_file = '../w2v_moments_300d.pth'\n",
    "    train_pth_path, valid_pth_path = '../moments_train.pth', '../moments_val.pth'\n",
    "    Xtr, Xval, ytr, yval, label_set = get_kineticslike_dataset(train_pth_path, valid_pth_path)\n",
    "\n",
    "son2parent = get_son2parent(tree_file)\n",
    "emb = get_emb('hyp', emb_file, label_set)\n",
    "train_loader, val_loader = get_dataloader(Xtr, Xval, ytr, yval, emb, batch_size = 128)\n",
    "metric = Metric(label_set, son2parent)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss:2.935,acc:0.762,1hop_acc:0.872,2hop_acc:0.914,mAP:0.573,1hop_mAP:0.844,2hop_mAP:0.894\n",
      "mAP:0.698,1hop_mAP:0.956,2hop_mAP:0.969\n",
"loss:2.845,acc:0.770,1hop_acc:0.870,2hop_acc:0.913,mAP:0.595,1hop_mAP:0.834,2hop_mAP:0.887\n", 306 | "mAP:0.719,1hop_mAP:0.937,2hop_mAP:0.957\n", 307 | "loss:2.846,acc:0.771,1hop_acc:0.870,2hop_acc:0.912,mAP:0.594,1hop_mAP:0.827,2hop_mAP:0.882\n", 308 | "mAP:0.717,1hop_mAP:0.929,2hop_mAP:0.951\n", 309 | "loss:2.850,acc:0.769,1hop_acc:0.867,2hop_acc:0.911,mAP:0.593,1hop_mAP:0.824,2hop_mAP:0.880\n", 310 | "mAP:0.715,1hop_mAP:0.928,2hop_mAP:0.950\n", 311 | "loss:2.854,acc:0.762,1hop_acc:0.864,2hop_acc:0.908,mAP:0.590,1hop_mAP:0.821,2hop_mAP:0.878\n", 312 | "mAP:0.713,1hop_mAP:0.927,2hop_mAP:0.949\n", 313 | "loss:2.853,acc:0.763,1hop_acc:0.863,2hop_acc:0.906,mAP:0.590,1hop_mAP:0.822,2hop_mAP:0.878\n", 314 | "mAP:0.712,1hop_mAP:0.926,2hop_mAP:0.949\n", 315 | "loss:2.856,acc:0.762,1hop_acc:0.863,2hop_acc:0.908,mAP:0.589,1hop_mAP:0.819,2hop_mAP:0.876\n", 316 | "mAP:0.711,1hop_mAP:0.924,2hop_mAP:0.947\n", 317 | "loss:2.854,acc:0.763,1hop_acc:0.860,2hop_acc:0.903,mAP:0.589,1hop_mAP:0.816,2hop_mAP:0.871\n", 318 | "mAP:0.711,1hop_mAP:0.922,2hop_mAP:0.944\n", 319 | "loss:2.857,acc:0.763,1hop_acc:0.859,2hop_acc:0.903,mAP:0.588,1hop_mAP:0.814,2hop_mAP:0.871\n", 320 | "mAP:0.711,1hop_mAP:0.921,2hop_mAP:0.944\n", 321 | "loss:2.853,acc:0.766,1hop_acc:0.863,2hop_acc:0.904,mAP:0.589,1hop_mAP:0.815,2hop_mAP:0.869\n", 322 | "mAP:0.711,1hop_mAP:0.919,2hop_mAP:0.942\n" 323 | ] 324 | } 325 | ], 326 | "source": [ 327 | "from pmath import pair_wise_cos, pair_wise_eud, pair_wise_hyp\n", 328 | "from model import RegressNet\n", 329 | "import torch\n", 330 | "c, T, epochs, dist_func, eval_dist = .1,1, 50, pair_wise_hyp, pair_wise_cos\n", 331 | "torch.manual_seed(42)\n", 332 | "emb = emb.cuda()\n", 333 | "\n", 334 | "train_loader, val_loader = get_dataloader(Xtr, Xval, ytr, yval, emb, batch_size = 128)\n", 335 | "model = RegressNet(T, c, dist_func, eval_dist, train_loader, val_loader, emb, metric)\n", 336 | "model = model.cuda()\n", 337 | "optimizer = torch.optim.Adam(model.parameters(), 
lr=1e-4)\n", 338 | "model._train(optimizer, epochs, T, eval_interval = 5)" 339 | ] 340 | }, 341 | { 342 | "cell_type": "code", 343 | "execution_count": 7, 344 | "metadata": {}, 345 | "outputs": [], 346 | "source": [ 347 | "import sys\n", 348 | "sys.path.append('../')\n", 349 | "from dataset import get_activitynet_dataset, get_kineticslike_dataset\n", 350 | "from dataset import get_son2parent, get_emb, get_dataloader\n", 351 | "from metric import Metric\n", 352 | "import os\n", 353 | "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n", 354 | "\n", 355 | "dataset_name, emb_type = 'moments', 'cone'\n", 356 | "\n", 357 | "if dataset_name == 'activitynet':\n", 358 | " tree_file = '../activity_net_depth_v5.csv'\n", 359 | " if emb_type == 'cone' : emb_file = '../cone_activity_net200.pth' \n", 360 | " if emb_type == 'glove' : emb_file = '../w2v_activity_300d.pth'\n", 361 | " feat_path, anno_fn = '../data.pickle', '../activity_net.v1-3.json'\n", 362 | " Xtr, Xval, ytr, yval, label_set = get_activitynet_dataset(feat_path, anno_fn)\n", 363 | " \n", 364 | "elif dataset_name == 'kinetics':\n", 365 | " tree_file = '../kinetics_depth_v2.csv'\n", 366 | " if emb_type == 'cone' : emb_file = '../cone_kinetics_200.pth'\n", 367 | " if emb_type == 'glove' : emb_file = '../w2v_kinetics_300d.pth'\n", 368 | " train_pth_path, valid_pth_path = '../kinetics_train.pth', '../kinetics_val.pth'\n", 369 | " Xtr, Xval, ytr, yval, label_set = get_kineticslike_dataset(train_pth_path, valid_pth_path)\n", 370 | " \n", 371 | "elif dataset_name == 'moments':\n", 372 | " tree_file = '../moments_depth_v5.csv'\n", 373 | " if emb_type == 'cone' : emb_file = '../cone_moments_200.pth'\n", 374 | " if emb_type == 'glove' : emb_file = '../w2v_moments_300d.pth'\n", 375 | " train_pth_path, valid_pth_path = '../moments_train.pth', '../moments_val.pth'\n", 376 | " Xtr, Xval, ytr, yval, label_set = get_kineticslike_dataset(train_pth_path, valid_pth_path)\n", 377 | "\n", 378 | "son2parent = get_son2parent(tree_file)\n", 379 
| "emb = get_emb('hyp', emb_file, label_set)\n", 380 | "train_loader, val_loader = get_dataloader(Xtr, Xval, ytr, yval, emb, batch_size = 128)\n", 381 | "metric = Metric(label_set, son2parent)" 382 | ] 383 | }, 384 | { 385 | "cell_type": "code", 386 | "execution_count": 8, 387 | "metadata": {}, 388 | "outputs": [ 389 | { 390 | "name": "stdout", 391 | "output_type": "stream", 392 | "text": [ 393 | "loss:4.406,acc:0.133,1hop_acc:0.188,2hop_acc:0.375,mAP:0.149,1hop_mAP:0.182,2hop_mAP:0.365\n", 394 | "mAP:0.209,1hop_mAP:0.388,2hop_mAP:0.527\n", 395 | "loss:4.365,acc:0.153,1hop_acc:0.207,2hop_acc:0.380,mAP:0.157,1hop_mAP:0.190,2hop_mAP:0.373\n", 396 | "mAP:0.254,1hop_mAP:0.414,2hop_mAP:0.557\n", 397 | "loss:4.353,acc:0.160,1hop_acc:0.214,2hop_acc:0.391,mAP:0.160,1hop_mAP:0.196,2hop_mAP:0.378\n", 398 | "mAP:0.273,1hop_mAP:0.417,2hop_mAP:0.561\n", 399 | "loss:4.352,acc:0.164,1hop_acc:0.218,2hop_acc:0.393,mAP:0.162,1hop_mAP:0.199,2hop_mAP:0.380\n", 400 | "mAP:0.281,1hop_mAP:0.419,2hop_mAP:0.563\n", 401 | "loss:4.356,acc:0.164,1hop_acc:0.219,2hop_acc:0.398,mAP:0.163,1hop_mAP:0.200,2hop_mAP:0.381\n", 402 | "mAP:0.287,1hop_mAP:0.420,2hop_mAP:0.564\n", 403 | "loss:4.358,acc:0.166,1hop_acc:0.222,2hop_acc:0.403,mAP:0.163,1hop_mAP:0.201,2hop_mAP:0.381\n", 404 | "mAP:0.292,1hop_mAP:0.419,2hop_mAP:0.560\n", 405 | "loss:4.366,acc:0.166,1hop_acc:0.222,2hop_acc:0.402,mAP:0.163,1hop_mAP:0.201,2hop_mAP:0.381\n", 406 | "mAP:0.296,1hop_mAP:0.419,2hop_mAP:0.561\n", 407 | "loss:4.370,acc:0.167,1hop_acc:0.223,2hop_acc:0.406,mAP:0.163,1hop_mAP:0.201,2hop_mAP:0.382\n", 408 | "mAP:0.300,1hop_mAP:0.415,2hop_mAP:0.561\n", 409 | "loss:4.374,acc:0.166,1hop_acc:0.222,2hop_acc:0.409,mAP:0.163,1hop_mAP:0.201,2hop_mAP:0.383\n", 410 | "mAP:0.302,1hop_mAP:0.415,2hop_mAP:0.559\n", 411 | "loss:4.381,acc:0.164,1hop_acc:0.218,2hop_acc:0.407,mAP:0.163,1hop_mAP:0.200,2hop_mAP:0.382\n", 412 | "mAP:0.307,1hop_mAP:0.414,2hop_mAP:0.564\n" 413 | ] 414 | } 415 | ], 416 | "source": [ 417 | "from pmath import 
pair_wise_cos, pair_wise_eud, pair_wise_hyp\n", 418 | "from model import RegressNet\n", 419 | "import torch\n", 420 | "c, T, epochs, dist_func, eval_dist = .1,1, 20, pair_wise_hyp, pair_wise_cos\n", 421 | "torch.manual_seed(42)\n", 422 | "emb = emb.cuda()\n", 423 | "\n", 424 | "train_loader, val_loader = get_dataloader(Xtr, Xval, ytr, yval, emb, batch_size = 128)\n", 425 | "model = RegressNet(T, c, dist_func, eval_dist, train_loader, val_loader, emb, metric)\n", 426 | "model = model.cuda()\n", 427 | "optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)\n", 428 | "model._train(optimizer, epochs, T, eval_interval = 2)" 429 | ] 430 | }, 431 | { 432 | "cell_type": "code", 433 | "execution_count": null, 434 | "metadata": {}, 435 | "outputs": [], 436 | "source": [] 437 | } 438 | ], 439 | "metadata": { 440 | "kernelspec": { 441 | "display_name": "Python 3", 442 | "language": "python", 443 | "name": "python3" 444 | }, 445 | "language_info": { 446 | "codemirror_mode": { 447 | "name": "ipython", 448 | "version": 3 449 | }, 450 | "file_extension": ".py", 451 | "mimetype": "text/x-python", 452 | "name": "python", 453 | "nbconvert_exporter": "python", 454 | "pygments_lexer": "ipython3", 455 | "version": "3.7.4" 456 | } 457 | }, 458 | "nbformat": 4, 459 | "nbformat_minor": 4 460 | } 461 | -------------------------------------------------------------------------------- /kinetics_depth.csv: -------------------------------------------------------------------------------- 1 | blowing glass,arts and crafts,Arts and Entertainment,Root,1 2 | spray painting,arts and crafts,Arts and Entertainment,Root,1 3 | weaving basket,arts and crafts,Arts and Entertainment,Root,0 4 | headbanging,body motions,Arts and Entertainment,Root,1 5 | pumping fist,body motions,Arts and Entertainment,Root,1 6 | stretching leg,body motions,Arts and Entertainment,Root,0 7 | belly dancing,dancing,Arts and Entertainment,Root,1 8 | breakdancing,dancing,Arts and Entertainment,Root,1 9 | cheerleading,dancing,Arts 
and Entertainment,Root,1 10 | country line dancing,dancing,Arts and Entertainment,Root,1 11 | dancing ballet,dancing,Arts and Entertainment,Root,1 12 | dancing gangnam style,dancing,Arts and Entertainment,Root,1 13 | dancing macarena,dancing,Arts and Entertainment,Root,0 14 | marching,dancing,Arts and Entertainment,Root,1 15 | robot dancing,dancing,Arts and Entertainment,Root,1 16 | salsa dancing,dancing,Arts and Entertainment,Root,1 17 | tango dancing,dancing,Arts and Entertainment,Root,1 18 | tap dancing,dancing,Arts and Entertainment,Root,1 19 | zumba,dancing,Arts and Entertainment,Root,0 20 | arm wrestling,martial arts,Arts and Entertainment,Root,1 21 | capoeira,martial arts,Arts and Entertainment,Root,1 22 | high kick,martial arts,Arts and Entertainment,Root,1 23 | punching bag,martial arts,Arts and Entertainment,Root,1 24 | side kick,martial arts,Arts and Entertainment,Root,1 25 | tai chi,martial arts,Arts and Entertainment,Root,0 26 | busking,music,Arts and Entertainment,Root,1 27 | playing accordion,music,Arts and Entertainment,Root,1 28 | playing bagpipes,music,Arts and Entertainment,Root,1 29 | playing bass guitar,music,Arts and Entertainment,Root,1 30 | playing cello,music,Arts and Entertainment,Root,1 31 | playing clarinet,music,Arts and Entertainment,Root,0 32 | playing didgeridoo,music,Arts and Entertainment,Root,1 33 | playing drums,music,Arts and Entertainment,Root,1 34 | playing guitar,music,Arts and Entertainment,Root,1 35 | playing harmonica,music,Arts and Entertainment,Root,1 36 | playing harp,music,Arts and Entertainment,Root,1 37 | playing recorder,music,Arts and Entertainment,Root,1 38 | playing saxophone,music,Arts and Entertainment,Root,0 39 | playing trombone,music,Arts and Entertainment,Root,1 40 | playing trumpet,music,Arts and Entertainment,Root,1 41 | playing ukulele,music,Arts and Entertainment,Root,1 42 | playing violin,music,Arts and Entertainment,Root,1 43 | playing xylophone,music,Arts and Entertainment,Root,1 44 | tapping 
guitar,music,Arts and Entertainment,Root,0 45 | cleaning floor,cleaning,Household Activities,Root,1 46 | tying knot (not on a tie),cleaning,Household Activities,Root,1 47 | washing dishes,cleaning,Household Activities,Root,0 48 | baking cookies,cooking,Household Activities,Root,1 49 | barbequing,cooking,Household Activities,Root,1 50 | cooking chicken,cooking,Household Activities,Root,1 51 | cutting watermelon,cooking,Household Activities,Root,1 52 | making pizza,cooking,Household Activities,Root,1 53 | picking fruit,cooking,Household Activities,Root,0 54 | scrambling eggs,cooking,Household Activities,Root,0 55 | chopping wood,garden + plants,Household Activities,Root,1 56 | mowing lawn,garden + plants,Household Activities,Root,1 57 | bookbinding,paper,Household Activities,Root,1 58 | folding napkins,paper,Household Activities,Root,1 59 | folding paper,paper,Household Activities,Root,1 60 | opening present,paper,Household Activities,Root,1 61 | reading book,paper,Household Activities,Root,1 62 | unboxing,paper,Household Activities,Root,1 63 | wrapping present,paper,Household Activities,Root,0 64 | sharpening pencil,using tools,Household Activities,Root,1 65 | using computer,using tools,Household Activities,Root,1 66 | welding,using tools,Household Activities,Root,0 67 | high jump,athletics jumping,Participating in Sports/Exercise/or Recreation,Root,1 68 | long jump,athletics jumping,Participating in Sports/Exercise/or Recreation,Root,1 69 | pole vault,athletics jumping,Participating in Sports/Exercise/or Recreation,Root,1 70 | triple jump,athletics jumping,Participating in Sports/Exercise/or Recreation,Root,0 71 | archery,athletics throwing + launching,Participating in Sports/Exercise/or Recreation,Root,1 72 | catching or throwing frisbee,athletics throwing + launching,Participating in Sports/Exercise/or Recreation,Root,1 73 | hammer throw,athletics throwing + launching,Participating in Sports/Exercise/or Recreation,Root,1 74 | javelin throw,athletics throwing + 
launching,Participating in Sports/Exercise/or Recreation,Root,1 75 | shot put,athletics throwing + launching,Participating in Sports/Exercise/or Recreation,Root,1 76 | throwing axe,athletics throwing + launching,Participating in Sports/Exercise/or Recreation,Root,1 77 | throwing discus,athletics throwing + launching,Participating in Sports/Exercise/or Recreation,Root,0 78 | bowling,ball sports,Participating in Sports/Exercise/or Recreation,Root,1 79 | dribbling basketball,ball sports,Participating in Sports/Exercise/or Recreation,Root,1 80 | dunking basketball,ball sports,Participating in Sports/Exercise/or Recreation,Root,1 81 | kicking field goal,ball sports,Participating in Sports/Exercise/or Recreation,Root,1 82 | passing American football (in game),ball sports,Participating in Sports/Exercise/or Recreation,Root,1 83 | passing American football (not in game),ball sports,Participating in Sports/Exercise/or Recreation,Root,1 84 | playing basketball,ball sports,Participating in Sports/Exercise/or Recreation,Root,0 85 | playing volleyball,ball sports,Participating in Sports/Exercise/or Recreation,Root,0 86 | golf driving,golf,Participating in Sports/Exercise/or Recreation,Root,1 87 | golf putting,golf,Participating in Sports/Exercise/or Recreation,Root,1 88 | bench pressing,gym,Participating in Sports/Exercise/or Recreation,Root,1 89 | clean and jerk,gym,Participating in Sports/Exercise/or Recreation,Root,1 90 | deadlifting,gym,Participating in Sports/Exercise/or Recreation,Root,1 91 | front raises,gym,Participating in Sports/Exercise/or Recreation,Root,1 92 | pull ups,gym,Participating in Sports/Exercise/or Recreation,Root,1 93 | situp,gym,Participating in Sports/Exercise/or Recreation,Root,1 94 | snatch weight lifting,gym,Participating in Sports/Exercise/or Recreation,Root,1 95 | squat,gym,Participating in Sports/Exercise/or Recreation,Root,0 96 | yoga,gym,Participating in Sports/Exercise/or Recreation,Root,0 97 | lunge,gym,Participating in Sports/Exercise/or 
Recreation,Root,1 98 | gymnastics tumbling,gymnastics,Participating in Sports/Exercise/or Recreation,Root,1 99 | somersaulting,gymnastics,Participating in Sports/Exercise/or Recreation,Root,1 100 | balloon blowing,head + mouth,Participating in Sports/Exercise/or Recreation,Root,1 101 | beatboxing,head + mouth,Participating in Sports/Exercise/or Recreation,Root,1 102 | blowing out candles,head + mouth,Participating in Sports/Exercise/or Recreation,Root,1 103 | shaking head,head + mouth,Participating in Sports/Exercise/or Recreation,Root,1 104 | singing,head + mouth,Participating in Sports/Exercise/or Recreation,Root,1 105 | smoking,head + mouth,Participating in Sports/Exercise/or Recreation,Root,1 106 | smoking hookah,head + mouth,Participating in Sports/Exercise/or Recreation,Root,0 107 | sticking tongue out,head + mouth,Participating in Sports/Exercise/or Recreation,Root,0 108 | abseiling,heights,Participating in Sports/Exercise/or Recreation,Root,1 109 | bungee jumping,heights,Participating in Sports/Exercise/or Recreation,Root,1 110 | climbing tree,heights,Participating in Sports/Exercise/or Recreation,Root,1 111 | diving cliff,heights,Participating in Sports/Exercise/or Recreation,Root,1 112 | paragliding,heights,Participating in Sports/Exercise/or Recreation,Root,1 113 | rock climbing,heights,Participating in Sports/Exercise/or Recreation,Root,1 114 | slacklining,heights,Participating in Sports/Exercise/or Recreation,Root,0 115 | trapezing,heights,Participating in Sports/Exercise/or Recreation,Root,1 116 | contact juggling,juggling,Participating in Sports/Exercise/or Recreation,Root,1 117 | hula hooping,juggling,Participating in Sports/Exercise/or Recreation,Root,1 118 | juggling balls,juggling,Participating in Sports/Exercise/or Recreation,Root,1 119 | spinning poi,juggling,Participating in Sports/Exercise/or Recreation,Root,0 120 | catching or throwing baseball,racquet + bat sports,Participating in Sports/Exercise/or Recreation,Root,1 121 | catching or 
throwing softball,racquet + bat sports,Participating in Sports/Exercise/or Recreation,Root,1 122 | hitting baseball,racquet + bat sports,Participating in Sports/Exercise/or Recreation,Root,1 123 | hurling (sport),racquet + bat sports,Participating in Sports/Exercise/or Recreation,Root,1 124 | playing badminton,racquet + bat sports,Participating in Sports/Exercise/or Recreation,Root,1 125 | playing cricket,racquet + bat sports,Participating in Sports/Exercise/or Recreation,Root,1 126 | playing squash or racquetball,racquet + bat sports,Participating in Sports/Exercise/or Recreation,Root,0 127 | playing tennis,racquet + bat sports,Participating in Sports/Exercise/or Recreation,Root,0 128 | biking through snow,snow + ice,Participating in Sports/Exercise/or Recreation,Root,1 129 | ice climbing,snow + ice,Participating in Sports/Exercise/or Recreation,Root,1 130 | ice skating,snow + ice,Participating in Sports/Exercise/or Recreation,Root,1 131 | making snowman,snow + ice,Participating in Sports/Exercise/or Recreation,Root,1 132 | playing ice hockey,snow + ice,Participating in Sports/Exercise/or Recreation,Root,1 133 | shoveling snow,snow + ice,Participating in Sports/Exercise/or Recreation,Root,1 134 | ski jumping,snow + ice,Participating in Sports/Exercise/or Recreation,Root,0 135 | skiing (not slalom or crosscountry),snow + ice,Participating in Sports/Exercise/or Recreation,Root,1 136 | sled dog racing,snow + ice,Participating in Sports/Exercise/or Recreation,Root,1 137 | snowboarding,snow + ice,Participating in Sports/Exercise/or Recreation,Root,1 138 | snowkiting,snow + ice,Participating in Sports/Exercise/or Recreation,Root,1 139 | tobogganing,snow + ice,Participating in Sports/Exercise/or Recreation,Root,0 140 | swimming backstroke,swimming,Participating in Sports/Exercise/or Recreation,Root,1 141 | swimming breast stroke,swimming,Participating in Sports/Exercise/or Recreation,Root,1 142 | canoeing or kayaking,water sports,Participating in Sports/Exercise/or 
Recreation,Root,1 143 | jetskiing,water sports,Participating in Sports/Exercise/or Recreation,Root,1 144 | kitesurfing,water sports,Participating in Sports/Exercise/or Recreation,Root,1 145 | parasailing,water sports,Participating in Sports/Exercise/or Recreation,Root,1 146 | sailing,water sports,Participating in Sports/Exercise/or Recreation,Root,1 147 | surfing water,water sports,Participating in Sports/Exercise/or Recreation,Root,0 148 | water skiing,water sports,Participating in Sports/Exercise/or Recreation,Root,1 149 | windsurfing,water sports,Participating in Sports/Exercise/or Recreation,Root,0 150 | braiding hair,hair,Personal Care,Root,1 151 | brushing hair,hair,Personal Care,Root,1 152 | curling hair,hair,Personal Care,Root,1 153 | shaving head,hair,Personal Care,Root,1 154 | trimming or shaving beard,hair,Personal Care,Root,0 155 | air drumming,hands,Personal Care,Root,1 156 | finger snapping,hands,Personal Care,Root,1 157 | doing nails,makeup,Personal Care,Root,1 158 | dying hair,makeup,Personal Care,Root,1 159 | filling eyebrows,makeup,Personal Care,Root,1 160 | waxing chest,makeup,Personal Care,Root,1 161 | waxing legs,makeup,Personal Care,Root,0 162 | brushing teeth,personal hygiene,Personal Care,Root,1 163 | washing feet,personal hygiene,Personal Care,Root,1 164 | washing hands,personal hygiene,Personal Care,Root,0 165 | feeding birds,interacting with animals,Relaxing and Leisure,Root,1 166 | feeding fish,interacting with animals,Relaxing and Leisure,Root,1 167 | feeding goats,interacting with animals,Relaxing and Leisure,Root,1 168 | milking cow,interacting with animals,Relaxing and Leisure,Root,1 169 | petting animal (not cat),interacting with animals,Relaxing and Leisure,Root,1 170 | riding elephant,interacting with animals,Relaxing and Leisure,Root,1 171 | riding or walking with horse,interacting with animals,Relaxing and Leisure,Root,0 172 | shearing sheep,interacting with animals,Relaxing and Leisure,Root,1 173 | walking the dog,interacting 
with animals,Relaxing and Leisure,Root,0 174 | crawling baby,mobility land,Relaxing and Leisure,Root,1 175 | driving car,mobility land,Relaxing and Leisure,Root,1 176 | driving tractor,mobility land,Relaxing and Leisure,Root,1 177 | motorcycling,mobility land,Relaxing and Leisure,Root,1 178 | pushing car,mobility land,Relaxing and Leisure,Root,1 179 | pushing cart,mobility land,Relaxing and Leisure,Root,1 180 | riding unicycle,mobility land,Relaxing and Leisure,Root,0 181 | roller skating,mobility land,Relaxing and Leisure,Root,1 182 | skateboarding,mobility land,Relaxing and Leisure,Root,1 183 | surfing crowd,mobility land,Relaxing and Leisure,Root,0 184 | crossing river,mobility water,Relaxing and Leisure,Root,1 185 | jumping into pool,mobility water,Relaxing and Leisure,Root,1 186 | scuba diving,mobility water,Relaxing and Leisure,Root,1 187 | snorkeling,mobility water,Relaxing and Leisure,Root,0 188 | flying kite,playing games,Relaxing and Leisure,Root,1 189 | playing chess,playing games,Relaxing and Leisure,Root,1 190 | playing paintball,playing games,Relaxing and Leisure,Root,1 191 | playing poker,playing games,Relaxing and Leisure,Root,1 192 | shuffling cards,playing games,Relaxing and Leisure,Root,0 193 | crying,communication,Social Activities,Root,1 194 | giving or receiving award,communication,Social Activities,Root,1 195 | laughing,communication,Social Activities,Root,1 196 | massaging back,communication,Social Activities,Root,1 197 | presenting weather forecast,communication,Social Activities,Root,0 198 | eating burger,eating + drinking,Social Activities,Root,1 199 | eating ice cream,eating + drinking,Social Activities,Root,1 200 | eating spaghetti,eating + drinking,Social Activities,Root,0 -------------------------------------------------------------------------------- /metric.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | import torch 3 | 4 | class Metric(): 5 | def 
__init__(self,label_set,son2parent): 6 | 7 | self.label_set = label_set # Make sure the labels are sorted the same way as in data preparation. 8 | 9 | self.son2parent = son2parent 10 | self.G = nx.Graph() 11 | for son,parent in self.son2parent.items(): 12 | self.G.add_edge(son, parent) # Add the son-parent edge (and both nodes) to the hierarchy tree 13 | 14 | shortest_path_gen = nx.all_pairs_shortest_path(self.G) # Cache all shortest paths for fast retrieval 15 | self.shortest_path_dict = dict(shortest_path_gen) 16 | self.hop_matrix = self._get_hops_matrix() # n_cls x n_cls matrix whose (i, j) entry stores the number of hops from label i to label j. 17 | 18 | def _get_hops_matrix(self): 19 | n_cls = len(self.label_set) 20 | hop_matrix = torch.zeros(n_cls,n_cls) 21 | for i in range(n_cls): 22 | for j in range(n_cls): 23 | source, target = self.label_set[i], self.label_set[j] 24 | path = self.shortest_path_dict[source][target] 25 | hop_matrix[i,j] = len(path) - 1 26 | return hop_matrix 27 | 28 | 29 | def hop_acc(self,ypred,ybatch, hops): 30 | 31 | correct_count = 0 32 | 33 | for i in range (len(ypred)): 34 | source, target = self.label_set[ypred[i]], self.label_set[ybatch[i]] 35 | path = self.shortest_path_dict[source][target] # shortest path on the hierarchy 36 | if len(path) - 1 <= hops: # count the prediction as correct if it lies within `hops` edges of the ground truth (0: same class, 2: sibling, 4: cousin) 37 | correct_count = correct_count + 1 38 | 39 | return correct_count / len(ybatch) 40 | 41 | def hop_mAP(self, ypred_topk, ybatch, hop = 0): 42 | # ypred_topk: n x k 43 | # ybatch: n x 1 44 | n,k = ypred_topk.size(0), ypred_topk.size(1) 45 | 46 | # correct_inx = (ypred == ybatch) # broadcast automatically to n x k 47 | correct_inx = torch.zeros(n,k) 48 | for i in range(n): 49 | x_pos = ybatch[i] # 1 x 1 50 | y_pos = ypred_topk[i] # 1 x k 51 | hop_dist = self.hop_matrix[x_pos, y_pos] # Get the graph distance (by hop) between pred_i and ybatch_i 52 | correct_inx[i,:] = (hop_dist <= hop) # For all predictions
 whose hop distance is at most hop, count them as correct predictions (accuracy: 0 hops, sibling: 2 hops, cousin: 4 hops) 53 | 54 | numerator = [correct_inx[:,:i+1].sum(dim=1) for i in range(k)] 55 | numerator = torch.stack(numerator).t() 56 | denominator = torch.arange(1,k+1).repeat(n,1) 57 | P = numerator.float() / denominator.float() 58 | AP = P.mean(dim=1,keepdim=True) 59 | 60 | return AP.mean() 61 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | 5 | def loss_fn(support, query, dist_func, c, T): 6 | # Here we use synthesised support. 7 | logits = -dist_func(support,query,c) / T 8 | fewshot_label = torch.arange(support.size(0)).cuda() 9 | loss = F.cross_entropy(logits, fewshot_label) 10 | 11 | return loss 12 | 13 | class RegressNet(nn.Module): 14 | def __init__(self, T, c, dist_func, eval_dist, train_loader, val_loader, emb, metric): 15 | super(RegressNet, self).__init__() 16 | self.layers = nn.Sequential( 17 | nn.Linear(2048, 2048), 18 | nn.LeakyReLU(), 19 | nn.Linear(2048, 2048), 20 | nn.LeakyReLU(), 21 | nn.Linear(2048, emb.size(1)) 22 | ) 23 | self.T = T 24 | self.c = c 25 | self.dist_func = dist_func 26 | self.eval_dist = eval_dist 27 | self.val_loader = val_loader 28 | self.train_loader = train_loader 29 | self.emb = emb 30 | self.metric = metric 31 | 32 | def forward(self, x): 33 | x = self.layers(x) 34 | if self.dist_func.__name__ == 'pair_wise_hyp': 35 | if (x.norm(dim=1) >= 1).sum() != 0: 36 | x = x / (x.norm(dim=1,keepdim=True) + 1e-2) 37 | return x 38 | 39 | def _train(self, optimizer, epochs, T, eval_interval): 40 | model = self 41 | model.train() 42 | for epoch in range(epochs): 43 | 44 | for i, (xbatch, ybatch, abatch) in enumerate(self.train_loader): 45 | 46 | xbatch, ybatch, abatch = xbatch.cuda(), ybatch.cuda(), abatch.cuda() 47 |
 optimizer.zero_grad() 48 | apred = model(xbatch) 49 | 50 | loss = loss_fn(apred, abatch, self.dist_func, self.c, self.T) 51 | 52 | loss.backward() 53 | optimizer.step() 54 | 55 | if epoch % eval_interval == 0: 56 | self.evaluation(flag = 'valid') 57 | self.evaluation2() 58 | 59 | def eval_zsl(self, unseen_idset): 60 | model = self 61 | unseen_idset = torch.tensor(unseen_idset).cuda() 62 | unseen_emb = self.emb[unseen_idset,:] 63 | 64 | # 0.query_candidates & GT 65 | aval_pred_list, yval_list = [], [] 66 | for i, (xbatch, ybatch, abatch) in enumerate(self.val_loader): 67 | xbatch, ybatch, abatch = xbatch.cuda(), ybatch.cuda(), abatch.cuda() 68 | apred = model(xbatch) 69 | aval_pred_list.append(apred) 70 | yval_list.append(ybatch) 71 | aval_pred = torch.cat(aval_pred_list) # search candidates 72 | yval_order1 = torch.cat(yval_list) # ground truth in the search candidates' order. 73 | 74 | GT_list, loss_list, ypred_list, ypred_topk_list = [], [], [], [] 75 | 76 | for i, (xbatch, ybatch, abatch) in enumerate(self.val_loader): 77 | xbatch, ybatch, abatch = xbatch.cuda(), ybatch.cuda(), abatch.cuda() 78 | apred = model(xbatch) 79 | # 1.loss 80 | loss = loss_fn(apred, abatch, self.dist_func, self.c, self.T).item() 81 | loss_list.append(loss) 82 | # 2.ypred for recognition 83 | dist = self.eval_dist(apred, unseen_emb, self.c) 84 | rank = dist.sort()[1] 85 | top1_rank = rank[:,0] 86 | ypred = unseen_idset[top1_rank] 87 | 88 | # 3.ypred_topk for retrieval 89 | dist_retrieval = self.eval_dist(apred, aval_pred, self.c) 90 | rank = dist_retrieval.sort()[1] 91 | topk_result_inx = rank[:,:50] # k = 50 92 | ypred_topk = yval_order1[topk_result_inx] 93 | ypred_topk_list.append(ypred_topk) 94 | ypred_list.append(ypred) 95 | GT_list.append(ybatch) 96 | GT = torch.cat(GT_list) 97 | ypred = torch.cat(ypred_list) 98 | ypred_topk = torch.cat(ypred_topk_list) 99 | 100 | loss = torch.tensor(loss_list).mean() 101 | hop0_acc = (ypred == GT).float().mean().item() 102 | hop1_acc =
 self.metric.hop_acc(ypred,GT, hops = 2) 103 | hop2_acc = self.metric.hop_acc(ypred,GT, hops = 4) 104 | hop0_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 0) 105 | hop1_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 2) 106 | hop2_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 4) 107 | print('loss:%.3f, acc:%.3f, 1hop_acc:%.3f, 2hop_acc:%.3f, mAP:%.3f, 1hop_mAP:%.3f, 2hop_mAP:%.3f'%(loss, hop0_acc, hop1_acc, hop2_acc, hop0_mAP, hop1_mAP, hop2_mAP)) 108 | 109 | def evaluation(self, flag): 110 | model = self 111 | 112 | # 0.query_candidates & GT 113 | aval_pred_list, yval_list = [], [] 114 | for i, (xbatch, ybatch, abatch) in enumerate(self.val_loader): 115 | xbatch, ybatch, abatch = xbatch.cuda(), ybatch.cuda(), abatch.cuda() 116 | apred = model(xbatch) 117 | aval_pred_list.append(apred) 118 | yval_list.append(ybatch) 119 | aval_pred = torch.cat(aval_pred_list) # search candidates 120 | yval_order1 = torch.cat(yval_list) # ground truth in the search candidates' order. 121 | 122 | GT_list = [] # ground truth in validation order 123 | loss_list, ypred_list, ypred_topk_list = [],[],[] 124 | torch.manual_seed(42) 125 | for i, (xbatch, ybatch, abatch) in enumerate(self.val_loader): 126 | xbatch, ybatch, abatch = xbatch.cuda(), ybatch.cuda(), abatch.cuda() 127 | apred = model(xbatch) 128 | # 1.loss 129 | loss = loss_fn(apred, abatch, self.dist_func, self.c, self.T).item() 130 | loss_list.append(loss) 131 | # 2.ypred for recognition 132 | dist = self.eval_dist(apred, self.emb, self.c) 133 | rank = dist.sort()[1] 134 | ypred = rank[:,0] 135 | ypred_list.append(ypred) 136 | # 3.ypred_topk for retrieval 137 | dist_retrieval = self.eval_dist(apred, aval_pred, self.c) 138 | rank = dist_retrieval.sort()[1] 139 | topk_result_inx = rank[:,:50] # k = 50 140 | ypred_topk = yval_order1[topk_result_inx] 141 | ypred_topk_list.append(ypred_topk) 142 | GT_list.append(ybatch) 143 | GT = torch.cat(GT_list) 144 | 145 | loss = torch.tensor(loss_list).mean() 146 | ypred = torch.cat(ypred_list)
147 |         ypred_topk = torch.cat(ypred_topk_list)
148 |
149 |         hop0_acc = (ypred == GT).float().mean().item()
150 |         hop1_acc = self.metric.hop_acc(ypred,GT, hops = 2)
151 |         hop2_acc = self.metric.hop_acc(ypred,GT, hops = 4)
152 |         hop0_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 0)
153 |         hop1_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 2)
154 |         hop2_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 4)
155 |
156 |         print('loss:%.3f,acc:%.3f,1hop_acc:%.3f,2hop_acc:%.3f,mAP:%.3f,1hop_mAP:%.3f,2hop_mAP:%.3f'%
157 |               (loss, hop0_acc, hop1_acc, hop2_acc, hop0_mAP, hop1_mAP, hop2_mAP))
158 |
159 |     def evaluation2(self):
160 |         model = self
161 |         torch.manual_seed(42)
162 |
163 |         # 0.query_candidates & GT
164 |         avalpred_list, yval_list = [], []
165 |         for i, (xbatch, ybatch, _) in enumerate(self.val_loader):
166 |             xbatch, ybatch = xbatch.cuda(), ybatch.cuda()
167 |             avalpred_batch = model(xbatch)
168 |             avalpred_list.append(avalpred_batch)
169 |             yval_list.append(ybatch)
170 |         aval_pred = torch.cat(avalpred_list) # search candidates
171 |         yval = torch.cat(yval_list) # ground truth in search-candidate order
172 |
173 |         dist_retrieval = self.eval_dist(self.emb, aval_pred, self.c) # n_cls x n_val
174 |         rank = dist_retrieval.sort()[1] # n_cls x n_val
175 |
176 |         topk_result_inx = rank[:,:50] # n_cls x k (k = 50)
177 |         ypred_topk = yval[topk_result_inx] # n_cls x k
178 |
179 |         GT = torch.arange(self.emb.size(0)).cuda()
180 |
181 |         hop0_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 0)
182 |         hop1_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 2)
183 |         hop2_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 4)
184 |
185 |         print('mAP:%.3f,1hop_mAP:%.3f,2hop_mAP:%.3f'%(hop0_mAP, hop1_mAP, hop2_mAP))
186 |
187 |
188 | class SoftmaxNet(nn.Module):
189 |     def __init__(self, T, c, eval_dist, train_loader, val_loader, emb, metric):
190 |         super(SoftmaxNet, self).__init__()
191 |         self.layers = nn.Sequential(
192 |             nn.Linear(2048, 2048),
193 |             nn.LeakyReLU(),
194 |             nn.Linear(2048, 2048),
195 |             nn.LeakyReLU(),
196 |             nn.Linear(2048, emb.size(0))
197 |         )
198 |         self.T = T
199 |         self.c = c
200 |         self.eval_dist = eval_dist
201 |         self.val_loader = val_loader
202 |         self.train_loader = train_loader
203 |         self.emb = emb
204 |         self.metric = metric
205 |
206 |     def forward(self, x):
207 |         x = self.layers(x)
208 |         return x
209 |
210 |     def _train(self, optimizer, epochs, eval_interval, feat_layer):
211 |         model = self
212 |         model.train()
213 |         for epoch in range(epochs):
214 |
215 |             for i, (xbatch, ybatch, abatch) in enumerate(self.train_loader):
216 |
217 |                 xbatch, ybatch, abatch = xbatch.cuda(), ybatch.cuda(), abatch.cuda()
218 |                 optimizer.zero_grad()
219 |                 logits = model(xbatch)
220 |                 logits = logits / self.T # Temperature trick
221 |                 loss_fn = F.cross_entropy
222 |                 loss = loss_fn(logits,ybatch)
223 |
224 |                 loss.backward()
225 |                 optimizer.step()
226 |
227 |             if epoch % eval_interval == 0:
228 |                 self.evaluation('valid', feat_layer)
229 |                 self.evaluation2()
230 |
231 |     def evaluation(self, flag, feat_layer): # feat_layer takes one of three values: 2, 4 or 6
232 |         model = self
233 |
234 |         # 0.query_candidates & GT
235 |         feat_list, yval_list = [], []
236 |         for i, (xbatch, ybatch, abatch) in enumerate(self.val_loader):
237 |             xbatch, ybatch, abatch = xbatch.cuda(), ybatch.cuda(), abatch.cuda()
238 |             feat_batch = model.layers[0:feat_layer](xbatch).detach()
239 |             feat_list.append(feat_batch)
240 |             yval_list.append(ybatch)
241 |         feat = torch.cat(feat_list) # search candidates
242 |         yval_order1 = torch.cat(yval_list) # ground truth in search-candidate order
243 |
244 |         GT_list = [] # ground truth in validation order
245 |         loss_list, ypred_list, ypred_topk_list = [],[],[]
246 |         torch.manual_seed(42)
247 |         for i, (xbatch, ybatch, abatch) in enumerate(self.val_loader):
248 |             xbatch, ybatch, abatch = xbatch.cuda(), ybatch.cuda(), abatch.cuda()
249 |             feat_batch = model.layers[0:feat_layer](xbatch).detach()
250 |
251 |             logits = model(xbatch)
252 |
253 |             # 1.loss
254 |             loss_fn = F.cross_entropy
255 |             loss = loss_fn(logits,ybatch).item()
256 |             loss_list.append(loss)
257 |
258 |             # 2.ypred for recognition
259 |             rank = logits.sort()[1]
260 |             ypred = rank[:,-1]
261 |             ypred_list.append(ypred)
262 |
263 |             # 3.ypred_topk for retrieval
264 |             dist_retrieval = self.eval_dist(feat_batch, feat, self.c)
265 |             rank = dist_retrieval.sort()[1]
266 |             topk_result_inx = rank[:,:50] # k = 50
267 |             ypred_topk = yval_order1[topk_result_inx] # batch_size x k
268 |             ypred_topk_list.append(ypred_topk)
269 |             GT_list.append(ybatch)
270 |         GT = torch.cat(GT_list)
271 |
272 |         loss = torch.tensor(loss_list).mean()
273 |         ypred = torch.cat(ypred_list)
274 |         ypred_topk = torch.cat(ypred_topk_list)
275 |
276 |         hop0_acc = (ypred == GT).float().mean().item()
277 |         hop1_acc = self.metric.hop_acc(ypred,GT, hops = 2)
278 |         hop2_acc = self.metric.hop_acc(ypred,GT, hops = 4)
279 |         hop0_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 0)
280 |         hop1_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 2)
281 |         hop2_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 4)
282 |
283 |         print('loss:%.3f,acc:%.3f,1hop_acc:%.3f,2hop_acc:%.3f,mAP:%.3f,1hop_mAP:%.3f,2hop_mAP:%.3f'%
284 |               (loss, hop0_acc, hop1_acc, hop2_acc, hop0_mAP, hop1_mAP, hop2_mAP))
285 |
286 |     def evaluation2(self):
287 |         model = self
288 |         # 0.query_candidates & GT
289 |         logits_val_list, yval_list = [], []
290 |         for i, (xbatch, ybatch, _) in enumerate(self.val_loader):
291 |             xbatch, ybatch = xbatch.cuda(), ybatch.cuda()
292 |             logits_batch = model(xbatch)
293 |             logits_batch = logits_batch / self.T
294 |             logits_val_list.append(logits_batch)
295 |             yval_list.append(ybatch)
296 |         logits_val = torch.cat(logits_val_list) # n_val x n_cls
297 |         probs_val = F.softmax(logits_val,dim=1) # n_val x n_cls
298 |         rank = probs_val.sort(dim=0)[1] # n_val x n_cls
299 |         topk_result_inx = rank[-50:,:] # k x n_cls (k = 50)
300 |         yval = torch.cat(yval_list) # ground truth in search-candidate order
301 |         ypred_topk = yval[topk_result_inx] # k x n_cls
302 |         ypred_topk = ypred_topk.t() # n_cls x k
303 |         GT = torch.arange(self.emb.size(0)).unsqueeze(1).cuda() # n_cls x 1
304 |
305 |         hop0_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 0)
306 |         hop1_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 2)
307 |         hop2_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 4)
308 |
309 |         print('mAP:%.3f,1hop_mAP:%.3f,2hop_mAP:%.3f'%(hop0_mAP, hop1_mAP, hop2_mAP))
310 |
311 | class CVPR19Net2(nn.Module):
312 |     def __init__(self, T, c, parent_set, grandpa_set, son2parent, eval_dist, train_loader, val_loader, emb, metric):
313 |         super(CVPR19Net2, self).__init__()
314 |
315 |         n_pa, n_gp = len(parent_set), len(grandpa_set)
316 |
317 |         self.dense1 = nn.Linear(2048, 2048)
318 |         self.leaky1 = nn.LeakyReLU()
319 |         self.dense2 = nn.Linear(2048, 2048)
320 |         self.leaky2 = nn.LeakyReLU()
321 |         self.dense_leaf = nn.Linear(2048, 200)
322 |         self.dense_parent = nn.Linear(2048,n_pa)
323 |         self.dense_grandpa = nn.Linear(2048,n_gp)
324 |
325 |         self.c = c
326 |         self.T = T
327 |         self.parent_set = parent_set
328 |         self.grandpa_set = grandpa_set
329 |         self.eval_dist = eval_dist
330 |         self.val_loader = val_loader
331 |         self.train_loader = train_loader
332 |         self.son2parent = son2parent
333 |         self.emb = emb
334 |         self.metric = metric
335 |
336 |     def forward(self, x):
337 |         hidden1 = self.dense1(x)
338 |         action1 = self.leaky1(hidden1)
339 |         hidden2 = self.dense2(action1)
340 |         action2 = self.leaky2(hidden2)
341 |
342 |         # all three heads read the second hidden layer (action2)
343 |         logits_leaf = self.dense_leaf(action2)
344 |         logits_pa = self.dense_parent(action2)
345 |         logits_gp = self.dense_grandpa(action2)
346 |
347 |         return logits_leaf, logits_pa, logits_gp, hidden1
348 |
349 |     def evaluation(self):
350 |         model = self
351 |
352 |         # 0.query_candidates & GT
353 |         cand_feat_list, cand_yval_list = [], []
354 |         for i, (xbatch, ybatch, abatch) in enumerate(self.val_loader):
355 |             xbatch, ybatch, abatch, _, _ = self.batch_generation_on_gpu(xbatch, ybatch, abatch)
356 |             _, _, _, batch_feat = model(xbatch)
357 |
358 |             cand_feat_list.append(batch_feat)
359 |             cand_yval_list.append(ybatch)
360 |         cand_feat = torch.cat(cand_feat_list) # search candidates
361 |         cand_yval = torch.cat(cand_yval_list) # ground truth in search-candidate order
362 |
363 |         GT_list = [] # ground truth in validation order
364 |         loss_list, ypred_list, ypred_topk_list = [],[],[]
365 |         torch.manual_seed(42)
366 |
367 |
368 |         loss_lists = ([],[],[])
369 |         for i, (xbatch, ybatch, abatch) in enumerate(self.val_loader):
370 |
371 |             xbatch, ybatch, abatch, ybatch_pa, ybatch_gp = self.batch_generation_on_gpu(xbatch, ybatch, abatch)
372 |             # 1.Loss
373 |             logits_leaf, logits_pa, logits_gp, batch_feat = model(xbatch)
374 |             loss_fn = F.cross_entropy
375 |
376 |             loss_leaf = loss_fn(logits_leaf,ybatch).item()
377 |             loss_pa = loss_fn(logits_pa,ybatch_pa).item()
378 |             loss_gp = loss_fn(logits_gp,ybatch_gp).item()
379 |             loss_lists[0].append(loss_leaf)
380 |             loss_lists[1].append(loss_pa)
381 |             loss_lists[2].append(loss_gp)
382 |
383 |             # For accs
384 |             rank = logits_leaf.sort()[1]
385 |             ypred = rank[:,-1]
386 |
387 |             # For retrievals
388 |             dist_retrieval = self.eval_dist(batch_feat, cand_feat, self.c)
389 |             rank = dist_retrieval.sort()[1]
390 |             topk_result_inx = rank[:,:50] # k = 50
391 |             ypred_topk = cand_yval[topk_result_inx]
392 |
393 |             # For ground truth
394 |             GT_list.append(ybatch)
395 |             ypred_list.append(ypred)
396 |             ypred_topk_list.append(ypred_topk)
397 |
398 |         loss_leaf = torch.tensor(loss_lists[0]).mean()
399 |         loss_pa = torch.tensor(loss_lists[1]).mean()
400 |         loss_gp = torch.tensor(loss_lists[2]).mean()
401 |
402 |         GT = torch.cat(GT_list)
403 |         ypred = torch.cat(ypred_list)
404 |         ypred_topk = torch.cat(ypred_topk_list)
405 |
406 |         hop0_acc = (ypred == GT).float().mean().item()
407 |         hop1_acc = self.metric.hop_acc(ypred,GT, hops = 2)
408 |         hop2_acc = self.metric.hop_acc(ypred,GT, hops = 4)
409 |         hop0_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 0)
410 |         hop1_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 2)
411 |         hop2_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 4)
412 |
413 |         print('leaf_loss:%.3f,pa_loss:%.3f,gp_loss:%.3f,acc:%.3f,1hop_acc:%.3f,2hop_acc:%.3f,mAP:%.3f,1hop_mAP:%.3f,2hop_mAP:%.3f'%(loss_leaf,loss_pa,loss_gp, hop0_acc, hop1_acc, hop2_acc, hop0_mAP, hop1_mAP, hop2_mAP))
414 |
415 |     def batch_generation_on_gpu(self, xbatch, ybatch, abatch):
416 |         son2parent = self.son2parent
417 |         label_set = [key for key,value in son2parent.items() if key not in son2parent.values()]
418 |         son2grandpa = {key:son2parent[value] for key,value in son2parent.items() if value in self.parent_set}
419 |
420 |         ybatch_pa = [self.parent_set.index(son2parent[label_set[item]]) for item in ybatch]
421 |         ybatch_pa = torch.tensor(ybatch_pa)
422 |         ybatch_gp = [self.grandpa_set.index(son2grandpa[label_set[item]]) for item in ybatch]
423 |         ybatch_gp = torch.tensor(ybatch_gp)
424 |
425 |         xbatch, ybatch, abatch = xbatch.cuda(), ybatch.cuda(), abatch.cuda()
426 |         ybatch_pa, ybatch_gp = ybatch_pa.cuda(), ybatch_gp.cuda()
427 |
428 |         return xbatch, ybatch, abatch, ybatch_pa, ybatch_gp
429 |
430 |     def _train(self, optimizer, epochs,T, eval_interval):
431 |         model = self
432 |         model.train()
433 |         for epoch in range(epochs):
434 |
435 |             for i, (xbatch, ybatch, abatch) in enumerate(self.train_loader):
436 |
437 |                 xbatch, ybatch, abatch, ybatch_pa, ybatch_gp = self.batch_generation_on_gpu(xbatch, ybatch, abatch)
438 |                 optimizer.zero_grad()
439 |                 logits_leaf, logits_pa, logits_gp, _ = model(xbatch)
440 |                 T = self.T # note: this overrides the T argument of _train
441 |                 logits_leaf, logits_pa, logits_gp = logits_leaf/T, logits_pa/T, logits_gp/T # Temperature trick
442 |
443 |                 loss_fn = F.cross_entropy
444 |                 loss_leaf = loss_fn(logits_leaf,ybatch)
445 |                 loss_pa = loss_fn(logits_pa,ybatch_pa)
446 |                 loss_gp = loss_fn(logits_gp,ybatch_gp)
447 |
448 |                 loss = loss_leaf + 1* loss_pa + 1 * loss_gp
449 |
450 |                 loss.backward()
451 |                 optimizer.step()
452 |
453 |             if epoch % eval_interval == 0:
454 |                 self.evaluation()
455 |                 self.evaluation2()
456 |
457 |     def eval_zsl(self, unseen_idset):
458 |         model = self
459 |         logits_val_list, yval_list = [], []
460 |         for i, (xbatch, ybatch, _) in enumerate(self.val_loader):
461 |             xbatch, ybatch = xbatch.cuda(), ybatch.cuda()
462 |             logits_batch, _, _, _ = model(xbatch)
463 |             logits_batch = logits_batch / self.T
464 |             logits_val_list.append(logits_batch)
465 |             yval_list.append(ybatch)
466 |         logits_val = torch.cat(logits_val_list) # n_val x n_cls
467 |
468 |         logits_val = logits_val[:,unseen_idset]
469 |
470 |         probs_val = F.softmax(logits_val,dim=1) # n_val x n_unseencls
471 |         rank = probs_val.sort(dim=0)[1] # n_val x n_unseencls
472 |         topk_result_inx = rank[-50:,:] # k x n_unseencls (top k for each class)
473 |
474 |         yval = torch.cat(yval_list) # ground truth in search-candidate order
475 |         ypred_topk = yval[topk_result_inx] # k x n_unseencls: labels of the top-k (k = 50) videos retrieved for each unseen class
476 |         ypred_topk = ypred_topk.t() # n_unseencls x k
477 |
478 |         GT = torch.tensor(unseen_idset).cuda()
479 |         hop0_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 0)
480 |         hop1_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 2)
481 |         hop2_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 4)
482 |         print('ZSL search by name: mAP:%.3f, 1hop_mAP:%.3f, 2hop_mAP:%.3f'%(hop0_mAP, hop1_mAP, hop2_mAP))
483 |
484 |     def evaluation2(self):
485 |         model = self
486 |         # 0.query_candidates & GT
487 |         logits_val_list, yval_list = [], []
488 |         for i, (xbatch, ybatch, _) in enumerate(self.val_loader):
489 |             xbatch, ybatch = xbatch.cuda(), ybatch.cuda()
490 |             logits_batch, _, _, _ = model(xbatch)
491 |             logits_batch = logits_batch / self.T
492 |             logits_val_list.append(logits_batch)
493 |             yval_list.append(ybatch)
494 |         logits_val = torch.cat(logits_val_list) # n_val x n_cls
495 |         probs_val = F.softmax(logits_val,dim=1) # n_val x n_cls
496 |         rank = probs_val.sort(dim=0)[1] # n_val x n_cls
497 |         topk_result_inx = rank[-50:,:] # k x n_cls (k = 50)
498 |         yval = torch.cat(yval_list) # ground truth in search-candidate order
499 |         ypred_topk = yval[topk_result_inx] # k x n_cls
500 |         ypred_topk = ypred_topk.t() # n_cls x k
501 |         GT = torch.arange(self.emb.size(0)).unsqueeze(1).cuda() # n_cls x 1
502 |
503 |         hop0_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 0)
504 |         hop1_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 2)
505 |         hop2_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 4)
506 |
507 |         print('mAP:%.3f,1hop_mAP:%.3f,2hop_mAP:%.3f'%(hop0_mAP, hop1_mAP, hop2_mAP))
508 |
509 |
510 | class CVPR19Net(nn.Module):
511 |     def __init__(self, T, c, parent_set, grandpa_set, son2parent, eval_dist, train_loader, val_loader, emb, metric):
512 |         super(CVPR19Net, self).__init__()
513 |
514 |         n_pa, n_gp = len(parent_set), len(grandpa_set)
515 |
516 |         self.dense1 = nn.Linear(2048, 2048)
517 |         self.leaky1 = nn.LeakyReLU()
518 |         self.dense_pa = nn.Linear(2048, 2048)
519 |         self.leaky_pa = nn.LeakyReLU()
520 |         self.dense_gp = nn.Linear(2048, 2048)
521 |         self.leaky_gp = nn.LeakyReLU()
522 |         self.dense_leaf = nn.Linear(2048, 200)
523 |         self.dense_parent = nn.Linear(4096, n_pa)
524 |         self.dense_grandpa = nn.Linear(6144, n_gp)
525 |
526 |         self.c = c
527 |         self.T = T
528 |         self.parent_set = parent_set
529 |         self.grandpa_set = grandpa_set
530 |         self.eval_dist = eval_dist
531 |         self.val_loader = val_loader
532 |         self.train_loader = train_loader
533 |         self.son2parent = son2parent
534 |         self.emb = emb
535 |         self.metric = metric
536 |
537 |     def forward(self, x):
538 |         hidden1 = self.dense1(x)
539 |         action1 = self.leaky1(hidden1)
540 |         hidden_pa = self.dense_pa(action1)
541 |         action_pa = self.leaky_pa(hidden_pa)
542 |         hidden_gp = self.dense_gp(action1)
543 |         action_gp = self.leaky_gp(hidden_gp)
544 |
545 |         logits_leaf = self.dense_leaf(action1)
546 |         # import pdb
547 |         # pdb.set_trace()
548 |         logits_pa = self.dense_parent(torch.cat((action1,action_pa),dim=1))
549 |         logits_gp = self.dense_grandpa(torch.cat((action1,action_pa,action_gp),dim=1))
550 |         # logits_leaf = self.dense_leaf(action2)
551 |         # logits_pa = self.dense_parent(action2)
552 |         # logits_gp = self.dense_grandpa(action2)
553 |
554 |
555 |         return logits_leaf, logits_pa, logits_gp, hidden1
556 |
557 |     def evaluation(self):
558 |         model = self
559 |
560 |         # 0.query_candidates & GT
561 |         cand_feat_list, cand_yval_list = [], []
562 |         for i, (xbatch, ybatch, abatch) in enumerate(self.val_loader):
563 |             xbatch, ybatch, abatch, _, _ = self.batch_generation_on_gpu(xbatch, ybatch, abatch)
564 |             _, _, _, batch_feat = model(xbatch)
565 |
566 |             cand_feat_list.append(batch_feat)
567 |             cand_yval_list.append(ybatch)
568 |         cand_feat = torch.cat(cand_feat_list) # search candidates
569 |         cand_yval = torch.cat(cand_yval_list) # ground truth in search-candidate order
570 |
571 |         GT_list = [] # ground truth in validation order
572 |         loss_list, ypred_list, ypred_topk_list = [],[],[]
573 |         torch.manual_seed(42)
574 |
575 |
576 |         loss_lists = ([],[],[])
577 |         for i, (xbatch, ybatch, abatch) in enumerate(self.val_loader):
578 |
579 |             xbatch, ybatch, abatch, ybatch_pa, ybatch_gp = self.batch_generation_on_gpu(xbatch, ybatch, abatch)
580 |             # 1.Loss
581 |             logits_leaf, logits_pa, logits_gp, batch_feat = model(xbatch)
582 |             loss_fn = F.cross_entropy
583 |
584 |             loss_leaf = loss_fn(logits_leaf,ybatch).item()
585 |             loss_pa = loss_fn(logits_pa,ybatch_pa).item()
586 |             loss_gp = loss_fn(logits_gp,ybatch_gp).item()
587 |             loss_lists[0].append(loss_leaf)
588 |             loss_lists[1].append(loss_pa)
589 |             loss_lists[2].append(loss_gp)
590 |
591 |             # For accs
592 |             rank = logits_leaf.sort()[1]
593 |             ypred = rank[:,-1]
594 |
595 |             # For retrievals
596 |             dist_retrieval = self.eval_dist(batch_feat, cand_feat, self.c)
597 |             rank = dist_retrieval.sort()[1]
598 |             topk_result_inx = rank[:,:50] # k = 50
599 |             ypred_topk = cand_yval[topk_result_inx]
600 |
601 |             # For ground truth
602 |             GT_list.append(ybatch)
603 |             ypred_list.append(ypred)
604 |             ypred_topk_list.append(ypred_topk)
605 |
606 |         loss_leaf = torch.tensor(loss_lists[0]).mean()
607 |         loss_pa = torch.tensor(loss_lists[1]).mean()
608 |         loss_gp = torch.tensor(loss_lists[2]).mean()
609 |
610 |         GT = torch.cat(GT_list)
611 |         ypred = torch.cat(ypred_list)
612 |         ypred_topk = torch.cat(ypred_topk_list)
613 |
614 |         hop0_acc = (ypred == GT).float().mean().item()
615 |         hop1_acc = self.metric.hop_acc(ypred,GT, hops = 2)
616 |         hop2_acc = self.metric.hop_acc(ypred,GT, hops = 4)
617 |         hop0_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 0)
618 |         hop1_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 2)
619 |         hop2_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 4)
620 |
621 |         print('leaf_loss:%.3f,pa_loss:%.3f,gp_loss:%.3f,acc:%.3f,1hop_acc:%.3f,2hop_acc:%.3f,mAP:%.3f,1hop_mAP:%.3f,2hop_mAP:%.3f'%(loss_leaf,loss_pa,loss_gp, hop0_acc, hop1_acc, hop2_acc, hop0_mAP, hop1_mAP, hop2_mAP))
622 |
623 |     def batch_generation_on_gpu(self, xbatch, ybatch, abatch):
624 |         son2parent = self.son2parent
625 |         label_set = [key for key,value in son2parent.items() if key not in son2parent.values()]
626 |         son2grandpa = {key:son2parent[value] for key,value in son2parent.items() if value in self.parent_set}
627 |
628 |         ybatch_pa = [self.parent_set.index(son2parent[label_set[item]]) for item in ybatch]
629 |         ybatch_pa = torch.tensor(ybatch_pa)
630 |         ybatch_gp = [self.grandpa_set.index(son2grandpa[label_set[item]]) for item in ybatch]
631 |         ybatch_gp = torch.tensor(ybatch_gp)
632 |
633 |         xbatch, ybatch, abatch = xbatch.cuda(), ybatch.cuda(), abatch.cuda()
634 |         ybatch_pa, ybatch_gp = ybatch_pa.cuda(), ybatch_gp.cuda()
635 |
636 |         return xbatch, ybatch, abatch, ybatch_pa, ybatch_gp
637 |
638 |     def _train(self, optimizer, epochs,T, eval_interval):
639 |         model = self
640 |         model.train()
641 |         for epoch in range(epochs):
642 |
643 |             for i, (xbatch, ybatch, abatch) in enumerate(self.train_loader):
644 |
645 |                 xbatch, ybatch, abatch, ybatch_pa, ybatch_gp = self.batch_generation_on_gpu(xbatch, ybatch, abatch)
646 |                 optimizer.zero_grad()
647 |                 logits_leaf, logits_pa, logits_gp, _ = model(xbatch)
648 |                 T = self.T # note: this overrides the T argument of _train
649 |                 logits_leaf, logits_pa, logits_gp = logits_leaf/T, logits_pa/T, logits_gp/T # Temperature trick
650 |
651 |                 loss_fn = F.cross_entropy
652 |                 loss_leaf = loss_fn(logits_leaf,ybatch)
653 |                 loss_pa = loss_fn(logits_pa,ybatch_pa)
654 |                 loss_gp = loss_fn(logits_gp,ybatch_gp)
655 |
656 |                 loss = loss_leaf + 1* loss_pa + 1 * loss_gp
657 |
658 |                 loss.backward()
659 |                 optimizer.step()
660 |
661 |             if epoch % eval_interval == 0:
662 |                 self.evaluation()
663 |                 self.evaluation2()
664 |
665 |     def eval_zsl(self, unseen_idset):
666 |         model = self
667 |         logits_val_list, yval_list = [], []
668 |         for i, (xbatch, ybatch, _) in enumerate(self.val_loader):
669 |             xbatch, ybatch = xbatch.cuda(), ybatch.cuda()
670 |             logits_batch, _, _, _ = model(xbatch)
671 |             logits_batch = logits_batch / self.T
672 |             logits_val_list.append(logits_batch)
673 |             yval_list.append(ybatch)
674 |         logits_val = torch.cat(logits_val_list) # n_val x n_cls
675 |         # key step: keep only the logits of the unseen classes
676 |         logits_val = logits_val[:,unseen_idset]
677 |
678 |         probs_val = F.softmax(logits_val,dim=1) # n_val x n_unseencls
679 |         rank = probs_val.sort(dim=0)[1] # n_val x n_unseencls
680 |         topk_result_inx = rank[-50:,:] # k x n_unseencls (top k for each class)
681 |
682 |         yval = torch.cat(yval_list) # ground truth in search-candidate order
683 |         ypred_topk = yval[topk_result_inx] # k x n_unseencls: labels of the top-k (k = 50) videos retrieved for each unseen class
684 |         ypred_topk = ypred_topk.t() # n_unseencls x k
685 |
686 |         # top1_result_inx = rank[-1,:] # n_unseencls x 1 (k = 1)
687 |         # ypred = yval[top1_result_inx]
688 |         # hop0_acc = (ypred == GT).float().mean().item()
689 |         # hop1_acc = self.metric.hop_acc(ypred,GT, hops = 2)
690 |         # hop2_acc = self.metric.hop_acc(ypred,GT, hops = 4)
691 |
692 |         GT = torch.tensor(unseen_idset).cuda()
693 |         hop0_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 0)
694 |         hop1_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 2)
695 |         hop2_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 4)
696 |         print('ZSL search by name: mAP:%.3f, 1hop_mAP:%.3f, 2hop_mAP:%.3f'%(hop0_mAP, hop1_mAP, hop2_mAP))
697 |
698 |         # dist = self.eval_dist(apred, unseen_emb, self.c)
699 |
700 |     def evaluation2(self):
701 |         model = self
702 |         # 0.query_candidates & GT
703 |         logits_val_list, yval_list = [], []
704 |         for i, (xbatch, ybatch, _) in enumerate(self.val_loader):
705 |             xbatch, ybatch = xbatch.cuda(), ybatch.cuda()
706 |             logits_batch, _, _, _ = model(xbatch)
707 |             logits_batch = logits_batch / self.T
708 |             logits_val_list.append(logits_batch)
709 |             yval_list.append(ybatch)
710 |         logits_val = torch.cat(logits_val_list) # n_val x n_cls
711 |         probs_val = F.softmax(logits_val,dim=1) # n_val x n_cls
712 |         rank = probs_val.sort(dim=0)[1] # n_val x n_cls
713 |         topk_result_inx = rank[-50:,:] # k x n_cls (k = 50)
714 |         yval = torch.cat(yval_list) # ground truth in search-candidate order
715 |         ypred_topk = yval[topk_result_inx] # k x n_cls
716 |         ypred_topk = ypred_topk.t() # n_cls x k
717 |         GT = torch.arange(self.emb.size(0)).unsqueeze(1).cuda() # n_cls x 1
718 |
719 |         hop0_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 0)
720 |         hop1_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 2)
721 |         hop2_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 4)
722 |
723 |         print('mAP:%.3f,1hop_mAP:%.3f,2hop_mAP:%.3f'%(hop0_mAP, hop1_mAP, hop2_mAP))
724 |
725 | class NIP19Proto():
726 |     def __init__(self, w2v_emb, lr, epochs, loss_lambdas):
727 |         self.w2v_emb = w2v_emb
728 |         self.lr = lr
729 |         self.epochs = epochs
730 |         self.loss_lambdas = loss_lambdas
731 |
732 |     def similarity(self,prototypes):
733 |         norm = torch.norm(prototypes, dim=1) # L2 norm of each row (expected to be 1)
734 |
735 |         deviation = (norm.sum() - prototypes.shape[0])
736 |         if deviation > 1e1:
737 |             print('deviation from norm 1', deviation)
738 |
739 |         t1 = norm.unsqueeze(1) # n_cls x 1
740 |         t2 = norm.unsqueeze(0) # 1 x n_cls
741 |         denominator = torch.matmul(t1, t2) # n_cls x n_cls, each element is a norm product
742 |         numerator = torch.matmul(prototypes, prototypes.t()) # each element is an inner product
743 |         cos_sim = numerator / denominator # n_cls x n_cls, each element is a cosine similarity
744 |         cos_sim_off_diag = cos_sim - torch.diag(torch.diag(cos_sim))
745 |         obj = cos_sim_off_diag.max(dim=1)[0]
746 |
747 |         return obj.mean(), cos_sim
748 |
749 |     def order_loss(self,prototypes, w2v, lmd=0):
750 |         B = prototypes.t()
751 |         _, S = self.similarity(w2v)
752 |         S = S.float()
753 |
754 |         # Laplacian matrix L
755 |         S1 = S - lmd
756 |         ones = torch.ones(S1.shape[0], 1)
757 |         L = torch.diag(S1.matmul(ones).squeeze(1)) - S # squeeze to 1-D so torch.diag builds the n_cls x n_cls degree matrix
758 |
759 |         # Loss = Trace(BLB')
760 |         M = B.matmul(L)
761 |         M = M.matmul(B.t())
762 |         o_loss = torch.trace(M) / B.shape[1]**2
763 |
764 |         return o_loss
765 |
766 |     def train(self):
767 |
768 |         emb = self.w2v_emb.cpu()
769 |         emb = F.normalize(emb)
770 |         prototypes = nn.Parameter(F.normalize(torch.randn(emb.size(0), 300), p=2, dim=1))
771 |         optimizer = torch.optim.SGD([prototypes], lr=self.lr, momentum=0.9)
772 |         best_loss = 1000
773 |
774 |         for i in range(self.epochs):
775 |             optimizer.zero_grad()
776 |             sim, _ = self.similarity(prototypes)
777 |             o_loss = self.order_loss(prototypes, emb.cpu(), lmd=0)
778 |             loss = self.loss_lambdas[0] * sim + self.loss_lambdas[1] * o_loss # default weighting is 1:1
779 |
780 |             loss.backward(retain_graph=True)
781 |             optimizer.step()
782 |             if i % 10 == 0:
783 |                 if i % 100 == 0:
784 |                     print(f'Loss: {loss}, Order Loss: {o_loss} and Sim Loss: {sim}')
785 |                 prototypes = nn.Parameter(F.normalize(prototypes, p=2, dim=1))
786 |                 optimizer = torch.optim.SGD([prototypes], lr=self.lr, momentum=0.9)
787 |
788 |         self.prototypes = prototypes
789 |
790 |
791 |
--------------------------------------------------------------------------------
/moments_depth.csv:
--------------------------------------------------------------------------------
1 | ,bowling,compete.v.01,act.v.01,Root,1
2 | ,competing,compete.v.01,act.v.01,Root,1
3 | ,playing+sports,compete.v.01,act.v.01,Root,1
4 | ,racing,compete.v.01,act.v.01,Root,0
5 | ,,descending,act.v.01,Root,1
6 | ,imitating,perform.v.02,act.v.01,Root,1
7 | marrying,officiate.v.01,perform.v.02,act.v.01,Root,1
8 | officiating,officiate.v.01,perform.v.02,act.v.01,Root,1
9 | ,performing,perform.v.02,act.v.01,Root,0
10 | ,gambling,play.v.16,act.v.01,Root,1
11 | ,sneezing,act_involuntarily.v.01,act.v.02,Root,1
12 | ,squinting,act_involuntarily.v.01,act.v.02,Root,1
13 | ,winking,act_involuntarily.v.01,act.v.02,Root,0
14 | ,playing+fun,play.v.16,act.v.02,Root,1
15 | ,playing+videogames,play.v.16,act.v.02,Root,1
16 | ,,hanging,be.v.01,Root,1
17 | ,,leaning,be.v.01,Root,1
18 | ,kneeling,rest.v.01,be.v.01,Root,0
19 | ,queuing,rest.v.01,be.v.01,Root,1
20 | ,resting,rest.v.01,be.v.01,Root,1
21 | ,,standing,be.v.01,Root,1
22 | ,,yawning,be.v.01,Root,1
23 | ,crouching,sit.v.01,be.v.03,Root,1
24 | ,sitting,sit.v.01,be.v.03,Root,1
25 | ,squatting,sit.v.01,be.v.03,Root,0
26 | ,,sleeping,be.v.03,Root,1
27 | ,sniffing,inhale.v.02,cause_to_be_perceived.v.01,Root,1
28 | ,smelling,smell.v.02,cause_to_be_perceived.v.01,Root,1
29 | ,dining,sound.v.02,cause_to_be_perceived.v.01,Root,1
30 | ,knocking,sound.v.02,cause_to_be_perceived.v.01,Root,1
31 | ,whistling,sound.v.02,cause_to_be_perceived.v.01,Root,0
32 | ,tuning,adjust.v.01,change.v.01,Root,1
33 | bathing,fancify.v.01,better.v.02,change.v.01,Root,1
34 | grooming,fancify.v.01,better.v.02,change.v.01,Root,1
35 | ,,camping,change.v.01,Root,1
36 | ,mopping,clean.v.01,change.v.01,Root,1
37 | ,rinsing,clean.v.01,change.v.01,Root,1
38 | ,scrubbing,clean.v.01,change.v.01,Root,1
39 | ,vacuuming,clean.v.01,change.v.01,Root,1
40 | ,washing,clean.v.01,change.v.01,Root,0
41 | ,,clearing,change.v.01,Root,1
42 | ,,crushing,change.v.01,Root,1
43 | ,,cutting,change.v.01,Root,1
44 | ,trimming,decorate.v.01,change.v.01,Root,1
45 | ,,drying,change.v.01,Root,1
46 | ,loading,fill.v.01,change.v.01,Root,1
47 | ,stretching,increase.v.02,change.v.01,Root,1
48 | ,,inflating,change.v.01,Root,1
49 | ,,putting,change.v.01,Root,1
50 | ,clipping,reduce.v.01,change.v.01,Root,1
51 | ,handcuffing,restrain.v.03,change.v.01,Root,1
52 | ,,socializing,change.v.01,Root,1
53 | ,combing,straighten.v.02,change.v.01,Root,1
54 | ,welding,unite.v.06,change.v.01,Root,1
55 | ,,waking,change.v.01,Root,1
56 | ,draining,weaken.v.01,change.v.01,Root,1
57 | ,drenching,wet.v.01,change.v.01,Root,1
58 | ,flooding,wet.v.01,change.v.01,Root,1
59 | ,leaking,wet.v.01,change.v.01,Root,1
60 | ,overflowing,wet.v.01,change.v.01,Root,1
61 | ,submerging,wet.v.01,change.v.01,Root,1
62 | ,watering,wet.v.01,change.v.01,Root,1
63 | ,wetting,wet.v.01,change.v.01,Root,0
64 | ,,boiling,change.v.02,Root,1
65 | ,,breaking,change.v.02,Root,1
66 | chewing,break.v.02,change_integrity.v.01,change.v.02,Root,1
67 | frying,cook.v.03,change_integrity.v.01,change.v.02,Root,1
68 | barbecuing,cook.v.03,change_integrity.v.01,change.v.02,Root,1
69 | grilling,cook.v.03,change_integrity.v.01,change.v.02,Root,0
70 | climbing,increase.v.01,change_magnitude.v.01,change.v.02,Root,1
71 | waxing,increase.v.01,change_magnitude.v.01,change.v.02,Root,1
72 | ,bowing,change_posture.v.01,change.v.02,Root,1
73 | ,bending,change_shape.v.01,change.v.02,Root,1
74 | ,pressing,change_shape.v.01,change.v.02,Root,0
75 | ,combusting,change_state.v.01,change.v.02,Root,1
76 | ,sanding,change_surface.v.01,change.v.02,Root,1
77 | ,cracking,crack.v.01,change.v.02,Root,1
78 | ,,crashing,change.v.02,Root,1
79 | ,dressing,dress.v.01,change.v.02,Root,1
80 | ,,dropping,change.v.02,Root,1
81 | ,tattooing,dye.v.01,change.v.02,Root,1
82 | ,,erupting,change.v.02,Root,1
83 | ,,falling,change.v.02,Root,1
84 | ,,folding,change.v.02,Root,1
85 | ,,landing,change.v.02,Root,1
86 | ,,removing,change.v.02,Root,1
87 | ,,rising,change.v.02,Root,1
88 | ,,turning,change.v.02,Root,0
89 | ,,adult+female+speaking,communicate.v.02,Root,1
90 | ,,adult+male+speaking,communicate.v.02,Root,1
91 | ,,baptizing,communicate.v.02,Root,1
92 | ,,buying,communicate.v.02,Root,1
93 | ,,child+speaking,communicate.v.02,Root,1
94 | ,giggling,express_emotion.v.01,communicate.v.02,Root,1
95 | ,laughing,express_emotion.v.01,communicate.v.02,Root,1
96 | ,shouting,express_emotion.v.01,communicate.v.02,Root,1
97 | ,applauding,gesticulate.v.01,communicate.v.02,Root,1
98 | ,clapping,gesticulate.v.01,communicate.v.02,Root,1
99 | ,shrugging,gesticulate.v.01,communicate.v.02,Root,1
100 | ,frowning,grimace.v.01,communicate.v.02,Root,1
101 | ,grinning,grimace.v.01,communicate.v.02,Root,1
102 | ,smiling,grimace.v.01,communicate.v.02,Root,1
103 | ,coaching,inform.v.01,communicate.v.02,Root,1
104 | ,instructing,inform.v.01,communicate.v.02,Root,1
105 | ,lecturing,inform.v.01,communicate.v.02,Root,1
106 | ,pointing,inform.v.01,communicate.v.02,Root,1
107 | ,teaching,inform.v.01,communicate.v.02,Root,0
108 | ,,paying,communicate.v.02,Root,1
109 | ,asking,request.v.01,communicate.v.02,Root,1
110 | ,praying,request.v.01,communicate.v.02,Root,1
111 | ,,selling,communicate.v.02,Root,1
112 | ,adult+female+singing,sing.v.02,communicate.v.02,Root,1
113 | ,adult+male+singing,sing.v.02,communicate.v.02,Root,1
114 | ,child+singing,sing.v.02,communicate.v.02,Root,1
115 | ,signing,sing.v.02,communicate.v.02,Root,1
116 | ,singing,sing.v.02,communicate.v.02,Root,0
117 | ,,speaking,communicate.v.02,Root,1
118 | ,discussing,talk.v.02,communicate.v.02,Root,1
119 | ,interviewing,talk.v.02,communicate.v.02,Root,1
120 | ,preaching,talk.v.02,communicate.v.02,Root,1
121 | ,talking,talk.v.02,communicate.v.02,Root,0
122 | ,telephoning,telecommunicate.v.01,communicate.v.02,Root,1
123 | ,,buttoning,connect.v.01,Root,1
124 | ,,joining,connect.v.01,Root,1
125 | ,,sewing,connect.v.01,Root,1
126 | ,,stitching,connect.v.01,Root,1
127 | ,,tying,connect.v.01,Root,0
128 | ,,drinking,consume.v.02,Root,1
129 | ,dunking,eat.v.01,consume.v.02,Root,1
130 | ,eating,eat.v.02,consume.v.02,Root,1
131 | ,,filling,consume.v.02,Root,1
132 | ,,smoking,consume.v.02,Root,0
133 | ,,boxing,contend.v.06,Root,1
134 | ,,fencing,contend.v.06,Root,1
135 | ,,fighting,contend.v.06,Root,1
136 | ,,wrestling,contend.v.06,Root,0
137 | ,piloting,steer.v.01,control.v.01,Root,1
138 | ,starting,steer.v.01,control.v.01,Root,1
139 | ,steering,steer.v.01,control.v.01,Root,1
140 | ,typing,steer.v.01,control.v.01,Root,0
141 | ,,dusting,cover.v.01,Root,1
142 | ,,painting,cover.v.01,Root,1
143 | ,,wrapping,cover.v.01,Root,1
144 | ,,boarding,enter.v.01,Root,1
145 | dipping,immerse.v.01,penetrate.v.01,enter.v.01,Root,1
146 | plunging,immerse.v.01,penetrate.v.01,enter.v.01,Root,1
147 | biting,pierce.v.05,penetrate.v.01,enter.v.01,Root,0
148 | ,bubbling,emit.v.01,exhaust.v.05,Root,1
149 | ,coughing,expectorate.v.02,exhaust.v.05,Root,1
150 | filming,record.v.01,save.v.02,have.v.01,Root,1
151 | photographing,record.v.01,save.v.02,have.v.01,Root,1
152 | ,,carrying,have.v.02,Root,1
153 | ,,burying,hide.v.01,Root,1
154 | ,,covering,hide.v.01,Root,1
155 | ,cuddling,embrace.v.02,hold.v.02,Root,1
156 | ,,locking,hold.v.02,Root,1
157 | ,,balancing,hold.v.14,Root,1
158 | ,,juggling,hold.v.14,Root,1
159 | ,arresting,arouse.v.01,make.v.03,Root,1
160 | ,celebrating,arouse.v.01,make.v.03,Root,1
161 | ,cheering,arouse.v.01,make.v.03,Root,1
162 | ,cheerleading,arouse.v.01,make.v.03,Root,1
163 | ,marching,arouse.v.01,make.v.03,Root,1
164 | ,parading,arouse.v.01,make.v.03,Root,1
165 | ,protesting,arouse.v.01,make.v.03,Root,0
166 | ,,building,make.v.03,Root,1
167 | ,,constructing,make.v.03,Root,1
168 | ,assembling,create_from_raw_material.v.01,make.v.03,Root,1
169 | ,baking,create_from_raw_material.v.01,make.v.03,Root,1
170 | ,cooking,create_from_raw_material.v.01,make.v.03,Root,1
171 | ,crafting,create_from_raw_material.v.01,make.v.03,Root,1
172 | ,hammering,create_from_raw_material.v.01,make.v.03,Root,1
173 | ,knitting,create_from_raw_material.v.01,make.v.03,Root,1
174 | ,spinning,create_from_raw_material.v.01,make.v.03,Root,0
175 | ,,playing,make.v.03,Root,1
176 | ,,raising,make.v.03,Root,1
177 | ,autographing,write.v.02,make.v.03,Root,1
178 | ,handwriting,write.v.02,make.v.03,Root,1
179 | ,sketching,write.v.02,make.v.03,Root,1
180 | ,writing,write.v.02,make.v.03,Root,0
181 | ,poking,agitate.v.06,move.v.02,Root,1
182 | ,,blowing,move.v.02,Root,1
183 | ,towing,drag.v.01,move.v.02,Root,1
184 | ,,dragging,move.v.02,Root,1
185 | ,,driving,move.v.02,Root,1
186 | ,,flicking,move.v.02,Root,1
187 | ,smashing,hit.v.12,move.v.02,Root,1
188 | ,,kicking,move.v.02,Root,1
189 | ,launching,launch.v.05,move.v.02,Root,1
190 | ,,lifting,move.v.02,Root,1
191 | ,dripping,pour.v.01,move.v.02,Root,1
192 | ,pouring,pour.v.01,move.v.02,Root,1
193 | ,punting,propel.v.01,move.v.02,Root,1
194 | ,,pulling,move.v.02,Root,1
195 | ,,pushing,move.v.02,Root,1
196 | ,aiming,put.v.01,move.v.02,Root,0
197 | ,attacking,put.v.01,move.v.02,Root,1
198 | ,cramming,put.v.01,move.v.02,Root,1
199 | ,placing,put.v.01,move.v.02,Root,1
200 | ,planting,put.v.01,move.v.02,Root,1
201 | ,plugging,put.v.01,move.v.02,Root,1
202 | ,shooting,put.v.01,move.v.02,Root,1
203 | ,snuggling,put.v.01,move.v.02,Root,1
204 | ,sowing,put.v.01,move.v.02,Root,1
205 | ,stacking,put.v.01,move.v.02,Root,1
206 | ,throwing,put.v.01,move.v.02,Root,0
207 | ,,rocking,move.v.02,Root,1
208 | ,carving,separate.v.02,move.v.02,Root,1
209 | ,chasing,separate.v.02,move.v.02,Root,1
210 | ,chopping,separate.v.02,move.v.02,Root,1
211 | ,manicuring,separate.v.02,move.v.02,Root,1
212 | ,mowing,separate.v.02,move.v.02,Root,1
213 | ,sawing,separate.v.02,move.v.02,Root,1
214 | ,shaving,separate.v.02,move.v.02,Root,1
215 | ,shredding,separate.v.02,move.v.02,Root,1
216 | ,slicing,separate.v.02,move.v.02,Root,1
217 | ,tearing,separate.v.02,move.v.02,Root,0
218 | ,,slipping,move.v.02,Root,1
219 | ,,spilling,move.v.02,Root,1
220 | ,reaching,succeed.v.01,move.v.02,Root,1
221 | ,entering,succeed.v.02,move.v.02,Root,1
222 | ,,swinging,move.v.02,Root,1
223 | ,,closing,move.v.03,Root,1
224 | ,,dancing,move.v.03,Root,1
225 | ,,exiting,move.v.03,Root,1
226 | ,,flipping,move.v.03,Root,1
227 | ,,flowing,move.v.03,Root,1
228 | ,bouncing,jump.v.01,move.v.03,Root,1
229 | ,jumping,jump.v.01,move.v.03,Root,1
230 | ,leaping,jump.v.01,move.v.03,Root,1
231 | ,skipping,jump.v.01,move.v.03,Root,0
232 | ,,pitching,move.v.03,Root,1
233 | ,,rolling,move.v.03,Root,1
234 | ,,sailing,move.v.03,Root,1
235 | ,,shaking,move.v.03,Root,1
236 | ,,snapping,move.v.03,Root,1
237 | ,,stealing,move.v.03,Root,1
238 | ,,stirring,move.v.03,Root,1
239 | ,,sweeping,move.v.03,Root,1
240 | ,,tripping,move.v.03,Root,1
241 | ,swerving,turn.v.01,move.v.03,Root,1
242 | ,drilling,turn.v.09,move.v.03,Root,1
243 | ,screwing,turn.v.09,move.v.03,Root,1
244 | ,twisting,turn.v.09,move.v.03,Root,0
245 | ,,waving,move.v.03,Root,1
246 | ,,catching,move.v.15,Root,1
247 | ,,opening,move.v.15,Root,1
248 | ,,raining,precipitate.v.03,Root,1
249 | ,,snowing,precipitate.v.03,Root,1
250 |
,,storming,precipitate.v.03,Root,1 251 | ,,stopping,prevent.v.01,Root,1 252 | guarding,protect.v.01,defend.v.02,prevent.v.02,Root,1 253 | blocking,obstruct.v.02,impede.v.01,prevent.v.02,Root,1 254 | ,,brushing,remove.v.01,Root,1 255 | ,,cleaning,remove.v.01,Root,1 256 | ,shoveling,dig.v.01,remove.v.01,Root,1 257 | ,digging,dig.v.01,remove.v.01,Root,1 258 | ,,drawing,remove.v.01,Root,1 259 | ,unloading,empty.v.04,remove.v.01,Root,0 260 | ,emptying,empty.v.04,remove.v.01,Root,1 261 | ,,picking,remove.v.01,Root,1 262 | ,peeling,take_off.v.02,remove.v.01,Root,1 263 | ,,unpacking,remove.v.01,Root,1 264 | ,,weeding,remove.v.01,Root,0 265 | ,,hunting,search.v.01,Root,1 266 | ,,surfing,search.v.01,Root,1 267 | ,drumming,play.v.07,sound.v.06,Root,1 268 | ,playing+music,play.v.07,sound.v.06,Root,1 269 | splashing,scatter.v.03,discharge.v.02,spread.v.01,Root,1 270 | spraying,scatter.v.03,discharge.v.02,spread.v.01,Root,1 271 | sprinkling,scatter.v.03,discharge.v.02,spread.v.01,Root,0 272 | ,,spreading,spread.v.01,Root,1 273 | ,gardening,care.v.02,support.v.01,Root,1 274 | ,,shopping,support.v.01,Root,1 275 | fishing,catch.v.04,seize.v.01,take.v.04,Root,1 276 | ,clawing,seize.v.01,take.v.04,Root,1 277 | ,gripping,seize.v.01,take.v.04,Root,1 278 | ,measuring,understand.v.01,think.v.03,Root,1 279 | ,reading,understand.v.01,think.v.03,Root,1 280 | ,studying,understand.v.01,think.v.03,Root,1 281 | massaging,manipulate.v.02,handle.v.04,touch.v.01,Root,1 282 | ,colliding,hit.v.02,touch.v.01,Root,0 283 | ,punching,hit.v.02,touch.v.01,Root,1 284 | ,,hitting,touch.v.01,Root,1 285 | ,,kissing,touch.v.01,Root,1 286 | ,,licking,touch.v.01,Root,1 287 | ,slapping,strike.v.01,touch.v.01,Root,1 288 | ,tickling,strike.v.01,touch.v.01,Root,0 289 | ,,stroking,touch.v.01,Root,1 290 | ,,clinging,touch.v.05,Root,1 291 | packaging,encase.v.01,enclose.v.03,touch.v.05,Root,1 292 | ,,hugging,touch.v.05,Root,1 293 | ,,rubbing,touch.v.05,Root,1 294 | ,,scratching,touch.v.05,Root,0 295 | 
feeding,provide.v.02,give.v.03,transfer.v.05,Root,1 296 | fueling,supply.v.01,give.v.03,transfer.v.05,Root,1 297 | ,,giving,transfer.v.05,Root,1 298 | ,,ascending,travel.v.01,Root,1 299 | saluting,greet.v.01,come.v.01,travel.v.01,Root,1 300 | ,,crawling,travel.v.01,Root,0 301 | ,diving,descend.v.01,travel.v.01,Root,1 302 | ,,floating,travel.v.01,Root,1 303 | ,,flying,travel.v.01,Root,1 304 | ,skating,glide.v.01,travel.v.01,Root,1 305 | ,,rafting,travel.v.01,Root,1 306 | ,,repairing,travel.v.01,Root,1 307 | ,bicycling,ride.v.02,travel.v.01,Root,1 308 | rowing,boat.v.01,ride.v.02,travel.v.01,Root,1 309 | ,boating,ride.v.02,travel.v.01,Root,1 310 | ,hitchhiking,ride.v.02,travel.v.01,Root,1 311 | ,pedaling,ride.v.02,travel.v.01,Root,0 312 | ,,riding,travel.v.01,Root,1 313 | ,,running,travel.v.01,Root,1 314 | ,,skiing,travel.v.01,Root,1 315 | ,,sliding,travel.v.01,Root,1 316 | ,,swimming,travel.v.01,Root,1 317 | ,jogging,travel_rapidly.v.01,travel.v.01,Root,1 318 | ,sprinting,travel_rapidly.v.01,travel.v.01,Root,1 319 | ,hiking,walk.v.01,travel.v.01,Root,1 320 | ,stomping,walk.v.01,travel.v.01,Root,0 321 | ,,walking,travel.v.01,Root,1 322 | ,,burning,treat.v.03,Root,1 323 | ,,injecting,treat.v.03,Root,1 324 | ,,operating,treat.v.03,Root,1 325 | ,,packing,treat.v.03,Root,0 326 | ,bulldozing,destroy.v.01,unmake.v.01,Root,1 327 | ,destroying,destroy.v.01,unmake.v.01,Root,1 328 | ,extinguishing,destroy.v.01,unmake.v.01,Root,0 329 | ,,exercising,use.v.01,Root,1 330 | ,,taping,use.v.01,Root,1 331 | ,,tapping,use.v.01,Root,1 332 | ,,barking,utter.v.02,Root,1 333 | ,,calling,utter.v.02,Root,1 334 | ,,crying,utter.v.02,Root,1 335 | ,,howling,utter.v.02,Root,1 336 | ,,roaring,utter.v.02,Root,1 337 | ,,spitting,utter.v.02,Root,0 338 | ,,working,work.v.01,Root,1 339 | ,,serving,work.v.02,Root,1 -------------------------------------------------------------------------------- /pmath.py: -------------------------------------------------------------------------------- 1 | import numpy 
as np 2 | import torch 3 | 4 | 5 | def pair_wise_eud(x,y,c=1.0): 6 | # input: 7 | # x, m x d 8 | # y, n x d 9 | # output: m x n squared Euclidean distances 10 | m = x.size(0) 11 | n = y.size(0) 12 | d = x.size(1) 13 | assert(x.size(1) == y.size(1)) 14 | xx = x.pow(2).sum(-1, keepdim = True) 15 | yy = y.pow(2).sum(-1, keepdim = True) 16 | xy = torch.einsum('ij,kj->ik', (x, y)) 17 | 18 | result = xx - 2*xy + yy.permute(1, 0) 19 | return result 20 | 21 | def pair_wise_cos(x,y,c=1.0): 22 | # input: 23 | # x, m x d 24 | # y, n x d 25 | # output: m x n negative cosine similarities (lower means closer) 26 | x_norm = torch.norm(x, dim=1, keepdim = True)# m x 1 27 | y_norm = torch.norm(y, dim=1, keepdim = True)# n x 1 28 | 29 | denominator = torch.matmul(x_norm, y_norm.t()) # m x n 30 | numerator = torch.matmul(x, y.t()) # m x n, each element is an inner product 31 | 32 | return -numerator / denominator # m x n 33 | 34 | def pair_wise_hyp(x, y, c=1.0): 35 | c = torch.as_tensor(c) 36 | return _dist_matrix(x, y, c) 37 | 38 | def tanh(x, clamp=15): 39 | return x.clamp(-clamp, clamp).tanh() 40 | 41 | 42 | def tensor_dot(x, y): 43 | res = torch.einsum('ij,kj->ik', (x, y)) 44 | return res 45 | 46 | def _mobius_addition_batch(x, y, c): 47 | xy = tensor_dot(x, y) # B x C 48 | x2 = x.pow(2).sum(-1, keepdim=True) # B x 1 49 | y2 = y.pow(2).sum(-1, keepdim=True) # C x 1 50 | num = (1 + 2 * c * xy + c * y2.permute(1, 0)) # B x C 51 | num = num.unsqueeze(2) * x.unsqueeze(1) # B x C x 1 * B x 1 x D = B x C x D 52 | num = num + (1 - c * x2).unsqueeze(2) * y # B x C x D + (B x 1 x 1 * C x D) = B x C x D 53 | denom_part1 = 1 + 2 * c * xy # B x C 54 | denom_part2 = c ** 2 * x2 * y2.permute(1, 0) # B x 1 * 1 x C = B x C 55 | denom = denom_part1 + denom_part2 56 | res = num / (denom.unsqueeze(2) + 1e-5) 57 | return res 58 | 59 | def _mobius_addition_same_size(x, y, c): 60 | xy = torch.einsum('ij,ij -> i', (x, y)).unsqueeze(1) # n x 1 61 | x2 = x.pow(2).sum(-1, keepdim=True) # n x 1 62 | y2 = y.pow(2).sum(-1, keepdim=True) # n x 1 63 | num = 1 + 2 * c * xy + c * y2 # n x 1 64 | num2 = num
* x # n x D 65 | num3 = num2 + (1 - c * x2) * y # n x D 66 | denom_part1 = 1 + 2 * c * xy # n x 1 67 | denom_part2 = c ** 2 * x2 * y2 # n x 1 68 | denom = denom_part1 + denom_part2 69 | res = num3 / (denom + 1e-5) 70 | 71 | return res 72 | 73 | def _dist_matrix(x, y, c): 74 | sqrt_c = c ** 0.5 75 | return 2 / sqrt_c * artanh(sqrt_c * torch.norm(_mobius_addition_batch(-x, y, c=c), dim=-1)) 76 | 77 | 78 | def dist_same_size(x,y, c=1.0): 79 | c = torch.as_tensor(c) 80 | sqrt_c = c ** 0.5 81 | return 2 / sqrt_c * artanh(sqrt_c * torch.norm(_mobius_addition_same_size(-x, y, c=c), dim=-1)) 82 | 83 | def project(x, *, c=1.0): 84 | r""" 85 | Safe projection onto the manifold for numerical stability: points with norm above (1 - 1e-3) / sqrt(c) are rescaled to that radius. 86 | See ``_project`` for parameters. 87 | """ 88 | c = torch.as_tensor(c).type_as(x) 89 | return _project(x, c) 90 | 91 | def _project(x, c): 92 | """Parameters 93 | ---------- 94 | x : tensor 95 | point on the Poincare ball 96 | c : float|tensor 97 | ball negative curvature 98 | Returns 99 | ------- 100 | tensor 101 | projected vector on the manifold""" 102 | c = torch.as_tensor(c).type_as(x) 103 | 104 | norm = torch.clamp_min(x.norm(dim=-1, keepdim=True, p=2), 1e-5) 105 | maxnorm = (1 - 1e-3) / (c ** 0.5) 106 | cond = norm > maxnorm 107 | projected = x / norm * maxnorm 108 | return torch.where(cond, projected, x) 109 | 110 | def mobius_matvec(m, x, c): 111 | x_norm = torch.clamp_min(x.norm(dim=-1, keepdim=True, p=2), 1e-5) 112 | sqrt_c = c ** 0.5 113 | mx = x @ m.transpose(-1, -2) 114 | mx_norm = mx.norm(dim=-1, keepdim=True, p=2) 115 | res_c = tanh(mx_norm / x_norm * artanh(sqrt_c * x_norm)) * mx / (mx_norm * sqrt_c) 116 | cond = (mx == 0).all(dim=-1, keepdim=True) # bool mask: True where m @ x is the zero vector 117 | res_0 = torch.zeros(1, dtype=res_c.dtype, device=res_c.device) 118 | res = torch.where(cond, res_0, res_c) 119 | return _project(res, c) 120 | 121 | def euc2hyp(x, c): 122 | new_x = _project(expmap0(x, c=c), c) 123 | return new_x 124 | 125 | ### Original code below 126 | def dist(x, y, *, c=1.0,
keepdim=False): 127 | r""" 128 | Distance on the Poincare ball 129 | .. math:: 130 | d_c(x, y) = \frac{2}{\sqrt{c}}\tanh^{-1}(\sqrt{c}\|(-x)\oplus_c y\|_2) 131 | .. plot:: plots/extended/poincare/distance.py 132 | Parameters 133 | ---------- 134 | x : tensor 135 | point on poincare ball 136 | y : tensor 137 | point on poincare ball 138 | c : float|tensor 139 | ball negative curvature 140 | keepdim : bool 141 | whether to retain the last dim (default: False) 142 | Returns 143 | ------- 144 | tensor 145 | geodesic distance between :math:`x` and :math:`y` 146 | """ 147 | c = torch.as_tensor(c).type_as(x) 148 | return _dist(x, y, c, keepdim=keepdim) 149 | 150 | 151 | def _dist(x, y, c, keepdim: bool = False): 152 | sqrt_c = c ** 0.5 153 | dist_c = artanh(sqrt_c * _mobius_add(-x, y, c).norm(dim=-1, p=2, keepdim=keepdim)) 154 | return dist_c * 2 / sqrt_c 155 | 156 | def mobius_add(x, y, *, c=1.0): 157 | r""" 158 | Mobius addition is a special operation in a hyperbolic space. 159 | .. math:: 160 | x \oplus_c y = \frac{ 161 | (1 + 2 c \langle x, y\rangle + c \|y\|^2_2) x + (1 - c \|x\|_2^2) y 162 | }{ 163 | 1 + 2 c \langle x, y\rangle + c^2 \|x\|^2_2 \|y\|^2_2 164 | } 165 | In general this operation is not commutative: 166 | .. math:: 167 | x \oplus_c y \ne y \oplus_c x 168 | But in some cases this property holds: 169 | * zero vector case 170 | .. math:: 171 | \mathbf{0} \oplus_c x = x \oplus_c \mathbf{0} 172 | * zero negative curvature case, which is the same as Euclidean addition 173 | .. math:: 174 | x \oplus_0 y = y \oplus_0 x 175 | Another useful property is the so-called left-cancellation law: 176 | ..
math:: 177 | (-x) \oplus_c (x \oplus_c y) = y 178 | Parameters 179 | ---------- 180 | x : tensor 181 | point on the Poincare ball 182 | y : tensor 183 | point on the Poincare ball 184 | c : float|tensor 185 | ball negative curvature 186 | Returns 187 | ------- 188 | tensor 189 | the result of mobius addition 190 | """ 191 | c = torch.as_tensor(c).type_as(x) 192 | return _mobius_add(x, y, c) 193 | 194 | 195 | def _mobius_add(x, y, c): 196 | x2 = x.pow(2).sum(dim=-1, keepdim=True) 197 | y2 = y.pow(2).sum(dim=-1, keepdim=True) 198 | xy = (x * y).sum(dim=-1, keepdim=True) 199 | num = (1 + 2 * c * xy + c * y2) * x + (1 - c * x2) * y 200 | denom = 1 + 2 * c * xy + c ** 2 * x2 * y2 201 | return num / (denom + 1e-5) 202 | 203 | 204 | def expmap0(u, *, c=1.0): 205 | r""" 206 | Exponential map for Poincare ball model from :math:`0`. 207 | .. math:: 208 | \operatorname{Exp}^c_0(u) = \tanh(\sqrt{c} \|u\|_2) \frac{u}{\sqrt{c}\|u\|_2} 209 | Parameters 210 | ---------- 211 | u : tensor 212 | tangent vector at the origin 213 | c : float|tensor 214 | ball negative curvature 215 | Returns 216 | ------- 217 | tensor 218 | :math:`\gamma_{0, u}(1)` end point 219 | """ 220 | c = torch.as_tensor(c).type_as(u) 221 | return _expmap0(u, c) 222 | 223 | 224 | def _expmap0(u, c): 225 | sqrt_c = c ** 0.5 226 | u_norm = torch.clamp_min(u.norm(dim=-1, p=2, keepdim=True), 1e-5) 227 | gamma_1 = tanh(sqrt_c * u_norm) * u / (sqrt_c * u_norm) 228 | return gamma_1 229 | 230 | 231 | def _hyperbolic_softmax(X, A, P, c): 232 | lambda_pkc = 2 / (1 - c * P.pow(2).sum(dim=1)) 233 | k = lambda_pkc * torch.norm(A, dim=1) / torch.sqrt(c) 234 | mob_add = _mobius_addition_batch(-P, X, c) 235 | num = 2 * torch.sqrt(c) * torch.sum(mob_add * A.unsqueeze(1), dim=-1) 236 | denom = torch.norm(A, dim=1, keepdim=True) * (1 - c * mob_add.pow(2).sum(dim=2)) 237 | logit = k.unsqueeze(1) * arsinh(num / denom) 238 | return logit.permute(1, 0) 239 | 240 | class Arsinh(torch.autograd.Function): 241 | @staticmethod 242 | def
forward(ctx, x): 243 | ctx.save_for_backward(x) 244 | return (x + torch.sqrt_(1 + x.pow(2))).clamp_min_(1e-5).log_() 245 | 246 | @staticmethod 247 | def backward(ctx, grad_output): 248 | input, = ctx.saved_tensors 249 | return grad_output / (1 + input ** 2) ** 0.5 250 | 251 | class Artanh(torch.autograd.Function): 252 | @staticmethod 253 | def forward(ctx, x): 254 | x = x.clamp(-1 + 1e-5, 1 - 1e-5) 255 | ctx.save_for_backward(x) 256 | res = (torch.log_(1 + x).sub_(torch.log_(1 - x))).mul_(0.5) 257 | return res 258 | 259 | @staticmethod 260 | def backward(ctx, grad_output): 261 | input, = ctx.saved_tensors 262 | return grad_output / (1 - input ** 2) 263 | 264 | def artanh(x): 265 | return Artanh.apply(x) 266 | 267 | 268 | def arsinh(x): 269 | return Arsinh.apply(x) 270 | -------------------------------------------------------------------------------- /poster.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tenglon/hyperbolic_action/4af6da6e85a8af33dd54955067efee2836508048/poster.pdf -------------------------------------------------------------------------------- /poster_jpg.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tenglon/hyperbolic_action/4af6da6e85a8af33dd54955067efee2836508048/poster_jpg.jpg --------------------------------------------------------------------------------
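Each row of `moments_depth.csv` lists an action followed by its ancestors up to `Root`, right-aligned so that shallower branches begin with empty fields, and ends with a 0/1 flag. The sketch below parses such rows into child-to-parent edges. Reading the last column as the seen (1) / unseen (0) marker is an assumption based on the README's "Seen/Unseen Split" heading, and `parse_hierarchy` / `ancestors` are hypothetical helpers, not part of the repository.

```python
import csv
import io
from typing import Dict, List, Tuple


def parse_hierarchy(text: str) -> Tuple[Dict[str, str], Dict[str, int]]:
    """Return (child -> parent map, leaf -> last-column flag) from right-aligned rows."""
    parent: Dict[str, str] = {}
    flag: Dict[str, int] = {}
    for row in csv.reader(io.StringIO(text)):
        if not row:
            continue  # skip blank lines
        *path, last = row
        names = [name for name in path if name]  # drop the left padding
        leaf = names[0]
        flag[leaf] = int(last)  # assumed: 1 = seen, 0 = unseen
        for child, par in zip(names, names[1:]):
            parent[child] = par
    return parent, flag


def ancestors(name: str, parent: Dict[str, str]) -> List[str]:
    """Walk parent links from `name` up to the root."""
    chain: List[str] = []
    while name in parent:
        name = parent[name]
        chain.append(name)
    return chain
```

For example, `ancestors("crashing", parent)` on the row `,,crashing,change.v.02,Root,1` yields `["change.v.02", "Root"]`, which gives each class's depth in the tree.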
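The formulas in the `mobius_add` and `dist` docstrings can be checked without PyTorch. The dependency-free sketch below implements them for single vectors (without the `1e-5` denominator guard used in `pmath.py`), so it can serve as a reference when testing the batched tensor versions such as `_mobius_addition_batch` and `dist_same_size`. It is an illustration, not a replacement for the repository code.

```python
import math
from typing import List, Sequence


def dot(u: Sequence[float], v: Sequence[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


def mobius_add(x: Sequence[float], y: Sequence[float], c: float = 1.0) -> List[float]:
    """x (+)_c y on the Poincare ball of curvature -c (no epsilon guard)."""
    xy, x2, y2 = dot(x, y), dot(x, x), dot(y, y)
    coef_x = 1 + 2 * c * xy + c * y2
    coef_y = 1 - c * x2
    denom = 1 + 2 * c * xy + c ** 2 * x2 * y2
    return [(coef_x * a + coef_y * b) / denom for a, b in zip(x, y)]


def poincare_dist(x: Sequence[float], y: Sequence[float], c: float = 1.0) -> float:
    """d_c(x, y) = 2 / sqrt(c) * artanh(sqrt(c) * ||(-x) (+)_c y||)."""
    sqrt_c = math.sqrt(c)
    diff = mobius_add([-a for a in x], y, c)
    return 2 / sqrt_c * math.atanh(sqrt_c * math.sqrt(dot(diff, diff)))
```

Three properties are easy to verify with it: the left-cancellation law `(-x) (+) (x (+) y) = y`, symmetry `d(x, y) = d(y, x)`, and `d(0, y) = 2 / sqrt(c) * artanh(sqrt(c) * ||y||)`. With `c = 0` the Möbius addition reduces to ordinary vector addition.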
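`euc2hyp` maps a Euclidean feature onto the ball by applying `expmap0` and then `_project`. The plain-Python sketch below repeats those two steps, assuming the constants used in `pmath.py` (norm floor `1e-5`, boundary margin `1e-3`). The `tanh` input clamp at ±15 from `pmath.py` is omitted because `math.tanh` saturates safely. `logmap0` is a hypothetical addition for the round-trip check and does not exist in the repository.

```python
import math
from typing import List, Sequence


def norm(v: Sequence[float]) -> float:
    return math.sqrt(sum(a * a for a in v))


def expmap0(u: Sequence[float], c: float = 1.0, min_norm: float = 1e-5) -> List[float]:
    # tanh(sqrt(c) * ||u||) * u / (sqrt(c) * ||u||), with ||u|| floored as in torch.clamp_min(..., 1e-5)
    sqrt_c = math.sqrt(c)
    u_norm = max(norm(u), min_norm)
    scale = math.tanh(sqrt_c * u_norm) / (sqrt_c * u_norm)
    return [scale * a for a in u]


def project(x: Sequence[float], c: float = 1.0, eps: float = 1e-3, min_norm: float = 1e-5) -> List[float]:
    # Rescale points lying outside radius (1 - eps) / sqrt(c) back onto that radius
    x_norm = max(norm(x), min_norm)
    max_norm = (1 - eps) / math.sqrt(c)
    if x_norm > max_norm:
        return [a / x_norm * max_norm for a in x]
    return list(x)


def euc2hyp(u: Sequence[float], c: float = 1.0) -> List[float]:
    return project(expmap0(u, c), c)


def logmap0(y: Sequence[float], c: float = 1.0, min_norm: float = 1e-5) -> List[float]:
    # Inverse of expmap0 (hypothetical helper, not part of pmath.py)
    sqrt_c = math.sqrt(c)
    y_norm = max(norm(y), min_norm)
    scale = math.atanh(sqrt_c * y_norm) / (sqrt_c * y_norm)
    return [scale * a for a in y]
```

The projection matters for large inputs: `tanh` rounds to exactly 1.0 in floating point, which would put the point on the boundary where `artanh` diverges; `_project` keeps it at radius `0.999 / sqrt(c)` instead.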
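`pair_wise_eud`, `pair_wise_cos` and `pair_wise_hyp` each return an m x n matrix in which lower values mean closer (squared Euclidean distance, negative cosine similarity, and Poincaré distance respectively), so searching for the best class of a video embedding reduces to a row-wise argmin. The sketch below illustrates that on toy 2-D points in plain Python; `pairwise` and `nearest_class` are hypothetical helpers, and the class and query vectors are made up for illustration.

```python
import math
from typing import Callable, List, Sequence

Vec = Sequence[float]


def dot(u: Vec, v: Vec) -> float:
    return sum(a * b for a, b in zip(u, v))


def sq_euclidean(x: Vec, y: Vec) -> float:
    # Same expansion as pair_wise_eud: ||x||^2 - 2<x, y> + ||y||^2
    return dot(x, x) - 2 * dot(x, y) + dot(y, y)


def neg_cosine(x: Vec, y: Vec) -> float:
    # Same quantity as pair_wise_cos: minus the cosine similarity
    return -dot(x, y) / (math.sqrt(dot(x, x)) * math.sqrt(dot(y, y)))


def poincare(x: Vec, y: Vec, c: float = 1.0) -> float:
    # Same quantity as pair_wise_hyp: 2 / sqrt(c) * artanh(sqrt(c) * ||(-x) (+)_c y||)
    mx = [-a for a in x]
    xy, x2, y2 = dot(mx, y), dot(mx, mx), dot(y, y)
    denom = 1 + 2 * c * xy + c ** 2 * x2 * y2
    z = [((1 + 2 * c * xy + c * y2) * a + (1 - c * x2) * b) / denom for a, b in zip(mx, y)]
    sqrt_c = math.sqrt(c)
    return 2 / sqrt_c * math.atanh(sqrt_c * math.sqrt(dot(z, z)))


def pairwise(xs: List[Vec], ys: List[Vec], fn: Callable[[Vec, Vec], float]) -> List[List[float]]:
    """m x n score matrix; for all three metrics above, lower means closer."""
    return [[fn(x, y) for y in ys] for x in xs]


def nearest_class(xs: List[Vec], ys: List[Vec], fn: Callable[[Vec, Vec], float]) -> List[int]:
    """Index of the closest class embedding for each query (row-wise argmin)."""
    return [min(range(len(row)), key=row.__getitem__) for row in pairwise(xs, ys, fn)]
```

For `c = 1` the Möbius form of the distance agrees with the closed form `arcosh(1 + 2 ||x - y||^2 / ((1 - ||x||^2)(1 - ||y||^2)))`, which is a convenient cross-check.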
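`_hyperbolic_softmax(X, A, P, c)` computes, for every sample x and class k with offset p_k (rows of `P`) and normal a_k (rows of `A`), the logit `lambda_p * ||a|| / sqrt(c) * arsinh(2 sqrt(c) <z, a> / ((1 - c ||z||^2) ||a||))` with `z = (-p) (+)_c x` and `lambda_p = 2 / (1 - c ||p||^2)`, returning an N x K matrix. The single-pair sketch below mirrors that expression in plain Python (without the `1e-5` guard inside the batched Möbius addition). The helper names are hypothetical. The check that the logit approaches `4 <x - p, a>` as `c -> 0` follows from `arsinh(t) ≈ t` for small `t` and is included only as an illustration.

```python
import math
from typing import List, Sequence


def dot(u: Sequence[float], v: Sequence[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


def mobius_add(x: Sequence[float], y: Sequence[float], c: float) -> List[float]:
    xy, x2, y2 = dot(x, y), dot(x, x), dot(y, y)
    denom = 1 + 2 * c * xy + c ** 2 * x2 * y2
    return [((1 + 2 * c * xy + c * y2) * a + (1 - c * x2) * b) / denom for a, b in zip(x, y)]


def hyperbolic_mlr_logit(x: Sequence[float], a: Sequence[float], p: Sequence[float], c: float = 1.0) -> float:
    """Logit of sample x for the hyperbolic hyperplane with normal a and offset p."""
    sqrt_c = math.sqrt(c)
    lambda_p = 2.0 / (1.0 - c * dot(p, p))  # conformal factor at p
    a_norm = math.sqrt(dot(a, a))
    z = mobius_add([-v for v in p], x, c)  # (-p) (+)_c x
    num = 2.0 * sqrt_c * dot(z, a)
    denom = a_norm * (1.0 - c * dot(z, z))
    return lambda_p * a_norm / sqrt_c * math.asinh(num / denom)


def logits(xs, normals, offsets, c: float = 1.0) -> List[List[float]]:
    """N x K matrix (rows = samples), matching the layout returned by _hyperbolic_softmax."""
    return [[hyperbolic_mlr_logit(x, a, p, c) for a, p in zip(normals, offsets)] for x in xs]
```

Because `arsinh` is odd and the prefactor is positive, the sign of each logit tells on which side of the hyperbolic hyperplane the sample lies.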