├── README.md
├── activity_net.v1-3.json
├── activity_net_depth.csv
├── cone_activity_net300.pth
├── cone_kinetics_300.pth
├── cone_moments_300.pth
├── dataset.py
├── demo.ipynb
├── kinetics_depth.csv
├── metric.py
├── model.py
├── moments_depth.csv
├── pmath.py
├── poster.pdf
└── poster_jpg.jpg
/README.md:
--------------------------------------------------------------------------------
1 | Code for **Searching for Actions on the Hyperbole**.
2 |
3 | The current code includes all the essentials needed to re-implement our paper.
4 | The demo code is a raw version; please leave a comment if you have any difficulty using it.
5 |
6 | # Requirements
7 | pytorch == 1.4.0
8 | geoopt == 0.2.0
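
To confirm that the installed packages match the versions pinned above, you can query package metadata without importing the packages. This is a minimal sketch using only the standard library. The helper name `installed_versions` is illustrative, and it assumes PyTorch is installed under the distribution name `torch`.

```python
from importlib import metadata

# Versions pinned in the Requirements section.
# Assumption: PyTorch is published under the distribution name "torch".
PINNED = {"torch": "1.4.0", "geoopt": "0.2.0"}


def installed_versions(pinned=PINNED):
    """Return {package: (pinned, installed or None)} without importing the packages."""
    report = {}
    for name, wanted in pinned.items():
        try:
            found = metadata.version(name)
        except metadata.PackageNotFoundError:
            found = None  # package is not installed in this environment
        report[name] = (wanted, found)
    return report


if __name__ == "__main__":
    for name, (wanted, found) in installed_versions().items():
        # A local build suffix such as "+cu100" still counts as a match.
        ok = found is not None and found.startswith(wanted)
        print(f"{name}: pinned {wanted}, installed {found} -> {'ok' if ok else 'mismatch'}")
```

Newer releases of either package may still work, but they have not been checked here.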
9 |
10 | If you have trouble using this code, please email me (uestc.longteng@gmail.com).
11 |
12 | # Data Hierarchy and Seen/Unseen Split.
13 |
14 | activity_net_depth.csv
15 | kinetics_depth.csv
16 | moments_depth.csv
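
Each row in these files lists an action, its ancestors up to `Root`, and a final 0/1 flag. The sketch below reads such rows into a child-to-parent map plus a per-action flag, in the same spirit as `get_son2parent` in `dataset.py`. The function name `parse_hierarchy` is illustrative. Dropping empty fields is an assumption about how shallower rows are padded. The flag presumably encodes the seen/unseen split, but which value means "seen" is not stated here.

```python
import csv
import io


def parse_hierarchy(rows):
    """Build child->parent links and a per-action 0/1 flag from hierarchy rows."""
    child_to_parent, flag = {}, {}
    for fields in rows:
        *path, last = fields              # path runs action -> ... -> Root
        path = [p for p in path if p]     # assumption: empty fields are padding
        for child, parent in zip(path, path[1:]):
            child_to_parent[child] = parent
        flag[path[0]] = int(last)         # flag belongs to the leaf action
    return child_to_parent, flag


# Two rows copied from activity_net_depth.csv.
sample = io.StringIO(
    "Ballet,Dancing,Arts and Entertainment,Root,1\n"
    "Belly dance,Dancing,Arts and Entertainment,Root,0\n"
)
child_to_parent, flag = parse_hierarchy(csv.reader(sample))
print(child_to_parent["Ballet"])  # Dancing
print(flag["Belly dance"])        # 0
```

To read a real file, pass `csv.reader(open(path, newline=""))` instead of the in-memory sample.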
17 |
18 | # Data files
19 |
20 | I have uploaded the extracted video features (~8GB) to Microsoft OneDrive; please contact me for data sharing.
21 |
22 | # Pretrained 300-D action embeddings
23 |
24 | cone_activity_net300.pth
25 | cone_kinetics_300.pth
26 | cone_moments_300.pth
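
The `'cone'` branch of `get_emb` in `dataset.py` loads one of these files with `torch.load`, treats the result as a mapping from name to vector, and then picks one row per class in the order of the sorted label set. The sketch below shows only that alignment step, using plain Python lists so it runs without PyTorch. The function name `align_to_labels` and the toy 2-D vectors are made up for illustration.

```python
def align_to_labels(name_to_vec, label_set):
    """Return one embedding row per label, in label_set order."""
    missing = [name for name in label_set if name not in name_to_vec]
    if missing:
        raise KeyError(f"no embedding for: {missing}")
    return [list(name_to_vec[name]) for name in label_set]


# Toy 2-D stand-in for the 300-D vectors; the values are invented.
toy = {"Dancing": [0.1, 0.2], "Ballet": [0.3, 0.4], "Tango": [0.5, 0.6]}
labels = sorted(["Tango", "Ballet"])  # dataset.py sorts the label set
rows = align_to_labels(toy, labels)
print(rows)  # [[0.3, 0.4], [0.5, 0.6]]
```

Row `i` then corresponds to class index `i`, which is why the label set must be sorted the same way when labels are converted to integers.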
27 |
28 | # Reference
29 |
30 | [Searching for Actions on the Hyperbole](http://openaccess.thecvf.com/content_CVPR_2020/html/Long_Searching_for_Actions_on_the_Hyperbole_CVPR_2020_paper.html)
31 |
32 | Teng Long, Pascal Mettes, Heng Tao Shen, Cees G. M. Snoek
33 |
34 | @InProceedings{Long_2020_CVPR,
35 | author = {Long, Teng and Mettes, Pascal and Shen, Heng Tao and Snoek, Cees G. M.},
36 | title = {Searching for Actions on the Hyperbole},
37 | booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
38 | month = {June},
39 | year = {2020}
40 | }
41 |
42 |
43 |
44 |
--------------------------------------------------------------------------------
/activity_net_depth.csv:
--------------------------------------------------------------------------------
1 | Ballet,Dancing,Arts and Entertainment,Root,1
2 | Tango,Dancing,Arts and Entertainment,Root,1
3 | Cheerleading,Dancing,Arts and Entertainment,Root,1
4 | Cumbia,Dancing,Arts and Entertainment,Root,1
5 | Breakdancing,Dancing,Arts and Entertainment,Root,1
6 | Belly dance,Dancing,Arts and Entertainment,Root,0
7 | Building sandcastles,Park activities,Arts and Entertainment,Root,1
8 | Fun sliding down,Park activities,Arts and Entertainment,Root,1
9 | Swinging at the playground,Park activities,Arts and Entertainment,Root,1
10 | Using the monkey bar,Park activities,Arts and Entertainment,Root,1
11 | Starting a campfire,Park activities,Arts and Entertainment,Root,0
12 | Drum corps,Playing musical instruments,Arts and Entertainment,Root,1
13 | Playing congas,Playing musical instruments,Arts and Entertainment,Root,1
14 | Playing drums,Playing musical instruments,Arts and Entertainment,Root,1
15 | Playing bagpipes,Playing musical instruments,Arts and Entertainment,Root,1
16 | Playing harmonica,Playing musical instruments,Arts and Entertainment,Root,1
17 | Playing saxophone,Playing musical instruments,Arts and Entertainment,Root,1
18 | Playing guitarra,Playing musical instruments,Arts and Entertainment,Root,1
19 | Playing flauta,Playing musical instruments,Arts and Entertainment,Root,1
20 | Playing piano,Playing musical instruments,Arts and Entertainment,Root,1
21 | Playing violin,Playing musical instruments,Arts and Entertainment,Root,1
22 | Playing accordion,Playing musical instruments,Arts and Entertainment,Root,0
23 | Having an ice cream,Eating and Drinking,Eating and drinking Activities,Root,1
24 | Drinking coffee,Eating and Drinking,Eating and drinking Activities,Root,1
25 | Drinking beer,Eating and Drinking,Eating and drinking Activities,Root,0
26 | Baking cookies,Food and drink preparation ,Eating and drinking Activities,Root,1
27 | Making a cake,Food and drink preparation ,Eating and drinking Activities,Root,1
28 | Making a lemonade,Food and drink preparation ,Eating and drinking Activities,Root,1
29 | Making an omelette,Food and drink preparation ,Eating and drinking Activities,Root,1
30 | Peeling potatoes,Food and drink preparation ,Eating and drinking Activities,Root,1
31 | Preparing pasta,Food and drink preparation ,Eating and drinking Activities,Root,1
32 | Preparing salad,Food and drink preparation ,Eating and drinking Activities,Root,1
33 | Making a sandwich,Food and drink preparation ,Eating and drinking Activities,Root,0
34 | Mixing drinks,Food and drink preparation ,Eating and drinking Activities,Root,0
35 | Washing dishes,Food and drink preparation ,Eating and drinking Activities,Root,1
36 | Clipping cat claws,Animals and Pets,Household Activities,Root,1
37 | Grooming dog,Animals and Pets,Household Activities,Root,1
38 | Bathing dog,Animals and Pets,Household Activities,Root,1
39 | Disc dog,Animals and Pets,Household Activities,Root,1
40 | Grooming horse,Animals and Pets,Household Activities,Root,1
41 | Walking the dog,Animals and Pets,Household Activities,Root,0
42 | Sharpening knives,Appliances/Tools/and Toys,Household Activities,Root,1
43 | Waxing skis,Appliances/Tools/and Toys,Household Activities,Root,1
44 | Welding,Appliances/Tools/and Toys,Household Activities,Root,0
45 | Mooping floor,Cleaning and Laundry,Household Activities,Root,1
46 | Cleaning windows,Cleaning and Laundry,Household Activities,Root,1
47 | Vacuuming floor,Cleaning and Laundry,Household Activities,Root,1
48 | Polishing forniture,Cleaning and Laundry,Household Activities,Root,1
49 | Ironing clothes,Cleaning and Laundry,Household Activities,Root,1
50 | Hand washing clothes,Cleaning and Laundry,Household Activities,Root,1
51 | Knitting,Cleaning and Laundry,Household Activities,Root,1
52 | Cleaning shoes,Cleaning and Laundry,Household Activities,Root,1
53 | Polishing shoes,Cleaning and Laundry,Household Activities,Root,0
54 | Shoveling snow,Exterior Maintenance/Repair/& Decoration,Household Activities,Root,1
55 | Fixing the roof,Exterior Maintenance/Repair/& Decoration,Household Activities,Root,1
56 | Roof shingle removal,Exterior Maintenance/Repair/& Decoration,Household Activities,Root,1
57 | Painting fence,Exterior Maintenance/Repair/& Decoration,Household Activities,Root,1
58 | Wrapping presents,Interior Maintenance/Repair/& Decoration,Household Activities,Root,1
59 | Painting furniture,Interior Maintenance/Repair/& Decoration,Household Activities,Root,1
60 | Chopping wood,Interior Maintenance/Repair/& Decoration,Household Activities,Root,1
61 | Carving jack-o-lanterns,Interior Maintenance/Repair/& Decoration,Household Activities,Root,1
62 | Cleaning sink,Interior Maintenance/Repair/& Decoration,Household Activities,Root,1
63 | Decorating the Christmas tree,Interior Maintenance/Repair/& Decoration,Household Activities,Root,1
64 | Hanging wallpaper,Interior Maintenance/Repair/& Decoration,Household Activities,Root,1
65 | Installing carpet,Interior Maintenance/Repair/& Decoration,Household Activities,Root,1
66 | Laying tile,Interior Maintenance/Repair/& Decoration,Household Activities,Root,1
67 | Plastering,Interior Maintenance/Repair/& Decoration,Household Activities,Root,0
68 | Painting,Interior Maintenance/Repair/& Decoration,Household Activities,Root,0
69 | Blowing leaves,Lawn/Garden/and Houseplants,Household Activities,Root,1
70 | Cutting the grass,Lawn/Garden/and Houseplants,Household Activities,Root,1
71 | Raking leaves,Lawn/Garden/and Houseplants,Household Activities,Root,1
72 | Spread mulch,Lawn/Garden/and Houseplants,Household Activities,Root,1
73 | Trimming branches or hedges,Lawn/Garden/and Houseplants,Household Activities,Root,1
74 | Mowing the lawn,Lawn/Garden/and Houseplants,Household Activities,Root,0
75 | Assembling bicycle,Vehicle repair and maintenance,Household Activities,Root,1
76 | Changing car wheel,Vehicle repair and maintenance,Household Activities,Root,1
77 | Hand car wash,Vehicle repair and maintenance,Household Activities,Root,1
78 | Removing ice from car,Vehicle repair and maintenance,Household Activities,Root,1
79 | Fixing bicycle,Vehicle repair and maintenance,Household Activities,Root,0
80 | Putting in contact lenses,Dress up,Personal Care,Root,1
81 | Putting on shoes,Dress up,Personal Care,Root,1
82 | Putting on makeup,Dress up,Personal Care,Root,1
83 | Brushing hair,Dress up,Personal Care,Root,1
84 | Doing nails,Dress up,Personal Care,Root,1
85 | Applying sunscreen,Dress up,Personal Care,Root,0
86 | Removing curlers,Grooming,Personal Care,Root,0
87 | Getting a tattoo,Grooming,Personal Care,Root,1
88 | Getting a piercing,Grooming,Personal Care,Root,1
89 | Getting a haircut,Grooming,Personal Care,Root,1
90 | Blow-drying hair,Grooming,Personal Care,Root,1
91 | Braiding hair,Grooming,Personal Care,Root,1
92 | Shaving legs,Grooming,Personal Care,Root,1
93 | Gargling mouthwash,Wash up,Personal Care,Root,1
94 | Washing face,Wash up,Personal Care,Root,1
95 | Brushing teeth,Wash up,Personal Care,Root,1
96 | Washing hands,Wash up,Personal Care,Root,1
97 | Shaving,Wash up,Personal Care,Root,1
98 | Playing blackjack,Playing games,Relaxing and Leisure,Root,1
99 | Beer pong,Playing games,Relaxing and Leisure,Root,1
100 | Hitting a pinata,Playing games,Relaxing and Leisure,Root,1
101 | Hula hoop,Playing games,Relaxing and Leisure,Root,1
102 | Kite flying,Playing games,Relaxing and Leisure,Root,1
103 | Playing pool,Playing games,Relaxing and Leisure,Root,1
104 | Playing rubik cube,Playing games,Relaxing and Leisure,Root,1
105 | Riding bumper cars,Playing games,Relaxing and Leisure,Root,1
106 | Rock-paper-scissors,Playing games,Relaxing and Leisure,Root,1
107 | Shuffleboard,Playing games,Relaxing and Leisure,Root,1
108 | Slacklining,Playing games,Relaxing and Leisure,Root,1
109 | Table soccer,Playing games,Relaxing and Leisure,Root,1
110 | Throwing darts,Playing games,Relaxing and Leisure,Root,0
111 | Tug of war,Playing games,Relaxing and Leisure,Root,1
112 | Hopscotch,Playing games,Relaxing and Leisure,Root,0
113 | Playing ten pins,Playing games,Relaxing and Leisure,Root,1
114 | Smoking hookah,Tobacco and drug use,Relaxing and Leisure,Root,1
115 | Smoking a cigarette,Tobacco and drug use,Relaxing and Leisure,Root,0
116 | Cricket,Bat-and-ball games,Sports/Exercise/and Recreation,Root,1
117 | Playing kickball,Bat-and-ball games,Sports/Exercise/and Recreation,Root,0
118 | Canoeing,Boating,Sports/Exercise/and Recreation,Root,1
119 | Rafting,Boating,Sports/Exercise/and Recreation,Root,1
120 | River tubing,Boating,Sports/Exercise/and Recreation,Root,0
121 | Sailing,Boating,Sports/Exercise/and Recreation,Root,1
122 | Kayaking,Boating,Sports/Exercise/and Recreation,Root,0
123 | Zumba,Doing aerobics,Sports/Exercise/and Recreation,Root,1
124 | Doing step aerobics,Doing aerobics,Sports/Exercise/and Recreation,Root,0
125 | Using the pommel horse,Doing gymnastics,Sports/Exercise/and Recreation,Root,1
126 | Using the balance beam,Doing gymnastics,Sports/Exercise/and Recreation,Root,1
127 | Tumbling,Doing gymnastics,Sports/Exercise/and Recreation,Root,1
128 | Using parallel bars,Doing gymnastics,Sports/Exercise/and Recreation,Root,1
129 | Using uneven bars,Doing gymnastics,Sports/Exercise/and Recreation,Root,1
130 | Baton twirling,Doing gymnastics,Sports/Exercise/and Recreation,Root,0
131 | Playing polo,Equestrian sports,Sports/Exercise/and Recreation,Root,1
132 | Horseback riding,Equestrian sports,Sports/Exercise/and Recreation,Root,1
133 | Camel ride,Equestrian sports,Sports/Exercise/and Recreation,Root,0
134 | BMX,Extreme sports,Sports/Exercise/and Recreation,Root,1
135 | Rock climbing,Extreme sports,Sports/Exercise/and Recreation,Root,1
136 | Powerbocking,Extreme sports,Sports/Exercise/and Recreation,Root,1
137 | Paintball,Extreme sports,Sports/Exercise/and Recreation,Root,1
138 | Bungee jumping,Extreme sports,Sports/Exercise/and Recreation,Root,1
139 | Doing motocross,Extreme sports,Sports/Exercise/and Recreation,Root,0
140 | High jump,Field sports,Sports/Exercise/and Recreation,Root,1
141 | Discus throw,Field sports,Sports/Exercise/and Recreation,Root,1
142 | Javelin throw,Field sports,Sports/Exercise/and Recreation,Root,1
143 | Long jump,Field sports,Sports/Exercise/and Recreation,Root,1
144 | Triple jump,Field sports,Sports/Exercise/and Recreation,Root,1
145 | Shot put,Field sports,Sports/Exercise/and Recreation,Root,1
146 | Hammer throw,Field sports,Sports/Exercise/and Recreation,Root,1
147 | Archery,Field sports,Sports/Exercise/and Recreation,Root,1
148 | Pole vault,Field sports,Sports/Exercise/and Recreation,Root,0
149 | Running a marathon,Field sports,Sports/Exercise/and Recreation,Root,1
150 | Capoeira,Martial arts,Sports/Exercise/and Recreation,Root,1
151 | Doing kickboxing,Martial arts,Sports/Exercise/and Recreation,Root,1
152 | Doing karate,Martial arts,Sports/Exercise/and Recreation,Root,1
153 | Tai chi,Martial arts,Sports/Exercise/and Recreation,Root,0
154 | Layup drill in basketball,Playing basketball,Sports/Exercise/and Recreation,Root,1
155 | Dodgeball,Playing basketball,Sports/Exercise/and Recreation,Root,0
156 | Playing ice hockey,Playing hockey,Sports/Exercise/and Recreation,Root,1
157 | Playing field hockey,Playing hockey,Sports/Exercise/and Recreation,Root,1
158 | Croquet,Playing hockey,Sports/Exercise/and Recreation,Root,1
159 | Hurling,Playing hockey,Sports/Exercise/and Recreation,Root,0
160 | Beach soccer,Playing soccer,Sports/Exercise/and Recreation,Root,1
161 | Futsal,Playing soccer,Sports/Exercise/and Recreation,Root,0
162 | Playing beach volleyball,Playing volleyball,Sports/Exercise/and Recreation,Root,1
163 | Volleyball,Playing volleyball,Sports/Exercise/and Recreation,Root,0
164 | Ping-pong,Racquet sports ,Sports/Exercise/and Recreation,Root,1
165 | Tennis serve with ball bouncing,Racquet sports ,Sports/Exercise/and Recreation,Root,1
166 | Playing squash,Racquet sports ,Sports/Exercise/and Recreation,Root,1
167 | Playing lacrosse,Racquet sports ,Sports/Exercise/and Recreation,Root,1
168 | Playing racquetball,Racquet sports ,Sports/Exercise/and Recreation,Root,1
169 | Playing badminton,Racquet sports ,Sports/Exercise/and Recreation,Root,0
170 | Bullfighting,Rodeo competitions,Sports/Exercise/and Recreation,Root,1
171 | Calf roping,Rodeo competitions,Sports/Exercise/and Recreation,Root,0
172 | Longboarding,Roller sports,Sports/Exercise/and Recreation,Root,1
173 | Rollerblading,Roller sports,Sports/Exercise/and Recreation,Root,1
174 | Skateboarding,Roller sports,Sports/Exercise/and Recreation,Root,0
175 | Ice fishing,Skiing/ice skating/snowboarding,Sports/Exercise/and Recreation,Root,1
176 | Curling,Skiing/ice skating/snowboarding,Sports/Exercise/and Recreation,Root,1
177 | Skiing,Skiing/ice skating/snowboarding,Sports/Exercise/and Recreation,Root,1
178 | Snow tubing,Skiing/ice skating/snowboarding,Sports/Exercise/and Recreation,Root,1
179 | Snowboarding,Skiing/ice skating/snowboarding,Sports/Exercise/and Recreation,Root,0
180 | Elliptical trainer,Using cardiovascular equipment,Sports/Exercise/and Recreation,Root,1
181 | Using the rowing machine,Using cardiovascular equipment,Sports/Exercise/and Recreation,Root,1
182 | Spinning,Using cardiovascular equipment,Sports/Exercise/and Recreation,Root,0
183 | Scuba diving,Water sports,Sports/Exercise/and Recreation,Root,1
184 | Surfing,Water sports,Sports/Exercise/and Recreation,Root,1
185 | Swimming,Water sports,Sports/Exercise/and Recreation,Root,1
186 | Wakeboarding,Water sports,Sports/Exercise/and Recreation,Root,1
187 | Waterskiing,Water sports,Sports/Exercise/and Recreation,Root,1
188 | Springboard diving,Water sports,Sports/Exercise/and Recreation,Root,1
189 | Plataform diving,Water sports,Sports/Exercise/and Recreation,Root,1
190 | Windsurfing,Water sports,Sports/Exercise/and Recreation,Root,1
191 | Playing water polo,Water sports,Sports/Exercise/and Recreation,Root,0
192 | Clean and jerk,Weightlifting,Sports/Exercise/and Recreation,Root,1
193 | Snatch,Weightlifting,Sports/Exercise/and Recreation,Root,0
194 | Doing crunches,Working out,Sports/Exercise/and Recreation,Root,1
195 | Kneeling,Working out,Sports/Exercise/and Recreation,Root,1
196 | Rope skipping,Working out,Sports/Exercise/and Recreation,Root,0
197 | Doing fencing,Wrestling,Sports/Exercise/and Recreation,Root,1
198 | Doing a powerbomb,Wrestling,Sports/Exercise/and Recreation,Root,1
199 | Arm wrestling,Wrestling,Sports/Exercise/and Recreation,Root,1
200 | Sumo,Wrestling,Sports/Exercise/and Recreation,Root,0
--------------------------------------------------------------------------------
/cone_activity_net300.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tenglon/hyperbolic_action/4af6da6e85a8af33dd54955067efee2836508048/cone_activity_net300.pth
--------------------------------------------------------------------------------
/cone_kinetics_300.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tenglon/hyperbolic_action/4af6da6e85a8af33dd54955067efee2836508048/cone_kinetics_300.pth
--------------------------------------------------------------------------------
/cone_moments_300.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tenglon/hyperbolic_action/4af6da6e85a8af33dd54955067efee2836508048/cone_moments_300.pth
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import torch, pickle, json
2 | from torch.utils.data import Dataset, DataLoader
3 |
4 | def get_son2parent(csv_path):
5 |     son2parent = dict()
6 |     with open(csv_path, 'r') as fp:
7 |         lines = fp.readlines()
8 |     for line in lines:
9 |         if line[-1] == '\n':
10 |             line = line[:-1]
11 |         tmp_list = line.split(',')
12 |         for i in range(len(tmp_list) - 2): # ignore last digit.
13 |             key, value = tmp_list[i], tmp_list[i+1]
14 |             son2parent[key] = value
15 |     if '' in son2parent.keys():
16 |         del son2parent['']
17 |     return son2parent
18 |
19 | class My_DS(Dataset):
20 |     def __init__(self, X, y, emb):
21 |         self.X = X
22 |         self.y = y
23 |         self.emb = emb
24 |
25 |     def __len__(self):
26 |         return self.X.shape[0]
27 |
28 |     def __getitem__(self, index):
29 |         x_batch = self.X[index, :]
30 |         y_batch = self.y[index]
31 |         a_batch = self.emb[y_batch,:]
32 |
33 |         return x_batch, y_batch, a_batch
34 |
35 | def get_dataloader(Xtr, Xval, ytr, yval, emb, batch_size):
36 |
37 |     train_dataset = My_DS(X=Xtr, y=ytr, emb=emb)
38 |     train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
39 |
40 |     val_dataset = My_DS(X=Xval, y=yval, emb=emb)
41 |     val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
42 |
43 |     return train_loader, val_loader
44 |
45 | def split_act_data(feat, label, fns, anno_fn): # This is only for ActivityNet.
46 |
47 |     with open(anno_fn) as json_file:
48 |         data = json.load(json_file)
49 |
50 |     anno_tr = {key:value for key,value in data['database'].items() if value['annotations']}
51 |
52 |     training_inx, validation_inx = [], []
53 |
54 |     for feat_fn in fns:
55 |         # Use the file name to recover the video key and the clip index.
56 |         # The key is everything after the second underscore in the base name.
57 |         key = '_'.join(feat_fn.split('.')[0].split('_')[2:])
58 |         clip_inx = feat_fn.split('.')[0].split('_')[0]
59 |         clip_inx = int(clip_inx)
60 |
61 |         # Get the video annotation, i.e. the action label.
62 |         video_anno = anno_tr[key]
63 |         if video_anno['subset'] == 'validation':
64 |             training_inx.append(False), validation_inx.append(True)
65 |         elif video_anno['subset'] == 'training':
66 |             training_inx.append(True), validation_inx.append(False)
67 |
68 |     training_inx, validation_inx = torch.tensor(training_inx), torch.tensor(validation_inx)
69 |
70 |     Xtr, Xval = feat[training_inx], feat[validation_inx]
71 |     ytr, yval = label[training_inx], label[validation_inx]
72 |     assert Xtr.dim() == 2 and ytr.dim() == 1
73 |
74 |     return Xtr, Xval, ytr, yval
75 |
76 | def get_activitynet_dataset(feat_path = './data.pickle', anno_fn = './activity_net.v1-3.json'):
77 |     # Generate X, the feature and y, the label.
78 |     with open(feat_path,'rb') as f:
79 |         data1 = pickle.load(f)
80 |
81 |     label_set = list(set(data1['label']))
82 |     label_set.sort() # make sure that the label set is sorted for one-hot encoding.
83 |
84 |     label = [label_set.index(item) for item in data1['label']]
85 |     label = torch.tensor(label)
86 |     feat, fns = data1['feat'], data1['fn']
87 |
88 |     Xtr, Xval, ytr, yval = split_act_data(feat, label, fns, anno_fn)
89 |
90 |     return Xtr, Xval, ytr, yval, label_set
91 |
92 |
93 | def get_emb(emb_type, emb_file, label_set):
94 |
95 |     n_cls = len(label_set)
96 |     if emb_type == 'rand':
97 |         emb = torch.rand(n_cls,300)
98 |
99 |     elif emb_type == 'wacv':
100 |         with open(emb_file, 'rb') as pickle_file:
101 |             content = pickle.load(pickle_file)
102 |         emb = content['embedding']
103 |         emb = torch.tensor(emb).cuda()
104 |
105 |
106 |     elif emb_type == 'oh':
107 |         one_hot = torch.zeros(n_cls, n_cls).long()
108 |         emb = one_hot.scatter_(dim=1, index=torch.unsqueeze(torch.arange(n_cls), dim=1), src=torch.ones(n_cls, n_cls).long())
109 |         emb = emb.float()
110 |
111 |     elif emb_type == 'glove':
112 |         data2 = torch.load(emb_file, map_location='cpu')
113 |         emb_names = data2['objects']
114 |         ext_emb = data2['embeddings'] # n_cls x dim
115 |
116 |         emb = torch.zeros(n_cls,ext_emb.shape[1])
117 |         emb = emb.cuda()
118 |         for i in range(n_cls):
119 |             pos = emb_names.index(label_set[i])
120 |             emb[i] = ext_emb[pos,:]
121 |
122 |     elif emb_type == 'hyp':
123 |         data2 = torch.load(emb_file, map_location='cpu')
124 |         emb_names = data2['objects']
125 |         ext_emb = data2['embeddings'] # 272 x dim
126 |         emb = torch.zeros(n_cls,ext_emb.shape[1])
127 |         for i in range(n_cls):
128 |             pos = emb_names.index(label_set[i])
129 |             emb[i] = ext_emb[pos,:]
130 |
131 |     elif emb_type == 'cone':
132 |         data2 = torch.load(emb_file, map_location='cpu')
133 |         if isinstance(data2, zip):
134 |             data2 = dict(data2)
135 |         emb_names = list(data2.keys())
136 |         ext_emb = list(data2.values()) # 271 x dim, root is discarded
137 |         ext_emb = torch.tensor(ext_emb)
138 |
139 |         emb = torch.zeros(n_cls,ext_emb.shape[1])
140 |         for i in range(n_cls):
141 |             pos = emb_names.index(label_set[i])
142 |             emb[i] = ext_emb[pos,:]
143 |
144 |     return emb
145 |
146 | def get_kineticslike_dataset(train_pth_path, valid_pth_path):
147 |     data_train = torch.load(train_pth_path)
148 |     data_val = torch.load(valid_pth_path)
149 |
150 |     label_set = list(set(data_train['label']))
151 |     label_set.sort() # Sorting is important: it fixes the label-to-index mapping.
152 |
153 |     Xtr, Xval = data_train['feat'], data_val['feat']
154 |     ytr = torch.tensor([label_set.index(item) for item in data_train['label']])
155 |     yval = torch.tensor([label_set.index(item) for item in data_val['label']])
156 |
157 |     return Xtr, Xval, ytr, yval, label_set
158 |
159 |
--------------------------------------------------------------------------------
/demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 6,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import sys\n",
10 | "sys.path.append('../')\n",
11 | "from dataset import get_activitynet_dataset, get_kineticslike_dataset\n",
12 | "from dataset import get_son2parent, get_emb, get_dataloader\n",
13 | "from metric import Metric\n",
14 | "import os\n",
15 | "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
16 | "\n",
17 | "dataset_name, emb_type = 'activitynet', 'cone'\n",
18 | "\n",
19 | "if dataset_name == 'activitynet':\n",
20 | " tree_file = '../activity_net_depth_v5.csv'\n",
21 | " if emb_type == 'cone' : emb_file = '../cone_activity_net300.pth' \n",
22 | " if emb_type == 'glove' : emb_file = '../w2v_activity_300d.pth'\n",
23 | " feat_path, anno_fn = '../data.pickle', '../activity_net.v1-3.json'\n",
24 | " Xtr, Xval, ytr, yval, label_set = get_activitynet_dataset(feat_path, anno_fn)\n",
25 | " \n",
26 | "elif dataset_name == 'kinetics':\n",
27 | " tree_file = '../kinetics_depth_v2.csv'\n",
28 | " if emb_type == 'cone' : emb_file = '../cones_kinetics300.pth'\n",
29 | " if emb_type == 'glove' : emb_file = '../w2v_kinetics_300d.pth'\n",
30 | " train_pth_path, valid_pth_path = '../kinetics_train.pth', '../kinetics_val.pth'\n",
31 | " Xtr, Xval, ytr, yval, label_set = get_kineticslike_dataset(train_pth_path, valid_pth_path)\n",
32 | " \n",
33 | "elif dataset_name == 'moments':\n",
34 | " tree_file = '../moments_depth_v4.csv'\n",
35 | " if emb_type == 'cone' : emb_file = '../cones_moments300.pth'\n",
36 | " if emb_type == 'glove' : emb_file = '../w2v_moments_300d.pth'\n",
37 | " train_pth_path, valid_pth_path = '../moments_train.pth', '../moments_val.pth'\n",
38 | " Xtr, Xval, ytr, yval, label_set = get_kineticslike_dataset(train_pth_path, valid_pth_path)\n",
39 | "\n",
40 | "son2parent = get_son2parent(tree_file)\n",
41 | "emb = get_emb('hyp', emb_file, label_set)\n",
42 | "train_loader, val_loader = get_dataloader(Xtr, Xval, ytr, yval, emb, batch_size = 128)\n",
43 | "metric = Metric(label_set, son2parent)"
44 | ]
45 | },
46 | {
47 | "cell_type": "markdown",
48 | "metadata": {},
49 | "source": [
50 | "### T = 0.1"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": 7,
56 | "metadata": {},
57 | "outputs": [
58 | {
59 | "name": "stdout",
60 | "output_type": "stream",
61 | "text": [
62 | "loss:1.291,acc:0.669,1hop_acc:0.829,2hop_acc:0.908,mAP:0.584,1hop_mAP:0.767,2hop_mAP:0.871\n",
63 | "mAP:0.620,1hop_mAP:0.958,2hop_mAP:0.983\n",
64 | "loss:1.185,acc:0.711,1hop_acc:0.838,2hop_acc:0.908,mAP:0.623,1hop_mAP:0.791,2hop_mAP:0.885\n",
65 | "mAP:0.712,1hop_mAP:0.955,2hop_mAP:0.979\n",
66 | "loss:1.226,acc:0.716,1hop_acc:0.837,2hop_acc:0.909,mAP:0.636,1hop_mAP:0.798,2hop_mAP:0.889\n",
67 | "mAP:0.720,1hop_mAP:0.960,2hop_mAP:0.982\n",
68 | "loss:1.216,acc:0.726,1hop_acc:0.849,2hop_acc:0.914,mAP:0.643,1hop_mAP:0.802,2hop_mAP:0.891\n",
69 | "mAP:0.730,1hop_mAP:0.960,2hop_mAP:0.981\n",
70 | "loss:1.272,acc:0.726,1hop_acc:0.849,2hop_acc:0.911,mAP:0.645,1hop_mAP:0.803,2hop_mAP:0.890\n",
71 | "mAP:0.730,1hop_mAP:0.962,2hop_mAP:0.981\n",
72 | "loss:1.268,acc:0.734,1hop_acc:0.849,2hop_acc:0.915,mAP:0.651,1hop_mAP:0.805,2hop_mAP:0.892\n",
73 | "mAP:0.740,1hop_mAP:0.960,2hop_mAP:0.981\n",
74 | "loss:1.262,acc:0.733,1hop_acc:0.847,2hop_acc:0.913,mAP:0.652,1hop_mAP:0.806,2hop_mAP:0.893\n",
75 | "mAP:0.740,1hop_mAP:0.959,2hop_mAP:0.981\n",
76 | "loss:1.299,acc:0.720,1hop_acc:0.840,2hop_acc:0.912,mAP:0.653,1hop_mAP:0.803,2hop_mAP:0.890\n",
77 | "mAP:0.741,1hop_mAP:0.959,2hop_mAP:0.981\n",
78 | "loss:1.333,acc:0.722,1hop_acc:0.843,2hop_acc:0.905,mAP:0.653,1hop_mAP:0.805,2hop_mAP:0.890\n",
79 | "mAP:0.740,1hop_mAP:0.956,2hop_mAP:0.978\n",
80 | "loss:1.379,acc:0.708,1hop_acc:0.831,2hop_acc:0.906,mAP:0.649,1hop_mAP:0.800,2hop_mAP:0.889\n",
81 | "mAP:0.737,1hop_mAP:0.957,2hop_mAP:0.979\n"
82 | ]
83 | }
84 | ],
85 | "source": [
86 | "from pmath import pair_wise_cos, pair_wise_eud, pair_wise_hyp\n",
87 | "from model import RegressNet\n",
88 | "import torch\n",
89 | "c, T, epochs, dist_func, eval_dist = .1,.1, 50, pair_wise_hyp, pair_wise_cos\n",
90 | "torch.manual_seed(42)\n",
91 | "emb = emb.cuda()\n",
92 | "\n",
93 | "train_loader, val_loader = get_dataloader(Xtr, Xval, ytr, yval, emb, batch_size = 128)\n",
94 | "model = RegressNet(T, c, dist_func, eval_dist, train_loader, val_loader, emb, metric)\n",
95 | "model = model.cuda()\n",
96 | "optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)\n",
97 | "model._train(optimizer, epochs, T, eval_interval = 5)"
98 | ]
99 | },
100 | {
101 | "cell_type": "markdown",
102 | "metadata": {},
103 | "source": [
104 | "### T = 1"
105 | ]
106 | },
107 | {
108 | "cell_type": "code",
109 | "execution_count": 8,
110 | "metadata": {},
111 | "outputs": [
112 | {
113 | "name": "stdout",
114 | "output_type": "stream",
115 | "text": [
116 | "loss:3.328,acc:0.635,1hop_acc:0.821,2hop_acc:0.898,mAP:0.568,1hop_mAP:0.811,2hop_mAP:0.895\n",
117 | "mAP:0.595,1hop_mAP:0.977,2hop_mAP:0.988\n",
118 | "loss:3.025,acc:0.731,1hop_acc:0.858,2hop_acc:0.920,mAP:0.646,1hop_mAP:0.847,2hop_mAP:0.913\n",
119 | "mAP:0.752,1hop_mAP:0.965,2hop_mAP:0.981\n",
120 | "loss:2.968,acc:0.744,1hop_acc:0.858,2hop_acc:0.919,mAP:0.667,1hop_mAP:0.849,2hop_mAP:0.913\n",
121 | "mAP:0.778,1hop_mAP:0.960,2hop_mAP:0.977\n",
122 | "loss:2.945,acc:0.749,1hop_acc:0.859,2hop_acc:0.919,mAP:0.675,1hop_mAP:0.847,2hop_mAP:0.911\n",
123 | "mAP:0.788,1hop_mAP:0.955,2hop_mAP:0.974\n",
124 | "loss:2.934,acc:0.751,1hop_acc:0.858,2hop_acc:0.918,mAP:0.678,1hop_mAP:0.844,2hop_mAP:0.909\n",
125 | "mAP:0.790,1hop_mAP:0.951,2hop_mAP:0.972\n",
126 | "loss:2.925,acc:0.750,1hop_acc:0.857,2hop_acc:0.918,mAP:0.679,1hop_mAP:0.842,2hop_mAP:0.908\n",
127 | "mAP:0.791,1hop_mAP:0.948,2hop_mAP:0.971\n",
128 | "loss:2.917,acc:0.750,1hop_acc:0.856,2hop_acc:0.918,mAP:0.679,1hop_mAP:0.840,2hop_mAP:0.908\n",
129 | "mAP:0.792,1hop_mAP:0.945,2hop_mAP:0.969\n",
130 | "loss:2.910,acc:0.751,1hop_acc:0.857,2hop_acc:0.918,mAP:0.680,1hop_mAP:0.840,2hop_mAP:0.907\n",
131 | "mAP:0.791,1hop_mAP:0.944,2hop_mAP:0.968\n",
132 | "loss:2.905,acc:0.751,1hop_acc:0.857,2hop_acc:0.919,mAP:0.680,1hop_mAP:0.839,2hop_mAP:0.907\n",
133 | "mAP:0.790,1hop_mAP:0.942,2hop_mAP:0.967\n",
134 | "loss:2.901,acc:0.750,1hop_acc:0.856,2hop_acc:0.918,mAP:0.679,1hop_mAP:0.838,2hop_mAP:0.907\n",
135 | "mAP:0.791,1hop_mAP:0.942,2hop_mAP:0.967\n"
136 | ]
137 | }
138 | ],
139 | "source": [
140 | "from pmath import pair_wise_cos, pair_wise_eud, pair_wise_hyp\n",
141 | "from model import RegressNet\n",
142 | "import torch\n",
143 | "c, T, epochs, dist_func, eval_dist = .1,1, 50, pair_wise_hyp, pair_wise_cos\n",
144 | "torch.manual_seed(42)\n",
145 | "emb = emb.cuda()\n",
146 | "\n",
147 | "train_loader, val_loader = get_dataloader(Xtr, Xval, ytr, yval, emb, batch_size = 128)\n",
148 | "model = RegressNet(T, c, dist_func, eval_dist, train_loader, val_loader, emb, metric)\n",
149 | "model = model.cuda()\n",
150 | "optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)\n",
151 | "model._train(optimizer, epochs, T, eval_interval = 5)"
152 | ]
153 | },
154 | {
155 | "cell_type": "markdown",
156 | "metadata": {},
157 | "source": [
158 | "### Step 2: 200-D hyperbolic embedding, for all three datasets"
159 | ]
160 | },
161 | {
162 | "cell_type": "code",
163 | "execution_count": 1,
164 | "metadata": {},
165 | "outputs": [],
166 | "source": [
167 | "import sys\n",
168 | "sys.path.append('../')\n",
169 | "from dataset import get_activitynet_dataset, get_kineticslike_dataset\n",
170 | "from dataset import get_son2parent, get_emb, get_dataloader\n",
171 | "from metric import Metric\n",
172 | "import os\n",
173 | "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
174 | "\n",
175 | "dataset_name, emb_type = 'activitynet', 'cone'\n",
176 | "\n",
177 | "if dataset_name == 'activitynet':\n",
178 | "    tree_file = '../activity_net_depth_v5.csv'\n",
179 | "    if emb_type == 'cone' : emb_file = '../cone_activity_net200.pth'\n",
180 | "    if emb_type == 'glove' : emb_file = '../w2v_activity_300d.pth'\n",
181 | "    feat_path, anno_fn = '../data.pickle', '../activity_net.v1-3.json'\n",
182 | "    Xtr, Xval, ytr, yval, label_set = get_activitynet_dataset(feat_path, anno_fn)\n",
183 | "\n",
184 | "elif dataset_name == 'kinetics':\n",
185 | "    tree_file = '../kinetics_depth_v2.csv'\n",
186 | "    if emb_type == 'cone' : emb_file = '../cones_kinetics300.pth'\n",
187 | "    if emb_type == 'glove' : emb_file = '../w2v_kinetics_300d.pth'\n",
188 | "    train_pth_path, valid_pth_path = '../kinetics_train.pth', '../kinetics_val.pth'\n",
189 | "    Xtr, Xval, ytr, yval, label_set = get_kineticslike_dataset(train_pth_path, valid_pth_path)\n",
190 | "\n",
191 | "elif dataset_name == 'moments':\n",
192 | "    tree_file = '../moments_depth_v4.csv'\n",
193 | "    if emb_type == 'cone' : emb_file = '../cones_moments300.pth'\n",
194 | "    if emb_type == 'glove' : emb_file = '../w2v_moments_300d.pth'\n",
195 | "    train_pth_path, valid_pth_path = '../moments_train.pth', '../moments_val.pth'\n",
196 | "    Xtr, Xval, ytr, yval, label_set = get_kineticslike_dataset(train_pth_path, valid_pth_path)\n",
197 | "\n",
198 | "son2parent = get_son2parent(tree_file)\n",
199 | "emb = get_emb('hyp', emb_file, label_set)\n",
200 | "train_loader, val_loader = get_dataloader(Xtr, Xval, ytr, yval, emb, batch_size = 128)\n",
201 | "metric = Metric(label_set, son2parent)"
202 | ]
203 | },
204 | {
205 | "cell_type": "code",
206 | "execution_count": 2,
207 | "metadata": {},
208 | "outputs": [
209 | {
210 | "name": "stdout",
211 | "output_type": "stream",
212 | "text": [
213 | "loss:3.318,acc:0.649,1hop_acc:0.822,2hop_acc:0.900,mAP:0.571,1hop_mAP:0.812,2hop_mAP:0.896\n",
214 | "mAP:0.608,1hop_mAP:0.974,2hop_mAP:0.986\n",
215 | "loss:3.034,acc:0.729,1hop_acc:0.858,2hop_acc:0.917,mAP:0.647,1hop_mAP:0.847,2hop_mAP:0.913\n",
216 | "mAP:0.756,1hop_mAP:0.966,2hop_mAP:0.981\n",
217 | "loss:2.980,acc:0.742,1hop_acc:0.860,2hop_acc:0.919,mAP:0.666,1hop_mAP:0.850,2hop_mAP:0.914\n",
218 | "mAP:0.778,1hop_mAP:0.961,2hop_mAP:0.978\n",
219 | "loss:2.955,acc:0.748,1hop_acc:0.858,2hop_acc:0.918,mAP:0.674,1hop_mAP:0.848,2hop_mAP:0.911\n",
220 | "mAP:0.787,1hop_mAP:0.955,2hop_mAP:0.974\n",
221 | "loss:2.942,acc:0.749,1hop_acc:0.859,2hop_acc:0.918,mAP:0.676,1hop_mAP:0.845,2hop_mAP:0.909\n",
222 | "mAP:0.789,1hop_mAP:0.952,2hop_mAP:0.972\n",
223 | "loss:2.933,acc:0.747,1hop_acc:0.858,2hop_acc:0.916,mAP:0.678,1hop_mAP:0.843,2hop_mAP:0.908\n",
224 | "mAP:0.789,1hop_mAP:0.950,2hop_mAP:0.971\n",
225 | "loss:2.925,acc:0.749,1hop_acc:0.858,2hop_acc:0.916,mAP:0.679,1hop_mAP:0.843,2hop_mAP:0.907\n",
226 | "mAP:0.789,1hop_mAP:0.948,2hop_mAP:0.970\n",
227 | "loss:2.921,acc:0.750,1hop_acc:0.858,2hop_acc:0.917,mAP:0.679,1hop_mAP:0.842,2hop_mAP:0.907\n",
228 | "mAP:0.789,1hop_mAP:0.945,2hop_mAP:0.969\n",
229 | "loss:2.915,acc:0.751,1hop_acc:0.859,2hop_acc:0.918,mAP:0.679,1hop_mAP:0.841,2hop_mAP:0.907\n",
230 | "mAP:0.790,1hop_mAP:0.945,2hop_mAP:0.969\n",
231 | "loss:2.909,acc:0.750,1hop_acc:0.858,2hop_acc:0.918,mAP:0.679,1hop_mAP:0.841,2hop_mAP:0.907\n",
232 | "mAP:0.791,1hop_mAP:0.945,2hop_mAP:0.968\n"
233 | ]
234 | }
235 | ],
236 | "source": [
237 | "from pmath import pair_wise_cos, pair_wise_eud, pair_wise_hyp\n",
238 | "from model import RegressNet\n",
239 | "import torch\n",
240 | "c, T, epochs, dist_func, eval_dist = .1,1, 50, pair_wise_hyp, pair_wise_cos\n",
241 | "torch.manual_seed(42)\n",
242 | "emb = emb.cuda()\n",
243 | "\n",
244 | "train_loader, val_loader = get_dataloader(Xtr, Xval, ytr, yval, emb, batch_size = 128)\n",
245 | "model = RegressNet(T, c, dist_func, eval_dist, train_loader, val_loader, emb, metric)\n",
246 | "model = model.cuda()\n",
247 | "optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)\n",
248 | "model._train(optimizer, epochs, T, eval_interval = 5)"
249 | ]
250 | },
251 | {
252 | "cell_type": "code",
253 | "execution_count": 4,
254 | "metadata": {},
255 | "outputs": [],
256 | "source": [
257 | "import sys\n",
258 | "sys.path.append('../')\n",
259 | "from dataset import get_activitynet_dataset, get_kineticslike_dataset\n",
260 | "from dataset import get_son2parent, get_emb, get_dataloader\n",
261 | "from metric import Metric\n",
262 | "import os\n",
263 | "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
264 | "\n",
265 | "dataset_name, emb_type = 'kinetics', 'cone'\n",
266 | "\n",
267 | "if dataset_name == 'activitynet':\n",
268 | "    tree_file = '../activity_net_depth_v5.csv'\n",
269 | "    if emb_type == 'cone' : emb_file = '../cone_activity_net200.pth'\n",
270 | "    if emb_type == 'glove' : emb_file = '../w2v_activity_300d.pth'\n",
271 | "    feat_path, anno_fn = '../data.pickle', '../activity_net.v1-3.json'\n",
272 | "    Xtr, Xval, ytr, yval, label_set = get_activitynet_dataset(feat_path, anno_fn)\n",
273 | "\n",
274 | "elif dataset_name == 'kinetics':\n",
275 | "    tree_file = '../kinetics_depth_v2.csv'\n",
276 | "    if emb_type == 'cone' : emb_file = '../cone_kinetics_200.pth'\n",
277 | "    if emb_type == 'glove' : emb_file = '../w2v_kinetics_300d.pth'\n",
278 | "    train_pth_path, valid_pth_path = '../kinetics_train.pth', '../kinetics_val.pth'\n",
279 | "    Xtr, Xval, ytr, yval, label_set = get_kineticslike_dataset(train_pth_path, valid_pth_path)\n",
280 | "\n",
281 | "elif dataset_name == 'moments':\n",
282 | "    tree_file = '../moments_depth_v4.csv'\n",
283 | "    if emb_type == 'cone' : emb_file = '../cones_moments300.pth'\n",
284 | "    if emb_type == 'glove' : emb_file = '../w2v_moments_300d.pth'\n",
285 | "    train_pth_path, valid_pth_path = '../moments_train.pth', '../moments_val.pth'\n",
286 | "    Xtr, Xval, ytr, yval, label_set = get_kineticslike_dataset(train_pth_path, valid_pth_path)\n",
287 | "\n",
288 | "son2parent = get_son2parent(tree_file)\n",
289 | "emb = get_emb('hyp', emb_file, label_set)\n",
290 | "train_loader, val_loader = get_dataloader(Xtr, Xval, ytr, yval, emb, batch_size = 128)\n",
291 | "metric = Metric(label_set, son2parent)"
292 | ]
293 | },
294 | {
295 | "cell_type": "code",
296 | "execution_count": 5,
297 | "metadata": {},
298 | "outputs": [
299 | {
300 | "name": "stdout",
301 | "output_type": "stream",
302 | "text": [
303 | "loss:2.935,acc:0.762,1hop_acc:0.872,2hop_acc:0.914,mAP:0.573,1hop_mAP:0.844,2hop_mAP:0.894\n",
304 | "mAP:0.698,1hop_mAP:0.956,2hop_mAP:0.969\n",
305 | "loss:2.845,acc:0.770,1hop_acc:0.870,2hop_acc:0.913,mAP:0.595,1hop_mAP:0.834,2hop_mAP:0.887\n",
306 | "mAP:0.719,1hop_mAP:0.937,2hop_mAP:0.957\n",
307 | "loss:2.846,acc:0.771,1hop_acc:0.870,2hop_acc:0.912,mAP:0.594,1hop_mAP:0.827,2hop_mAP:0.882\n",
308 | "mAP:0.717,1hop_mAP:0.929,2hop_mAP:0.951\n",
309 | "loss:2.850,acc:0.769,1hop_acc:0.867,2hop_acc:0.911,mAP:0.593,1hop_mAP:0.824,2hop_mAP:0.880\n",
310 | "mAP:0.715,1hop_mAP:0.928,2hop_mAP:0.950\n",
311 | "loss:2.854,acc:0.762,1hop_acc:0.864,2hop_acc:0.908,mAP:0.590,1hop_mAP:0.821,2hop_mAP:0.878\n",
312 | "mAP:0.713,1hop_mAP:0.927,2hop_mAP:0.949\n",
313 | "loss:2.853,acc:0.763,1hop_acc:0.863,2hop_acc:0.906,mAP:0.590,1hop_mAP:0.822,2hop_mAP:0.878\n",
314 | "mAP:0.712,1hop_mAP:0.926,2hop_mAP:0.949\n",
315 | "loss:2.856,acc:0.762,1hop_acc:0.863,2hop_acc:0.908,mAP:0.589,1hop_mAP:0.819,2hop_mAP:0.876\n",
316 | "mAP:0.711,1hop_mAP:0.924,2hop_mAP:0.947\n",
317 | "loss:2.854,acc:0.763,1hop_acc:0.860,2hop_acc:0.903,mAP:0.589,1hop_mAP:0.816,2hop_mAP:0.871\n",
318 | "mAP:0.711,1hop_mAP:0.922,2hop_mAP:0.944\n",
319 | "loss:2.857,acc:0.763,1hop_acc:0.859,2hop_acc:0.903,mAP:0.588,1hop_mAP:0.814,2hop_mAP:0.871\n",
320 | "mAP:0.711,1hop_mAP:0.921,2hop_mAP:0.944\n",
321 | "loss:2.853,acc:0.766,1hop_acc:0.863,2hop_acc:0.904,mAP:0.589,1hop_mAP:0.815,2hop_mAP:0.869\n",
322 | "mAP:0.711,1hop_mAP:0.919,2hop_mAP:0.942\n"
323 | ]
324 | }
325 | ],
326 | "source": [
327 | "from pmath import pair_wise_cos, pair_wise_eud, pair_wise_hyp\n",
328 | "from model import RegressNet\n",
329 | "import torch\n",
330 | "c, T, epochs, dist_func, eval_dist = .1,1, 50, pair_wise_hyp, pair_wise_cos\n",
331 | "torch.manual_seed(42)\n",
332 | "emb = emb.cuda()\n",
333 | "\n",
334 | "train_loader, val_loader = get_dataloader(Xtr, Xval, ytr, yval, emb, batch_size = 128)\n",
335 | "model = RegressNet(T, c, dist_func, eval_dist, train_loader, val_loader, emb, metric)\n",
336 | "model = model.cuda()\n",
337 | "optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)\n",
338 | "model._train(optimizer, epochs, T, eval_interval = 5)"
339 | ]
340 | },
341 | {
342 | "cell_type": "code",
343 | "execution_count": 7,
344 | "metadata": {},
345 | "outputs": [],
346 | "source": [
347 | "import sys\n",
348 | "sys.path.append('../')\n",
349 | "from dataset import get_activitynet_dataset, get_kineticslike_dataset\n",
350 | "from dataset import get_son2parent, get_emb, get_dataloader\n",
351 | "from metric import Metric\n",
352 | "import os\n",
353 | "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
354 | "\n",
355 | "dataset_name, emb_type = 'moments', 'cone'\n",
356 | "\n",
357 | "if dataset_name == 'activitynet':\n",
358 | "    tree_file = '../activity_net_depth_v5.csv'\n",
359 | "    if emb_type == 'cone' : emb_file = '../cone_activity_net200.pth'\n",
360 | "    if emb_type == 'glove' : emb_file = '../w2v_activity_300d.pth'\n",
361 | "    feat_path, anno_fn = '../data.pickle', '../activity_net.v1-3.json'\n",
362 | "    Xtr, Xval, ytr, yval, label_set = get_activitynet_dataset(feat_path, anno_fn)\n",
363 | "\n",
364 | "elif dataset_name == 'kinetics':\n",
365 | "    tree_file = '../kinetics_depth_v2.csv'\n",
366 | "    if emb_type == 'cone' : emb_file = '../cone_kinetics_200.pth'\n",
367 | "    if emb_type == 'glove' : emb_file = '../w2v_kinetics_300d.pth'\n",
368 | "    train_pth_path, valid_pth_path = '../kinetics_train.pth', '../kinetics_val.pth'\n",
369 | "    Xtr, Xval, ytr, yval, label_set = get_kineticslike_dataset(train_pth_path, valid_pth_path)\n",
370 | "\n",
371 | "elif dataset_name == 'moments':\n",
372 | "    tree_file = '../moments_depth_v5.csv'\n",
373 | "    if emb_type == 'cone' : emb_file = '../cone_moments_200.pth'\n",
374 | "    if emb_type == 'glove' : emb_file = '../w2v_moments_300d.pth'\n",
375 | "    train_pth_path, valid_pth_path = '../moments_train.pth', '../moments_val.pth'\n",
376 | "    Xtr, Xval, ytr, yval, label_set = get_kineticslike_dataset(train_pth_path, valid_pth_path)\n",
377 | "\n",
378 | "son2parent = get_son2parent(tree_file)\n",
379 | "emb = get_emb('hyp', emb_file, label_set)\n",
380 | "train_loader, val_loader = get_dataloader(Xtr, Xval, ytr, yval, emb, batch_size = 128)\n",
381 | "metric = Metric(label_set, son2parent)"
382 | ]
383 | },
384 | {
385 | "cell_type": "code",
386 | "execution_count": 8,
387 | "metadata": {},
388 | "outputs": [
389 | {
390 | "name": "stdout",
391 | "output_type": "stream",
392 | "text": [
393 | "loss:4.406,acc:0.133,1hop_acc:0.188,2hop_acc:0.375,mAP:0.149,1hop_mAP:0.182,2hop_mAP:0.365\n",
394 | "mAP:0.209,1hop_mAP:0.388,2hop_mAP:0.527\n",
395 | "loss:4.365,acc:0.153,1hop_acc:0.207,2hop_acc:0.380,mAP:0.157,1hop_mAP:0.190,2hop_mAP:0.373\n",
396 | "mAP:0.254,1hop_mAP:0.414,2hop_mAP:0.557\n",
397 | "loss:4.353,acc:0.160,1hop_acc:0.214,2hop_acc:0.391,mAP:0.160,1hop_mAP:0.196,2hop_mAP:0.378\n",
398 | "mAP:0.273,1hop_mAP:0.417,2hop_mAP:0.561\n",
399 | "loss:4.352,acc:0.164,1hop_acc:0.218,2hop_acc:0.393,mAP:0.162,1hop_mAP:0.199,2hop_mAP:0.380\n",
400 | "mAP:0.281,1hop_mAP:0.419,2hop_mAP:0.563\n",
401 | "loss:4.356,acc:0.164,1hop_acc:0.219,2hop_acc:0.398,mAP:0.163,1hop_mAP:0.200,2hop_mAP:0.381\n",
402 | "mAP:0.287,1hop_mAP:0.420,2hop_mAP:0.564\n",
403 | "loss:4.358,acc:0.166,1hop_acc:0.222,2hop_acc:0.403,mAP:0.163,1hop_mAP:0.201,2hop_mAP:0.381\n",
404 | "mAP:0.292,1hop_mAP:0.419,2hop_mAP:0.560\n",
405 | "loss:4.366,acc:0.166,1hop_acc:0.222,2hop_acc:0.402,mAP:0.163,1hop_mAP:0.201,2hop_mAP:0.381\n",
406 | "mAP:0.296,1hop_mAP:0.419,2hop_mAP:0.561\n",
407 | "loss:4.370,acc:0.167,1hop_acc:0.223,2hop_acc:0.406,mAP:0.163,1hop_mAP:0.201,2hop_mAP:0.382\n",
408 | "mAP:0.300,1hop_mAP:0.415,2hop_mAP:0.561\n",
409 | "loss:4.374,acc:0.166,1hop_acc:0.222,2hop_acc:0.409,mAP:0.163,1hop_mAP:0.201,2hop_mAP:0.383\n",
410 | "mAP:0.302,1hop_mAP:0.415,2hop_mAP:0.559\n",
411 | "loss:4.381,acc:0.164,1hop_acc:0.218,2hop_acc:0.407,mAP:0.163,1hop_mAP:0.200,2hop_mAP:0.382\n",
412 | "mAP:0.307,1hop_mAP:0.414,2hop_mAP:0.564\n"
413 | ]
414 | }
415 | ],
416 | "source": [
417 | "from pmath import pair_wise_cos, pair_wise_eud, pair_wise_hyp\n",
418 | "from model import RegressNet\n",
419 | "import torch\n",
420 | "c, T, epochs, dist_func, eval_dist = .1,1, 20, pair_wise_hyp, pair_wise_cos\n",
421 | "torch.manual_seed(42)\n",
422 | "emb = emb.cuda()\n",
423 | "\n",
424 | "train_loader, val_loader = get_dataloader(Xtr, Xval, ytr, yval, emb, batch_size = 128)\n",
425 | "model = RegressNet(T, c, dist_func, eval_dist, train_loader, val_loader, emb, metric)\n",
426 | "model = model.cuda()\n",
427 | "optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)\n",
428 | "model._train(optimizer, epochs, T, eval_interval = 2)"
429 | ]
430 | },
431 | {
432 | "cell_type": "code",
433 | "execution_count": null,
434 | "metadata": {},
435 | "outputs": [],
436 | "source": []
437 | }
438 | ],
439 | "metadata": {
440 | "kernelspec": {
441 | "display_name": "Python 3",
442 | "language": "python",
443 | "name": "python3"
444 | },
445 | "language_info": {
446 | "codemirror_mode": {
447 | "name": "ipython",
448 | "version": 3
449 | },
450 | "file_extension": ".py",
451 | "mimetype": "text/x-python",
452 | "name": "python",
453 | "nbconvert_exporter": "python",
454 | "pygments_lexer": "ipython3",
455 | "version": "3.7.4"
456 | }
457 | },
458 | "nbformat": 4,
459 | "nbformat_minor": 4
460 | }
461 |
--------------------------------------------------------------------------------
/kinetics_depth.csv:
--------------------------------------------------------------------------------
1 | blowing glass,arts and crafts,Arts and Entertainment,Root,1
2 | spray painting,arts and crafts,Arts and Entertainment,Root,1
3 | weaving basket,arts and crafts,Arts and Entertainment,Root,0
4 | headbanging,body motions,Arts and Entertainment,Root,1
5 | pumping fist,body motions,Arts and Entertainment,Root,1
6 | stretching leg,body motions,Arts and Entertainment,Root,0
7 | belly dancing,dancing,Arts and Entertainment,Root,1
8 | breakdancing,dancing,Arts and Entertainment,Root,1
9 | cheerleading,dancing,Arts and Entertainment,Root,1
10 | country line dancing,dancing,Arts and Entertainment,Root,1
11 | dancing ballet,dancing,Arts and Entertainment,Root,1
12 | dancing gangnam style,dancing,Arts and Entertainment,Root,1
13 | dancing macarena,dancing,Arts and Entertainment,Root,0
14 | marching,dancing,Arts and Entertainment,Root,1
15 | robot dancing,dancing,Arts and Entertainment,Root,1
16 | salsa dancing,dancing,Arts and Entertainment,Root,1
17 | tango dancing,dancing,Arts and Entertainment,Root,1
18 | tap dancing,dancing,Arts and Entertainment,Root,1
19 | zumba,dancing,Arts and Entertainment,Root,0
20 | arm wrestling,martial arts,Arts and Entertainment,Root,1
21 | capoeira,martial arts,Arts and Entertainment,Root,1
22 | high kick,martial arts,Arts and Entertainment,Root,1
23 | punching bag,martial arts,Arts and Entertainment,Root,1
24 | side kick,martial arts,Arts and Entertainment,Root,1
25 | tai chi,martial arts,Arts and Entertainment,Root,0
26 | busking,music,Arts and Entertainment,Root,1
27 | playing accordion,music,Arts and Entertainment,Root,1
28 | playing bagpipes,music,Arts and Entertainment,Root,1
29 | playing bass guitar,music,Arts and Entertainment,Root,1
30 | playing cello,music,Arts and Entertainment,Root,1
31 | playing clarinet,music,Arts and Entertainment,Root,0
32 | playing didgeridoo,music,Arts and Entertainment,Root,1
33 | playing drums,music,Arts and Entertainment,Root,1
34 | playing guitar,music,Arts and Entertainment,Root,1
35 | playing harmonica,music,Arts and Entertainment,Root,1
36 | playing harp,music,Arts and Entertainment,Root,1
37 | playing recorder,music,Arts and Entertainment,Root,1
38 | playing saxophone,music,Arts and Entertainment,Root,0
39 | playing trombone,music,Arts and Entertainment,Root,1
40 | playing trumpet,music,Arts and Entertainment,Root,1
41 | playing ukulele,music,Arts and Entertainment,Root,1
42 | playing violin,music,Arts and Entertainment,Root,1
43 | playing xylophone,music,Arts and Entertainment,Root,1
44 | tapping guitar,music,Arts and Entertainment,Root,0
45 | cleaning floor,cleaning,Household Activities,Root,1
46 | tying knot (not on a tie),cleaning,Household Activities,Root,1
47 | washing dishes,cleaning,Household Activities,Root,0
48 | baking cookies,cooking,Household Activities,Root,1
49 | barbequing,cooking,Household Activities,Root,1
50 | cooking chicken,cooking,Household Activities,Root,1
51 | cutting watermelon,cooking,Household Activities,Root,1
52 | making pizza,cooking,Household Activities,Root,1
53 | picking fruit,cooking,Household Activities,Root,0
54 | scrambling eggs,cooking,Household Activities,Root,0
55 | chopping wood,garden + plants,Household Activities,Root,1
56 | mowing lawn,garden + plants,Household Activities,Root,1
57 | bookbinding,paper,Household Activities,Root,1
58 | folding napkins,paper,Household Activities,Root,1
59 | folding paper,paper,Household Activities,Root,1
60 | opening present,paper,Household Activities,Root,1
61 | reading book,paper,Household Activities,Root,1
62 | unboxing,paper,Household Activities,Root,1
63 | wrapping present,paper,Household Activities,Root,0
64 | sharpening pencil,using tools,Household Activities,Root,1
65 | using computer,using tools,Household Activities,Root,1
66 | welding,using tools,Household Activities,Root,0
67 | high jump,athletics jumping,Participating in Sports/Exercise/or Recreation,Root,1
68 | long jump,athletics jumping,Participating in Sports/Exercise/or Recreation,Root,1
69 | pole vault,athletics jumping,Participating in Sports/Exercise/or Recreation,Root,1
70 | triple jump,athletics jumping,Participating in Sports/Exercise/or Recreation,Root,0
71 | archery,athletics throwing + launching,Participating in Sports/Exercise/or Recreation,Root,1
72 | catching or throwing frisbee,athletics throwing + launching,Participating in Sports/Exercise/or Recreation,Root,1
73 | hammer throw,athletics throwing + launching,Participating in Sports/Exercise/or Recreation,Root,1
74 | javelin throw,athletics throwing + launching,Participating in Sports/Exercise/or Recreation,Root,1
75 | shot put,athletics throwing + launching,Participating in Sports/Exercise/or Recreation,Root,1
76 | throwing axe,athletics throwing + launching,Participating in Sports/Exercise/or Recreation,Root,1
77 | throwing discus,athletics throwing + launching,Participating in Sports/Exercise/or Recreation,Root,0
78 | bowling,ball sports,Participating in Sports/Exercise/or Recreation,Root,1
79 | dribbling basketball,ball sports,Participating in Sports/Exercise/or Recreation,Root,1
80 | dunking basketball,ball sports,Participating in Sports/Exercise/or Recreation,Root,1
81 | kicking field goal,ball sports,Participating in Sports/Exercise/or Recreation,Root,1
82 | passing American football (in game),ball sports,Participating in Sports/Exercise/or Recreation,Root,1
83 | passing American football (not in game),ball sports,Participating in Sports/Exercise/or Recreation,Root,1
84 | playing basketball,ball sports,Participating in Sports/Exercise/or Recreation,Root,0
85 | playing volleyball,ball sports,Participating in Sports/Exercise/or Recreation,Root,0
86 | golf driving,golf,Participating in Sports/Exercise/or Recreation,Root,1
87 | golf putting,golf,Participating in Sports/Exercise/or Recreation,Root,1
88 | bench pressing,gym,Participating in Sports/Exercise/or Recreation,Root,1
89 | clean and jerk,gym,Participating in Sports/Exercise/or Recreation,Root,1
90 | deadlifting,gym,Participating in Sports/Exercise/or Recreation,Root,1
91 | front raises,gym,Participating in Sports/Exercise/or Recreation,Root,1
92 | pull ups,gym,Participating in Sports/Exercise/or Recreation,Root,1
93 | situp,gym,Participating in Sports/Exercise/or Recreation,Root,1
94 | snatch weight lifting,gym,Participating in Sports/Exercise/or Recreation,Root,1
95 | squat,gym,Participating in Sports/Exercise/or Recreation,Root,0
96 | yoga,gym,Participating in Sports/Exercise/or Recreation,Root,0
97 | lunge,gym,Participating in Sports/Exercise/or Recreation,Root,1
98 | gymnastics tumbling,gymnastics,Participating in Sports/Exercise/or Recreation,Root,1
99 | somersaulting,gymnastics,Participating in Sports/Exercise/or Recreation,Root,1
100 | balloon blowing,head + mouth,Participating in Sports/Exercise/or Recreation,Root,1
101 | beatboxing,head + mouth,Participating in Sports/Exercise/or Recreation,Root,1
102 | blowing out candles,head + mouth,Participating in Sports/Exercise/or Recreation,Root,1
103 | shaking head,head + mouth,Participating in Sports/Exercise/or Recreation,Root,1
104 | singing,head + mouth,Participating in Sports/Exercise/or Recreation,Root,1
105 | smoking,head + mouth,Participating in Sports/Exercise/or Recreation,Root,1
106 | smoking hookah,head + mouth,Participating in Sports/Exercise/or Recreation,Root,0
107 | sticking tongue out,head + mouth,Participating in Sports/Exercise/or Recreation,Root,0
108 | abseiling,heights,Participating in Sports/Exercise/or Recreation,Root,1
109 | bungee jumping,heights,Participating in Sports/Exercise/or Recreation,Root,1
110 | climbing tree,heights,Participating in Sports/Exercise/or Recreation,Root,1
111 | diving cliff,heights,Participating in Sports/Exercise/or Recreation,Root,1
112 | paragliding,heights,Participating in Sports/Exercise/or Recreation,Root,1
113 | rock climbing,heights,Participating in Sports/Exercise/or Recreation,Root,1
114 | slacklining,heights,Participating in Sports/Exercise/or Recreation,Root,0
115 | trapezing,heights,Participating in Sports/Exercise/or Recreation,Root,1
116 | contact juggling,juggling,Participating in Sports/Exercise/or Recreation,Root,1
117 | hula hooping,juggling,Participating in Sports/Exercise/or Recreation,Root,1
118 | juggling balls,juggling,Participating in Sports/Exercise/or Recreation,Root,1
119 | spinning poi,juggling,Participating in Sports/Exercise/or Recreation,Root,0
120 | catching or throwing baseball,racquet + bat sports,Participating in Sports/Exercise/or Recreation,Root,1
121 | catching or throwing softball,racquet + bat sports,Participating in Sports/Exercise/or Recreation,Root,1
122 | hitting baseball,racquet + bat sports,Participating in Sports/Exercise/or Recreation,Root,1
123 | hurling (sport),racquet + bat sports,Participating in Sports/Exercise/or Recreation,Root,1
124 | playing badminton,racquet + bat sports,Participating in Sports/Exercise/or Recreation,Root,1
125 | playing cricket,racquet + bat sports,Participating in Sports/Exercise/or Recreation,Root,1
126 | playing squash or racquetball,racquet + bat sports,Participating in Sports/Exercise/or Recreation,Root,0
127 | playing tennis,racquet + bat sports,Participating in Sports/Exercise/or Recreation,Root,0
128 | biking through snow,snow + ice,Participating in Sports/Exercise/or Recreation,Root,1
129 | ice climbing,snow + ice,Participating in Sports/Exercise/or Recreation,Root,1
130 | ice skating,snow + ice,Participating in Sports/Exercise/or Recreation,Root,1
131 | making snowman,snow + ice,Participating in Sports/Exercise/or Recreation,Root,1
132 | playing ice hockey,snow + ice,Participating in Sports/Exercise/or Recreation,Root,1
133 | shoveling snow,snow + ice,Participating in Sports/Exercise/or Recreation,Root,1
134 | ski jumping,snow + ice,Participating in Sports/Exercise/or Recreation,Root,0
135 | skiing (not slalom or crosscountry),snow + ice,Participating in Sports/Exercise/or Recreation,Root,1
136 | sled dog racing,snow + ice,Participating in Sports/Exercise/or Recreation,Root,1
137 | snowboarding,snow + ice,Participating in Sports/Exercise/or Recreation,Root,1
138 | snowkiting,snow + ice,Participating in Sports/Exercise/or Recreation,Root,1
139 | tobogganing,snow + ice,Participating in Sports/Exercise/or Recreation,Root,0
140 | swimming backstroke,swimming,Participating in Sports/Exercise/or Recreation,Root,1
141 | swimming breast stroke,swimming,Participating in Sports/Exercise/or Recreation,Root,1
142 | canoeing or kayaking,water sports,Participating in Sports/Exercise/or Recreation,Root,1
143 | jetskiing,water sports,Participating in Sports/Exercise/or Recreation,Root,1
144 | kitesurfing,water sports,Participating in Sports/Exercise/or Recreation,Root,1
145 | parasailing,water sports,Participating in Sports/Exercise/or Recreation,Root,1
146 | sailing,water sports,Participating in Sports/Exercise/or Recreation,Root,1
147 | surfing water,water sports,Participating in Sports/Exercise/or Recreation,Root,0
148 | water skiing,water sports,Participating in Sports/Exercise/or Recreation,Root,1
149 | windsurfing,water sports,Participating in Sports/Exercise/or Recreation,Root,0
150 | braiding hair,hair,Personal Care,Root,1
151 | brushing hair,hair,Personal Care,Root,1
152 | curling hair,hair,Personal Care,Root,1
153 | shaving head,hair,Personal Care,Root,1
154 | trimming or shaving beard,hair,Personal Care,Root,0
155 | air drumming,hands,Personal Care,Root,1
156 | finger snapping,hands,Personal Care,Root,1
157 | doing nails,makeup,Personal Care,Root,1
158 | dying hair,makeup,Personal Care,Root,1
159 | filling eyebrows,makeup,Personal Care,Root,1
160 | waxing chest,makeup,Personal Care,Root,1
161 | waxing legs,makeup,Personal Care,Root,0
162 | brushing teeth,personal hygiene,Personal Care,Root,1
163 | washing feet,personal hygiene,Personal Care,Root,1
164 | washing hands,personal hygiene,Personal Care,Root,0
165 | feeding birds,interacting with animals,Relaxing and Leisure,Root,1
166 | feeding fish,interacting with animals,Relaxing and Leisure,Root,1
167 | feeding goats,interacting with animals,Relaxing and Leisure,Root,1
168 | milking cow,interacting with animals,Relaxing and Leisure,Root,1
169 | petting animal (not cat),interacting with animals,Relaxing and Leisure,Root,1
170 | riding elephant,interacting with animals,Relaxing and Leisure,Root,1
171 | riding or walking with horse,interacting with animals,Relaxing and Leisure,Root,0
172 | shearing sheep,interacting with animals,Relaxing and Leisure,Root,1
173 | walking the dog,interacting with animals,Relaxing and Leisure,Root,0
174 | crawling baby,mobility land,Relaxing and Leisure,Root,1
175 | driving car,mobility land,Relaxing and Leisure,Root,1
176 | driving tractor,mobility land,Relaxing and Leisure,Root,1
177 | motorcycling,mobility land,Relaxing and Leisure,Root,1
178 | pushing car,mobility land,Relaxing and Leisure,Root,1
179 | pushing cart,mobility land,Relaxing and Leisure,Root,1
180 | riding unicycle,mobility land,Relaxing and Leisure,Root,0
181 | roller skating,mobility land,Relaxing and Leisure,Root,1
182 | skateboarding,mobility land,Relaxing and Leisure,Root,1
183 | surfing crowd,mobility land,Relaxing and Leisure,Root,0
184 | crossing river,mobility water,Relaxing and Leisure,Root,1
185 | jumping into pool,mobility water,Relaxing and Leisure,Root,1
186 | scuba diving,mobility water,Relaxing and Leisure,Root,1
187 | snorkeling,mobility water,Relaxing and Leisure,Root,0
188 | flying kite,playing games,Relaxing and Leisure,Root,1
189 | playing chess,playing games,Relaxing and Leisure,Root,1
190 | playing paintball,playing games,Relaxing and Leisure,Root,1
191 | playing poker,playing games,Relaxing and Leisure,Root,1
192 | shuffling cards,playing games,Relaxing and Leisure,Root,0
193 | crying,communication,Social Activities,Root,1
194 | giving or receiving award,communication,Social Activities,Root,1
195 | laughing,communication,Social Activities,Root,1
196 | massaging back,communication,Social Activities,Root,1
197 | presenting weather forecast,communication,Social Activities,Root,0
198 | eating burger,eating + drinking,Social Activities,Root,1
199 | eating ice cream,eating + drinking,Social Activities,Root,1
200 | eating spaghetti,eating + drinking,Social Activities,Root,0
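
The rows above read as a path from a leaf action up to `Root`, followed by a trailing 0/1 flag. The following is a minimal sketch, under that assumption, of turning such rows into a child-to-parent mapping. `read_hierarchy` is a hypothetical helper, not the repository's `get_son2parent`, and the meaning of the flag (presumably the seen/unseen split) is not interpreted here.

```python
import csv
import io

# Three rows copied from kinetics_depth.csv, used here as sample input.
SAMPLE = """blowing glass,arts and crafts,Arts and Entertainment,Root,1
weaving basket,arts and crafts,Arts and Entertainment,Root,0
cleaning floor,cleaning,Household Activities,Root,1
"""

def read_hierarchy(text):
    """Hypothetical helper: parse 'leaf,parent,grandparent,Root,flag' rows."""
    child_to_parent, flags = {}, {}
    for row in csv.reader(io.StringIO(text)):
        if not row:
            continue  # skip blank lines
        *path, flag = row
        # Link every node on the path to the next node towards Root.
        for child, parent in zip(path, path[1:]):
            child_to_parent[child] = parent
        # Keep the raw flag per leaf; its meaning is an assumption.
        flags[path[0]] = int(flag)
    return child_to_parent, flags

child_to_parent, flags = read_hierarchy(SAMPLE)
```

With this layout two leaves under the same parent are 2 edges apart, and two leaves under the same grandparent are 4 edges apart, which matches the hop thresholds used in `metric.py`.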
--------------------------------------------------------------------------------
/metric.py:
--------------------------------------------------------------------------------
1 | import networkx as nx
2 | import torch
3 | 
4 | class Metric():
5 |     def __init__(self,label_set,son2parent):
6 | 
7 |         self.label_set = label_set # Labels must be sorted in the same order as in data preparation.
8 | 
9 |         self.son2parent = son2parent
10 |         self.G = nx.Graph()
11 |         for son,parent in self.son2parent.items():
12 |             self.G.add_edge(son, parent) # Add the child-parent edge (and its nodes) to the hierarchy tree
13 | 
14 |         shortest_path_gen = nx.all_pairs_shortest_path(self.G) # Precompute all shortest paths for fast lookup
15 |         self.shortest_path_dict = dict(shortest_path_gen)
16 |         self.hop_matrix = self._get_hops_matrix() # n_cls x n_cls matrix whose (i,j) entry stores the number of hops from label i to label j.
17 | 
18 |     def _get_hops_matrix(self):
19 |         n_cls = len(self.label_set)
20 |         hop_matrix = torch.zeros(n_cls,n_cls)
21 |         for i in range(n_cls):
22 |             for j in range(n_cls):
23 |                 source, target = self.label_set[i], self.label_set[j]
24 |                 path = self.shortest_path_dict[source][target]
25 |                 hop_matrix[i,j] = len(path) - 1
26 |         return hop_matrix
27 | 
28 | 
29 |     def hop_acc(self,ypred,ybatch, hops):
30 | 
31 |         correct_count = 0
32 | 
33 |         for i in range (len(ypred)):
34 |             source, target = self.label_set[ypred[i]], self.label_set[ybatch[i]]
35 |             path = self.shortest_path_dict[source][target] # shortest path on the hierarchy
36 |             if len(path) - 1 <= hops: # Correct if the prediction is within `hops` edges of the ground truth (0: same class, 2: sibling, 4: cousin)
37 |                 correct_count = correct_count + 1
38 | 
39 |         return correct_count / len(ybatch)
40 | 
41 |     def hop_mAP(self, ypred_topk, ybatch, hop = 0):
42 |         # ypred_topk: n x k
43 |         # ybatch: n x 1
44 |         n,k = ypred_topk.size(0), ypred_topk.size(1)
45 | 
46 |         # correct_inx = (ypred == ybatch) # broadcast automatically to n x k
47 |         correct_inx = torch.zeros(n,k)
48 |         for i in range(n):
49 |             x_pos = ybatch[i] # 1 x 1
50 |             y_pos = ypred_topk[i] # 1 x k
51 |             hop_dist = self.hop_matrix[x_pos, y_pos] # Graph distance (in hops) between ybatch_i and each of its top-k predictions
52 |             correct_inx[i,:] = (hop_dist <= hop) # Count every prediction within `hop` hops of the ground truth as correct (0 hops: exact class, 2 hops: sibling, 4 hops: cousin).
53 | 
54 |         numerator = [correct_inx[:,:i+1].sum(dim=1) for i in range(k)]
55 |         numerator = torch.stack(numerator).t()
56 |         denominator = torch.arange(1,k+1).repeat(n,1)
57 |         P = numerator.float() / denominator.float()
58 |         AP = P.mean(dim=1,keepdim=True)
59 | 
60 |         return AP.mean()
61 | 
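
The averaging in `hop_mAP` can be checked by hand on a small case. The following is a torch-free sketch under the assumption that the hop distances of each query's ranked results are already known; `hop_average_precision` and the sample distances are illustrative and not part of the repository. As in `hop_mAP`, precision is averaged over every rank 1..k, not only over the ranks that hold a correct result.

```python
def hop_average_precision(hop_dists, hop=0):
    """hop_dists[q][r] is the hop distance between query q's ground truth
    and its r-th ranked result. A result counts as correct when that
    distance is at most `hop`."""
    ap_per_query = []
    for dists in hop_dists:
        hits_so_far = 0
        precisions = []
        for rank, dist in enumerate(dists, start=1):
            if dist <= hop:
                hits_so_far += 1
            # Precision at this rank: correct results so far / rank.
            precisions.append(hits_so_far / rank)
        # Average precision@i over all ranks, mirroring P.mean(dim=1).
        ap_per_query.append(sum(precisions) / len(precisions))
    # Mean over queries, mirroring AP.mean().
    return sum(ap_per_query) / len(ap_per_query)

# Two queries with four ranked results each (made-up hop distances).
sample = [[0, 2, 4, 0], [2, 2, 0, 6]]
exact = hop_average_precision(sample, hop=0)     # 35/96
sibling = hop_average_precision(sample, hop=2)   # 43/48
```

Relaxing the threshold from 0 to 2 hops also accepts siblings of the ground-truth class, so the score can only stay equal or go up.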
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torch
4 |
5 | def loss_fn(support, query, dist_func, c, T):
6 | #Here we use synthesised support.
7 | logits = -dist_func(support,query,c) / T
8 | fewshot_label = torch.arange(support.size(0)).cuda()
9 | loss = F.cross_entropy(logits, fewshot_label)
10 |
11 | return loss
12 |
13 | class RegressNet(nn.Module):
14 | def __init__(self, T, c, dist_func, eval_dist, train_loader, val_loader, emb, metric):
15 | super(RegressNet, self).__init__()
16 | self.layers = nn.Sequential(
17 | nn.Linear(2048, 2048),
18 | nn.LeakyReLU(),
19 | nn.Linear(2048, 2048),
20 | nn.LeakyReLU(),
21 | nn.Linear(2048, emb.size(1))
22 | )
23 | self.T = T
24 | self.c = c
25 | self.dist_func = dist_func
26 | self.eval_dist = eval_dist
27 | self.val_loader = val_loader
28 | self.train_loader = train_loader
29 | self.emb = emb
30 | self.metric = metric
31 |
32 | def forward(self, x):
33 | x = self.layers(x)
34 | if self.dist_func.__name__ == 'pair_wise_hyp':
35 | if (x.norm(dim=1) >= 1).sum() != 0:
36 | x = x / (x.norm(dim=1,keepdim=True) + 1e-2)
37 | return x
38 |
39 | def _train(self, optimizer, epochs, T, eval_interval):
40 | model = self
41 | model.train()
42 | for epoch in range(epochs):
43 |
44 | for i, (xbatch, ybatch, abatch) in enumerate(self.train_loader):
45 |
46 | xbatch, ybatch, abatch = xbatch.cuda(), ybatch.cuda(), abatch.cuda()
47 | optimizer.zero_grad()
48 | apred = model(xbatch)
49 |
50 | loss = loss_fn(apred, abatch, self.dist_func, self.c, self.T)
51 |
52 | loss.backward()
53 | optimizer.step()
54 |
55 | if epoch % eval_interval == 0:
56 | self.evaluation(flag = 'valid')
57 | self.evaluation2()
58 |
59 | def eval_zsl(self, unseen_idset):
60 | model = self
61 | unseen_idset = torch.tensor(unseen_idset).cuda()
62 | unseen_emb = self.emb[unseen_idset,:]
63 |
64 | # 0.query_candidates & GT
65 | aval_pred_list, yval_list = [], []
66 | for i, (xbatch, ybatch, abatch) in enumerate(self.val_loader):
67 | xbatch, ybatch, abatch = xbatch.cuda(), ybatch.cuda(), abatch.cuda()
68 | apred = model(xbatch)
69 | aval_pred_list.append(apred)
70 | yval_list.append(ybatch)
71 | aval_pred = torch.cat(aval_pred_list) # search candidates
72 | yval_order1 = torch.cat(yval_list) # ground truth in the search candidates' order.
73 |
74 | GT_list, loss_list, ypred_list, ypred_topk_list = [], [], [], []
75 |
76 | for i, (xbatch, ybatch, abatch) in enumerate(self.val_loader):
77 | xbatch, ybatch, abatch = xbatch.cuda(), ybatch.cuda(), abatch.cuda()
78 | apred = model(xbatch)
79 | # 1.loss
80 | loss = loss_fn(apred, abatch, self.dist_func, self.c, self.T).item()
81 | loss_list.append(loss)
82 | # 2.ypred for recognition
83 | dist = self.eval_dist(apred, unseen_emb, self.c)
84 | rank = dist.sort()[1]
85 | top1_rank = rank[:,0]
86 | ypred = unseen_idset[top1_rank]
87 |
88 | # 3.ypred_topk for retrieval
89 | dist_retrieval = self.eval_dist(apred, aval_pred, self.c)
90 | rank = dist_retrieval.sort()[1]
91 | topk_result_inx = rank[:,:50] # k = 50
92 | ypred_topk = yval_order1[topk_result_inx]
93 | ypred_topk_list.append(ypred_topk)
94 | ypred_list.append(ypred)
95 | GT_list.append(ybatch)
96 | GT = torch.cat(GT_list)
97 | ypred = torch.cat(ypred_list)
98 | ypred_topk = torch.cat(ypred_topk_list)
99 |
100 | loss = torch.tensor(loss_list).mean()
101 | hop0_acc = (ypred == GT).float().mean().item()
102 | hop1_acc = self.metric.hop_acc(ypred,GT, hops = 2)
103 | hop2_acc = self.metric.hop_acc(ypred,GT, hops = 4)
104 | hop0_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 0)
105 | hop1_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 2)
106 | hop2_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 4)
107 | print('loss :%.3f acc:%.3f, 1hop_acc:%.3f, 2hop_acc:%.3f, mAP:%.3f, 1hop_mAP:%.3f, 2hop_mAP:%.3f'%(loss, hop0_acc, hop1_acc, hop2_acc, hop0_mAP, hop1_mAP, hop2_mAP))
108 |
109 | def evaluation(self, flag):
110 | model = self
111 |
112 | # 0.query_candidates & GT
113 | aval_pred_list, yval_list = [], []
114 | for i, (xbatch, ybatch, abatch) in enumerate(self.val_loader):
115 | xbatch, ybatch, abatch = xbatch.cuda(), ybatch.cuda(), abatch.cuda()
116 | apred = model(xbatch)
117 | aval_pred_list.append(apred)
118 | yval_list.append(ybatch)
119 | aval_pred = torch.cat(aval_pred_list) # search candidates
120 | yval_order1 = torch.cat(yval_list) # ground truth in the search candidates' order.
121 |
122 | GT_list = [] # ground truth in validation order
123 | loss_list, ypred_list, ypred_topk_list = [],[],[]
124 | torch.manual_seed(42)
125 | for i, (xbatch, ybatch, abatch) in enumerate(self.val_loader):
126 | xbatch, ybatch, abatch = xbatch.cuda(), ybatch.cuda(), abatch.cuda()
127 | apred = model(xbatch)
128 | # 1.loss
129 | loss = loss_fn(apred, abatch, self.dist_func, self.c, self.T).item()
130 | loss_list.append(loss)
131 | # 2.ypred for recognition
132 | dist = self.eval_dist(apred, self.emb, self.c)
133 | rank = dist.sort()[1]
134 | ypred = rank[:,0]
135 | ypred_list.append(ypred)
136 | # 3.ypred_topk for retrieval
137 | dist_retrieval = self.eval_dist(apred, aval_pred, self.c)
138 | rank = dist_retrieval.sort()[1]
139 | topk_result_inx = rank[:,:50] # k = 50
140 | ypred_topk = yval_order1[topk_result_inx]
141 | ypred_topk_list.append(ypred_topk)
142 | GT_list.append(ybatch)
143 | GT = torch.cat(GT_list)
144 |
145 | loss = torch.tensor(loss_list).mean()
146 | ypred = torch.cat(ypred_list)
147 | ypred_topk = torch.cat(ypred_topk_list)
148 |
149 | hop0_acc = (ypred == GT).float().mean().item()
150 | hop1_acc = self.metric.hop_acc(ypred,GT, hops = 2)
151 | hop2_acc = self.metric.hop_acc(ypred,GT, hops = 4)
152 | hop0_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 0)
153 | hop1_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 2)
154 | hop2_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 4)
155 |
156 | print('loss:%.3f,acc:%.3f,1hop_acc:%.3f,2hop_acc:%.3f,mAP:%.3f,1hop_mAP:%.3f,2hop_mAP:%.3f'%
157 | (loss, hop0_acc, hop1_acc, hop2_acc, hop0_mAP, hop1_mAP, hop2_mAP))
158 |
159 | def evaluation2(self):
160 | model = self
161 | torch.manual_seed(42)
162 |
163 | # 0.query_candidates & GT
164 | avalpred_list, yval_list = [], []
165 | for i, (xbatch, ybatch, _) in enumerate(self.val_loader):
166 | xbatch, ybatch = xbatch.cuda(), ybatch.cuda()
167 | avalpred_batch = model(xbatch)
168 | avalpred_list.append(avalpred_batch)
169 | yval_list.append(ybatch)
170 | aval_pred = torch.cat(avalpred_list) # search candidates
171 | yval = torch.cat(yval_list) # ground truth in the search candidates' order.
172 |
173 | dist_retrieval = self.eval_dist(self.emb, aval_pred, self.c) # n_cls x n_val
174 | rank = dist_retrieval.sort()[1] # n_cls x n_val
175 |
176 | topk_result_inx = rank[:,:50] # n_cls x k (k = 50)
177 | ypred_topk = yval[topk_result_inx] # n_cls x k
178 |
179 | GT = torch.arange(self.emb.size(0)).cuda()
180 |
181 | hop0_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 0)
182 | hop1_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 2)
183 | hop2_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 4)
184 |
185 | print('mAP:%.3f,1hop_mAP:%.3f,2hop_mAP:%.3f'%(hop0_mAP, hop1_mAP, hop2_mAP))
186 |
187 |
188 | class SoftmaxNet(nn.Module):
189 | def __init__(self, T, c, eval_dist, train_loader, val_loader, emb, metric):
190 | super(SoftmaxNet, self).__init__()
191 | self.layers = nn.Sequential(
192 | nn.Linear(2048, 2048),
193 | nn.LeakyReLU(),
194 | nn.Linear(2048, 2048),
195 | nn.LeakyReLU(),
196 | nn.Linear(2048, emb.size(0))
197 | )
198 | self.T = T
199 | self.c = c
200 | self.eval_dist = eval_dist
201 | self.val_loader = val_loader
202 | self.train_loader = train_loader
203 | self.emb = emb
204 | self.metric = metric
205 |
206 | def forward(self, x):
207 | x = self.layers(x)
208 | return x
209 |
210 | def _train(self, optimizer, epochs, eval_interval, feat_layer):
211 | model = self
212 | model.train()
213 | for epoch in range(epochs):
214 |
215 | for i, (xbatch, ybatch, abatch) in enumerate(self.train_loader):
216 |
217 | xbatch, ybatch, abatch = xbatch.cuda(), ybatch.cuda(), abatch.cuda()
218 | optimizer.zero_grad()
219 | logits = model(xbatch)
220 | logits = logits / self.T # Temperature trick
221 | loss_fn = F.cross_entropy
222 | loss = loss_fn(logits,ybatch)
223 |
224 | loss.backward()
225 | optimizer.step()
226 |
227 | if epoch % eval_interval == 0:
228 | self.evaluation('valid', feat_layer)
229 | self.evaluation2()
230 |
231 | def evaluation(self, flag, feat_layer): # feat_layer takes one of three values: 2, 4 or 6
232 | model = self
233 |
234 | # 0.query_candidates & GT
235 | feat_list, yval_list = [], []
236 | for i, (xbatch, ybatch, abatch) in enumerate(self.val_loader):
237 | xbatch, ybatch, abatch = xbatch.cuda(), ybatch.cuda(), abatch.cuda()
238 | feat_batch = model.layers[0:feat_layer](xbatch).detach()
239 | feat_list.append(feat_batch)
240 | yval_list.append(ybatch)
241 | feat = torch.cat(feat_list) # search candidates
242 | yval_order1 = torch.cat(yval_list) # ground truth in the search candidates' order.
243 |
244 | GT_list = [] # ground truth in validation order
245 | loss_list, ypred_list, ypred_topk_list = [],[],[]
246 | torch.manual_seed(42)
247 | for i, (xbatch, ybatch, abatch) in enumerate(self.val_loader):
248 | xbatch, ybatch, abatch = xbatch.cuda(), ybatch.cuda(), abatch.cuda()
249 | feat_batch = model.layers[0:feat_layer](xbatch).detach()
250 |
251 | logits = model(xbatch)
252 |
253 | # 1.loss
254 | loss_fn = F.cross_entropy
255 | loss = loss_fn(logits,ybatch).item()
256 | loss_list.append(loss)
257 |
258 | # 2.ypred for recognition
259 | rank = logits.sort()[1]
260 | ypred = rank[:,-1]
261 | ypred_list.append(ypred)
262 |
263 | # 3.ypred_topk for retrieval
264 | dist_retrieval = self.eval_dist(feat_batch, feat, self.c)
265 | rank = dist_retrieval.sort()[1]
266 | topk_result_inx = rank[:,:50] # k = 50
267 | ypred_topk = yval_order1[topk_result_inx] # batch_size x k
268 | ypred_topk_list.append(ypred_topk)
269 | GT_list.append(ybatch)
270 | GT = torch.cat(GT_list)
271 |
272 | loss = torch.tensor(loss_list).mean()
273 | ypred = torch.cat(ypred_list)
274 | ypred_topk = torch.cat(ypred_topk_list)
275 |
276 | hop0_acc = (ypred == GT).float().mean().item()
277 | hop1_acc = self.metric.hop_acc(ypred,GT, hops = 2)
278 | hop2_acc = self.metric.hop_acc(ypred,GT, hops = 4)
279 | hop0_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 0)
280 | hop1_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 2)
281 | hop2_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 4)
282 |
283 | print('loss:%.3f,acc:%.3f,1hop_acc:%.3f,2hop_acc:%.3f,mAP:%.3f,1hop_mAP:%.3f,2hop_mAP:%.3f'%
284 | (loss, hop0_acc, hop1_acc, hop2_acc, hop0_mAP, hop1_mAP, hop2_mAP))
285 |
286 | def evaluation2(self):
287 | model = self
288 | # 0.query_candidates & GT
289 | logits_val_list, yval_list = [], []
290 | for i, (xbatch, ybatch, _) in enumerate(self.val_loader):
291 | xbatch, ybatch = xbatch.cuda(), ybatch.cuda()
292 | logits_batch = model(xbatch)
293 | logits_batch = logits_batch / self.T
294 | logits_val_list.append(logits_batch)
295 | yval_list.append(ybatch)
296 | logits_val = torch.cat(logits_val_list) # n_val x n_cls
297 | probs_val = F.softmax(logits_val,dim=1) # n_val x n_cls
298 | rank = probs_val.sort(dim=0)[1] # n_val x n_cls, ascending per class
299 | topk_result_inx = rank[-50:,:] # k x n_cls (k = 50)
300 | yval = torch.cat(yval_list) # ground truth in the search candidates' order.
301 | ypred_topk = yval[topk_result_inx] # k x n_cls
302 | ypred_topk = ypred_topk.t() # n_cls x k
303 | GT = torch.arange(self.emb.size(0)).unsqueeze(1).cuda() # n_cls x 1
304 |
305 | hop0_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 0)
306 | hop1_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 2)
307 | hop2_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 4)
308 |
309 | print('mAP:%.3f,1hop_mAP:%.3f,2hop_mAP:%.3f'%(hop0_mAP, hop1_mAP, hop2_mAP))
310 |
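The `evaluation2` methods pick the top-k samples per class with an ascending `sort(dim=0)` followed by `rank[-50:,:]`. The sketch below, on hypothetical scores, shows that this selects the right samples but lists them weakest-first, whereas `torch.topk` returns them best-first. Whether `metric.hop_mAP` depends on that order is not visible in this file, so treat this as a point to check rather than a confirmed issue.

```python
import torch

# Hypothetical scores: 6 validation samples (rows) x 2 classes (columns).
probs = torch.tensor([
    [0.9, 0.1],
    [0.2, 0.8],
    [0.6, 0.4],
    [0.1, 0.9],
    [0.7, 0.3],
    [0.3, 0.7],
])
k = 3

# Pattern used above: ascending sort along dim 0, then keep the last k rows.
ascending_topk = probs.sort(dim=0)[1][-k:, :]   # k x n_cls, weakest of the top-k first

# Same sample set, ordered from strongest to weakest.
best_first = torch.topk(probs, k, dim=0)[1]     # k x n_cls

# Per-class ranking lists, one row per class.
per_class = best_first.t()                      # n_cls x k
```

Flipping the ascending result along dim 0 (`ascending_topk.flip(0)`) gives the same best-first order without switching to `torch.topk`.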
311 | class CVPR19Net2(nn.Module):
312 | def __init__(self, T, c, parent_set, grandpa_set, son2parent, eval_dist, train_loader, val_loader, emb, metric):
313 | super(CVPR19Net2, self).__init__()
314 |
315 | n_pa, n_gp = len(parent_set), len(grandpa_set)
316 |
317 | self.dense1 = nn.Linear(2048, 2048)
318 | self.leaky1 = nn.LeakyReLU()
319 | self.dense2 = nn.Linear(2048, 2048)
320 | self.leaky2 = nn.LeakyReLU()
321 | self.dense_leaf = nn.Linear(2048, 200)
322 | self.dense_parent = nn.Linear(2048,n_pa)
323 | self.dense_grandpa = nn.Linear(2048,n_gp)
324 |
325 | self.c = c
326 | self.T = T
327 | self.parent_set = parent_set
328 | self.grandpa_set = grandpa_set
329 | self.eval_dist = eval_dist
330 | self.val_loader = val_loader
331 | self.train_loader = train_loader
332 | self.son2parent = son2parent
333 | self.emb = emb
334 | self.metric = metric
335 |
336 | def forward(self, x):
337 | hidden1 = self.dense1(x)
338 | action1 = self.leaky1(hidden1)
339 | hidden2 = self.dense2(action1) # second hidden layer (dense2/leaky2 were defined but never used)
340 | action2 = self.leaky2(hidden2)
341 |
342 | logits_leaf = self.dense_leaf(action1)
343 | logits_leaf = self.dense_leaf(action2)
344 | logits_pa = self.dense_parent(action2)
345 | logits_gp = self.dense_grandpa(action2)
346 |
347 | return logits_leaf, logits_pa, logits_gp, hidden1
348 |
349 | def evaluation(self):
350 | model = self
351 |
352 | # 0.query_candidates & GT
353 | cand_feat_list, cand_yval_list = [], []
354 | for i, (xbatch, ybatch, abatch) in enumerate(self.val_loader):
355 | xbatch, ybatch, abatch, _, _ = self.batch_generation_on_gpu(xbatch, ybatch, abatch)
356 | _, _, _, batch_feat = model(xbatch)
357 |
358 | cand_feat_list.append(batch_feat)
359 | cand_yval_list.append(ybatch)
360 | cand_feat = torch.cat(cand_feat_list) # search candidates
361 | cand_yval = torch.cat(cand_yval_list) # ground truth in the search candidates' order.
362 |
363 | GT_list = [] # ground truth in validation order
364 | loss_list, ypred_list, ypred_topk_list = [],[],[]
365 | torch.manual_seed(42)
366 |
367 |
368 | loss_lists = ([],[],[])
369 | for i, (xbatch, ybatch, abatch) in enumerate(self.val_loader):
370 |
371 | xbatch, ybatch, abatch, ybatch_pa, ybatch_gp = self.batch_generation_on_gpu(xbatch, ybatch, abatch)
372 | # 1.Loss
373 | logits_leaf, logits_pa, logits_gp, batch_feat = model(xbatch)
374 | loss_fn = F.cross_entropy
375 |
376 | loss_leaf = loss_fn(logits_leaf,ybatch).item()
377 | loss_pa = loss_fn(logits_pa,ybatch_pa).item()
378 | loss_gp = loss_fn(logits_gp,ybatch_gp).item()
379 | loss_lists[0].append(loss_leaf)
380 | loss_lists[1].append(loss_pa)
381 | loss_lists[2].append(loss_gp)
382 |
383 | # For accs
384 | rank = logits_leaf.sort()[1]
385 | ypred = rank[:,-1]
386 |
387 | # For retrievals
388 | dist_retrieval = self.eval_dist(batch_feat, cand_feat, self.c)
389 | rank = dist_retrieval.sort()[1]
390 | topk_result_inx = rank[:,:50] # k = 50
391 | ypred_topk = cand_yval[topk_result_inx]
392 |
393 | # For ground truth
394 | GT_list.append(ybatch)
395 | ypred_list.append(ypred)
396 | ypred_topk_list.append(ypred_topk)
397 |
398 | loss_leaf = torch.tensor(loss_lists[0]).mean()
399 | loss_pa = torch.tensor(loss_lists[1]).mean()
400 | loss_gp = torch.tensor(loss_lists[2]).mean()
401 |
402 | GT = torch.cat(GT_list)
403 | ypred = torch.cat(ypred_list)
404 | ypred_topk = torch.cat(ypred_topk_list)
405 |
406 | hop0_acc = (ypred == GT).float().mean().item()
407 | hop1_acc = self.metric.hop_acc(ypred,GT, hops = 2)
408 | hop2_acc = self.metric.hop_acc(ypred,GT, hops = 4)
409 | hop0_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 0)
410 | hop1_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 2)
411 | hop2_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 4)
412 |
413 | print('loss_leaf:%.3f,loss_pa:%.3f,loss_gp:%.3f,acc:%.3f,1hop_acc:%.3f,2hop_acc:%.3f,mAP:%.3f,1hop_mAP:%.3f,2hop_mAP:%.3f'%(loss_leaf,loss_pa,loss_gp, hop0_acc, hop1_acc, hop2_acc, hop0_mAP, hop1_mAP, hop2_mAP))
414 |
415 | def batch_generation_on_gpu(self, xbatch, ybatch, abatch):
416 | son2parent = self.son2parent
417 | label_set = [key for key,value in son2parent.items() if key not in son2parent.values()]
418 | son2grandpa = {key:son2parent[value] for key,value in son2parent.items() if value in self.parent_set}
419 |
420 | ybatch_pa = [self.parent_set.index(son2parent[label_set[item]]) for item in ybatch]
421 | ybatch_pa = torch.tensor(ybatch_pa)
422 | ybatch_gp = [self.grandpa_set.index(son2grandpa[label_set[item]]) for item in ybatch]
423 | ybatch_gp = torch.tensor(ybatch_gp)
424 |
425 | xbatch, ybatch, abatch = xbatch.cuda(), ybatch.cuda(), abatch.cuda()
426 | ybatch_pa, ybatch_gp = ybatch_pa.cuda(), ybatch_gp.cuda()
427 |
428 | return xbatch, ybatch, abatch, ybatch_pa, ybatch_gp
429 |
430 | def _train(self, optimizer, epochs,T, eval_interval):
431 | model = self
432 | model.train()
433 | for epoch in range(epochs):
434 |
435 | for i, (xbatch, ybatch, abatch) in enumerate(self.train_loader):
436 |
437 | xbatch, ybatch, abatch, ybatch_pa, ybatch_gp = self.batch_generation_on_gpu(xbatch, ybatch, abatch)
438 | optimizer.zero_grad()
439 | logits_leaf, logits_pa, logits_gp, _ = model(xbatch)
440 | T = self.T
441 | logits_leaf, logits_pa, logits_gp = logits_leaf/T, logits_pa/T, logits_gp/T # Temperature trick
442 |
443 | loss_fn = F.cross_entropy
444 | loss_leaf = loss_fn(logits_leaf,ybatch)
445 | loss_pa = loss_fn(logits_pa,ybatch_pa)
446 | loss_gp = loss_fn(logits_gp,ybatch_gp)
447 |
448 | loss = loss_leaf + 1* loss_pa + 1 * loss_gp
449 |
450 | loss.backward()
451 | optimizer.step()
452 |
453 | if epoch % eval_interval == 0:
454 | self.evaluation()
455 | self.evaluation2()
456 |
457 | def eval_zsl(self, unseen_idset):
458 | model = self
459 | logits_val_list, yval_list = [], []
460 | for i, (xbatch, ybatch, _) in enumerate(self.val_loader):
461 | xbatch, ybatch = xbatch.cuda(), ybatch.cuda()
462 | logits_batch, _, _, _ = model(xbatch)
463 | logits_batch = logits_batch / self.T
464 | logits_val_list.append(logits_batch)
465 | yval_list.append(ybatch)
466 | logits_val = torch.cat(logits_val_list) # n_val x n_cls
467 |
468 | logits_val = logits_val[:,unseen_idset]
469 |
470 | probs_val = F.softmax(logits_val,dim=1) # n_val x n_unseencls
471 | rank = probs_val.sort(dim=0)[1] # n_val x n_unseencls
472 | topk_result_inx = rank[-50:,:] # k x n_unseencls (top k for each class)
473 |
474 | yval = torch.cat(yval_list) # ground truth in the search candidates' order.
475 | ypred_topk = yval[topk_result_inx] # k x n_unseencls: labels of the top-k (k = 50) retrieved samples for each unseen class
476 | ypred_topk = ypred_topk.t()
477 |
478 | GT = torch.tensor(unseen_idset).cuda()
479 | hop0_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 0)
480 | hop1_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 2)
481 | hop2_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 4)
482 | print('ZSL-search by name:, mAP:%.3f, 1hop_mAP:%.3f, 2hop_mAP:%.3f'%(hop0_mAP, hop1_mAP, hop2_mAP))
483 |
484 | def evaluation2(self):
485 | model = self
486 | # 0.query_candidates & GT
487 | logits_val_list, yval_list = [], []
488 | for i, (xbatch, ybatch, _) in enumerate(self.val_loader):
489 | xbatch, ybatch = xbatch.cuda(), ybatch.cuda()
490 | logits_batch, _, _, _ = model(xbatch)
491 | logits_batch = logits_batch / self.T
492 | logits_val_list.append(logits_batch)
493 | yval_list.append(ybatch)
494 | logits_val = torch.cat(logits_val_list) # n_val x n_cls
495 | probs_val = F.softmax(logits_val,dim=1) # n_val x n_cls
496 | rank = probs_val.sort(dim=0)[1] # n_val x n_cls, ascending per class
497 | topk_result_inx = rank[-50:,:] # k x n_cls (k = 50)
498 | yval = torch.cat(yval_list) # ground truth in the search candidates' order.
499 | ypred_topk = yval[topk_result_inx] # k x n_cls
500 | ypred_topk = ypred_topk.t() # n_cls x k
501 | GT = torch.arange(self.emb.size(0)).unsqueeze(1).cuda() # n_cls x 1
502 |
503 | hop0_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 0)
504 | hop1_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 2)
505 | hop2_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 4)
506 |
507 | print('mAP:%.3f,1hop_mAP:%.3f,2hop_mAP:%.3f'%(hop0_mAP, hop1_mAP, hop2_mAP))
508 |
509 |
510 | class CVPR19Net(nn.Module):
511 | def __init__(self, T, c, parent_set, grandpa_set, son2parent, eval_dist, train_loader, val_loader, emb, metric):
512 | super(CVPR19Net, self).__init__()
513 |
514 | n_pa, n_gp = len(parent_set), len(grandpa_set)
515 |
516 | self.dense1 = nn.Linear(2048, 2048)
517 | self.leaky1 = nn.LeakyReLU()
518 | self.dense_pa = nn.Linear(2048, 2048)
519 | self.leaky_pa = nn.LeakyReLU()
520 | self.dense_gp = nn.Linear(2048, 2048)
521 | self.leaky_gp = nn.LeakyReLU()
522 | self.dense_leaf = nn.Linear(2048, 200)
523 | self.dense_parent = nn.Linear(4096, n_pa)
524 | self.dense_grandpa = nn.Linear(6144, n_gp)
525 |
526 | self.c = c
527 | self.T = T
528 | self.parent_set = parent_set
529 | self.grandpa_set = grandpa_set
530 | self.eval_dist = eval_dist
531 | self.val_loader = val_loader
532 | self.train_loader = train_loader
533 | self.son2parent = son2parent
534 | self.emb = emb
535 | self.metric = metric
536 |
537 | def forward(self, x):
538 | hidden1 = self.dense1(x)
539 | action1 = self.leaky1(hidden1)
540 | hidden_pa = self.dense_pa(action1)
541 | action_pa = self.leaky_pa(hidden_pa)
542 | hidden_gp = self.dense_gp(action1)
543 | action_gp = self.leaky_gp(hidden_gp)
544 |
545 | logits_leaf = self.dense_leaf(action1)
546 | # import pdb
547 | # pdb.set_trace()
548 | logits_pa = self.dense_parent(torch.cat((action1,action_pa),dim=1))
549 | logits_gp = self.dense_grandpa(torch.cat((action1,action_pa,action_gp),dim=1))
550 | # logits_leaf = self.dense_leaf(action2)
551 | # logits_pa = self.dense_parent(action2)
552 | # logits_gp = self.dense_grandpa(action2)
553 |
554 |
555 | return logits_leaf, logits_pa, logits_gp, hidden1
556 |
557 | def evaluation(self):
558 | model = self
559 |
560 | # 0.query_candidates & GT
561 | cand_feat_list, cand_yval_list = [], []
562 | for i, (xbatch, ybatch, abatch) in enumerate(self.val_loader):
563 | xbatch, ybatch, abatch, _, _ = self.batch_generation_on_gpu(xbatch, ybatch, abatch)
564 | _, _, _, batch_feat = model(xbatch)
565 |
566 | cand_feat_list.append(batch_feat)
567 | cand_yval_list.append(ybatch)
568 | cand_feat = torch.cat(cand_feat_list) # search candidates
569 | cand_yval = torch.cat(cand_yval_list) # ground truth in the search candidates' order.
570 |
571 | GT_list = [] # ground truth in validation order
572 | loss_list, ypred_list, ypred_topk_list = [],[],[]
573 | torch.manual_seed(42)
574 |
575 |
576 | loss_lists = ([],[],[])
577 | for i, (xbatch, ybatch, abatch) in enumerate(self.val_loader):
578 |
579 | xbatch, ybatch, abatch, ybatch_pa, ybatch_gp = self.batch_generation_on_gpu(xbatch, ybatch, abatch)
580 | # 1.Loss
581 | logits_leaf, logits_pa, logits_gp, batch_feat = model(xbatch)
582 | loss_fn = F.cross_entropy
583 |
584 | loss_leaf = loss_fn(logits_leaf,ybatch).item()
585 | loss_pa = loss_fn(logits_pa,ybatch_pa).item()
586 | loss_gp = loss_fn(logits_gp,ybatch_gp).item()
587 | loss_lists[0].append(loss_leaf)
588 | loss_lists[1].append(loss_pa)
589 | loss_lists[2].append(loss_gp)
590 |
591 | # For accs
592 | rank = logits_leaf.sort()[1]
593 | ypred = rank[:,-1]
594 |
595 | # For retrievals
596 | dist_retrieval = self.eval_dist(batch_feat, cand_feat, self.c)
597 | rank = dist_retrieval.sort()[1]
598 | topk_result_inx = rank[:,:50] # k = 50
599 | ypred_topk = cand_yval[topk_result_inx]
600 |
601 | # For ground truth
602 | GT_list.append(ybatch)
603 | ypred_list.append(ypred)
604 | ypred_topk_list.append(ypred_topk)
605 |
606 | loss_leaf = torch.tensor(loss_lists[0]).mean()
607 | loss_pa = torch.tensor(loss_lists[1]).mean()
608 | loss_gp = torch.tensor(loss_lists[2]).mean()
609 |
610 | GT = torch.cat(GT_list)
611 | ypred = torch.cat(ypred_list)
612 | ypred_topk = torch.cat(ypred_topk_list)
613 |
614 | hop0_acc = (ypred == GT).float().mean().item()
615 | hop1_acc = self.metric.hop_acc(ypred,GT, hops = 2)
616 | hop2_acc = self.metric.hop_acc(ypred,GT, hops = 4)
617 | hop0_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 0)
618 | hop1_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 2)
619 | hop2_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 4)
620 |
621 | print('loss_leaf:%.3f,loss_pa:%.3f,loss_gp:%.3f,acc:%.3f,1hop_acc:%.3f,2hop_acc:%.3f,mAP:%.3f,1hop_mAP:%.3f,2hop_mAP:%.3f'%(loss_leaf,loss_pa,loss_gp, hop0_acc, hop1_acc, hop2_acc, hop0_mAP, hop1_mAP, hop2_mAP))
622 |
623 | def batch_generation_on_gpu(self, xbatch, ybatch, abatch):
624 | son2parent = self.son2parent
625 | label_set = [key for key,value in son2parent.items() if key not in son2parent.values()]
626 | son2grandpa = {key:son2parent[value] for key,value in son2parent.items() if value in self.parent_set}
627 |
628 | ybatch_pa = [self.parent_set.index(son2parent[label_set[item]]) for item in ybatch]
629 | ybatch_pa = torch.tensor(ybatch_pa)
630 | ybatch_gp = [self.grandpa_set.index(son2grandpa[label_set[item]]) for item in ybatch]
631 | ybatch_gp = torch.tensor(ybatch_gp)
632 |
633 | xbatch, ybatch, abatch = xbatch.cuda(), ybatch.cuda(), abatch.cuda()
634 | ybatch_pa, ybatch_gp = ybatch_pa.cuda(), ybatch_gp.cuda()
635 |
636 | return xbatch, ybatch, abatch, ybatch_pa, ybatch_gp
637 |
638 | def _train(self, optimizer, epochs,T, eval_interval):
639 | model = self
640 | model.train()
641 | for epoch in range(epochs):
642 |
643 | for i, (xbatch, ybatch, abatch) in enumerate(self.train_loader):
644 |
645 | xbatch, ybatch, abatch, ybatch_pa, ybatch_gp = self.batch_generation_on_gpu(xbatch, ybatch, abatch)
646 | optimizer.zero_grad()
647 | logits_leaf, logits_pa, logits_gp, _ = model(xbatch)
648 | T = self.T
649 | logits_leaf, logits_pa, logits_gp = logits_leaf/T, logits_pa/T, logits_gp/T # Temperature trick
650 |
651 | loss_fn = F.cross_entropy
652 | loss_leaf = loss_fn(logits_leaf,ybatch)
653 | loss_pa = loss_fn(logits_pa,ybatch_pa)
654 | loss_gp = loss_fn(logits_gp,ybatch_gp)
655 |
656 | loss = loss_leaf + 1* loss_pa + 1 * loss_gp
657 |
658 | loss.backward()
659 | optimizer.step()
660 |
661 | if epoch % eval_interval == 0:
662 | self.evaluation()
663 | self.evaluation2()
664 |
665 | def eval_zsl(self, unseen_idset):
666 | model = self
667 | logits_val_list, yval_list = [], []
668 | for i, (xbatch, ybatch, _) in enumerate(self.val_loader):
669 | xbatch, ybatch = xbatch.cuda(), ybatch.cuda()
670 | logits_batch, _, _, _ = model(xbatch)
671 | logits_batch = logits_batch / self.T
672 | logits_val_list.append(logits_batch)
673 | yval_list.append(ybatch)
674 | logits_val = torch.cat(logits_val_list) # n_val x n_cls
675 | # Key step: keep only the logits of the unseen classes
676 | logits_val = logits_val[:,unseen_idset]
677 |
678 | probs_val = F.softmax(logits_val,dim=1) # n_val x n_unseencls
679 | rank = probs_val.sort(dim=0)[1] # n_val x n_unseencls
680 | topk_result_inx = rank[-50:,:] # k x n_unseencls (top k for each class)
681 |
682 | yval = torch.cat(yval_list) # ground truth in the search candidates' order.
683 | ypred_topk = yval[topk_result_inx] # k x n_unseencls: labels of the top-k (k = 50) retrieved samples for each unseen class
684 | ypred_topk = ypred_topk.t()
685 |
686 | # top1_result_inx = rank[-1,:] # n_unseencls x 1 (k = 1)
687 | # ypred = yval[top1_result_inx]
688 | # hop0_acc = (ypred == GT).float().mean().item()
689 | # hop1_acc = self.metric.hop_acc(ypred,GT, hops = 2)
690 | # hop2_acc = self.metric.hop_acc(ypred,GT, hops = 4)
691 |
692 | GT = torch.tensor(unseen_idset).cuda()
693 | hop0_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 0)
694 | hop1_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 2)
695 | hop2_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 4)
696 | print('ZSL-search by name:, mAP:%.3f, 1hop_mAP:%.3f, 2hop_mAP:%.3f'%(hop0_mAP, hop1_mAP, hop2_mAP))
697 |
698 | # dist = self.eval_dist(apred, unseen_emb, self.c)
699 |
700 | def evaluation2(self):
701 | model = self
702 | # 0.query_candidates & GT
703 | logits_val_list, yval_list = [], []
704 | for i, (xbatch, ybatch, _) in enumerate(self.val_loader):
705 | xbatch, ybatch = xbatch.cuda(), ybatch.cuda()
706 | logits_batch, _, _, _ = model(xbatch)
707 | logits_batch = logits_batch / self.T
708 | logits_val_list.append(logits_batch)
709 | yval_list.append(ybatch)
710 | logits_val = torch.cat(logits_val_list) # n_val x n_cls
711 | probs_val = F.softmax(logits_val,dim=1) # n_val x n_cls
712 | rank = probs_val.sort(dim=0)[1] # n_val x n_cls, ascending per class
713 | topk_result_inx = rank[-50:,:] # k x n_cls (k = 50)
714 | yval = torch.cat(yval_list) # ground truth in the search candidates' order.
715 | ypred_topk = yval[topk_result_inx] # k x n_cls
716 | ypred_topk = ypred_topk.t() # n_cls x k
717 | GT = torch.arange(self.emb.size(0)).unsqueeze(1).cuda() # n_cls x 1
718 |
719 | hop0_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 0)
720 | hop1_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 2)
721 | hop2_mAP = self.metric.hop_mAP(ypred_topk, GT, hop = 4)
722 |
723 | print('mAP:%.3f,1hop_mAP:%.3f,2hop_mAP:%.3f'%(hop0_mAP, hop1_mAP, hop2_mAP))
724 |
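`batch_generation_on_gpu` rebuilds `label_set` and `son2grandpa` from `son2parent` on every batch and then calls `list.index` per sample. A minimal sketch of precomputing the leaf-to-parent and leaf-to-grandparent lookup tables once is shown below. The hierarchy values and the helper name `hierarchy_targets` are hypothetical, and `parent_set` / `grandpa_set` are derived here with `sorted` only for the example; in the repository they are passed in by the caller.

```python
# Hypothetical three-level hierarchy in the same child -> parent form as `son2parent`.
son2parent = {
    "Ballet": "Dancing",
    "Tango": "Dancing",
    "Bowling": "Sports",
    "Dancing": "Arts and Entertainment",
    "Sports": "Recreation",
}

parents = set(son2parent.values())
# Leaves are keys that never appear as a parent; dict order is preserved (Python 3.7+).
label_set = [name for name in son2parent if name not in parents]
parent_set = sorted({son2parent[leaf] for leaf in label_set})
grandpa_set = sorted({son2parent[son2parent[leaf]] for leaf in label_set})

# Build the leaf-index -> parent-index tables once instead of once per batch.
leaf_to_parent = [parent_set.index(son2parent[leaf]) for leaf in label_set]
leaf_to_grandpa = [grandpa_set.index(son2parent[son2parent[leaf]]) for leaf in label_set]


def hierarchy_targets(ybatch):
    """Map leaf label ids to (parent ids, grandparent ids)."""
    return [leaf_to_parent[y] for y in ybatch], [leaf_to_grandpa[y] for y in ybatch]


ybatch_pa, ybatch_gp = hierarchy_targets([0, 2, 1])
```

With the tables stored as tensors on the model's device, the per-batch work would reduce to a single indexing operation.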
725 | class NIP19Proto():
726 | def __init__(self, w2v_emb, lr, epochs, loss_lambdas):
727 | self.w2v_emb = w2v_emb
728 | self.lr = lr
729 | self.epochs = epochs
730 | self.loss_lambdas = loss_lambdas
731 |
732 | def similarity(self,prototypes):
733 | norm = torch.norm(prototypes, dim=1) # each row is norm 1.
734 |
735 | deviation = (norm.sum() - prototypes.shape[0])
736 | if deviation > 1e1:
737 | print('deviation from norm 1', deviation)
738 |
739 | t1 = norm.unsqueeze(1) # n_cls x 1
740 | t2 = norm.unsqueeze(0) # 1 x n_cls
741 | denominator = torch.matmul(t1, t2) # n_cls x n_cls, each element is a norm product
742 | numerator = torch.matmul(prototypes, prototypes.t()) # each element is an inner product
743 | cos_sim = numerator / denominator # n_cls x n_cls, each element is a cos_sim
744 | cos_sim_off_diag = cos_sim - torch.diag(torch.diag(cos_sim))
745 | obj = cos_sim_off_diag.max(dim=1)[0]
746 |
747 | return obj.mean(), cos_sim
748 |
749 | def order_loss(self,prototypes, w2v, lmd=0):
750 | B = prototypes.t()
751 | _, S = self.similarity(w2v)
752 | S = S.float()
753 |
754 | # Laplacian matrix L
755 | S1 = S - lmd
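756 | ones = torch.ones(S1.shape[0]) # 1-D, so S1.matmul(ones) is a vector of row sums and torch.diag builds the n_cls x n_cls degree matrix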
757 | L = torch.diag(S1.matmul(ones)) - S
758 |
759 | # Loss = Trace(BLB')
760 | M = B.matmul(L)
761 | M = M.matmul(B.t())
762 | o_loss = torch.trace(M) / B.shape[1]**2
763 |
764 | return o_loss
765 |
766 | def train(self):
767 |
768 | emb = self.w2v_emb.cpu()
769 | emb = F.normalize(emb)
770 | prototypes = nn.Parameter(F.normalize(torch.randn(emb.size(0), 300), p=2, dim=1))
771 | optimizer = torch.optim.SGD([prototypes], lr=self.lr, momentum=0.9)
772 | best_loss = 1000
773 |
774 | for i in range(self.epochs):
775 | optimizer.zero_grad()
776 | sim, _ = self.similarity(prototypes)
777 | o_loss = self.order_loss(prototypes, emb.cpu(), lmd=0)
778 | loss = self.loss_lambdas[0] * sim + self.loss_lambdas[1] * o_loss # the default weighting is 1:1
779 |
780 | loss.backward(retain_graph=True)
781 | optimizer.step()
782 | if i % 10 == 0:
783 | if i % 100 == 0:
784 | print(f'Loss: {loss}, Order Loss: {o_loss} and Sim Loss: {sim}')
785 | prototypes = nn.Parameter(F.normalize(prototypes, p=2, dim=1))
786 | optimizer = torch.optim.SGD([prototypes], lr=self.lr, momentum=0.9)
787 |
788 | self.prototypes = prototypes
789 |
790 |
791 |
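The `# Loss = Trace(BLB')` step in `order_loss` relies on a graph Laplacian `L = D - S`, where `D` is the diagonal matrix of row sums of the similarity matrix. The sketch below, on hypothetical random data, checks the standard identity that this trace equals half the similarity-weighted sum of squared distances between prototypes (for symmetric `S`). It also shows that `torch.diag` builds a diagonal matrix only from a 1-D input; given a 2-D column it returns the diagonal of that column instead.

```python
import torch

torch.manual_seed(0)
n_cls, dim = 5, 3
S = torch.rand(n_cls, n_cls)
S = (S + S.t()) / 2                      # symmetric similarity matrix (hypothetical values)
B = torch.randn(dim, n_cls)              # columns are prototypes, as in B = prototypes.t()

degree = S.sum(dim=1)                    # 1-D row sums, shape (n_cls,)
L = torch.diag(degree) - S               # n_cls x n_cls graph Laplacian

trace_form = torch.trace(B @ L @ B.t())

# Pairwise form: 0.5 * sum_ij S_ij * ||b_i - b_j||^2
P = B.t()                                # n_cls x dim
sq_dists = ((P.unsqueeze(1) - P.unsqueeze(0)) ** 2).sum(dim=2)
pairwise_form = 0.5 * (S * sq_dists).sum()

# A 2-D (n_cls x 1) input makes torch.diag extract a diagonal, not build a matrix.
column_diag = torch.diag(degree.unsqueeze(1))
```

Minimising the trace therefore pulls together prototypes whose classes have high word-embedding similarity.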
--------------------------------------------------------------------------------
/moments_depth.csv:
--------------------------------------------------------------------------------
1 | ,bowling,compete.v.01,act.v.01,Root,1
2 | ,competing,compete.v.01,act.v.01,Root,1
3 | ,playing+sports,compete.v.01,act.v.01,Root,1
4 | ,racing,compete.v.01,act.v.01,Root,0
5 | ,,descending,act.v.01,Root,1
6 | ,imitating,perform.v.02,act.v.01,Root,1
7 | marrying,officiate.v.01,perform.v.02,act.v.01,Root,1
8 | officiating,officiate.v.01,perform.v.02,act.v.01,Root,1
9 | ,performing,perform.v.02,act.v.01,Root,0
10 | ,gambling,play.v.16,act.v.01,Root,1
11 | ,sneezing,act_involuntarily.v.01,act.v.02,Root,1
12 | ,squinting,act_involuntarily.v.01,act.v.02,Root,1
13 | ,winking,act_involuntarily.v.01,act.v.02,Root,0
14 | ,playing+fun,play.v.16,act.v.02,Root,1
15 | ,playing+videogames,play.v.16,act.v.02,Root,1
16 | ,,hanging,be.v.01,Root,1
17 | ,,leaning,be.v.01,Root,1
18 | ,kneeling,rest.v.01,be.v.01,Root,0
19 | ,queuing,rest.v.01,be.v.01,Root,1
20 | ,resting,rest.v.01,be.v.01,Root,1
21 | ,,standing,be.v.01,Root,1
22 | ,,yawning,be.v.01,Root,1
23 | ,crouching,sit.v.01,be.v.03,Root,1
24 | ,sitting,sit.v.01,be.v.03,Root,1
25 | ,squatting,sit.v.01,be.v.03,Root,0
26 | ,,sleeping,be.v.03,Root,1
27 | ,sniffing,inhale.v.02,cause_to_be_perceived.v.01,Root,1
28 | ,smelling,smell.v.02,cause_to_be_perceived.v.01,Root,1
29 | ,dining,sound.v.02,cause_to_be_perceived.v.01,Root,1
30 | ,knocking,sound.v.02,cause_to_be_perceived.v.01,Root,1
31 | ,whistling,sound.v.02,cause_to_be_perceived.v.01,Root,0
32 | ,tuning,adjust.v.01,change.v.01,Root,1
33 | bathing,fancify.v.01,better.v.02,change.v.01,Root,1
34 | grooming,fancify.v.01,better.v.02,change.v.01,Root,1
35 | ,,camping,change.v.01,Root,1
36 | ,mopping,clean.v.01,change.v.01,Root,1
37 | ,rinsing,clean.v.01,change.v.01,Root,1
38 | ,scrubbing,clean.v.01,change.v.01,Root,1
39 | ,vacuuming,clean.v.01,change.v.01,Root,1
40 | ,washing,clean.v.01,change.v.01,Root,0
41 | ,,clearing,change.v.01,Root,1
42 | ,,crushing,change.v.01,Root,1
43 | ,,cutting,change.v.01,Root,1
44 | ,trimming,decorate.v.01,change.v.01,Root,1
45 | ,,drying,change.v.01,Root,1
46 | ,loading,fill.v.01,change.v.01,Root,1
47 | ,stretching,increase.v.02,change.v.01,Root,1
48 | ,,inflating,change.v.01,Root,1
49 | ,,putting,change.v.01,Root,1
50 | ,clipping,reduce.v.01,change.v.01,Root,1
51 | ,handcuffing,restrain.v.03,change.v.01,Root,1
52 | ,,socializing,change.v.01,Root,1
53 | ,combing,straighten.v.02,change.v.01,Root,1
54 | ,welding,unite.v.06,change.v.01,Root,1
55 | ,,waking,change.v.01,Root,1
56 | ,draining,weaken.v.01,change.v.01,Root,1
57 | ,drenching,wet.v.01,change.v.01,Root,1
58 | ,flooding,wet.v.01,change.v.01,Root,1
59 | ,leaking,wet.v.01,change.v.01,Root,1
60 | ,overflowing,wet.v.01,change.v.01,Root,1
61 | ,submerging,wet.v.01,change.v.01,Root,1
62 | ,watering,wet.v.01,change.v.01,Root,1
63 | ,wetting,wet.v.01,change.v.01,Root,0
64 | ,,boiling,change.v.02,Root,1
65 | ,,breaking,change.v.02,Root,1
66 | chewing,break.v.02,change_integrity.v.01,change.v.02,Root,1
67 | frying,cook.v.03,change_integrity.v.01,change.v.02,Root,1
68 | barbecuing,cook.v.03,change_integrity.v.01,change.v.02,Root,1
69 | grilling,cook.v.03,change_integrity.v.01,change.v.02,Root,0
70 | climbing,increase.v.01,change_magnitude.v.01,change.v.02,Root,1
71 | waxing,increase.v.01,change_magnitude.v.01,change.v.02,Root,1
72 | ,bowing,change_posture.v.01,change.v.02,Root,1
73 | ,bending,change_shape.v.01,change.v.02,Root,1
74 | ,pressing,change_shape.v.01,change.v.02,Root,0
75 | ,combusting,change_state.v.01,change.v.02,Root,1
76 | ,sanding,change_surface.v.01,change.v.02,Root,1
77 | ,cracking,crack.v.01,change.v.02,Root,1
78 | ,,crashing,change.v.02,Root,1
79 | ,dressing,dress.v.01,change.v.02,Root,1
80 | ,,dropping,change.v.02,Root,1
81 | ,tattooing,dye.v.01,change.v.02,Root,1
82 | ,,erupting,change.v.02,Root,1
83 | ,,falling,change.v.02,Root,1
84 | ,,folding,change.v.02,Root,1
85 | ,,landing,change.v.02,Root,1
86 | ,,removing,change.v.02,Root,1
87 | ,,rising,change.v.02,Root,1
88 | ,,turning,change.v.02,Root,0
89 | ,,adult+female+speaking,communicate.v.02,Root,1
90 | ,,adult+male+speaking,communicate.v.02,Root,1
91 | ,,baptizing,communicate.v.02,Root,1
92 | ,,buying,communicate.v.02,Root,1
93 | ,,child+speaking,communicate.v.02,Root,1
94 | ,giggling,express_emotion.v.01,communicate.v.02,Root,1
95 | ,laughing,express_emotion.v.01,communicate.v.02,Root,1
96 | ,shouting,express_emotion.v.01,communicate.v.02,Root,1
97 | ,applauding,gesticulate.v.01,communicate.v.02,Root,1
98 | ,clapping,gesticulate.v.01,communicate.v.02,Root,1
99 | ,shrugging,gesticulate.v.01,communicate.v.02,Root,1
100 | ,frowning,grimace.v.01,communicate.v.02,Root,1
101 | ,grinning,grimace.v.01,communicate.v.02,Root,1
102 | ,smiling,grimace.v.01,communicate.v.02,Root,1
103 | ,coaching,inform.v.01,communicate.v.02,Root,1
104 | ,instructing,inform.v.01,communicate.v.02,Root,1
105 | ,lecturing,inform.v.01,communicate.v.02,Root,1
106 | ,pointing,inform.v.01,communicate.v.02,Root,1
107 | ,teaching,inform.v.01,communicate.v.02,Root,0
108 | ,,paying,communicate.v.02,Root,1
109 | ,asking,request.v.01,communicate.v.02,Root,1
110 | ,praying,request.v.01,communicate.v.02,Root,1
111 | ,,selling,communicate.v.02,Root,1
112 | ,adult+female+singing,sing.v.02,communicate.v.02,Root,1
113 | ,adult+male+singing,sing.v.02,communicate.v.02,Root,1
114 | ,child+singing,sing.v.02,communicate.v.02,Root,1
115 | ,signing,sing.v.02,communicate.v.02,Root,1
116 | ,singing,sing.v.02,communicate.v.02,Root,0
117 | ,,speaking,communicate.v.02,Root,1
118 | ,discussing,talk.v.02,communicate.v.02,Root,1
119 | ,interviewing,talk.v.02,communicate.v.02,Root,1
120 | ,preaching,talk.v.02,communicate.v.02,Root,1
121 | ,talking,talk.v.02,communicate.v.02,Root,0
122 | ,telephoning,telecommunicate.v.01,communicate.v.02,Root,1
123 | ,,buttoning,connect.v.01,Root,1
124 | ,,joining,connect.v.01,Root,1
125 | ,,sewing,connect.v.01,Root,1
126 | ,,stitching,connect.v.01,Root,1
127 | ,,tying,connect.v.01,Root,0
128 | ,,drinking,consume.v.02,Root,1
129 | ,dunking,eat.v.01,consume.v.02,Root,1
130 | ,eating,eat.v.02,consume.v.02,Root,1
131 | ,,filling,consume.v.02,Root,1
132 | ,,smoking,consume.v.02,Root,0
133 | ,,boxing,contend.v.06,Root,1
134 | ,,fencing,contend.v.06,Root,1
135 | ,,fighting,contend.v.06,Root,1
136 | ,,wrestling,contend.v.06,Root,0
137 | ,piloting,steer.v.01,control.v.01,Root,1
138 | ,starting,steer.v.01,control.v.01,Root,1
139 | ,steering,steer.v.01,control.v.01,Root,1
140 | ,typing,steer.v.01,control.v.01,Root,0
141 | ,,dusting,cover.v.01,Root,1
142 | ,,painting,cover.v.01,Root,1
143 | ,,wrapping,cover.v.01,Root,1
144 | ,,boarding,enter.v.01,Root,1
145 | dipping,immerse.v.01,penetrate.v.01,enter.v.01,Root,1
146 | plunging,immerse.v.01,penetrate.v.01,enter.v.01,Root,1
147 | biting,pierce.v.05,penetrate.v.01,enter.v.01,Root,0
148 | ,bubbling,emit.v.01,exhaust.v.05,Root,1
149 | ,coughing,expectorate.v.02,exhaust.v.05,Root,1
150 | filming,record.v.01,save.v.02,have.v.01,Root,1
151 | photographing,record.v.01,save.v.02,have.v.01,Root,1
152 | ,,carrying,have.v.02,Root,1
153 | ,,burying,hide.v.01,Root,1
154 | ,,covering,hide.v.01,Root,1
155 | ,cuddling,embrace.v.02,hold.v.02,Root,1
156 | ,,locking,hold.v.02,Root,1
157 | ,,balancing,hold.v.14,Root,1
158 | ,,juggling,hold.v.14,Root,1
159 | ,arresting,arouse.v.01,make.v.03,Root,1
160 | ,celebrating,arouse.v.01,make.v.03,Root,1
161 | ,cheering,arouse.v.01,make.v.03,Root,1
162 | ,cheerleading,arouse.v.01,make.v.03,Root,1
163 | ,marching,arouse.v.01,make.v.03,Root,1
164 | ,parading,arouse.v.01,make.v.03,Root,1
165 | ,protesting,arouse.v.01,make.v.03,Root,0
166 | ,,building,make.v.03,Root,1
167 | ,,constructing,make.v.03,Root,1
168 | ,assembling,create_from_raw_material.v.01,make.v.03,Root,1
169 | ,baking,create_from_raw_material.v.01,make.v.03,Root,1
170 | ,cooking,create_from_raw_material.v.01,make.v.03,Root,1
171 | ,crafting,create_from_raw_material.v.01,make.v.03,Root,1
172 | ,hammering,create_from_raw_material.v.01,make.v.03,Root,1
173 | ,knitting,create_from_raw_material.v.01,make.v.03,Root,1
174 | ,spinning,create_from_raw_material.v.01,make.v.03,Root,0
175 | ,,playing,make.v.03,Root,1
176 | ,,raising,make.v.03,Root,1
177 | ,autographing,write.v.02,make.v.03,Root,1
178 | ,handwriting,write.v.02,make.v.03,Root,1
179 | ,sketching,write.v.02,make.v.03,Root,1
180 | ,writing,write.v.02,make.v.03,Root,0
181 | ,poking,agitate.v.06,move.v.02,Root,1
182 | ,,blowing,move.v.02,Root,1
183 | ,towing,drag.v.01,move.v.02,Root,1
184 | ,,dragging,move.v.02,Root,1
185 | ,,driving,move.v.02,Root,1
186 | ,,flicking,move.v.02,Root,1
187 | ,smashing,hit.v.12,move.v.02,Root,1
188 | ,,kicking,move.v.02,Root,1
189 | ,launching,launch.v.05,move.v.02,Root,1
190 | ,,lifting,move.v.02,Root,1
191 | ,dripping,pour.v.01,move.v.02,Root,1
192 | ,pouring,pour.v.01,move.v.02,Root,1
193 | ,punting,propel.v.01,move.v.02,Root,1
194 | ,,pulling,move.v.02,Root,1
195 | ,,pushing,move.v.02,Root,1
196 | ,aiming,put.v.01,move.v.02,Root,0
197 | ,attacking,put.v.01,move.v.02,Root,1
198 | ,cramming,put.v.01,move.v.02,Root,1
199 | ,placing,put.v.01,move.v.02,Root,1
200 | ,planting,put.v.01,move.v.02,Root,1
201 | ,plugging,put.v.01,move.v.02,Root,1
202 | ,shooting,put.v.01,move.v.02,Root,1
203 | ,snuggling,put.v.01,move.v.02,Root,1
204 | ,sowing,put.v.01,move.v.02,Root,1
205 | ,stacking,put.v.01,move.v.02,Root,1
206 | ,throwing,put.v.01,move.v.02,Root,0
207 | ,,rocking,move.v.02,Root,1
208 | ,carving,separate.v.02,move.v.02,Root,1
209 | ,chasing,separate.v.02,move.v.02,Root,1
210 | ,chopping,separate.v.02,move.v.02,Root,1
211 | ,manicuring,separate.v.02,move.v.02,Root,1
212 | ,mowing,separate.v.02,move.v.02,Root,1
213 | ,sawing,separate.v.02,move.v.02,Root,1
214 | ,shaving,separate.v.02,move.v.02,Root,1
215 | ,shredding,separate.v.02,move.v.02,Root,1
216 | ,slicing,separate.v.02,move.v.02,Root,1
217 | ,tearing,separate.v.02,move.v.02,Root,0
218 | ,,slipping,move.v.02,Root,1
219 | ,,spilling,move.v.02,Root,1
220 | ,reaching,succeed.v.01,move.v.02,Root,1
221 | ,entering,succeed.v.02,move.v.02,Root,1
222 | ,,swinging,move.v.02,Root,1
223 | ,,closing,move.v.03,Root,1
224 | ,,dancing,move.v.03,Root,1
225 | ,,exiting,move.v.03,Root,1
226 | ,,flipping,move.v.03,Root,1
227 | ,,flowing,move.v.03,Root,1
228 | ,bouncing,jump.v.01,move.v.03,Root,1
229 | ,jumping,jump.v.01,move.v.03,Root,1
230 | ,leaping,jump.v.01,move.v.03,Root,1
231 | ,skipping,jump.v.01,move.v.03,Root,0
232 | ,,pitching,move.v.03,Root,1
233 | ,,rolling,move.v.03,Root,1
234 | ,,sailing,move.v.03,Root,1
235 | ,,shaking,move.v.03,Root,1
236 | ,,snapping,move.v.03,Root,1
237 | ,,stealing,move.v.03,Root,1
238 | ,,stirring,move.v.03,Root,1
239 | ,,sweeping,move.v.03,Root,1
240 | ,,tripping,move.v.03,Root,1
241 | ,swerving,turn.v.01,move.v.03,Root,1
242 | ,drilling,turn.v.09,move.v.03,Root,1
243 | ,screwing,turn.v.09,move.v.03,Root,1
244 | ,twisting,turn.v.09,move.v.03,Root,0
245 | ,,waving,move.v.03,Root,1
246 | ,,catching,move.v.15,Root,1
247 | ,,opening,move.v.15,Root,1
248 | ,,raining,precipitate.v.03,Root,1
249 | ,,snowing,precipitate.v.03,Root,1
250 | ,,storming,precipitate.v.03,Root,1
251 | ,,stopping,prevent.v.01,Root,1
252 | guarding,protect.v.01,defend.v.02,prevent.v.02,Root,1
253 | blocking,obstruct.v.02,impede.v.01,prevent.v.02,Root,1
254 | ,,brushing,remove.v.01,Root,1
255 | ,,cleaning,remove.v.01,Root,1
256 | ,shoveling,dig.v.01,remove.v.01,Root,1
257 | ,digging,dig.v.01,remove.v.01,Root,1
258 | ,,drawing,remove.v.01,Root,1
259 | ,unloading,empty.v.04,remove.v.01,Root,0
260 | ,emptying,empty.v.04,remove.v.01,Root,1
261 | ,,picking,remove.v.01,Root,1
262 | ,peeling,take_off.v.02,remove.v.01,Root,1
263 | ,,unpacking,remove.v.01,Root,1
264 | ,,weeding,remove.v.01,Root,0
265 | ,,hunting,search.v.01,Root,1
266 | ,,surfing,search.v.01,Root,1
267 | ,drumming,play.v.07,sound.v.06,Root,1
268 | ,playing+music,play.v.07,sound.v.06,Root,1
269 | splashing,scatter.v.03,discharge.v.02,spread.v.01,Root,1
270 | spraying,scatter.v.03,discharge.v.02,spread.v.01,Root,1
271 | sprinkling,scatter.v.03,discharge.v.02,spread.v.01,Root,0
272 | ,,spreading,spread.v.01,Root,1
273 | ,gardening,care.v.02,support.v.01,Root,1
274 | ,,shopping,support.v.01,Root,1
275 | fishing,catch.v.04,seize.v.01,take.v.04,Root,1
276 | ,clawing,seize.v.01,take.v.04,Root,1
277 | ,gripping,seize.v.01,take.v.04,Root,1
278 | ,measuring,understand.v.01,think.v.03,Root,1
279 | ,reading,understand.v.01,think.v.03,Root,1
280 | ,studying,understand.v.01,think.v.03,Root,1
281 | massaging,manipulate.v.02,handle.v.04,touch.v.01,Root,1
282 | ,colliding,hit.v.02,touch.v.01,Root,0
283 | ,punching,hit.v.02,touch.v.01,Root,1
284 | ,,hitting,touch.v.01,Root,1
285 | ,,kissing,touch.v.01,Root,1
286 | ,,licking,touch.v.01,Root,1
287 | ,slapping,strike.v.01,touch.v.01,Root,1
288 | ,tickling,strike.v.01,touch.v.01,Root,0
289 | ,,stroking,touch.v.01,Root,1
290 | ,,clinging,touch.v.05,Root,1
291 | packaging,encase.v.01,enclose.v.03,touch.v.05,Root,1
292 | ,,hugging,touch.v.05,Root,1
293 | ,,rubbing,touch.v.05,Root,1
294 | ,,scratching,touch.v.05,Root,0
295 | feeding,provide.v.02,give.v.03,transfer.v.05,Root,1
296 | fueling,supply.v.01,give.v.03,transfer.v.05,Root,1
297 | ,,giving,transfer.v.05,Root,1
298 | ,,ascending,travel.v.01,Root,1
299 | saluting,greet.v.01,come.v.01,travel.v.01,Root,1
300 | ,,crawling,travel.v.01,Root,0
301 | ,diving,descend.v.01,travel.v.01,Root,1
302 | ,,floating,travel.v.01,Root,1
303 | ,,flying,travel.v.01,Root,1
304 | ,skating,glide.v.01,travel.v.01,Root,1
305 | ,,rafting,travel.v.01,Root,1
306 | ,,repairing,travel.v.01,Root,1
307 | ,bicycling,ride.v.02,travel.v.01,Root,1
308 | rowing,boat.v.01,ride.v.02,travel.v.01,Root,1
309 | ,boating,ride.v.02,travel.v.01,Root,1
310 | ,hitchhiking,ride.v.02,travel.v.01,Root,1
311 | ,pedaling,ride.v.02,travel.v.01,Root,0
312 | ,,riding,travel.v.01,Root,1
313 | ,,running,travel.v.01,Root,1
314 | ,,skiing,travel.v.01,Root,1
315 | ,,sliding,travel.v.01,Root,1
316 | ,,swimming,travel.v.01,Root,1
317 | ,jogging,travel_rapidly.v.01,travel.v.01,Root,1
318 | ,sprinting,travel_rapidly.v.01,travel.v.01,Root,1
319 | ,hiking,walk.v.01,travel.v.01,Root,1
320 | ,stomping,walk.v.01,travel.v.01,Root,0
321 | ,,walking,travel.v.01,Root,1
322 | ,,burning,treat.v.03,Root,1
323 | ,,injecting,treat.v.03,Root,1
324 | ,,operating,treat.v.03,Root,1
325 | ,,packing,treat.v.03,Root,0
326 | ,bulldozing,destroy.v.01,unmake.v.01,Root,1
327 | ,destroying,destroy.v.01,unmake.v.01,Root,1
328 | ,extinguishing,destroy.v.01,unmake.v.01,Root,0
329 | ,,exercising,use.v.01,Root,1
330 | ,,taping,use.v.01,Root,1
331 | ,,tapping,use.v.01,Root,1
332 | ,,barking,utter.v.02,Root,1
333 | ,,calling,utter.v.02,Root,1
334 | ,,crying,utter.v.02,Root,1
335 | ,,howling,utter.v.02,Root,1
336 | ,,roaring,utter.v.02,Root,1
337 | ,,spitting,utter.v.02,Root,0
338 | ,,working,work.v.01,Root,1
339 | ,,serving,work.v.02,Root,1
--------------------------------------------------------------------------------
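A minimal sketch of how rows in the depth CSV above might be read, assuming each row lists a path from the action label up to `Root`, padded on the left with empty fields, followed by a 0/1 flag. The README heading "Data Hierarchy and Seen/Unseen Split" suggests the flag encodes the split, but which value means "seen" is not stated, so the flag is kept as a raw integer. The helper name `parse_depth_row` is hypothetical and is not part of the repository.

```python
import csv
import io


def parse_depth_row(row):
    """Split one CSV row into (path from leaf to Root, raw 0/1 flag)."""
    *fields, flag = row
    # Leading empty fields only pad shallower paths to a fixed width; drop them.
    path = [field for field in fields if field]
    return path, int(flag)


# Three rows copied from the file above, with different hierarchy depths.
sample = """,squatting,sit.v.01,be.v.03,Root,0
,,sleeping,be.v.03,Root,1
bathing,fancify.v.01,better.v.02,change.v.01,Root,1
"""

rows = [parse_depth_row(r) for r in csv.reader(io.StringIO(sample))]
for path, flag in rows:
    # Leaf label, number of edges between leaf and Root, raw flag.
    print(path[0], len(path) - 1, flag)
```

Stripping the padding makes the depth of each leaf explicit, which is the quantity the file name hints at.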
/pmath.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | def pair_wise_eud(x, y, c=1.0):
6 |     # input:
7 |     #   x, m x d
8 |     #   y, n x d
9 |     # output: m x n, squared Euclidean distances (c is unused)
10 |     m = x.size(0)
11 |     n = y.size(0)
12 |     d = x.size(1)
13 |     assert x.size(1) == y.size(1)
14 |     xx = x.pow(2).sum(-1, keepdim=True)
15 |     yy = y.pow(2).sum(-1, keepdim=True)
16 |     xy = torch.einsum('ij,kj->ik', (x, y))
17 |
18 |     result = xx - 2 * xy + yy.permute(1, 0)
19 |     return result
20 |
21 | def pair_wise_cos(x, y, c=1.0):
22 |     # input:
23 |     #   x, m x d
24 |     #   y, n x d
25 |     # output: m x n, negative cosine similarity (c is unused)
26 |     x_norm = torch.norm(x, dim=1, keepdim=True)  # m x 1
27 |     y_norm = torch.norm(y, dim=1, keepdim=True)  # n x 1
28 |
29 |     denominator = torch.matmul(x_norm, y_norm.t())  # m x n
30 |     numerator = torch.matmul(x, y.t())  # m x n, each element is an inner product
31 |
32 |     return -numerator / denominator  # m x n, smaller means more similar
33 |
34 | def pair_wise_hyp(x, y, c=1.0):
35 |     c = torch.as_tensor(c)
36 |     return _dist_matrix(x, y, c)
37 |
38 | def tanh(x, clamp=15):
39 |     return x.clamp(-clamp, clamp).tanh()
40 |
41 |
42 | def tensor_dot(x, y):
43 |     res = torch.einsum('ij,kj->ik', (x, y))
44 |     return res
45 |
46 | def _mobius_addition_batch(x, y, c):
47 |     xy = tensor_dot(x, y)  # B x C
48 |     x2 = x.pow(2).sum(-1, keepdim=True)  # B x 1
49 |     y2 = y.pow(2).sum(-1, keepdim=True)  # C x 1
50 |     num = (1 + 2 * c * xy + c * y2.permute(1, 0))  # B x C
51 |     num = num.unsqueeze(2) * x.unsqueeze(1)  # B x C x 1 * B x 1 x D = B x C x D
52 |     num = num + (1 - c * x2).unsqueeze(2) * y  # B x C x D + (B x 1 x 1 * C x D) = B x C x D
53 |     denom_part1 = 1 + 2 * c * xy  # B x C
54 |     denom_part2 = c ** 2 * x2 * y2.permute(1, 0)  # B x 1 * 1 x C = B x C
55 |     denom = denom_part1 + denom_part2
56 |     res = num / (denom.unsqueeze(2) + 1e-5)
57 |     return res
58 |
59 | def _mobius_addition_same_size(x, y, c):
60 |     xy = torch.einsum('ij,ij -> i', (x, y)).unsqueeze(1)  # n x 1
61 |     x2 = x.pow(2).sum(-1, keepdim=True)  # n x 1
62 |     y2 = y.pow(2).sum(-1, keepdim=True)  # n x 1
63 |     num = 1 + 2 * c * xy + c * y2  # n x 1
64 |     num2 = num * x  # n x D
65 |     num3 = num2 + (1 - c * x2) * y  # n x D; uses num2 so the first term is scaled by x
66 |     denom_part1 = 1 + 2 * c * xy  # n x 1
67 |     denom_part2 = c ** 2 * x2 * y2  # n x 1
68 |     denom = denom_part1 + denom_part2
69 |     res = num3 / (denom + 1e-5)
70 |
71 |     return res
72 |
73 | def _dist_matrix(x, y, c):
74 |     sqrt_c = c ** 0.5
75 |     return 2 / sqrt_c * artanh(sqrt_c * torch.norm(_mobius_addition_batch(-x, y, c=c), dim=-1))
76 |
77 |
78 | def dist_same_size(x, y, c=1.0):
79 |     c = torch.as_tensor(c)
80 |     sqrt_c = c ** 0.5
81 |     return 2 / sqrt_c * artanh(sqrt_c * torch.norm(_mobius_addition_same_size(-x, y, c=c), dim=-1))
82 |
83 | def project(x, *, c=1.0):
84 |     r"""
85 |     Safe projection on the manifold for numerical stability. This was mentioned in [1]_
86 |     Parameters and return value are documented in ``_project``.
87 |     """
88 |     c = torch.as_tensor(c).type_as(x)
89 |     return _project(x, c)
90 |
91 | def _project(x, c):
92 |     """Parameters
93 |     ----------
94 |     x : tensor
95 |         point on the Poincare ball
96 |     c : float|tensor
97 |         ball negative curvature
98 |     Returns
99 |     -------
100 |     tensor
101 |         projected vector on the manifold"""
102 |     c = torch.as_tensor(c).type_as(x)
103 |
104 |     norm = torch.clamp_min(x.norm(dim=-1, keepdim=True, p=2), 1e-5)
105 |     maxnorm = (1 - 1e-3) / (c ** 0.5)
106 |     cond = norm > maxnorm
107 |     projected = x / norm * maxnorm
108 |     return torch.where(cond, projected, x)
109 |
110 | def mobius_matvec(m, x, c):
111 |     x_norm = torch.clamp_min(x.norm(dim=-1, keepdim=True, p=2), 1e-5)
112 |     sqrt_c = c ** 0.5
113 |     mx = x @ m.transpose(-1, -2)
114 |     mx_norm = mx.norm(dim=-1, keepdim=True, p=2)
115 |     res_c = tanh(mx_norm / x_norm * artanh(sqrt_c * x_norm)) * mx / (mx_norm * sqrt_c)
116 |     cond = (mx == 0).all(dim=-1, keepdim=True)  # boolean mask: rows where mx is all zeros
117 |     res_0 = torch.zeros(1, dtype=res_c.dtype, device=res_c.device)
118 |     res = torch.where(cond, res_0, res_c)
119 |     return _project(res, c)
120 |
121 | def euc2hyp(x, c):
122 |     new_x = _project(expmap0(x, c=c), c)  # c is keyword-only in expmap0
123 |     return new_x
124 |
125 | ### The original code follows
126 | def dist(x, y, *, c=1.0, keepdim=False):
127 |     r"""
128 |     Distance on the Poincare ball
129 |     .. math::
130 |         d_c(x, y) = \frac{2}{\sqrt{c}}\tanh^{-1}(\sqrt{c}\|(-x)\oplus_c y\|_2)
131 |     .. plot:: plots/extended/poincare/distance.py
132 |     Parameters
133 |     ----------
134 |     x : tensor
135 |         point on poincare ball
136 |     y : tensor
137 |         point on poincare ball
138 |     c : float|tensor
139 |         ball negative curvature
140 |     keepdim : bool
141 |         retain the last dim? (default: false)
142 |     Returns
143 |     -------
144 |     tensor
145 |         geodesic distance between :math:`x` and :math:`y`
146 |     """
147 |     c = torch.as_tensor(c).type_as(x)
148 |     return _dist(x, y, c, keepdim=keepdim)
149 |
150 |
151 | def _dist(x, y, c, keepdim: bool = False):
152 |     sqrt_c = c ** 0.5
153 |     dist_c = artanh(sqrt_c * _mobius_add(-x, y, c).norm(dim=-1, p=2, keepdim=keepdim))
154 |     return dist_c * 2 / sqrt_c
155 |
156 | def mobius_add(x, y, *, c=1.0):
157 |     r"""
158 |     Mobius addition is a special operation in a hyperbolic space.
159 |     .. math::
160 |         x \oplus_c y = \frac{
161 |             (1 + 2 c \langle x, y\rangle + c \|y\|^2_2) x + (1 - c \|x\|_2^2) y
162 |             }{
163 |             1 + 2 c \langle x, y\rangle + c^2 \|x\|^2_2 \|y\|^2_2
164 |         }
165 |     In general this operation is not commutative:
166 |     .. math::
167 |         x \oplus_c y \ne y \oplus_c x
168 |     But in some cases this property holds:
169 |     * zero vector case
170 |     .. math::
171 |         \mathbf{0} \oplus_c x = x \oplus_c \mathbf{0}
172 |     * zero negative curvature case that is same as Euclidean addition
173 |     .. math::
174 |         x \oplus_0 y = y \oplus_0 x
175 |     Another useful property is the so-called left-cancellation law:
176 |     .. math::
177 |         (-x) \oplus_c (x \oplus_c y) = y
178 |     Parameters
179 |     ----------
180 |     x : tensor
181 |         point on the Poincare ball
182 |     y : tensor
183 |         point on the Poincare ball
184 |     c : float|tensor
185 |         ball negative curvature
186 |     Returns
187 |     -------
188 |     tensor
189 |         the result of mobius addition
190 |     """
191 |     c = torch.as_tensor(c).type_as(x)
192 |     return _mobius_add(x, y, c)
193 |
194 |
195 | def _mobius_add(x, y, c):
196 |     x2 = x.pow(2).sum(dim=-1, keepdim=True)
197 |     y2 = y.pow(2).sum(dim=-1, keepdim=True)
198 |     xy = (x * y).sum(dim=-1, keepdim=True)
199 |     num = (1 + 2 * c * xy + c * y2) * x + (1 - c * x2) * y
200 |     denom = 1 + 2 * c * xy + c ** 2 * x2 * y2
201 |     return num / (denom + 1e-5)
202 |
203 |
204 | def expmap0(u, *, c=1.0):
205 |     r"""
206 |     Exponential map for Poincare ball model from :math:`0`.
207 |     .. math::
208 |         \operatorname{Exp}^c_0(u) = \tanh(\sqrt{c} \|u\|_2) \frac{u}{\sqrt{c}\|u\|_2}
209 |     Parameters
210 |     ----------
211 |     u : tensor
212 |         speed vector on poincare ball
213 |     c : float|tensor
214 |         ball negative curvature
215 |     Returns
216 |     -------
217 |     tensor
218 |         :math:`\gamma_{0, u}(1)` end point
219 |     """
220 |     c = torch.as_tensor(c).type_as(u)
221 |     return _expmap0(u, c)
222 |
223 |
224 | def _expmap0(u, c):
225 |     sqrt_c = c ** 0.5
226 |     u_norm = torch.clamp_min(u.norm(dim=-1, p=2, keepdim=True), 1e-5)
227 |     gamma_1 = tanh(sqrt_c * u_norm) * u / (sqrt_c * u_norm)
228 |     return gamma_1
229 |
230 |
231 | def _hyperbolic_softmax(X, A, P, c):
232 |     lambda_pkc = 2 / (1 - c * P.pow(2).sum(dim=1))
233 |     k = lambda_pkc * torch.norm(A, dim=1) / torch.sqrt(c)
234 |     mob_add = _mobius_addition_batch(-P, X, c)
235 |     num = 2 * torch.sqrt(c) * torch.sum(mob_add * A.unsqueeze(1), dim=-1)
236 |     denom = torch.norm(A, dim=1, keepdim=True) * (1 - c * mob_add.pow(2).sum(dim=2))
237 |     logit = k.unsqueeze(1) * arsinh(num / denom)
238 |     return logit.permute(1, 0)
239 |
240 | class Arsinh(torch.autograd.Function):
241 |     @staticmethod
242 |     def forward(ctx, x):
243 |         ctx.save_for_backward(x)
244 |         return (x + torch.sqrt_(1 + x.pow(2))).clamp_min_(1e-5).log_()
245 |
246 |     @staticmethod
247 |     def backward(ctx, grad_output):
248 |         input, = ctx.saved_tensors
249 |         return grad_output / (1 + input ** 2) ** 0.5
250 |
251 | class Artanh(torch.autograd.Function):
252 |     @staticmethod
253 |     def forward(ctx, x):
254 |         x = x.clamp(-1 + 1e-5, 1 - 1e-5)
255 |         ctx.save_for_backward(x)
256 |         res = (torch.log_(1 + x).sub_(torch.log_(1 - x))).mul_(0.5)
257 |         return res
258 |
259 |     @staticmethod
260 |     def backward(ctx, grad_output):
261 |         input, = ctx.saved_tensors
262 |         return grad_output / (1 - input ** 2)
263 |
264 | def artanh(x):
265 |     return Artanh.apply(x)
266 |
267 |
268 | def arsinh(x):
269 |     return Arsinh.apply(x)
270 |
--------------------------------------------------------------------------------
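The formulas documented in the `mobius_add` and `dist` docstrings of `pmath.py` can be sanity-checked without PyTorch. The sketch below is a plain-Python restatement on lists of floats, not the repository's implementation: the function names are reused only for readability, `poincare_dist` is a hypothetical name, and the `1e-5` stabilizer and the clamping used in `pmath.py` are deliberately left out so the identities hold to floating-point precision.

```python
import math


def mobius_add(x, y, c=1.0):
    """Mobius addition of two points (lists of floats) in the Poincare ball."""
    xy = sum(a * b for a, b in zip(x, y))
    x2 = sum(a * a for a in x)
    y2 = sum(b * b for b in y)
    denom = 1 + 2 * c * xy + c ** 2 * x2 * y2
    return [((1 + 2 * c * xy + c * y2) * a + (1 - c * x2) * b) / denom
            for a, b in zip(x, y)]


def poincare_dist(x, y, c=1.0):
    """Geodesic distance: 2 / sqrt(c) * artanh(sqrt(c) * ||(-x) (+) y||)."""
    sqrt_c = math.sqrt(c)
    diff = mobius_add([-a for a in x], y, c)
    norm = math.sqrt(sum(d * d for d in diff))
    return 2 / sqrt_c * math.atanh(sqrt_c * norm)


x = [0.1, 0.2]
y = [-0.3, 0.4]

# Left-cancellation law from the docstring: (-x) (+) (x (+) y) = y.
recovered = mobius_add([-a for a in x], mobius_add(x, y))

# From the origin the distance reduces to 2 * artanh(||y||); for ||y|| = 0.5
# that equals ln(3).
origin_dist = poincare_dist([0.0, 0.0], [0.5, 0.0])

print(recovered, origin_dist)
```

Although Mobius addition itself is not commutative, the resulting distance is symmetric, which gives a further cheap check: `poincare_dist(x, y)` and `poincare_dist(y, x)` should agree.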
/poster.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tenglon/hyperbolic_action/4af6da6e85a8af33dd54955067efee2836508048/poster.pdf
--------------------------------------------------------------------------------
/poster_jpg.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tenglon/hyperbolic_action/4af6da6e85a8af33dd54955067efee2836508048/poster_jpg.jpg
--------------------------------------------------------------------------------