├── src ├── theta.npy ├── len_group.py ├── test.py ├── model.py ├── train.py └── data_loader.py ├── img_folder ├── fig_11.png ├── result.png └── architecture_10.png └── README.md /src/theta.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yanglan0225/s3net/HEAD/src/theta.npy -------------------------------------------------------------------------------- /img_folder/fig_11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yanglan0225/s3net/HEAD/img_folder/fig_11.png -------------------------------------------------------------------------------- /img_folder/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yanglan0225/s3net/HEAD/img_folder/result.png -------------------------------------------------------------------------------- /img_folder/architecture_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yanglan0225/s3net/HEAD/img_folder/architecture_10.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # S3NET 2 | Pytorch implementation of "S3NET: GRAPH REPRESENTATIONAL NETWORK FOR SKETCH RECOGNITION"
3 | The paper was accepted at ICME 2020. 4 | 5 | ![Image](https://github.com/BUPTYangLan/s3net/blob/master/img_folder/architecture_10.png) 6 | 7 | ## Recognition Result 8 | 9 | 10 | 11 | ## Prerequisites 12 | - Linux (tested on Ubuntu 16.04)
13 | - PyTorch >= 1.2
14 | - NVIDIA GPU + CUDA and cuDNN
15 | - torch_geometric [PyG](https://github.com/rusty1s/pytorch_geometric)
16 | 17 | ## Dataset 18 | - Sketch-RNN QuickDraw Dataset [Download](https://console.cloud.google.com/storage/quickdraw_dataset/sketchrnn) 19 | 20 | 21 | ## Conclusion 22 | Thank you and sorry for the bugs! 23 | If you would have further discussion on this code repository, please feel free to send email to LAN YANG. 24 | Email: ylan@bupt.edu.cn 25 | -------------------------------------------------------------------------------- /src/len_group.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def abs_data(data): 5 | abs_x = np.zeros(len(data)) 6 | abs_y = np.zeros(len(data)) 7 | abs_x[0] = data[0][0] 8 | abs_y[0] = data[0][1] 9 | ## convert the relative corrinates to the absolute corrdinates 10 | result = np.zeros((len(data),3)) 11 | for i in range(len(data)): 12 | if i != 0 : 13 | 14 | abs_x[i] = abs_x[i-1] + data[i][0] 15 | abs_y[i] = abs_y[i-1] + data[i][1] 16 | 17 | min_x = np.min(abs_x) 18 | min_y = np.min(abs_y) 19 | max_x = np.max(abs_x) 20 | max_y = np.max(abs_y) 21 | normalize_factor = np.max((max_x-min_x, max_y-min_y)) 22 | result[:,0] = abs_x/normalize_factor 23 | result[:,1] = abs_y/normalize_factor 24 | #np.divide(data[:,0], normalize_factor) 25 | result[:,2] = data[:,2] 26 | #data = ori_data 27 | return result 28 | 29 | def get_group(data, theta): 30 | absdata = abs_data(data) 31 | group_idx = 0 32 | length = 0 33 | label = 0 34 | stroke_id = 0 35 | group_result = np.zeros((len(data), 2), dtype=np.int) 36 | for i in range((len(data)-1)): 37 | if data[i][2] == 1: 38 | group_result[i, 1] = stroke_id 39 | stroke_id += 1 40 | dis = np.sqrt(np.sum(np.power(absdata[i + 1] - absdata[i], 2)[:2])) 41 | if dis >= 0.3 * theta: 42 | group_result[group_idx:(i + 1), 0] = label 43 | group_idx = i + 1 44 | label += 1 45 | length = 0 46 | else: 47 | group_result[i, 1] = stroke_id 48 | length += np.sqrt(np.sum(np.power(absdata[i+1] - absdata[i], 2)[:2])) 49 | if length >= theta: 50 | 
group_result[group_idx:(i+1), 0] = label 51 | group_idx = i+1 52 | label += 1 53 | length = 0 54 | 55 | group_result[group_idx:, 0] = label 56 | group_result[-1, 1] = stroke_id 57 | a = group_result.astype(int) 58 | return a 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /src/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import data_loader 4 | from torch_geometric.data import DataLoader 5 | import numpy as np 6 | import time 7 | import copy 8 | import os 9 | import model 10 | 11 | 12 | max_nodes = 400 13 | batch_size = 250 14 | num_class = 345 15 | input_chanel = 3 16 | hidden_chanel = 512 17 | fea_dim = 128 18 | hidden_chanel2 = 256 19 | hidden_chanel3 = 512 20 | out_chanel = 1024 21 | n_rnn_layer = 2 22 | num_epoches = 20 23 | learning_rate = 0.001 24 | data_dir = '/home/yl/sketchrnn.txt' 25 | class_list = '/home/yl/sketchrnn.txt' 26 | theta_list = np.load('/home/yl/theta.npy') 27 | model_path = '/home/yl/data/train_model/s3net/5.pkl' 28 | device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu') 29 | 30 | print('='*10, 'Initial Setting', '='*10) 31 | print('Batch Size: ', batch_size) 32 | print('Data_dir: ', data_dir) 33 | print('Input dim: ', input_chanel) 34 | print('hidden dim: ', hidden_chanel, ' ', hidden_chanel2, ' ', hidden_chanel3) 35 | print('Output dim: ', out_chanel) 36 | print('RNN Layers: ', n_rnn_layer) 37 | print('Num epochs:', num_epoches) 38 | print('Learning rate: ', learning_rate) 39 | print('Data_dir :', data_dir) 40 | print('Class info: ', class_list) 41 | print('Device: ', device) 42 | print('Train model save dir: ', model_path) 43 | 44 | 45 | 46 | """ 47 | dataset and data loader 48 | """ 49 | print('='*10, 'Start Data Loading', '='*10) 50 | test_dataset = data_loader.QuickDraw(data_dir, class_list, theta_list, type='test') 51 | test_loader = 
DataLoader(test_dataset, batch_size=batch_size, shuffle=True) 52 | 53 | print('='*10, 'Data Loaded', '='*10) 54 | 55 | 56 | """ 57 | model 58 | """ 59 | 60 | model = model.Net(input_chanel, hidden_chanel, fea_dim, hidden_chanel2, hidden_chanel3, out_chanel, num_class, n_rnn_layer).to(device) 61 | model.load_state_dict(torch.load(model_path, map_location=device)) 62 | 63 | """ 64 | test procedure 65 | """ 66 | 67 | model.eval() 68 | test_acc = 0.0 69 | test_loss = 0.0 70 | loss = 0.0 71 | 72 | for i, data in enumerate(test_loader): 73 | inputs = data 74 | label = data['y'].to(device).long() 75 | inputs = inputs.to(device) 76 | with torch.no_grad(): 77 | output, prediction, link_loss, ent_loss = model(inputs) 78 | loss = F.nll_loss(output, label.view(-1)) + link_loss + ent_loss 79 | test_loss = test_loss + data.y.size(0) * loss.item() 80 | _, preds = torch.max(output, 1) 81 | test_acc += torch.sum(preds == label.data) 82 | e = test_acc.double().cpu() 83 | 84 | 85 | g = test_loss / (len(test_dataset)) 86 | h = e / (len(test_dataset)) 87 | print('test: Loss:{:.6f}, Acc:{:.6f}'.format(g, h)) 88 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | from torch_geometric.nn import DenseSAGEConv, dense_diff_pool 5 | from torch.autograd import Variable 6 | import torch_geometric.utils as utils 7 | 8 | 9 | 10 | device = torch.device('cuda:2') 11 | 12 | class GNN(torch.nn.Module): 13 | def __init__(self, 14 | in_channels, 15 | hidden_channels, 16 | out_channels, 17 | normalize=False, 18 | add_loop=False, 19 | lin=True): 20 | super(GNN, self).__init__() 21 | 22 | self.add_loop = add_loop 23 | 24 | self.conv1 = DenseSAGEConv(in_channels, hidden_channels, normalize) 25 | self.bn1 = torch.nn.BatchNorm1d(hidden_channels) 26 | self.conv2 = DenseSAGEConv(hidden_channels, 
hidden_channels, normalize) 27 | self.bn2 = torch.nn.BatchNorm1d(hidden_channels) 28 | self.conv3 = DenseSAGEConv(hidden_channels, out_channels, normalize) 29 | self.bn3 = torch.nn.BatchNorm1d(out_channels) 30 | 31 | if lin is True: 32 | self.lin = torch.nn.Linear(2 * hidden_channels + out_channels, 33 | out_channels) 34 | else: 35 | self.lin = None 36 | 37 | def bn(self, i, x): 38 | batch_size, num_nodes, num_channels = x.size() 39 | 40 | x = x.view(-1, num_channels) 41 | x = getattr(self, 'bn{}'.format(i))(x) 42 | x = x.view(batch_size, num_nodes, num_channels) 43 | return x 44 | 45 | def forward(self, x, adj, mask=None): 46 | 47 | x0 = x 48 | x1 = self.bn(1, F.relu(self.conv1(x0, adj, mask, self.add_loop))) 49 | x2 = self.bn(2, F.relu(self.conv2(x1, adj, mask, self.add_loop))) 50 | x3 = self.bn(3, F.relu(self.conv3(x2, adj, mask, self.add_loop))) 51 | 52 | x = torch.cat([x1, x2, x3], dim=-1) 53 | 54 | if self.lin is not None: 55 | x = F.relu(self.lin(x)) 56 | 57 | return x 58 | 59 | 60 | class Net(torch.nn.Module): 61 | def __init__(self, input_chanel, hidden_chanel, fea_dim, hidden_chanel2,hidden_chanel3, out_chanel, num_class, n_rnn_layer): 62 | super(Net, self).__init__() 63 | 64 | num_nodes = 5 65 | 66 | self.gnn1_pool = GNN(fea_dim, hidden_chanel2, num_nodes) 67 | self.gnn1_embed = GNN(fea_dim, hidden_chanel2, hidden_chanel2, lin=False) 68 | 69 | self.gnn3_embed = GNN(3 * hidden_chanel2, hidden_chanel3, out_chanel, lin=False) 70 | 71 | self.lin1 = torch.nn.Linear(2 * hidden_chanel3 + out_chanel, out_chanel) 72 | self.lin2 = torch.nn.Linear(out_chanel, num_class) 73 | self.n_layer = n_rnn_layer 74 | self.n_classes = num_class 75 | self.lstm = nn.LSTM(input_chanel, hidden_chanel, n_rnn_layer, batch_first=True, bidirectional=True) 76 | self.dropout = nn.Dropout(0.5) 77 | 78 | self.fc = nn.Linear(hidden_chanel * 2, fea_dim) 79 | self.classify = nn.Linear(out_chanel, num_class) 80 | self.fea_dim = fea_dim 81 | 82 | 83 | def forward(self, data): 84 | seq_len = 
data['s'] 85 | 86 | inputs = data['c'].reshape((len(seq_len), -1, 3)) 87 | 88 | inputs = inputs.reshape((len(seq_len), -1, 3)) 89 | _, idx_sort = torch.sort(seq_len, dim=0, descending=True) 90 | _, idx_unsort = torch.sort(idx_sort, dim=0) 91 | input_x = inputs.index_select(0, Variable(idx_sort)) 92 | length_list = list(seq_len[idx_sort]) 93 | input_x = input_x.float() 94 | pack = nn.utils.rnn.pack_padded_sequence(input_x, length_list, batch_first=True) 95 | out, state = self.lstm(pack) 96 | del state 97 | un_padded = nn.utils.rnn.pad_packed_sequence(out, batch_first=True) 98 | un_padded = un_padded[0].index_select(0, Variable(idx_unsort)) 99 | out = self.dropout(un_padded) 100 | feature = self.fc(out) 101 | batch_feature = None 102 | del out, pack, un_padded 103 | for i in range(data.num_graphs): 104 | emptyfeature = torch.zeros((1, self.fea_dim)).to(device) 105 | fea = torch.cat((feature[i][:(seq_len[i])], emptyfeature)) 106 | if batch_feature is None: 107 | batch_feature = fea 108 | else: 109 | batch_feature = torch.cat((batch_feature, fea)) 110 | 111 | data['x'] = batch_feature 112 | x, edge_index = data.x, data.edge_index 113 | dense_x = utils.to_dense_batch(x, batch=data.batch) 114 | x = dense_x[0] 115 | adj = utils.to_dense_adj(data.edge_index, batch=data.batch) 116 | s = self.gnn1_pool(x, adj) 117 | x = self.gnn1_embed(x, adj) 118 | x, adj, l1, e1 = dense_diff_pool(x, adj, s) 119 | 120 | 121 | x = self.gnn3_embed(x, adj) 122 | 123 | x = x.mean(dim=1) 124 | x1 = self.lin1(x) 125 | x = F.relu(x1) 126 | x = self.lin2(x) 127 | return F.log_softmax(x, dim=-1), x1, l1, e1 128 | 129 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import data_loader 4 | from torch_geometric.data import DataLoader 5 | import numpy as np 6 | import time 7 | import copy 8 | import os 9 | import model 
10 | 11 | 12 | max_nodes = 400 13 | batch_size = 250 14 | num_class = 345 15 | input_chanel = 3 16 | hidden_chanel = 512 17 | fea_dim = 128 18 | hidden_chanel2 = 256 19 | hidden_chanel3 = 512 20 | out_chanel = 1024 21 | n_rnn_layer = 2 22 | num_epoches = 1 23 | learning_rate = 0.001 24 | data_dir = '/home/yl/sketchrnn.txt' 25 | class_list = '/home/yl/sketchrnn.txt' 26 | theta_list = np.load('/home/yl/theta.npy') 27 | train_model_save_dir = '/home/yl/data/train_model/s3net' 28 | save_dir = '/home/yl/data/model/s3net.pkl' 29 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 30 | 31 | print('='*10, 'Initial Setting', '='*10) 32 | print('Batch Size: ', batch_size) 33 | print('Data_dir: ', data_dir) 34 | print('Input dim: ', input_chanel) 35 | print('hidden dim: ', hidden_chanel, ' ', hidden_chanel2, ' ', hidden_chanel3) 36 | print('Output dim: ', out_chanel) 37 | print('RNN Layers: ', n_rnn_layer) 38 | print('Num epochs:', num_epoches) 39 | print('Learning rate: ', learning_rate) 40 | print('Data_dir :', data_dir) 41 | print('Class info: ', class_list) 42 | print('Device: ', device) 43 | print('Train model save dir: ', train_model_save_dir) 44 | print('Final model save path: ', save_dir) 45 | 46 | 47 | """ 48 | dataset and data loader 49 | """ 50 | print('='*10, 'Start Data Loading', '='*10) 51 | train_dataset = data_loader.QuickDraw(data_dir, class_list, theta_list, type='train') 52 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 53 | val_dataset = data_loader.QuickDraw(data_dir, class_list, theta_list, type='valid') 54 | val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True) 55 | print('='*10, 'Data Loaded', '='*10) 56 | 57 | 58 | """ 59 | model and optimizer 60 | """ 61 | 62 | model = model.Net(input_chanel, hidden_chanel, fea_dim, hidden_chanel2, hidden_chanel3, out_chanel, num_class, n_rnn_layer).to(device) 63 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) 64 | 
model.load_state_dict(torch.load('/home/yl/data/train_model/s3net/4.pkl')) 65 | 66 | 67 | train_loss = [] 68 | train_acc = [] 69 | valid_loss = [] 70 | valid_acc = [] 71 | best_acc = 0.0 72 | 73 | print('='*10, 'Start training', '='*10) 74 | 75 | for epoch in range(num_epoches): 76 | # print('=' * 10, 'Epoch ', epoch, '=' * 10) 77 | # if epoch == 5: 78 | # optimizer.param_groups[0]['lr'] = 1e-4 79 | # if epoch == 10: 80 | # optimizer.param_groups[0]['lr'] = 1e-5 81 | # if epoch == 15: 82 | # optimizer.param_groups[0]['lr'] = 1e-6 83 | print('learning rate: ', optimizer.param_groups[0]['lr']) 84 | 85 | since = time.time() 86 | running_acc = 0.0 87 | running_loss = 0.0 88 | val_loss = 0.0 89 | val_acc = 0.0 90 | model.train() 91 | for i, data in enumerate(train_loader): 92 | inputs = data 93 | label = data['y'].to(device).long() 94 | inputs = inputs.to(device) 95 | optimizer.zero_grad() 96 | output, prediction, link_loss, ent_loss = model(inputs) 97 | loss = F.nll_loss(output, label.view(-1)) + link_loss + ent_loss 98 | loss.backward() 99 | running_loss += data.y.size(0) * loss.item() 100 | optimizer.step() 101 | _, preds = torch.max(output, 1) 102 | running_acc += torch.sum(preds == label.data) 103 | if i % 10 == 0: 104 | print('the {}-th batch, loss: {:.6f}, acc: {:.6f}'.format(i, running_loss / (i*inputs.num_graphs + 1), 105 | running_acc.double().cpu() / (i*inputs.num_graphs + 1))) 106 | #return loss_all / len(train_dataset) 107 | j = running_loss / (len(train_dataset)) 108 | e = running_acc.double().cpu() / (len(train_dataset)) 109 | print('Finish {} epoch, Loss:{:.6f}, Acc:{:.6f}'.format(epoch + 1, j, e)) 110 | train_loss.append(j) 111 | train_acc.append(e) 112 | time_epoch = time.time() - since 113 | print("This epoch train costs time:{:.0f}m {:.0f}s".format(time_epoch // 60, time_epoch % 60)) 114 | 115 | model.eval() 116 | loss = 0.0 117 | for i, data in enumerate(val_loader): 118 | inputs = data 119 | label = data['y'].to(device).long() 120 | inputs = 
inputs.to(device) 121 | output, prediction, link_loss, ent_loss = model(inputs) 122 | loss = F.nll_loss(output, label.view(-1)) + link_loss + ent_loss 123 | val_loss = val_loss + data.y.size(0) * loss.item() 124 | _, preds = torch.max(output, 1) 125 | val_acc += torch.sum(preds == label.data) 126 | d = val_acc.double().cpu() 127 | save_path = os.path.join(train_model_save_dir, str(epoch) + '.pkl') 128 | torch.save(model.state_dict(), save_path) 129 | c = val_loss / (len(val_dataset)) 130 | f = d / (len(val_dataset)) 131 | if f > best_acc: 132 | best_acc = f 133 | best_model_wts = copy.deepcopy(model.state_dict()) 134 | print('val: Loss:{:.6f}, Acc:{:.6f}'.format(c, f)) 135 | valid_loss.append(c) 136 | valid_acc.append(f) 137 | time_epoch_val = time.time() - since 138 | del c, d, f 139 | print("This epoch val costs time:{:.0f}m {:.0f}s".format(time_epoch_val // 60, time_epoch_val % 60)) 140 | 141 | 142 | model.load_state_dict(best_model_wts) 143 | torch.save(model.state_dict(), save_dir) 144 | print('train_loss:{} train_acc:{} val_loss{} val_acc{}'.format(train_loss, train_acc, valid_loss, valid_acc)) 145 | -------------------------------------------------------------------------------- /src/data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import len_group 3 | import numpy as np 4 | import torch 5 | import torch.utils.data as data 6 | from torch_geometric.data import Data 7 | 8 | 9 | """ 10 | Define the device 11 | """ 12 | device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu') 13 | 14 | class QuickDraw(data.Dataset): 15 | def __init__(self, data_dir, class_list, theta_list, type): 16 | """ 17 | 18 | :param data_dir: txt file, the path of sketches 19 | :param class_list: txt file, the path of category info 20 | :param theta_list: numpy file, save theta for each category 21 | :param type: 'train', 'vaild', 'test' 22 | """ 23 | self.class_list = class_list 24 | self.type = type 25 | 
self.classes, self.class_to_idx = self.find_class(class_list) 26 | self.label_data_npy = np.zeros((1, 2)) # initial the array to save the label and sketches, dim0 is label, dim1 is storke-3 sketch 27 | self.theta_list = theta_list 28 | with open(data_dir) as class_url_list: 29 | for classes_list in class_url_list: 30 | 31 | self.classnpy = np.load(classes_list.replace('yanglan', 'yl/data').strip(), encoding='latin1', allow_pickle=True) 32 | classpath1, tempclass = os.path.split(classes_list) 33 | classname, exten = os.path.splitext(tempclass) 34 | self.label = self.class_to_idx[classname] 35 | 36 | if self.type == 'train': 37 | np.random.shuffle(self.classnpy['train']) 38 | self.coordinate = self.classnpy['train'][:9000] 39 | self.label_np = self.label * np.ones((9000, 1)) 40 | label_data_npy = np.c_[self.label_np, self.coordinate.reshape(9000, -1)] 41 | self.label_data_npy = np.r_[self.label_data_npy, label_data_npy] 42 | 43 | if self.type == 'valid': 44 | self.coordinate = self.classnpy['valid'] 45 | self.label_np = self.label * np.ones((2500, 1)) 46 | label_data_npy = np.c_[self.label_np, self.coordinate.reshape(2500, -1)] 47 | self.label_data_npy = np.r_[self.label_data_npy, label_data_npy] 48 | 49 | if self.type == 'test': 50 | self.coordinate = self.classnpy['test'] 51 | self.label_np = self.label * np.ones((2500, 1)) 52 | label_data_npy = np.c_[self.label_np, self.coordinate.reshape(2500, -1)] 53 | self.label_data_npy = np.r_[self.label_data_npy, label_data_npy] 54 | 55 | self.label_data_npy1 = self.label_data_npy[1:, :] # remove the first useless element 56 | self.max_length, self.max_groupnum = self.max_len() 57 | 58 | 59 | 60 | 61 | def __len__(self): 62 | return len(self.label_data_npy1) 63 | 64 | 65 | def __getitem__(self, item): 66 | tempcoordinate = self.label_data_npy1[item] 67 | label = tempcoordinate[0] 68 | coordinate = tempcoordinate[1] # original coordinate 69 | coordinate2 = np.zeros((self.max_length, 3)) 70 | coordinate2[:len(coordinate)] = 
coordinate 71 | c = torch.from_numpy(coordinate2).to(device) 72 | groupid = len_group.get_group(coordinate, self.theta_list[int(label)]) 73 | src, dst, groupNum = self.get_affinity_matrix(torch.squeeze(torch.from_numpy(groupid))) 74 | edge_idx = torch.tensor([np.concatenate((src,dst)), np.concatenate((dst,src))],dtype=torch.long) 75 | feature = torch.zeros((len(coordinate)+1, 128)) 76 | data = Data(x=feature, edge_index=edge_idx, y=label, s=len(coordinate), g=int(groupNum.item()), c=c) 77 | del feature, edge_idx 78 | 79 | 80 | return data 81 | 82 | 83 | def find_class(self, dir): 84 | with open(dir) as class_url_list: 85 | classlist = [] 86 | for classpath in class_url_list: 87 | classpath1, tempclass = os.path.split(classpath) 88 | classname, exten = os.path.splitext(tempclass) 89 | classlist.append(classname) 90 | classlist.sort() 91 | class_to_idx = {classlist[i]: i for i in range(len(classlist))} 92 | return classlist, class_to_idx 93 | 94 | def max_len(self): 95 | max_len = 0 96 | pos = 0 97 | for i in range(len(self.label_data_npy1)): 98 | 99 | if len(self.label_data_npy1[i][1]) >= max_len: 100 | max_len = len(self.label_data_npy1[i][1]) 101 | pos = i 102 | groupid = len_group.get_group(self.label_data_npy1[pos][1], self.theta_list[int(self.label_data_npy1[pos][0])]) 103 | src, dst, groupNum = self.get_affinity_matrix(torch.squeeze(torch.from_numpy(groupid))) 104 | return max_len, groupNum 105 | 106 | def get_affinity_matrix(self, groupId): 107 | groupnum = torch.max(groupId) 108 | src = [] 109 | dst = [] 110 | repre_point = [] 111 | id = 0 112 | 113 | # select the first point of each stroke as the representative point 114 | for i in range(len(groupId)): 115 | if i == 0: 116 | repre_point.append(i) 117 | id = groupId[i][0] 118 | else: 119 | if groupId[i][0] != id: # next group 120 | repre_point.append(i) 121 | id = groupId[i][0] 122 | repre_point.append(len(groupId)-1) 123 | 124 | 125 | # build edges of rule 1 126 | for i in range(len(repre_point)-1): 127 | 
for j in range(repre_point[i]+1, repre_point[i+1]+1): 128 | src.append(i) 129 | dst.append(j) 130 | 131 | 132 | # build edges of rule 2 133 | for i in range(len(repre_point)-1): 134 | if groupId[repre_point[i]][1] == groupId[repre_point[i+1]][1]: 135 | src.append(repre_point[i]) 136 | dst.append(repre_point[i+1]) 137 | 138 | 139 | # build edges of rule 3 140 | for i in range(len(repre_point)-1): 141 | src.append(len(groupId)) 142 | dst.append(repre_point[i]) 143 | 144 | 145 | return np.array(src), np.array(dst), groupnum + 1 146 | 147 | 148 | 149 | def abs_data(self, data): 150 | abs_x = np.zeros(len(data)) 151 | abs_y = np.zeros(len(data)) 152 | abs_x[0] = data[0][0] 153 | abs_y[0] = data[0][1] 154 | ## convert the relative corrinates to the absolute corrdinates 155 | result = np.zeros((len(data), 3)) 156 | for i in range(len(data)): 157 | if i != 0: 158 | abs_x[i] = abs_x[i - 1] + data[i][0] 159 | abs_y[i] = abs_y[i - 1] + data[i][1] 160 | 161 | min_x = np.min(abs_x) 162 | min_y = np.min(abs_y) 163 | max_x = np.max(abs_x) 164 | max_y = np.max(abs_y) 165 | normalize_factor = np.max((max_x - min_x, max_y - min_y)) 166 | result[:, 0] = abs_x / normalize_factor 167 | result[:, 1] = abs_y / normalize_factor 168 | # np.divide(data[:,0], normalize_factor) 169 | result[:, 2] = data[:, 2] 170 | # data = ori_data 171 | return result 172 | --------------------------------------------------------------------------------