├── input
│   ├── acmv9.mat
│   ├── dblpv7.mat
│   └── citationv1.mat
├── inits.py
├── .gitignore
├── metrics.py
├── README.md
├── layers.py
├── utils.py
├── train_WD.py
└── models.py
--------------------------------------------------------------------------------
/input/acmv9.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/daiquanyu/AdaGCN_TKDE/HEAD/input/acmv9.mat
--------------------------------------------------------------------------------
/input/dblpv7.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/daiquanyu/AdaGCN_TKDE/HEAD/input/dblpv7.mat
--------------------------------------------------------------------------------
/input/citationv1.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/daiquanyu/AdaGCN_TKDE/HEAD/input/citationv1.mat
--------------------------------------------------------------------------------
/inits.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np


def glorot(shape, name=None):
    """Glorot & Bengio (AISTATS 2010) init."""
    init_range = np.sqrt(6.0 / (shape[0] + shape[1]))
    initial = tf.random_uniform(shape, minval=-init_range, maxval=init_range, dtype=tf.float32)
    return tf.Variable(initial, name=name)


def zeros(shape, name=None):
    """All zeros."""
    initial = tf.zeros(shape, dtype=tf.float32)
    return tf.Variable(initial, name=name)
--------------------------------------------------------------------------------
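As a quick sanity check of the Glorot bound used above (a NumPy-only sketch; `glorot_numpy` is an illustrative name, not part of the repository):

```
import numpy as np

def glorot_numpy(shape, seed=None):
    # NumPy analogue of inits.glorot: uniform samples in
    # [-sqrt(6/(fan_in+fan_out)), +sqrt(6/(fan_in+fan_out))]
    rng = np.random.RandomState(seed)
    limit = np.sqrt(6.0 / (shape[0] + shape[1]))
    return rng.uniform(-limit, limit, size=shape).astype(np.float32)

W = glorot_numpy((1000, 100), seed=0)
assert np.abs(W).max() <= np.sqrt(6.0 / 1100)  # samples stay inside the Glorot bound
```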
/.gitignore:
--------------------------------------------------------------------------------
# Custom
*.idea
*.png
*.pdf
tmp/
*.txt

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# IPython Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# dotenv
.env

# virtualenv
venv/
ENV/

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject

*.pickle
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np


def masked_sigmoid_cross_entropy(preds, labels, mask):
    """Sigmoid cross-entropy loss with masking."""
    # loss has the same shape as the logits: one loss per class and per sample
    loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=preds, labels=labels)
    loss = tf.reduce_sum(loss, axis=1)

    # zero out unlabeled rows and rescale so the mean runs over labeled nodes only
    mask = tf.cast(mask, dtype=tf.float32)
    mask /= tf.reduce_mean(mask)
    loss *= mask
    return tf.reduce_mean(loss)


def multi_label_hot(prediction, threshold=0.5):
    """Threshold sigmoid outputs into 0/1 multi-label predictions.

    Examples:
        prediction = tf.sigmoid(logits)
        one_hot_prediction = multi_label_hot(prediction)
    """
    prediction = tf.cast(prediction, tf.float32)
    threshold = float(threshold)
    return tf.cast(tf.greater_equal(prediction, threshold), tf.int64)


def f1_score(y_true, y_pred, mask, epsilon=1e-8):
    """Micro-, macro-, and weighted-averaged F1 over the masked nodes."""
    f1s = [0, 0, 0]

    y_true = tf.cast(tf.boolean_mask(y_true, mask, axis=0), tf.float64)
    y_pred = tf.cast(tf.boolean_mask(y_pred, mask, axis=0), tf.float64)

    # axis=None aggregates counts globally (micro); axis=0 keeps them per class (macro)
    for i, axis in enumerate([None, 0]):
        TP = tf.cast(tf.count_nonzero(y_pred * y_true, axis=axis), tf.float64)
        FP = tf.cast(tf.count_nonzero(y_pred * (y_true - 1), axis=axis), tf.float64)
        FN = tf.cast(tf.count_nonzero((y_pred - 1) * y_true, axis=axis), tf.float64)

        precision = TP / (TP + FP + epsilon)
        recall = TP / (TP + FN + epsilon)
        f1 = 2 * precision * recall / (precision + recall + epsilon)

        f1s[i] = tf.reduce_mean(f1)

    # weighted F1: average the per-class F1 (from the axis=0 pass) by class frequency
    weights = tf.reduce_sum(y_true, axis=0)
    weights /= (tf.reduce_sum(weights) + epsilon)

    f1s[2] = tf.reduce_sum(f1 * weights)

    micro, macro, weighted = f1s
    return micro, macro, weighted, TP, FP, FN
--------------------------------------------------------------------------------
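A usage illustration for the masking logic (a minimal sketch assuming the TF1 environment from the README; the toy shapes and values are made up): `masked_sigmoid_cross_entropy` zeroes the loss of unlabeled rows and rescales so the mean is effectively taken over labeled nodes only.

```
import numpy as np
import tensorflow as tf
from metrics import masked_sigmoid_cross_entropy, multi_label_hot

logits = tf.constant(np.random.randn(4, 3).astype(np.float32))  # 4 nodes, 3 labels
labels = tf.constant(np.array([[1, 0, 1], [0, 1, 0], [1, 1, 0], [0, 0, 1]], np.float32))
mask = tf.constant([True, True, False, False])  # only the first two nodes are labeled

loss = masked_sigmoid_cross_entropy(logits, labels, mask)
preds = multi_label_hot(tf.sigmoid(logits))  # 0/1 predictions at threshold 0.5

with tf.Session() as sess:
    print(sess.run([loss, preds]))
```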
/README.md:
--------------------------------------------------------------------------------

# Graph transfer learning via adversarial domain adaptation with graph convolution

This is our implementation for the following paper:

>[Q. Dai, X.-M. Wu, J. Xiao, X. Shen, and D. Wang, “Graph transfer learning via adversarial domain adaptation with graph convolution,” IEEE Transactions on Knowledge and Data Engineering, pp. 1–1, 2022](https://ieeexplore.ieee.org/abstract/document/9684927).


## Abstract
This paper studies the problem of cross-network node classification to overcome the insufficiency of labeled data in a single network. It aims to leverage the label information in a partially labeled source network to assist node classification in a completely unlabeled or partially labeled target network. Existing methods for single-network learning cannot solve this problem due to the domain shift across networks. Some multi-network learning methods heavily rely on the existence of cross-network connections and are thus inapplicable to this problem. To tackle this problem, we propose a novel graph transfer learning framework, AdaGCN, by leveraging the techniques of adversarial domain adaptation and graph convolution. It consists of two components: a semi-supervised learning component and an adversarial domain adaptation component. The former aims to learn class-discriminative node representations from the given label information of the source and target networks, while the latter contributes to mitigating the distribution divergence between the source and target domains to facilitate knowledge transfer. Extensive empirical evaluations on real-world datasets show that AdaGCN can successfully transfer class information with a low label rate on the source network and a substantial divergence between the source and target domains.

## Environment requirements
The code has been tested running under Python 3.5.2. The required packages are as follows:
* python == 3.5.2
* tensorflow-gpu == 1.13.0-rc0
* numpy == 1.16.2

## Examples to run the code
* Multi-label classification with a source training rate of 10% (Table 3)
```
python train_WD.py # set signal = [1], target_train_rate = [0], FLAGS.gnn=gcn or FLAGS.gnn=igcn
```

* Multi-label classification with a source training rate of 10% and a target training rate of 5% (Table 4)
```
python train_WD.py # set signal = [2], target_train_rate = [0.05], FLAGS.gnn=gcn or FLAGS.gnn=igcn
```

## Citation
If you would like to use our code, please cite:
```
@ARTICLE{dai_graph_2022,
  author={Dai, Quanyu and Wu, Xiao-Ming and Xiao, Jiaren and Shen, Xiao and Wang, Dan},
  journal={IEEE Transactions on Knowledge and Data Engineering},
  title={Graph Transfer Learning via Adversarial Domain Adaptation with Graph Convolution},
  year={2022},
  volume={},
  number={},
  pages={1-1},
  doi={10.1109/TKDE.2022.3144250}}
```
--------------------------------------------------------------------------------
/layers.py:
--------------------------------------------------------------------------------
from inits import *
import tensorflow as tf

flags = tf.app.flags
FLAGS = flags.FLAGS

# global unique layer ID dictionary for layer name assignment
_LAYER_UIDS = {}


def get_layer_uid(layer_name=''):
    """Helper function, assigns unique layer IDs."""
    if layer_name not in _LAYER_UIDS:
        _LAYER_UIDS[layer_name] = 1
        return 1
    else:
        _LAYER_UIDS[layer_name] += 1
        return _LAYER_UIDS[layer_name]


def sparse_dropout(x, keep_prob, noise_shape):
    """Dropout for sparse tensors."""
    random_tensor = keep_prob
    random_tensor += tf.random_uniform(noise_shape)
    dropout_mask = tf.cast(tf.floor(random_tensor), dtype=tf.bool)
    pre_out = tf.sparse_retain(x, dropout_mask)
    return pre_out * (1. / keep_prob)


def dot(x, y, sparse=False):
    """Wrapper for tf.matmul (sparse vs dense)."""
    if sparse:
        res = tf.sparse_tensor_dense_matmul(x, y)
    else:
        res = tf.matmul(x, y)
    return res

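# Intuition for the floor trick in sparse_dropout above: floor(keep_prob + u)
# equals 1 with probability keep_prob for u ~ U[0, 1), i.e. it draws a Bernoulli
# keep-mask. Illustrative self-check only; this helper is not used by the model.
def _bernoulli_mask_demo(keep_prob=0.7, n=100000, seed=0):
    import numpy as np
    u = np.random.RandomState(seed).uniform(size=n)
    mask = np.floor(keep_prob + u).astype(bool)
    return mask.mean()  # ~= keep_prob
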
class Layer(object):
    """Base layer class. Defines basic API for all layer objects.
    Implementation inspired by keras (http://keras.io).

    # Properties
        name: String, defines the variable scope of the layer.
        logging: Boolean, switches Tensorflow histogram logging on/off

    # Methods
        _call(inputs): Defines computation graph of layer
            (i.e. takes input, returns output)
        __call__(inputs): Wrapper for _call()
        _log_vars(): Log all variables
    """

    def __init__(self, **kwargs):
        allowed_kwargs = {'name', 'logging'}
        for kwarg in kwargs.keys():
            assert kwarg in allowed_kwargs, 'Invalid keyword argument: ' + kwarg
        name = kwargs.get('name')
        if not name:
            layer = self.__class__.__name__.lower()
            name = layer + '_' + str(get_layer_uid(layer))
        self.name = name

        logging = kwargs.get('logging', False)
        self.logging = logging
        self.sparse_inputs = False

    def _call(self, inputs):
        return inputs

    def __call__(self, inputs):
        with tf.name_scope(self.name):
            if self.logging and not self.sparse_inputs:
                tf.summary.histogram(self.name + '/inputs', inputs)
            outputs = self._call(inputs)
            if self.logging:
                tf.summary.histogram(self.name + '/outputs', outputs)
            return outputs


class Dense(Layer):
    """Dense layer."""
    def __init__(self,
                 placeholder_dropout,
                 placeholder_num_features_nonzero,
                 weights,
                 bias,
                 dropout=True,
                 sparse_inputs=False,
                 act=tf.nn.relu,
                 flag_bias=False,
                 **kwargs):
        super(Dense, self).__init__(**kwargs)

        if dropout:
            self.dropout = placeholder_dropout
        else:
            self.dropout = 0.

        self.act = act
        self.sparse_inputs = sparse_inputs
        self.weights = weights
        self.bias = bias
        self.flag_bias = flag_bias

        # helper variable for sparse dropout
        self.num_features_nonzero = placeholder_num_features_nonzero

    def _call(self, inputs):
        x = inputs

        # dropout
        if self.sparse_inputs:
            x = sparse_dropout(x, 1 - self.dropout, self.num_features_nonzero)
        else:
            x = tf.nn.dropout(x, 1 - self.dropout)

        # transform
        output = dot(x, self.weights, sparse=self.sparse_inputs)

        # bias
        if self.flag_bias:
            output += self.bias

        return self.act(output)


class GraphConvolution(Layer):
    """Graph convolution layer."""
    def __init__(self,
                 input_dim,
                 output_dim,
                 placeholder_dropout,
                 placeholder_support,
                 placeholder_num_features_nonzero,
                 weights,
                 bias,
                 dropout=True,
                 sparse_inputs=False,
                 act=tf.nn.relu,
                 flag_bias=False,
                 featureless=False,
                 **kwargs):
        super(GraphConvolution, self).__init__(**kwargs)

        if dropout:
            self.dropout = placeholder_dropout
        else:
            self.dropout = 0.

        self.act = act
        self.support = placeholder_support
        self.sparse_inputs = sparse_inputs
        self.featureless = featureless
        self.flag_bias = flag_bias
        self.weights = weights
        self.bias = bias

        # helper variable for sparse dropout
        self.num_features_nonzero = placeholder_num_features_nonzero

    def conv(self, adj, features):
        """IGCN renormalization (RNM) filtering: compute adj^k @ features."""

        def tf_rnm(adj, features, k):
            # apply the sparse adjacency k times instead of materializing adj^k
            new_feature = features
            for _ in range(k):
                new_feature = tf.sparse_tensor_dense_matmul(adj, new_feature)
            return new_feature

        result = tf_rnm(adj, features, FLAGS.smoothing_steps)
        return result

    def _call(self, inputs):
        x = inputs

        # dropout
        if self.sparse_inputs:
            x = sparse_dropout(x, 1 - self.dropout, self.num_features_nonzero)
        else:
            x = tf.nn.dropout(x, 1 - self.dropout)

        # convolve
        supports = list()
        for i in range(len(self.support)):
            if not self.featureless:
                pre_sup = dot(x, self.weights, sparse=self.sparse_inputs)
            else:
                pre_sup = self.weights

            if FLAGS.gnn == 'gcn':
                support = dot(self.support[i], pre_sup, sparse=True)
            elif FLAGS.gnn == 'igcn':
                support = self.conv(self.support[i], pre_sup)
            supports.append(support)

        output = tf.add_n(supports)

        # bias
        if self.flag_bias:
            output += self.bias

        return self.act(output)
--------------------------------------------------------------------------------
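The IGCN filter in `GraphConvolution.conv` applies the renormalized adjacency `k` times, i.e. it computes Â^k X without ever forming Â^k. The same computation in SciPy (a standalone sketch on a random graph; the names here are illustrative, not part of the repository):

```
import numpy as np
import scipy.sparse as sp

def rnm_filter(adj_norm, features, k):
    # Â^k X by repeated sparse matmul (mirrors tf_rnm in layers.py)
    out = features
    for _ in range(k):
        out = adj_norm.dot(out)
    return out

A = sp.random(100, 100, density=0.05, format='csr')
A = A + A.T + sp.eye(100)                      # symmetrize and add self-loops
d = np.asarray(A.sum(1)).ravel()
D_inv_sqrt = sp.diags(1.0 / np.sqrt(d))
A_hat = D_inv_sqrt.dot(A).dot(D_inv_sqrt)      # renormalized adjacency
X = np.random.randn(100, 16)
print(rnm_filter(A_hat, X, k=10).shape)        # (100, 16)
```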
/utils.py:
--------------------------------------------------------------------------------
import math
import numpy as np
import pickle as pkl
import networkx as nx
import scipy
import scipy.sparse as sp
import scipy.io as sio
from scipy.sparse.linalg.eigen.arpack import eigsh
from scipy.sparse import csc_matrix, hstack, vstack
from sklearn.decomposition import PCA
from sklearn.decomposition import TruncatedSVD
import sys

import tensorflow as tf


def parse_index_file(filename):
    """Parse index file."""
    index = []
    for line in open(filename):
        index.append(int(line.strip()))
    return index


def sample_mask(idx, l):
    """Create mask."""
    mask = np.zeros(l)
    mask[idx] = 1
    return np.array(mask, dtype=np.bool)


def sparse_to_tuple(sparse_mx):
    """Convert sparse matrix to tuple representation."""
    def to_tuple(mx):
        if not sp.isspmatrix_coo(mx):
            mx = mx.tocoo()
        coords = np.vstack((mx.row, mx.col)).transpose()
        values = mx.data
        shape = mx.shape
        return coords, values, shape

    if isinstance(sparse_mx, list):
        for i in range(len(sparse_mx)):
            sparse_mx[i] = to_tuple(sparse_mx[i])
    else:
        sparse_mx = to_tuple(sparse_mx)

    return sparse_mx


def preprocess_features(features):
    """Row-normalize feature matrix and convert to tuple representation."""
    rowsum = np.array(features.sum(1))
    r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    features = r_mat_inv.dot(features)
    return sparse_to_tuple(features)


def normalize_adj(adj):
    """Symmetrically normalize adjacency matrix."""
    adj = sp.coo_matrix(adj)
    rowsum = np.array(adj.sum(1))
    d_inv_sqrt = np.power(rowsum, -0.5).flatten()
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
    d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
    return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo()


def preprocess_adj(adj):
    """Preprocessing of adjacency matrix for simple GCN model and conversion to tuple representation."""
    adj_normalized = normalize_adj(adj + sp.eye(adj.shape[0]))
    return sparse_to_tuple(adj_normalized)

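# Illustrative check of preprocess_features (not used by the pipeline): each row
# is divided by its sum, and the isinf guard leaves all-zero rows untouched.
def _row_normalize_demo():
    X = sp.csr_matrix(np.array([[1., 3.], [0., 0.], [2., 2.]]))
    coords, values, shape = preprocess_features(X)
    # rows sum to 1 where possible; the all-zero row stays all-zero
    return coords, values, shape
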
def construct_feed_dict(features_t, Y_t, support_t, labels_t, labels_mask_t,
                        features_s, Y_s, support_s, labels_s, labels_mask_s, placeholders, lr_gen, lr_dis):
    """Construct feed dictionary for joint source/target training."""
    feed_dict = dict()
    # target network
    feed_dict.update({placeholders['labels_t']: labels_t})
    feed_dict.update({placeholders['labels_mask_t']: labels_mask_t})
    feed_dict.update({placeholders['features_t']: features_t})
    feed_dict.update({placeholders['support_t'][i]: support_t[i] for i in range(len(support_t))})
    feed_dict.update({placeholders['num_features_nonzero_t']: features_t[1].shape})
    # source network
    feed_dict.update({placeholders['labels_s']: labels_s})
    feed_dict.update({placeholders['labels_mask_s']: labels_mask_s})
    feed_dict.update({placeholders['features_s']: features_s})
    feed_dict.update({placeholders['support_s'][i]: support_s[i] for i in range(len(support_s))})
    feed_dict.update({placeholders['num_features_nonzero_s']: features_s[1].shape})
    # learning rates
    feed_dict.update({placeholders['lr_gen']: lr_gen})
    feed_dict.update({placeholders['lr_dis']: lr_dis})

    feed_dict.update({placeholders['source_top_k_list']: np.array(np.sum(Y_s, 1), dtype=np.int32)})
    feed_dict.update({placeholders['target_top_k_list']: np.array(np.sum(Y_t, 1), dtype=np.int32)})

    return feed_dict


def construct_feed_dict_target(features, y, support, labels, labels_mask, placeholders):
    """Construct feed dictionary for evaluating on the target network."""
    feed_dict = dict()
    feed_dict.update({placeholders['labels_t']: labels})
    feed_dict.update({placeholders['labels_mask_t']: labels_mask})
    feed_dict.update({placeholders['features_t']: features})
    feed_dict.update({placeholders['support_t'][i]: support[i] for i in range(len(support))})
    feed_dict.update({placeholders['num_features_nonzero_t']: features[1].shape})
    feed_dict.update({placeholders['dropout']: 0.0})
    feed_dict.update({placeholders['target_top_k_list']: np.array(np.sum(y, 1), dtype=np.int32)})
    return feed_dict


def construct_feed_dict_source(features, y, support, labels, labels_mask, placeholders):
    """Construct feed dictionary for evaluating on the source network."""
    feed_dict = dict()
    feed_dict.update({placeholders['labels_s']: labels})
    feed_dict.update({placeholders['labels_mask_s']: labels_mask})
    feed_dict.update({placeholders['features_s']: features})
    feed_dict.update({placeholders['support_s'][i]: support[i] for i in range(len(support))})
    feed_dict.update({placeholders['num_features_nonzero_s']: features[1].shape})
    feed_dict.update({placeholders['dropout']: 0.0})
    feed_dict.update({placeholders['source_top_k_list']: np.array(np.sum(y, 1), dtype=np.int32)})
    return feed_dict


def data_splits(y, Ntr, Nval, Nts, s_type='planetoid'):

    np.random.seed(123456)

    if s_type == 'planetoid':
        if Ntr > 0:
            idx_train = []
            tr_label_per_class = Ntr // y.shape[1]

            for i in range(y.shape[1]):
                if i == y.shape[1] - 1:
                    # the last class absorbs the remainder of the budget
                    tr_label_per_class = Ntr - tr_label_per_class * (y.shape[1] - 1)
                idx_tr_class_i = list(np.random.choice(np.where(y[:, i] != 0)[0], tr_label_per_class, replace=False))
                idx_train = idx_train + idx_tr_class_i
            idx = list(set(range(y.shape[0])) - set(idx_train))

            if Nval > 0:
                idx_val = list(np.random.choice(idx, Nval, replace=False))
                idx_test = list(set(idx) - set(idx_val))
            else:
                idx_val = None
                idx_test = idx

        elif Ntr == 0 and Nval > 0:
            idx_train = None
            idx = range(y.shape[0])

            idx_val = list(np.random.choice(idx, Nval, replace=False))
            idx_test = list(set(range(y.shape[0])) - set(idx_val))

        elif Ntr == 0 and Nval == 0:
            idx_train = None
            idx_val = None
            idx_test = range(y.shape[0])

    elif s_type == 'random':
        if Ntr > 0:
            idx_train = list(np.random.choice(np.array(range(y.shape[0])), Ntr, replace=False))
            y_train_label = np.sum(y[idx_train, :], axis=0)

            # resample until every class appears at least once in the training set
            while np.where(y_train_label == 0)[0].shape[0] > 0:
                idx_train = list(np.random.choice(np.array(range(y.shape[0])), Ntr, replace=False))
                y_train_label = np.sum(y[idx_train, :], axis=0)

            idx = list(set(range(y.shape[0])) - set(idx_train))

            if Nval > 0:
                idx_val = list(np.random.choice(idx, Nval, replace=False))
                idx_test = list(set(idx) - set(idx_val))
            else:
                idx_val = None
                idx_test = idx

        elif Ntr == 0 and Nval > 0:
            idx_train = None
            idx = range(y.shape[0])

            idx_val = list(np.random.choice(idx, Nval, replace=False))
            idx_test = list(set(range(y.shape[0])) - set(idx_val))

        elif Ntr == 0 and Nval == 0:
            idx_train = None
            idx_val = None
            idx_test = range(y.shape[0])

    return idx_train, idx_val, idx_test

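# Example usage of data_splits (illustrative only, not called anywhere): with a
# one-hot y of 100 nodes and 3 classes, the 'planetoid' strategy samples roughly
# Ntr / 3 training nodes per class. Assumes every class has at least 3 members.
def _data_splits_demo():
    y = np.eye(3, dtype=np.int32)[np.random.randint(0, 3, size=100)]
    idx_train, idx_val, idx_test = data_splits(y, Ntr=9, Nval=10, Nts=81, s_type='planetoid')
    return len(idx_train), len(idx_val), len(idx_test)  # (9, 10, 81)
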
def get_splits(y, tr_ratio, val_ratio, ts_ratio, flag=True, s_type='planetoid'):
    """
    flag:
        - True : tr_ratio, val_ratio, ts_ratio are ratios of the node set
        - False: tr_ratio, val_ratio, ts_ratio are absolute numbers of nodes
    """

    N = y.shape[0]
    if flag:
        Ntr = int(N * tr_ratio)
        Nval = int(N * val_ratio)
        Nts = N - Ntr - Nval
    else:
        Ntr = tr_ratio
        Nval = val_ratio
        Nts = ts_ratio

    idx_train, idx_val, idx_test = data_splits(y, Ntr, Nval, Nts, s_type=s_type)

    if Ntr == 0:
        y_train = np.zeros(y.shape, dtype=np.int32)
        train_mask = np.array(np.zeros(N), dtype=np.bool)
    else:
        y_train = np.zeros(y.shape, dtype=np.int32)
        y_train[idx_train] = y[idx_train]
        train_mask = sample_mask(idx_train, N)

    if Nval == 0:
        y_val = np.zeros(y.shape, dtype=np.int32)
        val_mask = np.array(np.zeros(N), dtype=np.bool)
    else:
        y_val = np.zeros(y.shape, dtype=np.int32)
        y_val[idx_val] = y[idx_val]
        val_mask = sample_mask(idx_val, N)

    y_test = np.zeros(y.shape, dtype=np.int32)
    y_test[idx_test] = y[idx_test]
    test_mask = sample_mask(idx_test, N)

    return y_train, y_val, y_test, train_mask, val_mask, test_mask


def load_mat_data(file, train_ratio, val_ratio, test_ratio, s_type='planetoid'):
    """Load a .mat network (attrb/network/group) and build train/val/test splits."""
    net = sio.loadmat(file)
    X, A, Y = net['attrb'], net['network'], net['group']
    if not isinstance(X, scipy.sparse.csc.csc_matrix):
        X = csc_matrix(X)
    train_num, val_num, test_num = int(Y.shape[0] * train_ratio), int(Y.shape[0] * val_ratio), int(Y.shape[0] * test_ratio)
    y_train, y_val, y_test, train_mask, val_mask, test_mask = get_splits(Y, train_num, val_num, test_num, flag=False, s_type=s_type)

    return A, X, Y, y_train, y_val, y_test, train_mask, val_mask, test_mask
--------------------------------------------------------------------------------
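Each `input/*.mat` file stores an attributed network as three matrices — node attributes `attrb`, adjacency `network`, and multi-hot labels `group` — so loading and splitting a network looks like this (a minimal sketch, assuming it is run from the repository root so `./input/` resolves):

```
from utils import load_mat_data

A, X, Y, y_train, y_val, y_test, train_mask, val_mask, test_mask = load_mat_data(
    './input/dblpv7.mat', train_ratio=0.1, val_ratio=0.0, test_ratio=0.9,
    s_type='planetoid')
print(A.shape, X.shape, Y.shape)       # N x N adjacency, N x F attributes, N x C labels
print(train_mask.sum(), test_mask.sum())
```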
/train_WD.py:
--------------------------------------------------------------------------------
from __future__ import division
from __future__ import print_function

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import time
from utils import *
from models import GCN


# Define model evaluation function
def evaluate(sess, model, features, y, support, labels, mask, placeholders):
    t_test = time.time()
    feed_dict_val = construct_feed_dict_target(features, y, support, labels, mask, placeholders)
    outs_val = sess.run([model.clf_loss_t, model.micro_f1_t, model.macro_f1_t,
                         model.weighted_f1_t], feed_dict=feed_dict_val)

    return outs_val[0], outs_val[1], outs_val[2], outs_val[3], (time.time() - t_test)


def train(FLAGS, X_t, Y_t, support_t, y_train_t, train_mask_t,
          X_s, Y_s, support_s, y_train_s, train_mask_s, placeholders):
    # NOTE: model_func, y_test and test_mask are module-level globals assigned
    # in the experiment loop at the bottom of this script.

    # Create model
    model = model_func(placeholders, X_t[2][1], Y_s.shape[0], Y_t.shape[0], logging=True)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    micro_f1 = []
    macro_f1 = []

    # Train model
    for epoch in range(FLAGS.epochs):
        ###########################
        # learning rate decaying
        ###########################
        if FLAGS.da_method == 'WD':
            # naive schedule: decay both rates every 100 epochs from epoch 500 on
            if (epoch + 1) >= 500 and (epoch + 1) % 100 == 0:
                FLAGS.lr_gen = FLAGS.lr_gen * FLAGS.shrinking
                FLAGS.lr_dis = FLAGS.lr_dis * FLAGS.shrinking

        # Construct feed dictionary
        feed_dict = construct_feed_dict(X_t, Y_t, support_t, y_train_t, train_mask_t,
                                        X_s, Y_s, support_s, y_train_s, train_mask_s, placeholders, FLAGS.lr_gen, FLAGS.lr_dis)
        feed_dict.update({placeholders['dropout']: FLAGS.dropout})

        if FLAGS.signal == 1:
            # domain adaptation; only the source has labeled nodes
            if FLAGS.da_method == 'WD':
                wd_loss, dis_loss_total = [], []
                for _ in range(FLAGS.D_train_step):
                    outs_dis = sess.run([model.wd_d_op, model.wd_loss, model.dis_loss_total], feed_dict=feed_dict)
                    wd_loss.append(outs_dis[1])
                    dis_loss_total.append(outs_dis[2])
                outs_gen = sess.run([model.opt_op_total_s], feed_dict=feed_dict)

        elif FLAGS.signal == 2:
            # domain adaptation; both source and target have labeled nodes
            if FLAGS.da_method == 'WD':
                for _ in range(FLAGS.D_train_step):
                    outs_dis = sess.run([model.wd_d_op, model.wd_loss, model.dis_loss_total], feed_dict=feed_dict)

                outs_gen = sess.run([model.opt_op_total_s_t], feed_dict=feed_dict)

        ##########################################
        # Recording test results after each epoch
        ##########################################
        test_clf_loss_t, test_micf1, test_macf1, test_wf1, test_duration = evaluate(sess, model, X_t, Y_t, support_t, y_test, test_mask, placeholders)
        print("Epoch:{}".format(epoch + 1), "signal={}".format(FLAGS.signal), "S-T:{}-{}".format(FLAGS.source, FLAGS.target),
              "hiddens={}, dropout={}, l2_param={}".format(FLAGS.hiddens_gcn, FLAGS.dropout, FLAGS.l2_param),
              "cost=", "{:.3f}".format(test_clf_loss_t), "micro_f1={:.3f}".format(test_micf1),
              "macro_f1={:.3f}".format(test_macf1), "weighted_f1={:.3f}".format(test_wf1))

        micro_f1.append(test_micf1)
        macro_f1.append(test_macf1)

    #####################
    # Testing
    #####################
    test_clf_loss_t, test_micf1, test_macf1, test_wf1, test_duration = evaluate(sess, model, X_t, Y_t, support_t, y_test, test_mask, placeholders)
    print("signal={}".format(FLAGS.signal), "S-T:{}-{}".format(FLAGS.source, FLAGS.target),
          "hiddens={}, dropout={}, l2_param={}".format(FLAGS.hiddens_gcn, FLAGS.dropout, FLAGS.l2_param),
          "cost=", "{:.3f}".format(test_clf_loss_t), "micro_f1={:.3f}".format(test_micf1),
          "macro_f1={:.3f}".format(test_macf1), "weighted_f1={:.3f}".format(test_wf1))

    return test_micf1, test_macf1

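# The decay schedule in train() above fires at epochs 500, 600, ..., 1000, i.e.
# six times over 1000 epochs, so the final rate is lr0 * shrinking**6
# (1.5e-3 * 0.8**6 ~= 3.9e-4). Illustrative helper, not called anywhere:
def decayed_lr(lr0, shrinking, epochs):
    for epoch in range(epochs):
        if (epoch + 1) >= 500 and (epoch + 1) % 100 == 0:
            lr0 *= shrinking
    return lr0
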
##################################################################################################################
# Set random seed
seed = 1234
np.random.seed(seed)
tf.set_random_seed(seed)

# Settings
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('source', 'citationv1', 'Source dataset string.')  # 'dblpv7', 'citationv1', 'acmv9'
flags.DEFINE_string('target', 'dblpv7', 'Target dataset string.')  # 'dblpv7', 'citationv1', 'acmv9'
flags.DEFINE_integer('epochs', 1000, 'Number of epochs to train.')
flags.DEFINE_float('dropout', 0.3, 'Dropout rate (1 - keep probability).')
flags.DEFINE_float('l2_param', 5e-5, 'Weight for L2 loss on embedding matrix.')
flags.DEFINE_integer('signal', 4, 'Training signal: 1 - domain adaptation with labeled nodes on the source only; 2 - domain adaptation with labeled nodes on both source and target.')
flags.DEFINE_float('da_param', 1, 'Weight for Wasserstein loss.')
flags.DEFINE_float('gp_param', 10, 'Weight for gradient penalty loss.')
flags.DEFINE_integer('D_train_step', 10, 'The number of steps for training the discriminator per epoch.')
flags.DEFINE_float('shrinking', 0.8, 'Multiplicative decay factor for the learning rates.')
flags.DEFINE_float('train_rate', 0, 'The ratio of labeled nodes in the target network.')
flags.DEFINE_float('val_rate', 0, 'The ratio of validation nodes in the target network.')
flags.DEFINE_string('hiddens_gcn', '1000|100|16', 'Number of units in different hidden layers for gcn.')
flags.DEFINE_string('hiddens_clf', '', 'Number of units in different hidden layers for the supervised classifier.')
flags.DEFINE_string('hiddens_dis', '16', 'Number of units in different hidden layers for the discriminator.')
flags.DEFINE_string('da_method', 'WD', 'Domain adaptation method.')
flags.DEFINE_float('lr_gen', 1.5e-3, 'Initial learning rate.')
flags.DEFINE_float('lr_dis', 1.5e-3, 'Initial learning rate for the discriminator.')
flags.DEFINE_boolean('with_metrics', True, 'Whether to compute f1 scores within tensorflow.')
flags.DEFINE_float('source_train_rate', 0.1, 'The ratio of labeled nodes in the source network.')
#-----------------
# IGCN
flags.DEFINE_integer('num_gcn_layers', 1, 'The number of gcn layers in the IGCN model.')
flags.DEFINE_integer('smoothing_steps', 10, 'The setting of k in A^k.')
flags.DEFINE_string('gnn', 'igcn', 'Convolutional method.')  # 'gcn', 'igcn'
#-----------------

############################################################
############################################################
datasets = ['dblpv7', 'citationv1', 'acmv9']
signal = [1]
target_train_rate = [0]  # shall be zero for signal=1; shall not be zero (e.g., 0.05) for signal=2
smoothing_steps = [10]  # smoothing steps chosen according to the source training rate
lr_gen = 1.5e-3
lr_dis = 1.5e-3

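# The experiment loop below chooses the labeled-node sampling strategy per
# network: if the balanced per-class quota exceeds the size of the rarest class,
# per-class ('planetoid') sampling is infeasible and fully random sampling is
# used instead. Equivalent reference helper (the loop inlines this logic):
def pick_split_type(Y, train_rate):
    N_tmp, M_tmp = Y.shape
    tr_label_per_class = int(N_tmp * train_rate) // M_tmp  # balanced per-class quota
    label_node_min = np.min(np.sum(Y, axis=0))             # size of the smallest class
    return 'random' if tr_label_per_class > label_node_min else 'planetoid'
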
for s in range(len(signal)):
    FLAGS.signal = signal[s]
    ###################################################################
    for i in range(len(datasets)):
        for j in range(len(datasets)):
            FLAGS.target = datasets[i]
            if i == j:
                continue
            else:
                FLAGS.source = datasets[j]
                final_micro = []
                final_macro = []
                for k, t in enumerate(target_train_rate):
                    FLAGS.lr_gen = lr_gen
                    FLAGS.lr_dis = lr_dis
                    FLAGS.train_rate = t
                    FLAGS.smoothing_steps = smoothing_steps[k]
                    ####################
                    # Load target data
                    ####################
                    # determining the labeled-node sampling method
                    Y_tmp = sio.loadmat('./input/{}.mat'.format(FLAGS.target))['group']
                    N_tmp, M_tmp = Y_tmp.shape
                    Ntr_s = int(N_tmp * FLAGS.train_rate)
                    tr_label_per_class = Ntr_s // M_tmp
                    label_node_min = np.min(np.sum(Y_tmp, axis=0))
                    if tr_label_per_class > label_node_min:
                        s_type = 'random'
                    else:
                        s_type = 'planetoid'
                    #------------------------------
                    train_ratio = FLAGS.train_rate
                    val_ratio = FLAGS.val_rate
                    test_ratio = 1 - FLAGS.train_rate - FLAGS.val_rate
                    A_t, X_t, Y_t, y_train_t, y_val, y_test, train_mask_t, val_mask, test_mask = load_mat_data(
                        './input/{}.mat'.format(FLAGS.target), train_ratio, val_ratio, test_ratio, s_type=s_type)
                    #############################################
                    # Load source data
                    #############################################
                    # determining the labeled-node sampling method
                    Y_tmp = sio.loadmat('./input/{}.mat'.format(FLAGS.source))['group']
                    N_tmp, M_tmp = Y_tmp.shape
                    Ntr_s = int(N_tmp * FLAGS.source_train_rate)
                    tr_label_per_class = Ntr_s // M_tmp
                    label_node_min = np.min(np.sum(Y_tmp, axis=0))
                    if tr_label_per_class > label_node_min:
                        s_type = 'random'
                    else:
                        s_type = 'planetoid'
                    ##############################################
                    source_train_ratio = FLAGS.source_train_rate
                    source_val_ratio = 0
                    source_test_ratio = 1 - source_train_ratio - source_val_ratio
                    A_s, X_s, Y_s, y_train_s, y_val_s, y_test_s, train_mask_s, val_mask_s, test_mask_s = load_mat_data(
                        './input/{}.mat'.format(FLAGS.source), source_train_ratio, source_val_ratio, source_test_ratio, s_type=s_type)
                    ####################
                    # Some preprocessing
                    ####################
                    N_t = Y_t.shape[0]
                    N_s = Y_s.shape[0]
                    X_t = preprocess_features(X_t)
                    X_s = preprocess_features(X_s)
                    support_t = [preprocess_adj(A_t)]
                    support_s = [preprocess_adj(A_s)]
                    num_supports = 1
                    model_func = GCN
                    ####################
                    # Define placeholders
                    placeholders = {
                        'support_t': [tf.sparse_placeholder(tf.float32) for _ in range(num_supports)],
                        'features_t': tf.sparse_placeholder(tf.float32, shape=tf.constant(X_t[2], dtype=tf.int64)),
                        'labels_t': tf.placeholder(tf.float32, shape=(Y_t.shape[0], Y_t.shape[1])),
                        'labels_mask_t': tf.placeholder(tf.bool, shape=(Y_t.shape[0])),
                        'num_features_nonzero_t': tf.placeholder(tf.int32),  # helper variable for sparse dropout
                        'support_s': [tf.sparse_placeholder(tf.float32) for _ in range(num_supports)],
                        'features_s': tf.sparse_placeholder(tf.float32, shape=tf.constant(X_s[2], dtype=tf.int64)),
                        'labels_s': tf.placeholder(tf.float32, shape=(Y_s.shape[0], Y_s.shape[1])),
                        'labels_mask_s': tf.placeholder(tf.bool, shape=(Y_s.shape[0])),
                        'num_features_nonzero_s': tf.placeholder(tf.int32),  # helper variable for sparse dropout
                        'dropout': tf.placeholder_with_default(0., shape=()),
                        'lr_dis': tf.placeholder(tf.float32, shape=()),
                        'lr_gen': tf.placeholder(tf.float32, shape=()),
                        'l': tf.placeholder(tf.float32, shape=()),
                        'source_top_k_list': tf.placeholder(tf.int32, shape=(Y_s.shape[0])),  # for multi-label classification
                        'target_top_k_list': tf.placeholder(tf.int32, shape=(Y_t.shape[0]))
                    }
                    ###############################################################################
                    test_micf1, test_macf1 = train(FLAGS, X_t, Y_t, support_t, y_train_t, train_mask_t,
                                                   X_s, Y_s, support_s, y_train_s, train_mask_s, placeholders)
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
from layers import *
from metrics import *


flags = tf.app.flags
FLAGS = flags.FLAGS


def define_variables(hiddens, weight_name, bias_name, flag=False):
    variables = {}
    for i in range(len(hiddens) - 1):
        variables[weight_name.format(i)] = glorot([hiddens[i], hiddens[i+1]], name=weight_name.format(i))
        if flag:
            variables[bias_name.format(i)] = zeros([hiddens[i+1]], name=bias_name.format(i))

    return variables

####################

class GCN(object):
    def __init__(self,
                 placeholders,
                 input_dim,
                 N_s,
                 N_t,
                 bias_flag=False,
                 c_type='multi-label',
                 **kwargs):
        allowed_kwargs = {'name', 'logging'}
        for kwarg in kwargs.keys():
            assert kwarg in allowed_kwargs, 'Invalid keyword argument: ' + kwarg
        name = kwargs.get('name')
        if not name:
            name = self.__class__.__name__.lower()
        self.name = name

        logging = kwargs.get('logging', False)
        self.logging = logging

        self.da_method = FLAGS.da_method
        self.c_type = c_type
        self.bias_flag = bias_flag

        self.inputs_t = placeholders['features_t']
        self.inputs_s = placeholders['features_s']
        self.output_dim = placeholders['labels_t'].get_shape().as_list()[1]

        self.N = N_s if N_s
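models.py is truncated above, but the training script shows it exposes a Wasserstein-distance critic (`wd_loss`, `dis_loss_total`, `wd_d_op`) with a gradient penalty weighted by `FLAGS.gp_param`. As a rough illustration only — not the repository's code; the function, its name, and its structure are assumptions — a WGAN-GP-style critic objective over source/target embeddings looks like:

```
import tensorflow as tf

def wasserstein_critic_loss(critic, h_s, h_t, gp_param=10.0):
    # Illustrative WGAN-GP-style objective: the critic maximizes the mean score
    # gap between source and target embeddings, with a gradient penalty keeping
    # it approximately 1-Lipschitz. `critic` is any callable mapping embeddings
    # to scalar scores; assumes h_s and h_t have the same batch size
    # (e.g., after sampling). All names here are hypothetical.
    wd = tf.reduce_mean(critic(h_s)) - tf.reduce_mean(critic(h_t))
    alpha = tf.random_uniform([tf.shape(h_s)[0], 1])
    interp = alpha * h_s + (1.0 - alpha) * h_t   # random points between domains
    grads = tf.gradients(critic(interp), [interp])[0]
    grad_norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=1) + 1e-12)
    grad_penalty = tf.reduce_mean(tf.square(grad_norm - 1.0))
    # critic step minimizes this; the feature extractor minimizes wd itself
    return -wd + gp_param * grad_penalty
```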