├── src
│   ├── theta.npy
│   ├── len_group.py
│   ├── test.py
│   ├── model.py
│   ├── train.py
│   └── data_loader.py
├── img_folder
│   ├── fig_11.png
│   ├── result.png
│   └── architecture_10.png
└── README.md
/src/theta.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yanglan0225/s3net/HEAD/src/theta.npy
--------------------------------------------------------------------------------
/img_folder/fig_11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yanglan0225/s3net/HEAD/img_folder/fig_11.png
--------------------------------------------------------------------------------
/img_folder/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yanglan0225/s3net/HEAD/img_folder/result.png
--------------------------------------------------------------------------------
/img_folder/architecture_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yanglan0225/s3net/HEAD/img_folder/architecture_10.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # S3NET
2 | PyTorch implementation of "S3NET: GRAPH REPRESENTATIONAL NETWORK FOR SKETCH RECOGNITION".
3 | The paper was accepted at ICME 2020.
4 |
5 | 
6 |
7 | ## Recognition Result
8 |
9 |
10 |
11 | ## Prerequisites
12 | - Linux (tested on Ubuntu 16.04)
13 | - PyTorch >= 1.2
14 | - NVIDIA GPU + CUDA and cuDNN
15 | - torch_geometric [PyG](https://github.com/rusty1s/pytorch_geometric)
16 |
17 | ## Dataset
18 | - Sketch-RNN QuickDraw Dataset [Download](https://console.cloud.google.com/storage/quickdraw_dataset/sketchrnn)
19 |
20 |
21 | ## Conclusion
22 | Thank you, and apologies for any bugs!
23 | If you would like to discuss this code repository further, please feel free to email LAN YANG.
24 | Email: ylan@bupt.edu.cn
25 |
--------------------------------------------------------------------------------
/src/len_group.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
def abs_data(data):
    """Convert a stroke-3 sketch from relative offsets to scaled absolute coordinates.

    :param data: array-like of shape (N, 3). Columns 0-1 are (dx, dy) offsets from
        the previous point (row 0 is used as the starting position) and column 2 is
        the pen state, which is copied through unchanged.
    :return: float array of shape (N, 3). Columns 0-1 hold the cumulative x/y
        positions divided by the larger of the x and y extents. When that extent
        is 0 (a single point, or all points identical) the positions are returned
        unscaled instead of producing NaN/inf.
    """
    data = np.asarray(data)
    result = np.zeros((len(data), 3))
    if len(data) == 0:
        return result

    # Running sums turn per-point offsets into absolute positions; cumsum adds
    # sequentially, exactly like an explicit accumulation loop.
    abs_x = np.cumsum(data[:, 0], dtype=float)
    abs_y = np.cumsum(data[:, 1], dtype=float)

    # Scale by the larger bounding-box side so the aspect ratio is preserved.
    normalize_factor = np.max((abs_x.max() - abs_x.min(), abs_y.max() - abs_y.min()))
    if normalize_factor == 0:
        # Degenerate sketch: dividing by 0 would yield NaN/inf, so leave it unscaled.
        normalize_factor = 1.0

    result[:, 0] = abs_x / normalize_factor
    result[:, 1] = abs_y / normalize_factor
    result[:, 2] = data[:, 2]
    return result
28 |
def get_group(data, theta):
    """Split a stroke-3 sketch into groups of consecutive points.

    Walking the points in order, a new group starts after point i when either
      * point i ends a stroke (pen state 1) and the jump to point i+1 is at
        least 0.3 * theta, or
      * point i does not end a stroke and the drawn length accumulated since
        the last group boundary reaches theta.
    Distances are measured on the scaled absolute coordinates from abs_data.

    :param data: array of shape (N, 3) in stroke-3 format (dx, dy, pen state).
    :param theta: grouping length threshold, in abs_data units.
    :return: int array of shape (N, 2). Column 0 is the group label of each
        point and column 1 its stroke id. An empty input gives an empty (0, 2)
        array.
    """
    data = np.asarray(data)
    n_points = len(data)
    # Built-in int: the np.int alias was removed in NumPy 1.24.
    group_result = np.zeros((n_points, 2), dtype=int)
    if n_points == 0:
        return group_result

    absdata = abs_data(data)
    # Euclidean length of each step i -> i+1, using only the x/y columns.
    steps = np.sqrt(np.sum(np.power(np.diff(absdata[:, :2], axis=0), 2), axis=1))

    group_start = 0
    length = 0
    label = 0
    stroke_id = 0
    for i in range(n_points - 1):
        group_result[i, 1] = stroke_id
        if data[i][2] == 1:
            # End of a stroke: split only on a long jump to the next stroke.
            stroke_id += 1
            boundary = steps[i] >= 0.3 * theta
        else:
            length += steps[i]
            boundary = length >= theta
        if boundary:
            group_result[group_start:i + 1, 0] = label
            group_start = i + 1
            label += 1
            length = 0

    # Remaining points form the last group; the final point keeps the last stroke id.
    group_result[group_start:, 0] = label
    group_result[-1, 1] = stroke_id
    return group_result
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
--------------------------------------------------------------------------------
/src/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import data_loader
4 | from torch_geometric.data import DataLoader
5 | import numpy as np
6 | import time
7 | import copy
8 | import os
9 | import model
10 |
11 |
max_nodes = 400
batch_size = 250
num_class = 345
input_chanel = 3
hidden_chanel = 512
fea_dim = 128
hidden_chanel2 = 256
hidden_chanel3 = 512
out_chanel = 1024
n_rnn_layer = 2
num_epoches = 20
learning_rate = 0.001
data_dir = '/home/yl/sketchrnn.txt'
class_list = '/home/yl/sketchrnn.txt'
theta_list = np.load('/home/yl/theta.npy')
model_path = '/home/yl/data/train_model/s3net/5.pkl'
device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')

print('='*10, 'Initial Setting', '='*10)
print('Batch Size: ', batch_size)
print('Data_dir: ', data_dir)
print('Input dim: ', input_chanel)
print('hidden dim: ', hidden_chanel, ' ', hidden_chanel2, ' ', hidden_chanel3)
print('Output dim: ', out_chanel)
print('RNN Layers: ', n_rnn_layer)
print('Num epochs:', num_epoches)
print('Learning rate: ', learning_rate)
print('Data_dir :', data_dir)
print('Class info: ', class_list)
print('Device: ', device)
print('Model path: ', model_path)


"""
dataset and data loader
"""
print('='*10, 'Start Data Loading', '='*10)
test_dataset = data_loader.QuickDraw(data_dir, class_list, theta_list, type='test')
# Order does not affect the aggregate metrics; not shuffling keeps runs reproducible.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print('='*10, 'Data Loaded', '='*10)


"""
model
"""
# Named `net` so the imported `model` module is not shadowed.
net = model.Net(input_chanel, hidden_chanel, fea_dim, hidden_chanel2, hidden_chanel3,
                out_chanel, num_class, n_rnn_layer).to(device)
net.load_state_dict(torch.load(model_path, map_location=device))

"""
test procedure
"""
net.eval()
test_acc = 0
test_loss = 0.0

with torch.no_grad():
    for data in test_loader:
        label = data['y'].to(device).long().view(-1)
        inputs = data.to(device)
        output, prediction, link_loss, ent_loss = net(inputs)
        loss = F.nll_loss(output, label) + link_loss + ent_loss
        # Weight by batch size so the final figure is a per-sample mean.
        test_loss += data.y.size(0) * loss.item()
        _, preds = torch.max(output, 1)
        test_acc += (preds == label).sum().item()

g = test_loss / len(test_dataset)
h = test_acc / len(test_dataset)
print('test: Loss:{:.6f}, Acc:{:.6f}'.format(g, h))
88 |
--------------------------------------------------------------------------------
/src/model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 | from torch_geometric.nn import DenseSAGEConv, dense_diff_pool
5 | from torch.autograd import Variable
6 | import torch_geometric.utils as utils
7 |
8 |
9 |
# Fall back to the CPU when CUDA is unavailable, matching the other modules.
device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
11 |
class GNN(torch.nn.Module):
    """Three stacked DenseSAGEConv layers, each followed by ReLU and batch norm.

    The outputs of all three layers are concatenated along the feature axis.
    When ``lin`` is True, a final Linear + ReLU maps that concatenation to
    ``out_channels``. Otherwise the concatenation
    (2 * hidden_channels + out_channels wide) is returned as-is.
    """

    def __init__(self,
                 in_channels,
                 hidden_channels,
                 out_channels,
                 normalize=False,
                 add_loop=False,
                 lin=True):
        super(GNN, self).__init__()

        self.add_loop = add_loop

        # Attribute names (conv1/bn1/...) are the state_dict keys of saved checkpoints.
        self.conv1 = DenseSAGEConv(in_channels, hidden_channels, normalize)
        self.bn1 = torch.nn.BatchNorm1d(hidden_channels)
        self.conv2 = DenseSAGEConv(hidden_channels, hidden_channels, normalize)
        self.bn2 = torch.nn.BatchNorm1d(hidden_channels)
        self.conv3 = DenseSAGEConv(hidden_channels, out_channels, normalize)
        self.bn3 = torch.nn.BatchNorm1d(out_channels)

        concat_width = 2 * hidden_channels + out_channels
        self.lin = torch.nn.Linear(concat_width, out_channels) if lin is True else None

    def bn(self, i, x):
        """Apply batch norm ``bn{i}`` to a (batch, nodes, channels) tensor."""
        n_batch, n_nodes, n_channels = x.size()
        norm = getattr(self, 'bn{}'.format(i))
        # BatchNorm1d expects (N, C), so fold the node axis into the batch axis.
        return norm(x.view(-1, n_channels)).view(n_batch, n_nodes, n_channels)

    def forward(self, x, adj, mask=None):
        layer_outputs = []
        h = x
        for idx in range(1, 4):
            conv = getattr(self, 'conv{}'.format(idx))
            h = self.bn(idx, F.relu(conv(h, adj, mask, self.add_loop)))
            layer_outputs.append(h)

        h = torch.cat(layer_outputs, dim=-1)
        if self.lin is None:
            return h
        return F.relu(self.lin(h))
58 |
59 |
class Net(torch.nn.Module):
    """S3Net classifier: a BiLSTM point encoder followed by a DiffPool graph network.

    Per-point features from a bidirectional LSTM over the stroke-3 sequence
    become node features of each sketch graph. One dense_diff_pool step maps
    the nodes onto ``num_nodes`` clusters, a second GNN embeds the clusters,
    and the mean-pooled result is classified.
    """

    def __init__(self, input_chanel, hidden_chanel, fea_dim, hidden_chanel2, hidden_chanel3, out_chanel, num_class, n_rnn_layer):
        super(Net, self).__init__()

        # Number of clusters the DiffPool assignment maps nodes onto.
        num_nodes = 5

        self.gnn1_pool = GNN(fea_dim, hidden_chanel2, num_nodes)
        self.gnn1_embed = GNN(fea_dim, hidden_chanel2, hidden_chanel2, lin=False)

        # gnn1_embed has lin=False, so it outputs its three layers concatenated.
        self.gnn3_embed = GNN(3 * hidden_chanel2, hidden_chanel3, out_chanel, lin=False)

        # gnn3_embed also concatenates: hidden3 + hidden3 + out channels.
        self.lin1 = torch.nn.Linear(2 * hidden_chanel3 + out_chanel, out_chanel)
        self.lin2 = torch.nn.Linear(out_chanel, num_class)
        self.n_layer = n_rnn_layer
        self.n_classes = num_class
        self.lstm = nn.LSTM(input_chanel, hidden_chanel, n_rnn_layer, batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(0.5)

        # Projects the concatenated forward/backward LSTM outputs to the node feature size.
        self.fc = nn.Linear(hidden_chanel * 2, fea_dim)
        # Unused in forward(); kept so previously saved state_dicts still load strictly.
        self.classify = nn.Linear(out_chanel, num_class)
        self.fea_dim = fea_dim

    def forward(self, data):
        """Classify a batch of sketch graphs.

        :param data: batched graph object exposing ``s`` (per-sketch point counts),
            ``c`` (zero-padded stroke-3 coordinates, reshaped here to
            (batch, -1, 3)), ``edge_index``, ``batch`` and ``num_graphs``.
            ``data['x']`` is overwritten with the LSTM node features.
        :return: (log-probabilities, pre-ReLU lin1 output, DiffPool link loss,
            DiffPool entropy loss).
        """
        seq_len = data['s']
        inputs = data['c'].reshape((len(seq_len), -1, 3))

        # pack_padded_sequence (with its default enforce_sorted=True) needs
        # lengths in decreasing order, so sort the batch and undo the
        # permutation afterwards.
        _, idx_sort = torch.sort(seq_len, dim=0, descending=True)
        _, idx_unsort = torch.sort(idx_sort, dim=0)
        input_x = inputs.index_select(0, idx_sort).float()
        # Lengths must live on the CPU regardless of where the batch is.
        lengths = seq_len[idx_sort].cpu()
        pack = nn.utils.rnn.pack_padded_sequence(input_x, lengths, batch_first=True)
        out, _ = self.lstm(pack)
        un_padded, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)
        un_padded = un_padded.index_select(0, idx_unsort)
        feature = self.fc(self.dropout(un_padded))
        # Drop references early to lower peak memory before the GNN stage.
        del out, pack, un_padded

        # Per sketch: one feature row per point, then one zero row for the extra
        # sketch-level node the data loader adds (node index == point count).
        # new_zeros follows feature's device/dtype rather than a hard-coded GPU.
        empty = feature.new_zeros((1, self.fea_dim))
        pieces = []
        for i in range(data.num_graphs):
            pieces.append(feature[i][:seq_len[i]])
            pieces.append(empty)
        # A single concatenation instead of growing a tensor inside the loop.
        data['x'] = torch.cat(pieces)

        x = utils.to_dense_batch(data.x, batch=data.batch)[0]
        adj = utils.to_dense_adj(data.edge_index, batch=data.batch)
        # NOTE(review): the node mask from to_dense_batch is not passed on, so padded
        # nodes take part in pooling; kept as-is to match trained checkpoints.
        s = self.gnn1_pool(x, adj)
        x = self.gnn1_embed(x, adj)
        x, adj, l1, e1 = dense_diff_pool(x, adj, s)

        x = self.gnn3_embed(x, adj)

        x = x.mean(dim=1)
        x1 = self.lin1(x)
        x = self.lin2(F.relu(x1))
        return F.log_softmax(x, dim=-1), x1, l1, e1
128 |
129 |
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import data_loader
4 | from torch_geometric.data import DataLoader
5 | import numpy as np
6 | import time
7 | import copy
8 | import os
9 | import model
10 |
11 |
max_nodes = 400
batch_size = 250
num_class = 345
input_chanel = 3
hidden_chanel = 512
fea_dim = 128
hidden_chanel2 = 256
hidden_chanel3 = 512
out_chanel = 1024
n_rnn_layer = 2
num_epoches = 1
learning_rate = 0.001
data_dir = '/home/yl/sketchrnn.txt'
class_list = '/home/yl/sketchrnn.txt'
theta_list = np.load('/home/yl/theta.npy')
train_model_save_dir = '/home/yl/data/train_model/s3net'
save_dir = '/home/yl/data/model/s3net.pkl'
# Checkpoint training resumes from; set to None to start from fresh weights.
resume_path = '/home/yl/data/train_model/s3net/4.pkl'
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

print('='*10, 'Initial Setting', '='*10)
print('Batch Size: ', batch_size)
print('Data_dir: ', data_dir)
print('Input dim: ', input_chanel)
print('hidden dim: ', hidden_chanel, ' ', hidden_chanel2, ' ', hidden_chanel3)
print('Output dim: ', out_chanel)
print('RNN Layers: ', n_rnn_layer)
print('Num epochs:', num_epoches)
print('Learning rate: ', learning_rate)
print('Data_dir :', data_dir)
print('Class info: ', class_list)
print('Device: ', device)
print('Train model save dir: ', train_model_save_dir)
print('Final model save path: ', save_dir)


"""
dataset and data loader
"""
print('='*10, 'Start Data Loading', '='*10)
train_dataset = data_loader.QuickDraw(data_dir, class_list, theta_list, type='train')
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataset = data_loader.QuickDraw(data_dir, class_list, theta_list, type='valid')
# Validation metrics are order-independent; no shuffling keeps them reproducible.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
print('='*10, 'Data Loaded', '='*10)


"""
model and optimizer
"""
# Named `net` so the imported `model` module is not shadowed.
net = model.Net(input_chanel, hidden_chanel, fea_dim, hidden_chanel2, hidden_chanel3,
                out_chanel, num_class, n_rnn_layer).to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
if resume_path is not None:
    # map_location: the checkpoint may have been saved from a different GPU.
    net.load_state_dict(torch.load(resume_path, map_location=device))

# torch.save does not create missing directories.
os.makedirs(train_model_save_dir, exist_ok=True)
os.makedirs(os.path.dirname(save_dir) or '.', exist_ok=True)

train_loss = []
train_acc = []
valid_loss = []
valid_acc = []
best_acc = 0.0
# Start from the current weights so the final save works even if validation
# accuracy never rises above best_acc.
best_model_wts = copy.deepcopy(net.state_dict())

print('='*10, 'Start training', '='*10)

for epoch in range(num_epoches):
    # print('=' * 10, 'Epoch ', epoch, '=' * 10)
    # if epoch == 5:
    #     optimizer.param_groups[0]['lr'] = 1e-4
    # if epoch == 10:
    #     optimizer.param_groups[0]['lr'] = 1e-5
    # if epoch == 15:
    #     optimizer.param_groups[0]['lr'] = 1e-6
    print('learning rate: ', optimizer.param_groups[0]['lr'])

    since = time.time()
    running_acc = 0
    running_loss = 0.0
    seen = 0  # samples processed so far this epoch, for the running averages
    net.train()
    for i, data in enumerate(train_loader):
        label = data['y'].to(device).long().view(-1)
        inputs = data.to(device)
        optimizer.zero_grad()
        output, prediction, link_loss, ent_loss = net(inputs)
        loss = F.nll_loss(output, label) + link_loss + ent_loss
        loss.backward()
        optimizer.step()

        batch_n = data.y.size(0)
        seen += batch_n
        running_loss += batch_n * loss.item()
        _, preds = torch.max(output, 1)
        running_acc += (preds == label).sum().item()
        if i % 10 == 0:
            print('the {}-th batch, loss: {:.6f}, acc: {:.6f}'.format(i, running_loss / seen,
                                                                        running_acc / seen))
    j = running_loss / len(train_dataset)
    e = running_acc / len(train_dataset)
    print('Finish {} epoch, Loss:{:.6f}, Acc:{:.6f}'.format(epoch + 1, j, e))
    train_loss.append(j)
    train_acc.append(e)
    time_epoch = time.time() - since
    print("This epoch train costs time:{:.0f}m {:.0f}s".format(time_epoch // 60, time_epoch % 60))

    net.eval()
    val_since = time.time()
    val_loss = 0.0
    val_acc = 0
    # No autograd graph is needed for evaluation.
    with torch.no_grad():
        for data in val_loader:
            label = data['y'].to(device).long().view(-1)
            inputs = data.to(device)
            output, prediction, link_loss, ent_loss = net(inputs)
            loss = F.nll_loss(output, label) + link_loss + ent_loss
            val_loss += data.y.size(0) * loss.item()
            _, preds = torch.max(output, 1)
            val_acc += (preds == label).sum().item()
    save_path = os.path.join(train_model_save_dir, str(epoch) + '.pkl')
    torch.save(net.state_dict(), save_path)
    c = val_loss / len(val_dataset)
    f = val_acc / len(val_dataset)
    if f > best_acc:
        best_acc = f
        best_model_wts = copy.deepcopy(net.state_dict())
    print('val: Loss:{:.6f}, Acc:{:.6f}'.format(c, f))
    valid_loss.append(c)
    valid_acc.append(f)
    time_epoch_val = time.time() - val_since
    print("This epoch val costs time:{:.0f}m {:.0f}s".format(time_epoch_val // 60, time_epoch_val % 60))


net.load_state_dict(best_model_wts)
torch.save(net.state_dict(), save_dir)
print('train_loss:{} train_acc:{} val_loss{} val_acc{}'.format(train_loss, train_acc, valid_loss, valid_acc))
145 |
--------------------------------------------------------------------------------
/src/data_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import len_group
3 | import numpy as np
4 | import torch
5 | import torch.utils.data as data
6 | from torch_geometric.data import Data
7 |
8 |
9 | """
10 | Define the device
11 | """
12 | device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
13 |
class QuickDraw(data.Dataset):
    """QuickDraw sketches as graph samples (points + one sketch-level node)."""

    # Number of sketches drawn per category for the training split.
    train_per_class = 9000

    def __init__(self, data_dir, class_list, theta_list, type):
        """
        :param data_dir: txt file listing, one per line, the path of a per-category
            file that np.load can open and that has 'train'/'valid'/'test' entries
        :param class_list: txt file, the path of category info
        :param theta_list: array indexed by class index, holding the grouping
            threshold theta for each category
        :param type: 'train', 'valid' or 'test'
        """
        self.class_list = class_list
        self.type = type
        self.classes, self.class_to_idx = self.find_class(class_list)
        self.theta_list = theta_list

        parts = []
        with open(data_dir) as class_url_list:
            for line in class_url_list:
                if not line.strip():
                    continue  # tolerate blank lines in the list file
                label = self.class_to_idx[self._class_name(line)]
                # Path rewrite for the original author's machine, kept unchanged.
                path = line.replace('yanglan', 'yl/data').strip()
                with np.load(path, encoding='latin1', allow_pickle=True) as classnpy:
                    if self.type == 'train':
                        # Each classnpy[key] access re-reads the array, so shuffling
                        # classnpy['train'] directly would shuffle a throwaway copy.
                        # Load once, shuffle that array, then slice the same array.
                        sketches = classnpy['train']
                        np.random.shuffle(sketches)
                        sketches = sketches[:self.train_per_class]
                    elif self.type in ('valid', 'test'):
                        sketches = classnpy[self.type]
                    else:
                        continue
                n = len(sketches)
                # Row layout: column 0 = label, column 1 = stroke-3 sketch.
                parts.append(np.c_[label * np.ones((n, 1)), sketches.reshape(n, -1)])

        # A single concatenation instead of np.r_ inside the loop. The leading
        # zeros row is a placeholder that label_data_npy1 drops.
        self.label_data_npy = np.concatenate([np.zeros((1, 2))] + parts)
        self.label_data_npy1 = self.label_data_npy[1:, :]
        self.max_length, self.max_groupnum = self.max_len()

    def __len__(self):
        return len(self.label_data_npy1)

    def __getitem__(self, item):
        """Build the graph sample for sketch ``item``.

        The returned Data holds placeholder node features ``x`` (points + 1 rows),
        ``edge_index`` (both directions of every affinity edge), label ``y``,
        point count ``s``, group count ``g`` and the zero-padded coordinates ``c``.
        """
        label, coordinate = self.label_data_npy1[item]
        padded = np.zeros((self.max_length, 3))
        padded[:len(coordinate)] = coordinate
        # Left on the CPU; the training/eval loops move each batch to their device.
        c = torch.from_numpy(padded)
        groupid = len_group.get_group(coordinate, self.theta_list[int(label)])
        src, dst, groupNum = self.get_affinity_matrix(torch.from_numpy(groupid))
        edge_idx = torch.tensor(np.stack((np.concatenate((src, dst)), np.concatenate((dst, src)))),
                                dtype=torch.long)
        # Placeholder features: Net.forward replaces x with LSTM features.
        feature = torch.zeros((len(coordinate) + 1, 128))
        return Data(x=feature, edge_index=edge_idx, y=label, s=len(coordinate), g=int(groupNum.item()), c=c)

    @staticmethod
    def _class_name(path_line):
        """Return the category name in one list-file line (basename without extension)."""
        return os.path.splitext(os.path.basename(path_line.strip()))[0]

    def find_class(self, dir):
        """Read the category list file and return (sorted names, name -> index)."""
        with open(dir) as class_url_list:
            classlist = sorted(self._class_name(line) for line in class_url_list if line.strip())
        class_to_idx = {name: idx for idx, name in enumerate(classlist)}
        return classlist, class_to_idx

    def max_len(self):
        """Return (length of the longest sketch, group count of that sketch).

        Ties go to the last longest sketch. NOTE(review): the second value is
        the group count of the longest sketch, not the largest group count in
        the dataset. An empty dataset yields (0, 0).
        """
        if len(self.label_data_npy1) == 0:
            return 0, 0
        lengths = [len(sketch) for sketch in self.label_data_npy1[:, 1]]
        longest = max(lengths)
        pos = len(lengths) - 1 - lengths[::-1].index(longest)
        label, sketch = self.label_data_npy1[pos]
        groupid = len_group.get_group(sketch, self.theta_list[int(label)])
        _, _, groupNum = self.get_affinity_matrix(torch.from_numpy(groupid))
        return longest, groupNum

    def get_affinity_matrix(self, groupId):
        """Build the directed edge lists of one sketch graph from its grouping.

        :param groupId: integer tensor of shape (N, 2), as returned by
            len_group.get_group. Column 0 is the group label, column 1 the
            stroke id. A squeezed (2,) tensor for a one-point sketch is accepted.
        :return: (src, dst, groupnum). src and dst are numpy arrays of node
            indices; points are nodes 0..N-1 and node N is an extra sketch-level
            node. groupnum is a 0-d tensor equal to the largest group label + 1.
        """
        groupId = torch.as_tensor(groupId).reshape(-1, 2)
        num_points = len(groupId)
        # Number of groups comes from the label column only, not the stroke ids.
        groupnum = torch.max(groupId[:, 0])
        labels = groupId[:, 0].tolist()
        strokes = groupId[:, 1].tolist()
        src = []
        dst = []

        # The first point of each group is its representative; the last point
        # is appended as a closing sentinel.
        repre_point = [i for i in range(num_points) if i == 0 or labels[i] != labels[i - 1]]
        repre_point.append(num_points - 1)

        # Rule 1: link group i to every point after its representative, up to and
        # including the next representative.
        # NOTE(review): the source is the group index i rather than the node index
        # repre_point[i]; kept unchanged because trained models depend on it.
        for i in range(len(repre_point) - 1):
            for j in range(repre_point[i] + 1, repre_point[i + 1] + 1):
                src.append(i)
                dst.append(j)

        # Rule 2: chain consecutive representatives that lie on the same stroke.
        for i in range(len(repre_point) - 1):
            if strokes[repre_point[i]] == strokes[repre_point[i + 1]]:
                src.append(repre_point[i])
                dst.append(repre_point[i + 1])

        # Rule 3: connect the sketch-level node (index N) to each representative.
        for i in range(len(repre_point) - 1):
            src.append(num_points)
            dst.append(repre_point[i])

        return np.array(src), np.array(dst), groupnum + 1

    def abs_data(self, data):
        """Convert relative stroke-3 offsets to scaled absolute coordinates.

        Delegates to len_group.abs_data so the two copies cannot drift apart.
        """
        return len_group.abs_data(data)
172 |
--------------------------------------------------------------------------------