├── .gitignore
├── README.md
├── configs
│   ├── config_random_15_new.py
│   └── config_random_60_new.py
├── data
│   ├── TPT-48.zip
│   ├── toy_d15_spiral_tight_boundary.pkl
│   └── toy_d60_spiral.pkl
├── dataset_generator
│   ├── dataset_generator.py
│   └── draw_graph_utils.py
├── dataset_utils
│   └── dataset.py
├── derive_g_encode
│   ├── 499_pred_GDA_new.pkl
│   ├── derive.py
│   ├── g_encode.pkl
│   ├── g_encode_15.pkl
│   └── g_encode_60.pkl
├── fig
│   ├── GRDA-DG-15-results.png
│   ├── GRDA-domain-graph-US.png
│   ├── blog-method-DA-vs-GRDA.png
│   ├── compcar_quantitive_result.jpg
│   ├── dg_15_60_quantitive_result.jpg
│   └── tpt_48_quantitive_result.jpg
├── main.py
├── model
│   ├── model.py
│   └── modules.py
├── requirements.txt
└── utils
    └── utils.py
/.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | __pycache__/ 3 | dataset_utils/__pycache__/ 4 | dataset_utils/__init__.py 5 | dump/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Graph-Relational Domain Adaptation (GRDA) 2 | This repo contains the code for our ICLR 2022 paper:
3 | **Graph-Relational Domain Adaptation**
4 | Zihao Xu, Hao He, Guang-He Lee, Yuyang Wang, Hao Wang
5 | *Tenth International Conference on Learning Representations (ICLR), 2022*
6 | [[Paper](http://wanghao.in/paper/ICLR22_GRDA.pdf)] [[Talk](https://www.youtube.com/watch?v=oNM5hZGVv34)] [[Slides](http://wanghao.in/slides/GRDA_slides.pptx)] [[TPT-48 Dataset](data/TPT-48.zip)] 7 | 8 | 9 | ## Beyond Domain Adaptation: Brief Introduction for GRDA 10 | GRDA goes beyond the current (categorical) domain adaptation regime and proposes the first approach to **adapt across graph-relational domains**. We introduce a new notion, dubbed "**domain graph**", to encode domain adjacency, e.g., a graph of US states in which each state is a domain and each edge indicates that two states are adjacent. Theoretical analysis shows that *at equilibrium, GRDA recovers classic domain adaptation when the graph is a clique, and achieves non-trivial alignment for other types of graphs*. See the following example (black nodes are source domains and white nodes are target domains). 11 | 12 |

13 | 14 |
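As a concrete illustration of a domain graph, the sketch below uses NumPy to build a symmetric 0/1 adjacency matrix from an undirected edge list. The helper name `edges_to_adjacency` and the four-domain chain are hypothetical and not part of this repo. The sketch also builds a clique, which is the graph under which, according to the analysis above, GRDA recovers classic domain adaptation.

```python
import numpy as np


def edges_to_adjacency(num_domain, edges):
    """Build a symmetric 0/1 adjacency matrix for an undirected domain graph.

    Hypothetical helper for illustration; domains are indexed 0..num_domain-1.
    """
    A = np.zeros((num_domain, num_domain))
    for i, j in edges:
        A[i, j] = 1  # edge i-j: domains i and j are adjacent
        A[j, i] = 1  # undirected graph, so mirror the entry
    return A


# a 4-domain chain 0-1-2-3, e.g. four regions lined up west to east
A_chain = edges_to_adjacency(4, [(0, 1), (1, 2), (2, 3)])

# a clique: every pair of distinct domains is adjacent
A_clique = np.ones((4, 4)) - np.eye(4)
```

The toy-data generator in `dataset_generator/dataset_generator.py` stores its matrix under the key `A` using the same convention: it is symmetric and its diagonal is zero.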

15 | 16 | ## Sample Results 17 | Consider a DA problem with 15 domains connected by a domain graph (see the figure below). Suppose we use domains 0, 3, 4, 8, 12, and 14 as source domains (left of the figure) and the remaining domains as target domains. The right of the figure shows sample results from previous domain adaptation methods and from GRDA. GRDA successfully generalizes across the different domains in the graph. 18 | 19 |

20 | 21 |
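The same split appears in `configs/config_random_15_new.py`. That config sets `opt.src_domain = np.array([0, 12, 3, 4, 14, 8])` and computes the number of target domains as `opt.num_domain - opt.num_source`. Below is a minimal sketch of recovering the target domain indices from those values. The variable names mirror the config, and `tgt_domain` is a hypothetical name introduced here.

```python
import numpy as np

num_domain = 15
src_domain = np.array([0, 12, 3, 4, 14, 8])  # as in config_random_15_new.py

num_source = src_domain.shape[0]
num_target = num_domain - num_source

# every domain that is not a source domain is a target domain (sorted output)
tgt_domain = np.setdiff1d(np.arange(num_domain), src_domain)

print(num_source, num_target)  # 6 9
print(tgt_domain.tolist())     # [1, 2, 5, 6, 7, 9, 10, 11, 13]
```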

22 | 23 | ## Method Overview 24 | We provide a simple yet effective learning framework with **theoretical guarantees** (see the [Theory section](#theory-informal) later in this README). Below is a quick comparison between previous domain adaptation methods and GRDA (differences marked in red). 25 | * Previous domain adaptation methods use a discriminator to classify different domains (as categorical values), while GRDA's discriminator directly reconstructs the domain graph (as an adjacency matrix). 26 | * Previous domain adaptation methods' encoders ignore domain IDs, while GRDA's encoder takes the domain IDs, together with the domain graph, as input. 27 |

28 | 29 |
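The first difference can be made concrete with a discriminator objective that reconstructs the domain graph instead of classifying domains. The minimal NumPy sketch below assumes an inner-product decoder. Each sample's discriminator output is a vector `z_i`, and the predicted probability that the domains of samples i and j are adjacent is sigmoid(z_i · z_j). The loss is binary cross-entropy against the adjacency matrix. The function name and this exact form are illustrative assumptions; check `model/model.py` for the loss the repo actually implements.

```python
import numpy as np


def graph_recon_loss(z, u, A):
    """Hypothetical graph-reconstruction loss for a GRDA-style discriminator.

    z : (n, k) discriminator outputs, one row per sample
    u : (n,)   integer domain index of each sample
    A : (num_domain, num_domain) 0/1 domain-graph adjacency matrix
    """
    logits = z @ z.T              # (n, n) pairwise inner products z_i . z_j
    target = A[np.ix_(u, u)]      # adjacency between each pair's domains
    # numerically stable BCE-with-logits: max(x, 0) - x*y + log(1 + exp(-|x|))
    bce = np.maximum(logits, 0) - logits * target + np.log1p(np.exp(-np.abs(logits)))
    off_diag = ~np.eye(len(u), dtype=bool)  # skip each sample paired with itself
    return float(bce[off_diag].mean())
```

BCE-with-logits is used here instead of `log(sigmoid(...))` so that large inner products do not overflow or produce `log(0)`.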

30 | 31 | ## Quantitative Results 32 | #### Toy Dataset: DG-15 and DG-60 33 |

34 | 35 |

36 | 37 | #### TPT-48 38 |

39 | 40 |

41 | 42 | #### CompCars 43 |

44 | 45 |

46 | 47 | ## Theory (Informal) 48 | - Traditional DA is equivalent to using GRDA with a fully connected graph (i.e., a clique). 49 | - D and E converge if and only if $\mathbb{E}_{i\sim p(u|e),\,j\sim p(u|e')}[A_{i,j} \mid e, e'] = \mathbb{E}_{i,j}[A_{i,j}]$. 50 | - The global optimum of the two-player game between E and D matches that of the three-player game between E, D, and F. 51 | 52 | 53 | 54 | ## Installation 55 | pip install -r requirements.txt 56 | 57 | ## How to Train GRDA 58 | python main.py 59 | 60 | ## Visualization 61 | We use visdom for visualization. We assume the code runs on a remote GPU machine. 62 | 63 | ### Change Configurations 64 | Find the config in the "configs" folder. Choose the config you need and set "opt.use_visdom" to "True". 65 | 66 | ### Start a Visdom Server on Your Machine 67 | python -m visdom.server -p 2000 68 | Now connect your computer to the GPU server and forward port 2000 to your local computer. You can then go to: 69 | http://localhost:2000 (your local address) 70 | to see the visualization during training. 71 | 72 | ## Also Check Out Relevant Work 73 | **Continuously Indexed Domain Adaptation**
74 | Hao Wang*, Hao He*, Dina Katabi
75 | *Thirty-Seventh International Conference on Machine Learning (ICML), 2020*
76 | [[Paper](http://wanghao.in/paper/ICML20_CIDA.pdf)] [[Code](https://github.com/hehaodele/CIDA)] [[Talk](https://www.youtube.com/watch?v=KtZPSCD-WhQ)] [[Blog](http://wanghao.in/CIDA-Blog/CIDA.html)] [[Slides](http://wanghao.in/slides/CIDA_slides.pptx)] 77 | 78 | ## Reference 79 | [Graph-Relational Domain Adaptation](http://wanghao.in/paper/ICLR22_GRDA.pdf) 80 | ```bib 81 | @inproceedings{GRDA, 82 | title={Graph-Relational Domain Adaptation}, 83 | author={Xu, Zihao and He, Hao and Lee, Guang-He and Wang, Yuyang and Wang, Hao}, 84 | booktitle={International Conference on Learning Representations}, 85 | year={2022} 86 | } 87 | ``` 88 | -------------------------------------------------------------------------------- /configs/config_random_15_new.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict 2 | import numpy as np 3 | import pickle 4 | 5 | 6 | def read_pickle(name): 7 | with open(name, "rb") as f: 8 | data = pickle.load(f) 9 | return data 10 | 11 | 12 | # load/output dir 13 | opt = EasyDict() 14 | opt.loadf = "./dump" 15 | opt.outf = "./dump" 16 | 17 | # normalize each data domain 18 | # opt.normalize_domain = False 19 | 20 | # now it is half circle 21 | opt.num_domain = 15 22 | # the specific source and target domain: 23 | opt.src_domain = np.array([0, 12, 3, 4, 14, 8]) # tight_boundary 24 | opt.num_source = opt.src_domain.shape[0] 25 | opt.num_target = opt.num_domain - opt.num_source 26 | opt.test_on_all_dmn = True 27 | 28 | 29 | print("src domain: {}".format(opt.src_domain)) 30 | 31 | # opt.model = "DANN" 32 | # opt.model = "CDANN" 33 | # opt.model = "ADDA" 34 | # opt.model = 'MDD' 35 | opt.model = "GDA" 36 | opt.cond_disc = ( 37 | False # whether use conditional discriminator or not (for CDANN) 38 | ) 39 | print("model: {}".format(opt.model)) 40 | opt.use_visdom = False 41 | opt.visdom_port = 2000 42 | 43 | opt.use_g_encode = True # False # True 44 | if opt.use_g_encode: 45 | opt.g_encode = 
read_pickle("derive_g_encode/g_encode_15.pkl") 46 | 47 | opt.device = "cuda" 48 | opt.seed = 233 # 1# 101 # 1 # 233 # 1 49 | 50 | opt.lambda_gan = 1 # 0.5 51 | # for MDD use only 52 | opt.lambda_src = 0.5 53 | opt.lambda_tgt = 0.5 54 | 55 | 56 | opt.num_epoch = 500 57 | opt.batch_size = 10 58 | opt.lr_d = 1e-4 # 3e-5 # 1e-4 59 | opt.lr_e = 1e-4 # 3e-5 # 1e-4 60 | opt.lr_g = 1e-4 61 | opt.gamma = 100 62 | opt.beta1 = 0.9 63 | opt.weight_decay = 5e-4 64 | opt.wgan = False # do not use wgan to train 65 | opt.no_bn = True # do not use batch normalization # True 66 | 67 | # model size configs, used for D, E, F 68 | opt.nx = 2 # dimension of the input data 69 | opt.nt = 2 # dimension of the vertex embedding 70 | opt.nh = 512 # dimension of hidden # 512 71 | opt.nc = 2 # number of label class 72 | opt.nd_out = 2 # dimension of D's output 73 | 74 | # sample how many vertices for training D 75 | opt.sample_v = 10 76 | 77 | # # sample how many vertices for training G 78 | opt.sample_v_g = 15 79 | 80 | opt.test_interval = 20 81 | opt.save_interval = 100 82 | # drop out rate 83 | opt.p = 0.2 84 | opt.shuffle = True 85 | 86 | # dataset 87 | opt.dataset = "data/toy_d15_spiral_tight_boundary.pkl" 88 | -------------------------------------------------------------------------------- /configs/config_random_60_new.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict 2 | import numpy as np 3 | import pickle 4 | 5 | 6 | def read_pickle(name): 7 | with open(name, "rb") as f: 8 | data = pickle.load(f) 9 | return data 10 | 11 | 12 | # load/output dir 13 | opt = EasyDict() # set experiment configs 14 | opt.loadf = "./dump" 15 | opt.outf = "./dump" 16 | 17 | # normalize each data domain 18 | # opt.normalize_domain = False 19 | 20 | # now it is half circle 21 | opt.num_domain = 60 22 | # the specific source and target domain: 23 | opt.src_domain = np.array([2, 14, 41, 23, 59, 33]) # 60-spiral 24 | opt.num_source = 
opt.src_domain.shape[0] 25 | opt.num_target = opt.num_domain - opt.num_source 26 | opt.test_on_all_dmn = True 27 | 28 | 29 | print("src domain: {}".format(opt.src_domain)) 30 | 31 | # opt.model = "DANN" 32 | # opt.model = "CDANN" 33 | # opt.model = "ADDA" 34 | # opt.model = 'MDD' 35 | opt.model = "GDA" 36 | opt.cond_disc = ( 37 | False # whether to use a conditional discriminator or not (for CDANN) 38 | ) 39 | print("model: {}".format(opt.model)) 40 | opt.use_visdom = True 41 | opt.visdom_port = 2000 42 | 43 | # load the precomputed domain-graph encoding for the random-60 dataset 44 | opt.use_g_encode = True 45 | if opt.use_g_encode: 46 | opt.g_encode = read_pickle("derive_g_encode/g_encode_60.pkl") 47 | 48 | opt.device = "cuda" 49 | opt.seed = 233 # 1# 101 # 1 # 233 # 1 50 | 51 | opt.lambda_gan = 0.5 # 0.5 # 0.3125 # 0.5 # 0.5 52 | 53 | # for MDD use only 54 | opt.lambda_src = 0.5 55 | opt.lambda_tgt = 0.5 56 | 57 | opt.num_epoch = 500 58 | opt.batch_size = 10 # 10 59 | opt.lr_d = 1e-5 # 3e-5 # 1e-4 # 2.9 * 1e-5 #3e-5 # 1e-4 60 | opt.lr_e = 1e-5 # 3e-5 # 1e-4 # 2.9 * 1e-5 61 | opt.lr_g = 1e-3 # 3e-4 62 | opt.gamma = 100 63 | opt.beta1 = 0.9 64 | opt.weight_decay = 5e-4 65 | opt.wgan = False # do not use wgan to train 66 | opt.no_bn = True # do not use batch normalization # True 67 | 68 | # model size configs, used for D, E, F 69 | opt.nx = 2 # dimension of the input data 70 | opt.nt = 2 # dimension of the vertex embedding 71 | opt.nh = 512 # dimension of hidden # 512 72 | opt.nc = 2 # number of label class 73 | opt.nd_out = 2 # dimension of D's output 74 | 75 | # sample how many vertices for training D 76 | opt.sample_v = 30 77 | 78 | # # sample how many vertices for training G 79 | opt.sample_v_g = 60 80 | 81 | opt.test_interval = 20 # 20 82 | opt.save_interval = 100 83 | # drop out rate 84 | opt.p = 0.2 85 | opt.shuffle = True 86 | 87 | # dataset 88 | opt.dataset = "data/toy_d60_spiral.pkl" 89 |
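The dataset referenced above is produced by `create_toy_data()` in `dataset_generator/dataset_generator.py`. Its adjacency loop sets `p = cos(a_i)cos(a_j) + sin(a_i)sin(a_j)`, which equals cos(a_i - a_j). A pair gets no edge when `p` falls below a threshold (0.5 for 15 domains, 0.2 for 60); otherwise the edge is drawn from a Bernoulli(p). A vectorized sketch of that rule follows. The function name, the `threshold` argument, and the use of `numpy.random.default_rng` are choices made for this sketch rather than taken from the script.

```python
import numpy as np


def sample_domain_graph(angles, threshold, rng=None):
    """Sample a symmetric domain graph from per-domain angles (hypothetical helper).

    Mirrors the edge rule in create_toy_data(): p_ij = cos(a_i - a_j);
    no edge if p_ij < threshold, otherwise an edge with probability p_ij.
    """
    rng = np.random.default_rng() if rng is None else rng
    angles = np.asarray(angles, dtype=float)
    n = angles.shape[0]

    # cos(a_i)cos(a_j) + sin(a_i)sin(a_j) == cos(a_i - a_j)
    p = np.cos(angles[:, None] - angles[None, :])

    # one Bernoulli(p_ij) draw per pair, suppressed below the threshold
    edges = (rng.random((n, n)) < p) & (p >= threshold)

    # keep i < j only, then mirror so A is symmetric with a zero diagonal
    upper = np.triu(edges, k=1)
    return (upper | upper.T).astype(float)


# 60 domains with angles in [0, 2*pi), matching the script's angle range for 60
A = sample_domain_graph(np.random.rand(60) * 2 * np.pi, threshold=0.2)
```

Drawing all pairs at once and keeping only the upper triangle replaces the script's double loop with the same edge distribution.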
-------------------------------------------------------------------------------- /data/TPT-48.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-ML-Lab/GRDA/96800d62577b8046af3fbc9ade0a51c2b71129c1/data/TPT-48.zip -------------------------------------------------------------------------------- /data/toy_d15_spiral_tight_boundary.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-ML-Lab/GRDA/96800d62577b8046af3fbc9ade0a51c2b71129c1/data/toy_d15_spiral_tight_boundary.pkl -------------------------------------------------------------------------------- /data/toy_d60_spiral.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-ML-Lab/GRDA/96800d62577b8046af3fbc9ade0a51c2b71129c1/data/toy_d60_spiral.pkl -------------------------------------------------------------------------------- /dataset_generator/dataset_generator.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import collections 3 | import csv 4 | import shutil 5 | import os 6 | import numpy as np 7 | from os.path import join 8 | import matplotlib.pyplot as plt 9 | # from utils import read_pickle 10 | # from utils import write_pickle 11 | from scipy.ndimage.interpolation import zoom 12 | import re 13 | import pickle 14 | import networkx as nx 15 | from draw_graph_utils import draw 16 | 17 | 18 | def read_pickle(name): 19 | with open(name, 'rb') as f: 20 | data = pickle.load(f) 21 | return data 22 | 23 | def write_pickle(data, name): 24 | with open(name,'wb') as f: 25 | # the default protocol level is 4 26 | pickle.dump(data, f) 27 | 28 | def show_graph_with_labels(adjacency_matrix, my_angles): 29 | rows, cols = np.where(adjacency_matrix == 1) 30 | edges = zip(rows.tolist(), cols.tolist()) 31 | gr = nx.Graph() 32 | gr.add_edges_from(edges) 
33 | 34 | pos = nx.kamada_kawai_layout(gr) # good, littel better than spring 35 | 36 | num_domain = adjacency_matrix.shape[0] 37 | 38 | # expand the graph in horizontal 39 | for i in range(num_domain): 40 | pos[i][0] *= 1.4 41 | pos[i][1] *= 0.8 42 | 43 | 44 | labels = dict() 45 | for i in range(num_domain): 46 | labels[i] = i 47 | 48 | # use self defined picture drawing picture. 49 | fig, ax = plt.subplots(1, 1) 50 | draw(gr, pos, node_radius=0.077, font_color='white', node_angles=my_angles, labels=labels, with_labels=True, ax=ax) 51 | ax.set_aspect("equal") 52 | plt.show() 53 | 54 | # generate data given the mean/std, radius and number 55 | def generate_data(mean, std, radius, num): 56 | dim = mean.shape[0] 57 | m_data = np.random.randn(num, dim) 58 | print('bingo', m_data.shape) 59 | m_data *= std[None, :] 60 | m_radius = m_data[:, 0] ** 2 + m_data[:, 1] ** 2 61 | m_data += mean[None, :] 62 | m_data = m_data[m_radius <= radius ** 2, :] 63 | 64 | # random choice 65 | choice = np.random.choice(m_data.shape[0], size=50, replace=False) 66 | print(choice) 67 | m_data = m_data[choice, :] 68 | 69 | print('num of data points within radius', radius, ':', m_data.shape[0]) 70 | return m_data 71 | 72 | # generate label for circle-shape data 73 | def generate_label(m_data, radius): 74 | m_radius = m_data[:, 0] ** 2 + m_data[:, 1] ** 2 75 | m_label = np.zeros((m_data.shape[0],)) 76 | m_label[m_radius > radius ** 2] = 1 77 | print("=============") 78 | print("label 0's num: {}".format(np.sum(m_label == 0))) 79 | print("label 1's num: {}".format(np.sum(m_label == 1))) 80 | return m_label 81 | 82 | # create dataset, circle-shape domain manifold 83 | def create_toy_data(): 84 | fname = 'toy_d60_spiral.pkl' 85 | num_domain = 60 86 | # fname = 'toy_d30_spiral.pkl' 87 | # l_angle = np.random.rand(15) * np.pi / 2 88 | # l_angle = np.random.rand(num_domain) * np.pi * 2 89 | l_angle = np.random.rand(num_domain) * np.pi / 30 * num_domain 90 | 91 | radius_start = 1 92 | std_small = 1 93 
| radius_step = 0.1 94 | radius_small = 1 95 | 96 | lm_data = [] 97 | l_domain = [] 98 | l_label = [] 99 | for i, angle in enumerate(l_angle): 100 | # radius = radius_start + angle / (np.pi / 2) * radius_step 101 | # radius = radius_start + angle / (np.pi) * radius_step 102 | radius = radius_start + angle / (np.pi / 30 * num_domain) * radius_step * 60 103 | mean = np.array([np.cos(angle), np.sin(angle)]) * radius 104 | std = np.ones((2,)) * std_small 105 | m_data = generate_data(mean, std, radius_small, 300) 106 | m_data = np.append(m_data, generate_data(-mean, std, radius_small, 300), axis=0) 107 | m_label = np.ones(m_data.shape[0],) 108 | m_label[0:int(0.5 * m_data.shape[0])] = 0 109 | # m_data = generate_data(mean, std, radius_small, 300) 110 | # m_label = generate_label(m_data, ra) 111 | l_label.append(m_label) 112 | lm_data.append(m_data) 113 | l_domain.append(np.ones(m_data.shape[0],) * i) 114 | 115 | angle_all = np.array(l_angle) 116 | data_all = np.concatenate(lm_data, axis=0) 117 | domain_all = np.concatenate(l_domain, axis=0) 118 | label_all = np.concatenate(l_label, axis=0)# generate_label(data_all, radius_large) 119 | 120 | # generate A 121 | # A's generation: 122 | A = np.zeros((num_domain, num_domain)) 123 | for i in range(num_domain): 124 | for j in range(i + 1, num_domain): 125 | p = np.cos(angle_all[i]) * np.cos(angle_all[j]) + np.sin(angle_all[i]) * np.sin(angle_all[j]) 126 | 127 | if num_domain == 15: 128 | if p < 0.5: 129 | c = 0 130 | else: 131 | c = np.random.binomial(1, p) 132 | elif num_domain == 60: 133 | if p < 0.2: 134 | c = 0 135 | else: 136 | c = np.random.binomial(1, p) 137 | 138 | A[i][j] = c 139 | A[j][i] = c 140 | 141 | show_graph_with_labels(A, angle_all) 142 | print(angle_all) 143 | 144 | 145 | d_pkl = dict() 146 | d_pkl['data'] = data_all 147 | d_pkl['label'] = label_all 148 | d_pkl['domain'] = domain_all 149 | d_pkl['A'] = A 150 | d_pkl['angle'] = angle_all 151 | write_pickle(d_pkl, fname) 152 | 153 | l_style = ['k*', 'r*', 
'b*', 'y*', 'k.', 'r.'] 154 | for i in range(2): 155 | data_sub = data_all[label_all == i, :] 156 | plt.plot(data_sub[:, 0], data_sub[:, 1], l_style[i]) 157 | plt.grid() 158 | plt.show() 159 | 160 | if __name__ == '__main__': 161 | create_toy_data() -------------------------------------------------------------------------------- /dataset_generator/draw_graph_utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.patches as mpatches 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | 5 | # this is a modified version of networkx 6 | 7 | 8 | def arc_patch(center, radius, theta1, theta2, ax=None, resolution=50, **kwargs): 9 | # be sure that theta is in radian!! 10 | # make sure ax is not empty 11 | if ax is None: 12 | ax = plt.gca() 13 | # generate the points 14 | theta = np.linspace(theta1, theta2, resolution) 15 | points = np.vstack((radius*np.cos(theta) + center[0], 16 | radius*np.sin(theta) + center[1])) 17 | # build the polygon and add it to the axes 18 | # print(**kwargs) 19 | poly = mpatches.Polygon(points.T, closed=True, **kwargs) 20 | # poly = mpatches.Polygon(points.T, closed=True, fill=True, color='blue') 21 | ax.add_patch(poly) 22 | ax.plot() 23 | return poly 24 | 25 | 26 | """ 27 | ********** 28 | Matplotlib 29 | ********** 30 | 31 | Draw networks with matplotlib. 
32 | 33 | See Also 34 | -------- 35 | 36 | matplotlib: http://matplotlib.org/ 37 | 38 | pygraphviz: http://pygraphviz.github.io/ 39 | 40 | """ 41 | from numbers import Number 42 | import networkx as nx 43 | from networkx.drawing.layout import ( 44 | shell_layout, 45 | circular_layout, 46 | kamada_kawai_layout, 47 | spectral_layout, 48 | spring_layout, 49 | random_layout, 50 | planar_layout, 51 | ) 52 | 53 | __all__ = [ 54 | "draw", 55 | "draw_networkx", 56 | "draw_networkx_nodes", 57 | "draw_networkx_edges", 58 | "draw_networkx_labels", 59 | "draw_networkx_edge_labels", 60 | "draw_circular", 61 | "draw_kamada_kawai", 62 | "draw_random", 63 | "draw_spectral", 64 | "draw_spring", 65 | "draw_planar", 66 | "draw_shell", 67 | ] 68 | 69 | 70 | def draw(G, pos=None, ax=None, **kwds): 71 | """Draw the graph G with Matplotlib. 72 | 73 | Draw the graph as a simple representation with no node 74 | labels or edge labels and using the full Matplotlib figure area 75 | and no axis labels by default. See draw_networkx() for more 76 | full-featured drawing that allows title, axis labels etc. 77 | 78 | Parameters 79 | ---------- 80 | G : graph 81 | A networkx graph 82 | 83 | pos : dictionary, optional 84 | A dictionary with nodes as keys and positions as values. 85 | If not specified a spring layout positioning will be computed. 86 | See :py:mod:`networkx.drawing.layout` for functions that 87 | compute node positions. 88 | 89 | ax : Matplotlib Axes object, optional 90 | Draw the graph in specified Matplotlib axes. 91 | 92 | kwds : optional keywords 93 | See networkx.draw_networkx() for a description of optional keywords. 
94 | 95 | Examples 96 | -------- 97 | >>> G = nx.dodecahedral_graph() 98 | >>> nx.draw(G) 99 | >>> nx.draw(G, pos=nx.spring_layout(G)) # use spring layout 100 | 101 | See Also 102 | -------- 103 | draw_networkx() 104 | draw_networkx_nodes() 105 | draw_networkx_edges() 106 | draw_networkx_labels() 107 | draw_networkx_edge_labels() 108 | 109 | Notes 110 | ----- 111 | This function has the same name as pylab.draw and pyplot.draw 112 | so beware when using `from networkx import *` 113 | 114 | since you might overwrite the pylab.draw function. 115 | 116 | With pyplot use 117 | 118 | >>> import matplotlib.pyplot as plt 119 | >>> G = nx.dodecahedral_graph() 120 | >>> nx.draw(G) # networkx draw() 121 | >>> plt.draw() # pyplot draw() 122 | 123 | Also see the NetworkX drawing examples at 124 | https://networkx.github.io/documentation/latest/auto_examples/index.html 125 | """ 126 | try: 127 | import matplotlib.pyplot as plt 128 | except ImportError as e: 129 | raise ImportError("Matplotlib required for draw()") from e 130 | except RuntimeError: 131 | print("Matplotlib unable to open display") 132 | raise 133 | 134 | if ax is None: 135 | cf = plt.gcf() 136 | else: 137 | cf = ax.get_figure() 138 | cf.set_facecolor("w") 139 | if ax is None: 140 | if cf._axstack() is None: 141 | ax = cf.add_axes((0, 0, 1, 1)) 142 | else: 143 | ax = cf.gca() 144 | 145 | if "with_labels" not in kwds: 146 | kwds["with_labels"] = "labels" in kwds 147 | 148 | draw_networkx(G, pos=pos, ax=ax, **kwds) 149 | ax.set_axis_off() 150 | plt.draw_if_interactive() 151 | return 152 | 153 | 154 | 155 | def draw_networkx(G, pos=None, arrows=True, with_labels=True, **kwds): 156 | """Draw the graph G using Matplotlib. 157 | 158 | Draw the graph with Matplotlib with options for node positions, 159 | labeling, titles, and many other drawing features. 160 | See draw() for simple drawing without labels or axes. 
161 | 162 | Parameters 163 | ---------- 164 | G : graph 165 | A networkx graph 166 | 167 | pos : dictionary, optional 168 | A dictionary with nodes as keys and positions as values. 169 | If not specified a spring layout positioning will be computed. 170 | See :py:mod:`networkx.drawing.layout` for functions that 171 | compute node positions. 172 | 173 | arrows : bool, optional (default=True) 174 | For directed graphs, if True draw arrowheads. 175 | Note: Arrows will be the same color as edges. 176 | 177 | arrowstyle : str, optional (default='-|>') 178 | For directed graphs, choose the style of the arrowsheads. 179 | See :py:class: `matplotlib.patches.ArrowStyle` for more 180 | options. 181 | 182 | arrowsize : int, optional (default=10) 183 | For directed graphs, choose the size of the arrow head head's length and 184 | width. See :py:class: `matplotlib.patches.FancyArrowPatch` for attribute 185 | `mutation_scale` for more info. 186 | 187 | with_labels : bool, optional (default=True) 188 | Set to True to draw labels on the nodes. 189 | 190 | ax : Matplotlib Axes object, optional 191 | Draw the graph in the specified Matplotlib axes. 192 | 193 | nodelist : list, optional (default G.nodes()) 194 | Draw only specified nodes 195 | 196 | edgelist : list, optional (default=G.edges()) 197 | Draw only specified edges 198 | 199 | node_size : scalar or array, optional (default=300) 200 | Size of nodes. If an array is specified it must be the 201 | same length as nodelist. 202 | 203 | node_color : color or array of colors (default='#1f78b4') 204 | Node color. Can be a single color or a sequence of colors with the same 205 | length as nodelist. Color can be string, or rgb (or rgba) tuple of 206 | floats from 0-1. If numeric values are specified they will be 207 | mapped to colors using the cmap and vmin,vmax parameters. See 208 | matplotlib.scatter for more details. 209 | 210 | node_shape : string, optional (default='o') 211 | The shape of the node. 
Specification is as matplotlib.scatter 212 | marker, one of 'so^>v>> G = nx.dodecahedral_graph() 274 | >>> nx.draw(G) 275 | >>> nx.draw(G, pos=nx.spring_layout(G)) # use spring layout 276 | 277 | >>> import matplotlib.pyplot as plt 278 | >>> limits = plt.axis("off") # turn of axis 279 | 280 | Also see the NetworkX drawing examples at 281 | https://networkx.github.io/documentation/latest/auto_examples/index.html 282 | 283 | See Also 284 | -------- 285 | draw() 286 | draw_networkx_nodes() 287 | draw_networkx_edges() 288 | draw_networkx_labels() 289 | draw_networkx_edge_labels() 290 | """ 291 | try: 292 | import matplotlib.pyplot as plt 293 | except ImportError as e: 294 | raise ImportError("Matplotlib required for draw()") from e 295 | except RuntimeError: 296 | print("Matplotlib unable to open display") 297 | raise 298 | 299 | valid_node_kwds = ( 300 | # some of my own parameters: 301 | "node_angles", 302 | "node_radius", 303 | 304 | "nodelist", 305 | "node_size", 306 | "node_color", 307 | "node_shape", 308 | "alpha", 309 | "cmap", 310 | "vmin", 311 | "vmax", 312 | "ax", 313 | "linewidths", 314 | "edgecolors", 315 | "label", 316 | ) 317 | 318 | valid_edge_kwds = ( 319 | "edgelist", 320 | "width", 321 | "edge_color", 322 | "style", 323 | "alpha", 324 | "arrowstyle", 325 | "arrowsize", 326 | "edge_cmap", 327 | "edge_vmin", 328 | "edge_vmax", 329 | "ax", 330 | "label", 331 | "node_size", 332 | "nodelist", 333 | "node_shape", 334 | "connectionstyle", 335 | "min_source_margin", 336 | "min_target_margin", 337 | ) 338 | 339 | valid_label_kwds = ( 340 | "labels", 341 | "font_size", 342 | "font_color", 343 | "font_family", 344 | "font_weight", 345 | "alpha", 346 | "bbox", 347 | "ax", 348 | "horizontalalignment", 349 | "verticalalignment", 350 | ) 351 | 352 | valid_kwds = valid_node_kwds + valid_edge_kwds + valid_label_kwds 353 | 354 | if any([k not in valid_kwds for k in kwds]): 355 | invalid_args = ", ".join([k for k in kwds if k not in valid_kwds]) 356 | raise 
ValueError(f"Received invalid argument(s): {invalid_args}") 357 | 358 | node_kwds = {k: v for k, v in kwds.items() if k in valid_node_kwds} 359 | edge_kwds = {k: v for k, v in kwds.items() if k in valid_edge_kwds} 360 | label_kwds = {k: v for k, v in kwds.items() if k in valid_label_kwds} 361 | 362 | if pos is None: 363 | pos = nx.drawing.spring_layout(G) # default to spring layout 364 | 365 | # draw_networkx_nodes(G, pos, **node_kwds) 366 | 367 | # use my own nodes instead 368 | my_draw_networkx_nodes(G, pos, **node_kwds) 369 | 370 | draw_networkx_edges(G, pos, arrows=arrows, **edge_kwds) 371 | if with_labels: 372 | draw_networkx_labels(G, pos, **label_kwds) 373 | plt.draw_if_interactive() 374 | 375 | 376 | def my_draw_networkx_nodes( 377 | G, 378 | pos, 379 | node_angles=None, 380 | node_radius=1, 381 | nodelist=None, 382 | ax=None, 383 | label=None, 384 | ): 385 | from collections.abc import Iterable 386 | 387 | if ax is None: 388 | ax = plt.gca() 389 | 390 | if nodelist is None: 391 | nodelist = list(G) 392 | 393 | # xys = np.asarray([pos[v] for v in nodelist]) 394 | # num_node = 395 | # for i in range(num_node): 396 | for v in nodelist: 397 | arc_patch(pos[v], node_radius, node_angles[v], node_angles[v] + np.pi, ax, fill=True, color='blue') 398 | arc_patch(pos[v], node_radius, node_angles[v]+np.pi, node_angles[v] + 2 * np.pi, ax, fill=True, color='red') 399 | 400 | ax.tick_params( 401 | axis="both", 402 | which="both", 403 | bottom=False, 404 | left=False, 405 | labelbottom=False, 406 | labelleft=False, 407 | ) 408 | 409 | 410 | 411 | def draw_networkx_nodes( 412 | G, 413 | pos, 414 | nodelist=None, 415 | node_size=300, 416 | node_color="#1f78b4", 417 | node_shape="o", 418 | alpha=None, 419 | cmap=None, 420 | vmin=None, 421 | vmax=None, 422 | ax=None, 423 | linewidths=None, 424 | edgecolors=None, 425 | label=None, 426 | ): 427 | """Draw the nodes of the graph G. 428 | 429 | This draws only the nodes of the graph G. 
430 | 431 | Parameters 432 | ---------- 433 | G : graph 434 | A networkx graph 435 | 436 | pos : dictionary 437 | A dictionary with nodes as keys and positions as values. 438 | Positions should be sequences of length 2. 439 | 440 | ax : Matplotlib Axes object, optional 441 | Draw the graph in the specified Matplotlib axes. 442 | 443 | nodelist : list, optional 444 | Draw only specified nodes (default G.nodes()) 445 | 446 | node_size : scalar or array 447 | Size of nodes (default=300). If an array is specified it must be the 448 | same length as nodelist. 449 | 450 | node_color : color or array of colors (default='#1f78b4') 451 | Node color. Can be a single color or a sequence of colors with the same 452 | length as nodelist. Color can be string, or rgb (or rgba) tuple of 453 | floats from 0-1. If numeric values are specified they will be 454 | mapped to colors using the cmap and vmin,vmax parameters. See 455 | matplotlib.scatter for more details. 456 | 457 | node_shape : string 458 | The shape of the node. 
Specification is as matplotlib.scatter 459 | marker, one of 'so^>v>> G = nx.dodecahedral_graph() 490 | >>> nodes = nx.draw_networkx_nodes(G, pos=nx.spring_layout(G)) 491 | 492 | Also see the NetworkX drawing examples at 493 | https://networkx.github.io/documentation/latest/auto_examples/index.html 494 | 495 | See Also 496 | -------- 497 | draw() 498 | draw_networkx() 499 | draw_networkx_edges() 500 | draw_networkx_labels() 501 | draw_networkx_edge_labels() 502 | """ 503 | from collections.abc import Iterable 504 | 505 | try: 506 | import matplotlib.pyplot as plt 507 | from matplotlib.collections import PathCollection 508 | import numpy as np 509 | except ImportError as e: 510 | raise ImportError("Matplotlib required for draw()") from e 511 | except RuntimeError: 512 | print("Matplotlib unable to open display") 513 | raise 514 | 515 | if ax is None: 516 | ax = plt.gca() 517 | 518 | if nodelist is None: 519 | nodelist = list(G) 520 | 521 | if len(nodelist) == 0: # empty nodelist, no drawing 522 | return PathCollection(None) 523 | 524 | try: 525 | xy = np.asarray([pos[v] for v in nodelist]) 526 | except KeyError as e: 527 | raise nx.NetworkXError(f"Node {e} has no position.") from e 528 | except ValueError as e: 529 | raise nx.NetworkXError("Bad value in node positions.") from e 530 | 531 | if isinstance(alpha, Iterable): 532 | node_color = apply_alpha(node_color, alpha, nodelist, cmap, vmin, vmax) 533 | alpha = None 534 | 535 | node_collection = ax.scatter( 536 | xy[:, 0], 537 | xy[:, 1], 538 | s=node_size, 539 | c=node_color, 540 | marker=node_shape, 541 | cmap=cmap, 542 | vmin=vmin, 543 | vmax=vmax, 544 | alpha=alpha, 545 | linewidths=linewidths, 546 | edgecolors=edgecolors, 547 | label=label, 548 | ) 549 | ax.tick_params( 550 | axis="both", 551 | which="both", 552 | bottom=False, 553 | left=False, 554 | labelbottom=False, 555 | labelleft=False, 556 | ) 557 | 558 | node_collection.set_zorder(2) 559 | return node_collection 560 | 561 | 562 | 563 | def 
draw_networkx_edges( 564 | G, 565 | pos, 566 | edgelist=None, 567 | width=1.0, 568 | edge_color="k", 569 | style="solid", 570 | alpha=None, 571 | arrowstyle="-|>", 572 | arrowsize=10, 573 | edge_cmap=None, 574 | edge_vmin=None, 575 | edge_vmax=None, 576 | ax=None, 577 | arrows=True, 578 | label=None, 579 | node_size=300, 580 | nodelist=None, 581 | node_shape="o", 582 | connectionstyle=None, 583 | min_source_margin=0, 584 | min_target_margin=0, 585 | ): 586 | """Draw the edges of the graph G. 587 | 588 | This draws only the edges of the graph G. 589 | 590 | Parameters 591 | ---------- 592 | G : graph 593 | A networkx graph 594 | 595 | pos : dictionary 596 | A dictionary with nodes as keys and positions as values. 597 | Positions should be sequences of length 2. 598 | 599 | edgelist : collection of edge tuples 600 | Draw only specified edges(default=G.edges()) 601 | 602 | width : float, or array of floats 603 | Line width of edges (default=1.0) 604 | 605 | edge_color : color or array of colors (default='k') 606 | Edge color. Can be a single color or a sequence of colors with the same 607 | length as edgelist. Color can be string, or rgb (or rgba) tuple of 608 | floats from 0-1. If numeric values are specified they will be 609 | mapped to colors using the edge_cmap and edge_vmin,edge_vmax parameters. 610 | 611 | style : string 612 | Edge line style (default='solid') (solid|dashed|dotted,dashdot) 613 | 614 | alpha : float 615 | The edge transparency (default=None) 616 | 617 | edge_ cmap : Matplotlib colormap 618 | Colormap for mapping intensities of edges (default=None) 619 | 620 | edge_vmin,edge_vmax : floats 621 | Minimum and maximum for edge colormap scaling (default=None) 622 | 623 | ax : Matplotlib Axes object, optional 624 | Draw the graph in the specified Matplotlib axes. 625 | 626 | arrows : bool, optional (default=True) 627 | For directed graphs, if True draw arrowheads. 628 | Note: Arrows will be the same color as edges. 
629 | 630 | arrowstyle : str, optional (default='-|>') 631 | For directed graphs, choose the style of the arrow heads. 632 | See :py:class:`matplotlib.patches.ArrowStyle` for more 633 | options. 634 | 635 | arrowsize : int, optional (default=10) 636 | For directed graphs, choose the size of the arrow head's length and 637 | width. See :py:class:`matplotlib.patches.FancyArrowPatch` for attribute 638 | `mutation_scale` for more info. 639 | 640 | connectionstyle : str, optional (default=None) 641 | Pass the connectionstyle parameter to create curved arc of rounding 642 | radius rad. For example, connectionstyle='arc3,rad=0.2'. 643 | See :py:class:`matplotlib.patches.ConnectionStyle` and 644 | :py:class:`matplotlib.patches.FancyArrowPatch` for more info. 645 | 646 | label : [None| string] 647 | Label for legend 648 | 649 | min_source_margin : int, optional (default=0) 650 | The minimum margin (gap) at the beginning of the edge at the source. 651 | 652 | min_target_margin : int, optional (default=0) 653 | The minimum margin (gap) at the end of the edge at the target. 654 | 655 | Returns 656 | ------- 657 | matplotlib.collections.LineCollection 658 | `LineCollection` of the edges 659 | 660 | list of matplotlib.patches.FancyArrowPatch 661 | `FancyArrowPatch` instances of the directed edges 662 | 663 | Depending on whether the drawing includes arrows or not. 664 | 665 | Notes 666 | ----- 667 | For directed graphs, arrows are drawn at the head end. Arrows can be 668 | turned off with keyword arrows=False. Be sure to include `node_size` as a 669 | keyword argument; arrows are drawn considering the size of nodes.
670 | 671 | Examples 672 | -------- 673 | >>> G = nx.dodecahedral_graph() 674 | >>> edges = nx.draw_networkx_edges(G, pos=nx.spring_layout(G)) 675 | 676 | >>> G = nx.DiGraph() 677 | >>> G.add_edges_from([(1, 2), (1, 3), (2, 3)]) 678 | >>> arcs = nx.draw_networkx_edges(G, pos=nx.spring_layout(G)) 679 | >>> alphas = [0.3, 0.4, 0.5] 680 | >>> for i, arc in enumerate(arcs): # change alpha values of arcs 681 | ... arc.set_alpha(alphas[i]) 682 | 683 | Also see the NetworkX drawing examples at 684 | https://networkx.github.io/documentation/latest/auto_examples/index.html 685 | 686 | See Also 687 | -------- 688 | draw() 689 | draw_networkx() 690 | draw_networkx_nodes() 691 | draw_networkx_labels() 692 | draw_networkx_edge_labels() 693 | """ 694 | try: 695 | import matplotlib.pyplot as plt 696 | from matplotlib.colors import colorConverter, Colormap, Normalize 697 | from matplotlib.collections import LineCollection 698 | from matplotlib.patches import FancyArrowPatch 699 | import numpy as np 700 | except ImportError as e: 701 | raise ImportError("Matplotlib required for draw()") from e 702 | except RuntimeError: 703 | print("Matplotlib unable to open display") 704 | raise 705 | 706 | if ax is None: 707 | ax = plt.gca() 708 | 709 | if edgelist is None: 710 | edgelist = list(G.edges()) 711 | 712 | if len(edgelist) == 0: # no edges! 713 | if not G.is_directed() or not arrows: 714 | return LineCollection(None) 715 | else: 716 | return [] 717 | 718 | if nodelist is None: 719 | nodelist = list(G.nodes()) 720 | 721 | # FancyArrowPatch handles color=None different from LineCollection 722 | if edge_color is None: 723 | edge_color = "k" 724 | 725 | # set edge positions 726 | edge_pos = np.asarray([(pos[e[0]], pos[e[1]]) for e in edgelist]) 727 | 728 | # Check if edge_color is an array of floats and map to edge_cmap. 
729 | # This is the only case handled differently from matplotlib 730 | if ( 731 | np.iterable(edge_color) 732 | and (len(edge_color) == len(edge_pos)) 733 | and np.all([isinstance(c, Number) for c in edge_color]) 734 | ): 735 | if edge_cmap is not None: 736 | assert isinstance(edge_cmap, Colormap) 737 | else: 738 | edge_cmap = plt.get_cmap() 739 | if edge_vmin is None: 740 | edge_vmin = min(edge_color) 741 | if edge_vmax is None: 742 | edge_vmax = max(edge_color) 743 | color_normal = Normalize(vmin=edge_vmin, vmax=edge_vmax) 744 | edge_color = [edge_cmap(color_normal(e)) for e in edge_color] 745 | 746 | if not G.is_directed() or not arrows: 747 | edge_collection = LineCollection( 748 | edge_pos, 749 | colors=edge_color, 750 | linewidths=width, 751 | antialiaseds=(1,), 752 | linestyle=style, 753 | transOffset=ax.transData, 754 | alpha=alpha, 755 | ) 756 | 757 | edge_collection.set_cmap(edge_cmap) 758 | edge_collection.set_clim(edge_vmin, edge_vmax) 759 | 760 | edge_collection.set_zorder(1) # edges go behind nodes 761 | edge_collection.set_label(label) 762 | ax.add_collection(edge_collection) 763 | 764 | return edge_collection 765 | 766 | arrow_collection = None 767 | 768 | if G.is_directed() and arrows: 769 | # Note: Waiting for someone to implement arrow to intersection with 770 | # marker. Meanwhile, this works well for polygons with more than 4 771 | # sides and circle.
772 | 773 | def to_marker_edge(marker_size, marker): 774 | if marker in "s^>v>> G = nx.dodecahedral_graph() 929 | >>> labels = nx.draw_networkx_labels(G, pos=nx.spring_layout(G)) 930 | 931 | Also see the NetworkX drawing examples at 932 | https://networkx.github.io/documentation/latest/auto_examples/index.html 933 | 934 | See Also 935 | -------- 936 | draw() 937 | draw_networkx() 938 | draw_networkx_nodes() 939 | draw_networkx_edges() 940 | draw_networkx_edge_labels() 941 | """ 942 | try: 943 | import matplotlib.pyplot as plt 944 | except ImportError as e: 945 | raise ImportError("Matplotlib required for draw()") from e 946 | except RuntimeError: 947 | print("Matplotlib unable to open display") 948 | raise 949 | 950 | if ax is None: 951 | ax = plt.gca() 952 | 953 | if labels is None: 954 | labels = {n: n for n in G.nodes()} 955 | 956 | text_items = {} # there is no text collection so we'll fake one 957 | for n, label in labels.items(): 958 | (x, y) = pos[n] 959 | if not isinstance(label, str): 960 | label = str(label) # this makes "1" and 1 labeled the same 961 | t = ax.text( 962 | x, 963 | y, 964 | label, 965 | size=font_size, 966 | color=font_color, 967 | family=font_family, 968 | weight=font_weight, 969 | alpha=alpha, 970 | horizontalalignment=horizontalalignment, 971 | verticalalignment=verticalalignment, 972 | transform=ax.transData, 973 | bbox=bbox, 974 | clip_on=True, 975 | ) 976 | text_items[n] = t 977 | 978 | ax.tick_params( 979 | axis="both", 980 | which="both", 981 | bottom=False, 982 | left=False, 983 | labelbottom=False, 984 | labelleft=False, 985 | ) 986 | 987 | return text_items 988 | 989 | 990 | 991 | def draw_networkx_edge_labels( 992 | G, 993 | pos, 994 | edge_labels=None, 995 | label_pos=0.5, 996 | font_size=10, 997 | font_color="k", 998 | font_family="sans-serif", 999 | font_weight="normal", 1000 | alpha=None, 1001 | bbox=None, 1002 | horizontalalignment="center", 1003 | verticalalignment="center", 1004 | ax=None, 1005 | rotate=True, 1006 | ): 
1007 | """Draw edge labels. 1008 | 1009 | Parameters 1010 | ---------- 1011 | G : graph 1012 | A networkx graph 1013 | 1014 | pos : dictionary 1015 | A dictionary with nodes as keys and positions as values. 1016 | Positions should be sequences of length 2. 1017 | 1018 | ax : Matplotlib Axes object, optional 1019 | Draw the graph in the specified Matplotlib axes. 1020 | 1021 | alpha : float or None 1022 | The text transparency (default=None) 1023 | 1024 | edge_labels : dictionary 1025 | Edge labels in a dictionary keyed by edge two-tuple of text 1026 | labels (default=None). Only labels for the keys in the dictionary 1027 | are drawn. 1028 | 1029 | label_pos : float 1030 | Position of edge label along edge (0=head, 0.5=center, 1=tail) 1031 | 1032 | font_size : int 1033 | Font size for text labels (default=10) 1034 | 1035 | font_color : string 1036 | Font color string (default='k' black) 1037 | 1038 | font_weight : string 1039 | Font weight (default='normal') 1040 | 1041 | font_family : string 1042 | Font family (default='sans-serif') 1043 | 1044 | bbox : Matplotlib bbox 1045 | Specify text box shape and colors (default: a rounded white box). 1046 | 1047 | clip_on : bool 1048 | Turn on clipping at axis boundaries (default=True) 1049 | 1050 | horizontalalignment : {'center', 'right', 'left'} 1051 | Horizontal alignment (default='center') 1052 | 1053 | verticalalignment : {'center', 'top', 'bottom', 'baseline', 'center_baseline'} 1054 | Vertical alignment (default='center') 1055 | 1056 | ax : Matplotlib Axes object, optional 1057 | Draw the graph in the specified Matplotlib axes.
1058 | 1059 | Returns 1060 | ------- 1061 | dict 1062 | `dict` of labels keyed on the edges 1063 | 1064 | Examples 1065 | -------- 1066 | >>> G = nx.dodecahedral_graph() 1067 | >>> edge_labels = nx.draw_networkx_edge_labels(G, pos=nx.spring_layout(G)) 1068 | 1069 | Also see the NetworkX drawing examples at 1070 | https://networkx.github.io/documentation/latest/auto_examples/index.html 1071 | 1072 | See Also 1073 | -------- 1074 | draw() 1075 | draw_networkx() 1076 | draw_networkx_nodes() 1077 | draw_networkx_edges() 1078 | draw_networkx_labels() 1079 | """ 1080 | try: 1081 | import matplotlib.pyplot as plt 1082 | import numpy as np 1083 | except ImportError as e: 1084 | raise ImportError("Matplotlib required for draw()") from e 1085 | except RuntimeError: 1086 | print("Matplotlib unable to open display") 1087 | raise 1088 | 1089 | if ax is None: 1090 | ax = plt.gca() 1091 | if edge_labels is None: 1092 | labels = {(u, v): d for u, v, d in G.edges(data=True)} 1093 | else: 1094 | labels = edge_labels 1095 | text_items = {} 1096 | for (n1, n2), label in labels.items(): 1097 | (x1, y1) = pos[n1] 1098 | (x2, y2) = pos[n2] 1099 | (x, y) = ( 1100 | x1 * label_pos + x2 * (1.0 - label_pos), 1101 | y1 * label_pos + y2 * (1.0 - label_pos), 1102 | ) 1103 | 1104 | if rotate: 1105 | # in degrees 1106 | angle = np.arctan2(y2 - y1, x2 - x1) / (2.0 * np.pi) * 360 1107 | # make label orientation "right-side-up" 1108 | if angle > 90: 1109 | angle -= 180 1110 | if angle < -90: 1111 | angle += 180 1112 | # transform data coordinate angle to screen coordinate angle 1113 | xy = np.array((x, y)) 1114 | trans_angle = ax.transData.transform_angles( 1115 | np.array((angle,)), xy.reshape((1, 2)) 1116 | )[0] 1117 | else: 1118 | trans_angle = 0.0 1119 | # use default box of white with white border 1120 | if bbox is None: 1121 | bbox = dict(boxstyle="round", ec=(1.0, 1.0, 1.0), fc=(1.0, 1.0, 1.0)) 1122 | if not isinstance(label, str): 1123 | label = str(label) # this makes "1" and 1 labeled the 
same 1124 | 1125 | t = ax.text( 1126 | x, 1127 | y, 1128 | label, 1129 | size=font_size, 1130 | color=font_color, 1131 | family=font_family, 1132 | weight=font_weight, 1133 | alpha=alpha, 1134 | horizontalalignment=horizontalalignment, 1135 | verticalalignment=verticalalignment, 1136 | rotation=trans_angle, 1137 | transform=ax.transData, 1138 | bbox=bbox, 1139 | zorder=1, 1140 | clip_on=True, 1141 | ) 1142 | text_items[(n1, n2)] = t 1143 | 1144 | ax.tick_params( 1145 | axis="both", 1146 | which="both", 1147 | bottom=False, 1148 | left=False, 1149 | labelbottom=False, 1150 | labelleft=False, 1151 | ) 1152 | 1153 | return text_items 1154 | 1155 | 1156 | 1157 | def draw_circular(G, **kwargs): 1158 | """Draw the graph G with a circular layout. 1159 | 1160 | Parameters 1161 | ---------- 1162 | G : graph 1163 | A networkx graph 1164 | 1165 | kwargs : optional keywords 1166 | See networkx.draw_networkx() for a description of optional keywords, 1167 | with the exception of the pos parameter which is not used by this 1168 | function. 1169 | """ 1170 | draw(G, circular_layout(G), **kwargs) 1171 | 1172 | 1173 | 1174 | def draw_kamada_kawai(G, **kwargs): 1175 | """Draw the graph G with a Kamada-Kawai force-directed layout. 1176 | 1177 | Parameters 1178 | ---------- 1179 | G : graph 1180 | A networkx graph 1181 | 1182 | kwargs : optional keywords 1183 | See networkx.draw_networkx() for a description of optional keywords, 1184 | with the exception of the pos parameter which is not used by this 1185 | function. 1186 | """ 1187 | draw(G, kamada_kawai_layout(G), **kwargs) 1188 | 1189 | 1190 | 1191 | def draw_random(G, **kwargs): 1192 | """Draw the graph G with a random layout. 1193 | 1194 | Parameters 1195 | ---------- 1196 | G : graph 1197 | A networkx graph 1198 | 1199 | kwargs : optional keywords 1200 | See networkx.draw_networkx() for a description of optional keywords, 1201 | with the exception of the pos parameter which is not used by this 1202 | function. 
1203 | """ 1204 | draw(G, random_layout(G), **kwargs) 1205 | 1206 | 1207 | 1208 | def draw_spectral(G, **kwargs): 1209 | """Draw the graph G with a spectral 2D layout. 1210 | 1211 | Using the unnormalized Laplacian, the layout shows possible clusters of 1212 | nodes which are an approximation of the ratio cut. The positions are the 1213 | entries of the second and third eigenvectors corresponding to the 1214 | ascending eigenvalues starting from the second one. 1215 | 1216 | Parameters 1217 | ---------- 1218 | G : graph 1219 | A networkx graph 1220 | 1221 | kwargs : optional keywords 1222 | See networkx.draw_networkx() for a description of optional keywords, 1223 | with the exception of the pos parameter which is not used by this 1224 | function. 1225 | """ 1226 | draw(G, spectral_layout(G), **kwargs) 1227 | 1228 | 1229 | 1230 | def draw_spring(G, **kwargs): 1231 | """Draw the graph G with a spring layout. 1232 | 1233 | Parameters 1234 | ---------- 1235 | G : graph 1236 | A networkx graph 1237 | 1238 | kwargs : optional keywords 1239 | See networkx.draw_networkx() for a description of optional keywords, 1240 | with the exception of the pos parameter which is not used by this 1241 | function. 1242 | """ 1243 | draw(G, spring_layout(G), **kwargs) 1244 | 1245 | 1246 | 1247 | def draw_shell(G, **kwargs): 1248 | """Draw networkx graph with shell layout. 1249 | 1250 | Parameters 1251 | ---------- 1252 | G : graph 1253 | A networkx graph 1254 | 1255 | kwargs : optional keywords 1256 | See networkx.draw_networkx() for a description of optional keywords, 1257 | with the exception of the pos parameter which is not used by this 1258 | function. 1259 | """ 1260 | nlist = kwargs.get("nlist", None) 1261 | if nlist is not None: 1262 | del kwargs["nlist"] 1263 | draw(G, shell_layout(G, nlist=nlist), **kwargs) 1264 | 1265 | 1266 | 1267 | def draw_planar(G, **kwargs): 1268 | """Draw a planar networkx graph with planar layout. 
1269 | 1270 | Parameters 1271 | ---------- 1272 | G : graph 1273 | A planar networkx graph 1274 | 1275 | kwargs : optional keywords 1276 | See networkx.draw_networkx() for a description of optional keywords, 1277 | with the exception of the pos parameter which is not used by this 1278 | function. 1279 | """ 1280 | draw(G, planar_layout(G), **kwargs) 1281 | 1282 | 1283 | 1284 | def apply_alpha(colors, alpha, elem_list, cmap=None, vmin=None, vmax=None): 1285 | """Apply an alpha (or list of alphas) to the colors provided. 1286 | 1287 | Parameters 1288 | ---------- 1289 | 1290 | colors : color string, or array of floats 1291 | Color of element. Can be a single color format string (default='r'), 1292 | or a sequence of colors with the same length as nodelist. 1293 | If numeric values are specified they will be mapped to 1294 | colors using the cmap and vmin,vmax parameters. See 1295 | matplotlib.scatter for more details. 1296 | 1297 | alpha : float or array of floats 1298 | Alpha values for elements. This can be a single alpha value, in 1299 | which case it will be applied to all the elements of color. Otherwise, 1300 | if it is an array, the elements of alpha will be applied to the colors 1301 | in order (cycling through alpha multiple times if necessary). 1302 | 1303 | elem_list : array of networkx objects 1304 | The list of elements which are being colored. These could be nodes, 1305 | edges or labels. 1306 | 1307 | cmap : matplotlib colormap 1308 | Color map for use if colors is a list of floats corresponding to points 1309 | on a color mapping. 1310 | 1311 | vmin, vmax : float 1312 | Minimum and maximum values for normalizing colors if a color mapping is 1313 | used. 1314 | 1315 | Returns 1316 | ------- 1317 | 1318 | rgba_colors : numpy ndarray 1319 | Array containing RGBA format values for each of the node colours. 
1320 | 1321 | """ 1322 | from itertools import islice, cycle 1323 | 1324 | try: 1325 | import numpy as np 1326 | from matplotlib.colors import colorConverter 1327 | import matplotlib.cm as cm 1328 | except ImportError as e: 1329 | raise ImportError("Matplotlib required for draw()") from e 1330 | 1331 | # If we have been provided with a list of numbers as long as elem_list, 1332 | # apply the color mapping. 1333 | if len(colors) == len(elem_list) and isinstance(colors[0], Number): 1334 | mapper = cm.ScalarMappable(cmap=cmap) 1335 | mapper.set_clim(vmin, vmax) 1336 | rgba_colors = mapper.to_rgba(colors) 1337 | # Otherwise, convert colors to matplotlib's RGB using the colorConverter 1338 | # object. These are converted to numpy ndarrays to be consistent with the 1339 | # to_rgba method of ScalarMappable. 1340 | else: 1341 | try: 1342 | rgba_colors = np.array([colorConverter.to_rgba(colors)]) 1343 | except ValueError: 1344 | rgba_colors = np.array([colorConverter.to_rgba(color) for color in colors]) 1345 | # Set the final column of the rgba_colors to have the relevant alpha values 1346 | try: 1347 | # If alpha is longer than the number of colors, resize to the number of 1348 | # elements. 
Also, if rgba_colors.size (the number of elements of 1349 | # rgba_colors) is the same as the number of elements, resize the array, 1350 | # to avoid it being interpreted as a colormap by scatter() 1351 | if len(alpha) > len(rgba_colors) or rgba_colors.size == len(elem_list): 1352 | rgba_colors = np.resize(rgba_colors, (len(elem_list), 4)) 1353 | rgba_colors[1:, 0] = rgba_colors[0, 0] 1354 | rgba_colors[1:, 1] = rgba_colors[0, 1] 1355 | rgba_colors[1:, 2] = rgba_colors[0, 2] 1356 | rgba_colors[:, 3] = list(islice(cycle(alpha), len(rgba_colors))) 1357 | except TypeError: 1358 | rgba_colors[:, -1] = alpha 1359 | return rgba_colors -------------------------------------------------------------------------------- /dataset_utils/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import Dataset 3 | import pickle 4 | 5 | 6 | def read_pickle(name): 7 | with open(name, "rb") as f: 8 | data = pickle.load(f) 9 | return data 10 | 11 | 12 | def write_pickle(data, name): 13 | with open(name, "wb") as f: 14 | pickle.dump(data, f) 15 | 16 | 17 | class ToyDataset(Dataset): 18 | def __init__(self, pkl, domain_id, opt=None): 19 | idx = pkl["domain"] == domain_id 20 | self.data = pkl["data"][idx].astype(np.float32) 21 | self.label = pkl["label"][idx].astype(np.int64) 22 | self.domain = domain_id 23 | 24 | # if opt.normalize_domain: 25 | # print('===> Normalize in every domain') 26 | # self.data_m, self.data_s = self.data.mean(0, keepdims=True), self.data.std(0, keepdims=True) 27 | # self.data = (self.data - self.data_m) / self.data_s 28 | 29 | def __getitem__(self, idx): 30 | return self.data[idx], self.label[idx], self.domain 31 | 32 | def __len__(self): 33 | return len(self.data) 34 | 35 | 36 | class SeqToyDataset(Dataset): 37 | def __init__(self, datasets, size=3 * 200): 38 | self.datasets = datasets 39 | self.size = size 40 | print( 41 | "SeqDataset Size {} Sub Size {}".format( 42 | size, 
[len(ds) for ds in datasets] 43 | ) 44 | ) 45 | 46 | def __len__(self): 47 | return self.size 48 | 49 | def __getitem__(self, i): 50 | return [ds[i] for ds in self.datasets] 51 | -------------------------------------------------------------------------------- /derive_g_encode/499_pred_GDA_new.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-ML-Lab/GRDA/96800d62577b8046af3fbc9ade0a51c2b71129c1/derive_g_encode/499_pred_GDA_new.pkl -------------------------------------------------------------------------------- /derive_g_encode/derive.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | 4 | def read_pickle(name): 5 | with open(name, "rb") as f: 6 | data = pickle.load(f) 7 | return data 8 | 9 | 10 | def write_pickle(data, name): 11 | with open(name, "wb") as f: 12 | pickle.dump(data, f) 13 | 14 | 15 | # read_file = "499_pred_GDA_new.pkl" 16 | read_file = "499_pred.pkl" 17 | num_domain = 60 18 | 19 | info = read_pickle(read_file) 20 | z = info["z"] 21 | 22 | # print(z) 23 | 24 | g_encode = dict() 25 | for i in range(num_domain): 26 | g_encode[str(i)] = z[i] 27 | 28 | write_pickle(g_encode, "g_encode_60.pkl") 29 | print("success!") 30 | -------------------------------------------------------------------------------- /derive_g_encode/g_encode.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-ML-Lab/GRDA/96800d62577b8046af3fbc9ade0a51c2b71129c1/derive_g_encode/g_encode.pkl -------------------------------------------------------------------------------- /derive_g_encode/g_encode_15.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-ML-Lab/GRDA/96800d62577b8046af3fbc9ade0a51c2b71129c1/derive_g_encode/g_encode_15.pkl -------------------------------------------------------------------------------- 
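`draw_networkx_edges` maps `edge_color` through a colormap only when it is a numeric sequence with one value per edge. In that case `edge_vmin`/`edge_vmax` default to the data range, and `Normalize` rescales the values to [0, 1]. The sketch below reproduces that arithmetic without dependencies. It assumes a hypothetical two-color colormap `lerp_cmap` (blue to red) in place of a real Matplotlib `Colormap`, and the helper names are illustrative only.

```python
from numbers import Number


def normalize(values, vmin=None, vmax=None):
    """Linearly map values to [0, 1]; vmin/vmax default to the data range."""
    vmin = min(values) if vmin is None else vmin
    vmax = max(values) if vmax is None else vmax
    if vmax == vmin:  # degenerate range: avoid dividing by zero
        return [0.0 for _ in values]
    return [(v - vmin) / (vmax - vmin) for v in values]


def lerp_cmap(t, lo=(0.0, 0.0, 1.0, 1.0), hi=(1.0, 0.0, 0.0, 1.0)):
    """Hypothetical colormap: blend linearly from blue (t=0) to red (t=1)."""
    t = min(max(t, 0.0), 1.0)  # clip out-of-range inputs
    return tuple(a + (b - a) * t for a, b in zip(lo, hi))


def map_edge_colors(edge_color, n_edges, vmin=None, vmax=None):
    """Map a numeric list/tuple with one value per edge to RGBA tuples;
    anything else (e.g. a color string) is returned unchanged."""
    if (
        isinstance(edge_color, (list, tuple))
        and len(edge_color) == n_edges
        and all(isinstance(c, Number) for c in edge_color)
    ):
        return [lerp_cmap(t) for t in normalize(edge_color, vmin, vmax)]
    return edge_color
```

For example, `map_edge_colors([0, 5, 10], 3)` gives pure blue, purple, and pure red. A plain string such as `"k"` passes through untouched.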
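`draw_networkx_edge_labels` places each label by linear interpolation between the two endpoints, with `label_pos` weighting the first node. When `rotate=True`, it folds the edge angle by 180 degrees so the text is never upside down. The sketch below follows that arithmetic in data coordinates only. It deliberately omits the `ax.transData.transform_angles` step that converts the angle to screen space, so on axes with an unequal aspect ratio its angle is only an approximation. `edge_label_placement` is a hypothetical name.

```python
import math


def edge_label_placement(p1, p2, label_pos=0.5):
    """Return (x, y, angle_deg) for a label on the edge p1 -> p2.

    label_pos weights p1, as in draw_networkx_edge_labels:
    1.0 puts the label at p1, 0.0 at p2, 0.5 at the midpoint.
    """
    (x1, y1), (x2, y2) = p1, p2
    x = x1 * label_pos + x2 * (1.0 - label_pos)
    y = y1 * label_pos + y2 * (1.0 - label_pos)
    # Edge direction in degrees, in the range (-180, 180]
    angle = math.degrees(math.atan2(y2 - y1, x2 - x1))
    # Fold into [-90, 90] so the text reads "right-side-up"
    if angle > 90:
        angle -= 180
    if angle < -90:
        angle += 180
    return x, y, angle
```

An edge drawn right-to-left, e.g. from `(2, 0)` to `(0, 0)`, gets an angle of about 0 degrees rather than 180, which is the point of the folding step.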
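`apply_alpha` resizes a single color to one row per element when needed, then fills the alpha channel, cycling through a list of alphas with `itertools.cycle`/`islice`. The simplified pure-Python sketch below works on RGBA tuples. `apply_alpha_rgba` is a hypothetical name, the colormap branch for numeric colors is left out, and this version always broadcasts a single color.

```python
from itertools import cycle, islice


def apply_alpha_rgba(colors, alpha, n_elems):
    """colors: list of RGBA tuples, either one color (broadcast to all
    n_elems elements) or one color per element.
    alpha: a single number, or a sequence cycled to cover every color."""
    if len(colors) == 1:
        colors = colors * n_elems  # broadcast the single color
    if isinstance(alpha, (int, float)):
        alphas = [float(alpha)] * len(colors)
    else:
        # Repeat the alpha sequence as often as needed, then truncate
        alphas = list(islice(cycle(alpha), len(colors)))
    # Keep the RGB channels, replace the alpha channel
    return [(r, g, b, a) for (r, g, b, _), a in zip(colors, alphas)]
```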
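`ToyDataset` keeps the rows whose `pkl["domain"]` equals one domain id. `SeqToyDataset.__getitem__` then returns the `i`-th sample of every domain together, so each batch holds one slice per domain. The sketch below reproduces that indexing with NumPy only, without `torch.utils.data`. `split_by_domain` and `seq_item` are hypothetical helper names.

```python
import numpy as np


def split_by_domain(pkl, num_domain):
    """Mimic ToyDataset: one (data, label, domain_id) triple per domain."""
    subsets = []
    for d in range(num_domain):
        idx = pkl["domain"] == d  # boolean mask over all samples
        subsets.append(
            (pkl["data"][idx].astype(np.float32), pkl["label"][idx].astype(np.int64), d)
        )
    return subsets


def seq_item(subsets, i):
    """Mimic SeqToyDataset.__getitem__: the i-th sample from every domain."""
    return [(data[i], label[i], d) for data, label, d in subsets]
```

`main.py` builds `SeqToyDataset` with `size=len(datasets[0])`. That choice appears to assume every domain has at least as many samples as domain 0; a smaller domain would raise an `IndexError` when indexed.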
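In `BaseModel.test` (`model/model.py`), accuracy is first averaged within each domain. The reported "test average" then subtracts the source-domain accuracies from the total and divides by `opt.num_target`, scaled to percent. The NumPy sketch below shows that arithmetic on predictions that are already computed. It assumes, as the formula seems to, that all domains are tested and that `num_target` counts the non-source domains. `domain_accuracies` is a hypothetical name.

```python
import numpy as np


def domain_accuracies(pred, label, src_domain, num_target):
    """pred, label: integer arrays of shape (num_domain, n_samples).
    src_domain: list of source-domain indices.
    Returns (per-domain accuracy, overall mean %, target mean %)."""
    acc = (pred == label).astype(float).mean(-1)  # accuracy within each domain
    test_acc = (acc.sum() - acc[src_domain].sum()) / num_target * 100
    return acc, acc.mean() * 100, test_acc
```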
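`BaseModel.__loss_G__` trains the domain encodings so that the inner product of two domain embeddings, read as a logit, predicts whether those domains are adjacent in `opt.A`. The loss is averaged over all unordered pairs of sampled domains. The NumPy sketch below computes that objective and hand-writes the numerically stable form of `BCEWithLogitsLoss`: max(x, 0) - x*y + log(1 + exp(-|x|)). It assumes `z` holds one embedding row per domain, whereas the original picks row `v * self.tmp_batch_size` from a batch-flattened tensor. The function names are hypothetical.

```python
import numpy as np


def bce_with_logits(logit, target):
    """Stable binary cross-entropy on a single raw logit."""
    return max(logit, 0.0) - logit * target + np.log1p(np.exp(-abs(logit)))


def graph_loss(z, A, nodes):
    """Mean BCE between z[i]·z[j] and A[i][j] over all unordered pairs
    of the sampled nodes (same pair count as sample_v*(sample_v-1)/2)."""
    total, pairs = 0.0, 0
    for a in range(len(nodes)):
        for b in range(a + 1, len(nodes)):
            i, j = nodes[a], nodes[b]
            total += bce_with_logits(float(z[i] @ z[j]), float(A[i][j]))
            pairs += 1
    return total / pairs
```

With all-zero embeddings every logit is 0, so the loss is log 2 regardless of the graph. Well-separated embeddings that agree with `A` drive it toward 0.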
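`BaseModel.__sub_graph__` returns either a uniform random sample of domains or, with probability one half, a set of random-walk chains built by `__rand_walk__`. Each chain starts at an unvisited domain and keeps stepping to a random unvisited neighbor in `opt.A`. It stops when the budget is met or the walk gets stuck, and the caller then starts a new chain. The sketch below covers the chain-building branch only. It uses NumPy's `Generator` API instead of the global `np.random` state, and `rand_walk`/`sub_graph` are hypothetical names.

```python
import numpy as np


def rand_walk(A, vis, budget, rng):
    """Grow one chain from a random unvisited node, stepping to a random
    unvisited neighbor until `budget` nodes are taken or the walk is stuck."""
    start = rng.choice(np.flatnonzero(vis == 0))
    vis[start] = 1
    chain = [int(start)]
    cur = start
    while len(chain) < budget:
        # Unvisited nodes in random order, keeping only neighbors of cur
        candidates = [i for i in rng.permutation(np.flatnonzero(vis == 0)) if A[cur][i]]
        if not candidates:
            break  # dead end: the caller starts a new chain elsewhere
        cur = candidates[0]
        vis[cur] = 1
        chain.append(int(cur))
    return chain


def sub_graph(A, sample_v, rng):
    """Collect sample_v distinct nodes as one or more random-walk chains."""
    vis = np.zeros(len(A))
    nodes = []
    while len(nodes) < sample_v:
        nodes.extend(rand_walk(A, vis, sample_v - len(nodes), rng))
    return nodes
```

Sampling connected chains rather than arbitrary node sets makes it likely that some sampled pairs are adjacent. This plausibly keeps positive edges in the pairwise graph loss even when the domain graph is sparse.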
/derive_g_encode/g_encode_60.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-ML-Lab/GRDA/96800d62577b8046af3fbc9ade0a51c2b71129c1/derive_g_encode/g_encode_60.pkl -------------------------------------------------------------------------------- /fig/GRDA-DG-15-results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-ML-Lab/GRDA/96800d62577b8046af3fbc9ade0a51c2b71129c1/fig/GRDA-DG-15-results.png -------------------------------------------------------------------------------- /fig/GRDA-domain-graph-US.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-ML-Lab/GRDA/96800d62577b8046af3fbc9ade0a51c2b71129c1/fig/GRDA-domain-graph-US.png -------------------------------------------------------------------------------- /fig/blog-method-DA-vs-GRDA.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-ML-Lab/GRDA/96800d62577b8046af3fbc9ade0a51c2b71129c1/fig/blog-method-DA-vs-GRDA.png -------------------------------------------------------------------------------- /fig/compcar_quantitive_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-ML-Lab/GRDA/96800d62577b8046af3fbc9ade0a51c2b71129c1/fig/compcar_quantitive_result.jpg -------------------------------------------------------------------------------- /fig/dg_15_60_quantitive_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-ML-Lab/GRDA/96800d62577b8046af3fbc9ade0a51c2b71129c1/fig/dg_15_60_quantitive_result.jpg -------------------------------------------------------------------------------- /fig/tpt_48_quantitive_result.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-ML-Lab/GRDA/96800d62577b8046af3fbc9ade0a51c2b71129c1/fig/tpt_48_quantitive_result.jpg -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | import pickle 5 | 6 | # import the experiment setting 7 | 8 | from configs.config_random_15_new import opt 9 | 10 | # load the data 11 | from torch.utils.data import DataLoader 12 | from dataset_utils.dataset import ToyDataset, SeqToyDataset 13 | 14 | # actually the config doesn't change much 15 | # from configs.config_random_60_new import opt 16 | 17 | np.random.seed(opt.seed) 18 | random.seed(opt.seed) 19 | torch.manual_seed(opt.seed) 20 | 21 | if opt.model == "DANN": 22 | from model.model import DANN as Model 23 | elif opt.model == "GDA": 24 | from model.model import GDA as Model 25 | elif opt.model == "CDANN": 26 | from model.model import CDANN as Model 27 | 28 | opt.cond_disc = True 29 | elif opt.model == "ADDA": 30 | from model.model import ADDA as Model 31 | elif opt.model == "MDD": 32 | from model.model import MDD as Model 33 | model = Model(opt).to(opt.device) # .double() 34 | 35 | data_source = opt.dataset 36 | 37 | with open(data_source, "rb") as data_file: 38 | data_pkl = pickle.load(data_file) 39 | print(f"Data: {data_pkl['data'].shape}\nLabel: {data_pkl['label'].shape}") 40 | 41 | # build dataset 42 | opt.A = data_pkl["A"] 43 | 44 | data = data_pkl["data"] 45 | data_mean = data.mean(0, keepdims=True) 46 | data_std = data.std(0, keepdims=True) 47 | data_pkl["data"] = (data - data_mean) / data_std # normalize the raw data 48 | datasets = [ 49 | ToyDataset(data_pkl, i, opt) for i in range(opt.num_domain) 50 | ] # sub dataset for each domain 51 | 52 | 53 | dataset = SeqToyDataset( 54 | datasets, size=len(datasets[0]) 55 | ) 
# mix sub dataset to a large one 56 | dataloader = DataLoader( 57 | dataset=dataset, shuffle=True, batch_size=opt.batch_size 58 | ) 59 | 60 | # train 61 | for epoch in range(opt.num_epoch): 62 | model.learn(epoch, dataloader) 63 | if (epoch + 1) % opt.save_interval == 0 or (epoch + 1) == opt.num_epoch: 64 | model.save() 65 | if (epoch + 1) % opt.test_interval == 0 or (epoch + 1) == opt.num_epoch: 66 | model.test(epoch, dataloader) 67 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import torch.nn.functional as F 6 | import torch.optim.lr_scheduler as lr_scheduler 7 | import numpy as np 8 | from model.modules import ( 9 | FeatureNet, 10 | PredNet, 11 | GNet, 12 | ClassDiscNet, 13 | CondClassDiscNet, 14 | DiscNet, 15 | GraphDNet, 16 | ) 17 | import pickle 18 | from visdom import Visdom 19 | 20 | # =========================================================================================================== 21 | 22 | 23 | def to_np(x): 24 | return x.detach().cpu().numpy() 25 | 26 | 27 | def to_tensor(x, device="cuda"): 28 | if isinstance(x, np.ndarray): 29 | x = torch.from_numpy(x).to(device) 30 | else: 31 | x = x.to(device) 32 | return x 33 | 34 | 35 | def flat(x): 36 | n, m = x.shape[:2] 37 | return x.reshape(n * m, *x.shape[2:]) 38 | 39 | 40 | def write_pickle(data, name): 41 | with open(name, "wb") as f: 42 | pickle.dump(data, f) 43 | 44 | 45 | # ====================================================================================================================== 46 | 47 | 48 | # the base model 49 | class BaseModel(nn.Module): 50 | def __init__(self, opt): 51 | super(BaseModel, self).__init__() 52 | # set output format 53 | np.set_printoptions(suppress=True, precision=6) 54 | 55 | self.opt = opt 56 | self.device = opt.device 57 | self.batch_size 
= opt.batch_size 58 | # visualization 59 | self.use_visdom = opt.use_visdom 60 | self.use_g_encode = opt.use_g_encode 61 | if opt.use_visdom: 62 | self.env = Visdom(port=opt.visdom_port) 63 | self.test_pane = dict() 64 | 65 | self.num_domain = opt.num_domain 66 | if self.opt.test_on_all_dmn: 67 | self.test_dmn_num = self.num_domain 68 | else: 69 | self.test_dmn_num = self.opt.tgt_dmn_num 70 | 71 | self.train_log = self.opt.outf + "/loss.log" 72 | self.model_path = opt.outf + "/model.pth" 73 | self.out_pic_f = opt.outf + "/plt_pic" 74 | if not os.path.exists(self.opt.outf): 75 | os.mkdir(self.opt.outf) 76 | if not os.path.exists(self.out_pic_f): 77 | os.mkdir(self.out_pic_f) 78 | with open(self.train_log, "w") as f: 79 | f.write("log start!\n") 80 | 81 | mask_list = np.zeros(opt.num_domain) 82 | mask_list[opt.src_domain] = 1 83 | self.domain_mask = torch.IntTensor(mask_list).to( 84 | opt.device 85 | ) 86 | 87 | def learn(self, epoch, dataloader): 88 | self.train() 89 | if self.use_g_encode: 90 | self.netG.eval() 91 | self.epoch = epoch 92 | loss_values = {loss: 0 for loss in self.loss_names} 93 | 94 | count = 0 95 | for data in dataloader: 96 | count += 1 97 | self.__set_input__(data) 98 | self.__train_forward__() 99 | new_loss_values = self.__optimize__() 100 | 101 | # for the loss visualization 102 | for key, loss in new_loss_values.items(): 103 | loss_values[key] += loss 104 | 105 | for key, _ in new_loss_values.items(): 106 | loss_values[key] /= count 107 | 108 | if self.use_visdom: 109 | self.__vis_loss__(loss_values) 110 | 111 | if (self.epoch + 1) % 10 == 0: 112 | print("epoch {}: {}".format(self.epoch, loss_values)) 113 | 114 | # learning rate decay 115 | for lr_scheduler in self.lr_schedulers: 116 | lr_scheduler.step() 117 | 118 | def test(self, epoch, dataloader): 119 | self.eval() 120 | 121 | acc_curve = [] 122 | l_x = [] 123 | l_domain = [] 124 | l_label = [] 125 | l_encode = [] 126 | z_seq = 0 127 | 128 | for data in dataloader: 129 |
self.__set_input__(data) 130 | 131 | # forward 132 | with torch.no_grad(): 133 | self.__test_forward__() 134 | 135 | acc_curve.append( 136 | self.g_seq.eq(self.y_seq) 137 | .to(torch.float) 138 | .mean(-1, keepdim=True) 139 | ) 140 | l_x.append(to_np(self.x_seq)) 141 | l_domain.append(to_np(self.domain_seq)) 142 | l_encode.append(to_np(self.e_seq)) 143 | l_label.append(to_np(self.g_seq)) 144 | 145 | x_all = np.concatenate(l_x, axis=1) 146 | e_all = np.concatenate(l_encode, axis=1) 147 | domain_all = np.concatenate(l_domain, axis=1) 148 | label_all = np.concatenate(l_label, axis=1) 149 | 150 | z_seq = to_np(self.z_seq) 151 | z_seq_all = z_seq[ 152 | 0 : self.tmp_batch_size * self.test_dmn_num : self.tmp_batch_size, : 153 | ] 154 | 155 | d_all = dict() 156 | 157 | d_all["data"] = flat(x_all) 158 | d_all["domain"] = flat(domain_all) 159 | d_all["label"] = flat(label_all) 160 | d_all["encodeing"] = flat(e_all) 161 | d_all["z"] = z_seq_all 162 | 163 | acc = to_np(torch.cat(acc_curve, 1).mean(-1)) 164 | test_acc = ( 165 | (acc.sum() - acc[self.opt.src_domain].sum()) 166 | / (self.opt.num_target) 167 | * 100 168 | ) 169 | acc_msg = "[Test][{}] Accuracy: total average {:.1f}, test average {:.1f}, in each domain {}".format( 170 | epoch, acc.mean() * 100, test_acc, np.around(acc * 100, decimals=1) 171 | ) 172 | self.__log_write__(acc_msg) 173 | if self.use_visdom: 174 | self.__vis_test_error__(test_acc, "test acc") 175 | 176 | d_all["acc_msg"] = acc_msg 177 | 178 | write_pickle(d_all, self.opt.outf + "/" + str(epoch) + "_pred.pkl") 179 | 180 | def __vis_test_error__(self, loss, title): 181 | if self.epoch == self.opt.test_interval - 1: 182 | # initialize 183 | self.test_pane[title] = self.env.line( 184 | X=np.array([self.epoch]), 185 | Y=np.array([loss]), 186 | opts=dict(title=title), 187 | ) 188 | else: 189 | self.env.line( 190 | X=np.array([self.epoch]), 191 | Y=np.array([loss]), 192 | win=self.test_pane[title], 193 | update="append", 194 | ) 195 | 196 | def save(self): 
197 | torch.save(self.state_dict(), self.model_path) 198 | 199 | def __set_input__(self, data, train=True): 200 | """ 201 | :param 202 | x_seq: Number of domain x Batch size x Data dim 203 | y_seq: Number of domain x Batch size x Predict Data dim 204 | one_hot_seq: Number of domain x Batch size x Number of vertices (domains) 205 | domain_seq: Number of domain x Batch size x domain dim (1) 206 | """ 207 | # the domain seq is in d3!! 208 | x_seq, y_seq, domain_seq = ( 209 | [d[0][None, :, :] for d in data], 210 | [d[1][None, :] for d in data], 211 | [d[2][None, :] for d in data], 212 | ) 213 | self.x_seq = torch.cat(x_seq, 0).to(self.device) 214 | self.y_seq = torch.cat(y_seq, 0).to(self.device) 215 | self.domain_seq = torch.cat(domain_seq, 0).to(self.device) 216 | self.tmp_batch_size = self.x_seq.shape[1] 217 | one_hot_seq = [ 218 | torch.nn.functional.one_hot(d[2], self.num_domain) 219 | for d in data 220 | ] 221 | 222 | if train: 223 | self.one_hot_seq = ( 224 | torch.cat(one_hot_seq, 0) 225 | .reshape(self.num_domain, self.tmp_batch_size, -1) 226 | .to(self.device) 227 | ) 228 | else: 229 | self.one_hot_seq = ( 230 | torch.cat(one_hot_seq, 0) 231 | .reshape(self.test_dmn_num, self.tmp_batch_size, -1) 232 | .to(self.device) 233 | ) 234 | 235 | def __train_forward__(self): 236 | pass 237 | 238 | def __test_forward__(self): 239 | self.z_seq = self.netG(self.one_hot_seq) 240 | self.e_seq = self.netE(self.x_seq, self.z_seq) # encoder of the data 241 | self.f_seq = self.netF(self.e_seq) 242 | self.g_seq = torch.argmax( 243 | self.f_seq.detach(), dim=2 244 | ) # class of the prediction 245 | 246 | def __optimize__(self): 247 | loss_value = dict() 248 | if not self.use_g_encode: 249 | self.loss_G = self.__loss_G__() 250 | 251 | if self.opt.lambda_gan != 0: 252 | self.loss_D = self.__loss_D__() 253 | 254 | self.loss_E, self.loss_E_pred, self.loss_E_gan = self.__loss_EF__() 255 | 256 | if not self.use_g_encode: 257 | self.optimizer_G.zero_grad() 258 | 
            self.loss_G.backward(retain_graph=True)
        self.optimizer_D.zero_grad()
        self.loss_D.backward(retain_graph=True)
        self.optimizer_EF.zero_grad()
        self.loss_E.backward()

        if not self.use_g_encode:
            self.optimizer_G.step()
        self.optimizer_D.step()
        self.optimizer_EF.step()

        if not self.use_g_encode:
            loss_value["G"] = self.loss_G.item()
        loss_value["D"], loss_value["E_pred"], loss_value["E_gan"] = \
            self.loss_D.item(), self.loss_E_pred.item(), self.loss_E_gan.item()
        return loss_value

    def __loss_G__(self):
        criterion = nn.BCEWithLogitsLoss()

        sub_graph = self.__sub_graph__(my_sample_v=self.opt.sample_v_g)
        errorG = torch.zeros((1,)).to(self.device)
        sample_v = self.opt.sample_v_g

        for i in range(sample_v):
            v_i = sub_graph[i]
            for j in range(i + 1, sample_v):
                v_j = sub_graph[j]
                label = torch.tensor(self.opt.A[v_i][v_j]).to(self.device)
                # z_seq is flattened by GNet to (num_domain * batch, nt), so
                # row v * batch holds the embedding of domain v
                output = (
                    self.z_seq[v_i * self.tmp_batch_size]
                    * self.z_seq[v_j * self.tmp_batch_size]
                ).sum()
                errorG += criterion(output, label)

        errorG /= sample_v * (sample_v - 1) / 2
        return errorG

    def __loss_D__(self):
        pass

    def __loss_EF__(self):
        pass

    def __log_write__(self, loss_msg):
        print(loss_msg)
        with open(self.train_log, "a") as f:
            f.write(loss_msg + "\n")

    def __vis_loss__(self, loss_values):
        if self.epoch == 0:
            self.panes = {
                loss_name: self.env.line(
                    X=np.array([self.epoch]),
                    Y=np.array([loss_values[loss_name]]),
                    opts=dict(title="loss for {} on epochs".format(loss_name)),
                )
                for loss_name in self.loss_names
            }
        else:
            for loss_name in self.loss_names:
                self.env.line(
                    X=np.array([self.epoch]),
                    Y=np.array([loss_values[loss_name]]),
                    win=self.panes[loss_name],
                    update="append",
                )

    def __init_weight__(self, net=None):
        if net is None:
            net = self
        for m in net.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0, std=0.01)
                # nn.init.normal_(m.weight, mean=0, std=0.1)
                # nn.init.xavier_normal_(m.weight, gain=10)
                nn.init.constant_(m.bias, val=0)

    # for graph random sampling:
    def __rand_walk__(self, vis, left_nodes):
        chain_node = []
        node_num = 0
        # start the walk at a random unvisited node
        node_index = np.where(vis == 0)[0]
        st = np.random.choice(node_index)
        vis[st] = 1
        chain_node.append(st)
        left_nodes -= 1
        node_num += 1

        cur_node = st
        while left_nodes > 0:
            nx_node = -1

            # visit the unvisited nodes in random order
            node_to_choose = np.where(vis == 0)[0]
            num = node_to_choose.shape[0]
            node_to_choose = np.random.choice(
                node_to_choose, num, replace=False
            )

            for i in node_to_choose:
                if cur_node != i:
                    # move along an edge to an unvisited neighbour
                    if self.opt.A[cur_node][i] and not vis[i]:
                        nx_node = i
                        vis[nx_node] = 1
                        chain_node.append(nx_node)
                        left_nodes -= 1
                        node_num += 1
                        break
            if nx_node >= 0:
                cur_node = nx_node
            else:
                break
        return chain_node, node_num

    def __sub_graph__(self, my_sample_v):
        # with probability 1/2, sample the domains uniformly at random
        if np.random.randint(0, 2) == 0:
            return np.random.choice(
                self.num_domain, size=my_sample_v, replace=False
            )

        # subsample a chain (or multiple chains in graph)
        left_nodes = my_sample_v
        choosen_node = []
        vis = np.zeros(self.num_domain)
        while left_nodes > 0:
            chain_node, node_num = self.__rand_walk__(vis, left_nodes)
            choosen_node.extend(chain_node)
            left_nodes -= node_num

        return choosen_node


class DANN(BaseModel):
    """
    DANN Model
    """

    def __init__(self, opt):
        super(DANN, self).__init__(opt)
        self.netE = FeatureNet(opt).to(opt.device)
        self.netF = PredNet(opt).to(opt.device)
        self.netG = GNet(opt).to(opt.device)
        self.netD = ClassDiscNet(opt).to(opt.device)

        self.__init_weight__()
        EF_parameters = list(self.netE.parameters()) + list(
            self.netF.parameters()
        )
        self.optimizer_EF = optim.Adam(
            EF_parameters, lr=opt.lr_e, betas=(opt.beta1, 0.999)
        )
        self.optimizer_D = optim.Adam(
            self.netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999)
        )
        if not self.use_g_encode:
            self.optimizer_G = optim.Adam(
                self.netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999)
            )

        self.lr_scheduler_EF = lr_scheduler.ExponentialLR(
            optimizer=self.optimizer_EF, gamma=0.5 ** (1 / 100)
        )
        self.lr_scheduler_D = lr_scheduler.ExponentialLR(
            optimizer=self.optimizer_D, gamma=0.5 ** (1 / 100)
        )
        if not self.use_g_encode:
            self.lr_scheduler_G = lr_scheduler.ExponentialLR(
                optimizer=self.optimizer_G, gamma=0.5 ** (1 / 100)
            )
            self.lr_schedulers = [
                self.lr_scheduler_EF,
                self.lr_scheduler_D,
                self.lr_scheduler_G,
            ]
        else:
            self.lr_schedulers = [self.lr_scheduler_EF, self.lr_scheduler_D]
        self.loss_names = ["E_pred", "E_gan", "D", "G"]
        if self.use_g_encode:
            self.loss_names.remove("G")

        self.lambda_gan = self.opt.lambda_gan

    def __train_forward__(self):
        self.z_seq = self.netG(self.one_hot_seq)
        self.e_seq = self.netE(self.x_seq, self.z_seq)  # encoder of the data
        self.f_seq = self.netF(self.e_seq)
        self.d_seq = self.netD(self.e_seq)

    def __loss_D__(self):
        return F.nll_loss(flat(self.d_seq), flat(self.domain_seq))

    def __loss_EF__(self):
        # the D loss has already been computed in __optimize__; reuse it here
        loss_E_gan = -self.loss_D

        y_seq_source = self.y_seq[self.domain_mask == 1]
        f_seq_source = self.f_seq[self.domain_mask == 1]

        loss_E_pred = F.nll_loss(
            flat(f_seq_source), flat(y_seq_source)
        )

        loss_E = loss_E_gan * self.lambda_gan + loss_E_pred

        return loss_E, loss_E_pred, loss_E_gan


class CDANN(BaseModel):
    """
    CDANN Model
    """

    def __init__(self, opt):
        super(CDANN, self).__init__(opt)
        self.netE = FeatureNet(opt).to(opt.device)
        self.netF = PredNet(opt).to(opt.device)
        self.netG = GNet(opt).to(opt.device)
        self.netD = CondClassDiscNet(opt).to(opt.device)

        self.__init_weight__()
        EF_parameters = list(self.netE.parameters()) + list(
            self.netF.parameters()
        )
        self.optimizer_EF = optim.Adam(
            EF_parameters, lr=opt.lr_e, betas=(opt.beta1, 0.999)
        )
        self.optimizer_D = optim.Adam(
            self.netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999)
        )

        if not self.use_g_encode:
            self.optimizer_G = optim.Adam(
                self.netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999)
            )

        self.lr_scheduler_EF = lr_scheduler.ExponentialLR(
            optimizer=self.optimizer_EF, gamma=0.5 ** (1 / 100)
        )
        self.lr_scheduler_D = lr_scheduler.ExponentialLR(
            optimizer=self.optimizer_D, gamma=0.5 ** (1 / 100)
        )
        if not self.use_g_encode:
            self.lr_scheduler_G = lr_scheduler.ExponentialLR(
                optimizer=self.optimizer_G, gamma=0.5 ** (1 / 100)
            )
            self.lr_schedulers = [
                self.lr_scheduler_EF,
                self.lr_scheduler_D,
                self.lr_scheduler_G,
            ]
        else:
            self.lr_schedulers = [self.lr_scheduler_EF, self.lr_scheduler_D]
        self.loss_names = ["E_pred", "E_gan", "D", "G"]
        if self.use_g_encode:
            self.loss_names.remove("G")

        self.lambda_gan = self.opt.lambda_gan

    def __train_forward__(self):
        self.z_seq = self.netG(self.one_hot_seq)
        self.e_seq = self.netE(self.x_seq, self.z_seq)  # encoder of the data
        self.f_seq = self.netF(self.e_seq)
        self.f_seq_sig = torch.sigmoid(self.f_seq.detach())

    def __loss_D__(self):
        d_seq = self.netD(self.e_seq.detach(), self.f_seq_sig.detach())
        return F.nll_loss(flat(d_seq), flat(self.domain_seq))

    def __loss_EF__(self):
        d_seq = self.netD(self.e_seq, self.f_seq_sig.detach())

        loss_E_gan = -F.nll_loss(flat(d_seq), flat(self.domain_seq))

        y_seq_source = self.y_seq[self.domain_mask == 1]
        f_seq_source = self.f_seq[self.domain_mask == 1]

        loss_E_pred = F.nll_loss(
            flat(f_seq_source), flat(y_seq_source)
        )

        loss_E = loss_E_gan * self.lambda_gan + loss_E_pred
        return loss_E, loss_E_pred, loss_E_gan


class ADDA(BaseModel):
    def __init__(self, opt):
        super(ADDA, self).__init__(opt)
        self.netE = FeatureNet(opt).to(opt.device)
        self.netF = PredNet(opt).to(opt.device)
        self.netG = GNet(opt).to(opt.device)
        self.netD = DiscNet(opt).to(opt.device)
        self.__init_weight__()
        EF_parameters = list(self.netE.parameters()) + list(
            self.netF.parameters()
        )
        self.optimizer_EF = optim.Adam(
            EF_parameters, lr=opt.lr_e, betas=(opt.beta1, 0.999)
        )
        self.optimizer_D = optim.Adam(
            self.netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999)
        )

        if not self.use_g_encode:
            self.optimizer_G = optim.Adam(
                self.netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999)
            )

        self.lr_scheduler_EF = lr_scheduler.ExponentialLR(
            optimizer=self.optimizer_EF, gamma=0.5 ** (1 / 100)
        )
        self.lr_scheduler_D = lr_scheduler.ExponentialLR(
            optimizer=self.optimizer_D, gamma=0.5 ** (1 / 100)
        )
        if not self.use_g_encode:
            self.lr_scheduler_G = lr_scheduler.ExponentialLR(
                optimizer=self.optimizer_G, gamma=0.5 ** (1 / 100)
            )
            self.lr_schedulers = [
                self.lr_scheduler_EF,
                self.lr_scheduler_D,
                self.lr_scheduler_G,
            ]
        else:
            self.lr_schedulers = [self.lr_scheduler_EF, self.lr_scheduler_D]
        self.loss_names = ["E_pred", "E_gan", "D", "G"]
        if self.use_g_encode:
            self.loss_names.remove("G")

    def __train_forward__(self):
        self.z_seq = self.netG(self.one_hot_seq)
        self.e_seq = self.netE(self.x_seq, self.z_seq)  # encoder of the data
        self.f_seq = self.netF(self.e_seq)

    def __loss_D__(self):
        d_seq = self.netD(self.e_seq.detach())
        d_seq_source = d_seq[self.domain_mask == 1]
        d_seq_target = d_seq[self.domain_mask == 0]
        # D: discriminator loss from classifying source vs. target
        loss_D = (
            -torch.log(d_seq_source + 1e-10).mean()
            - torch.log(1 - d_seq_target + 1e-10).mean()
        )
        return loss_D

    def __loss_EF__(self):
        d_seq = self.netD(self.e_seq)
        d_seq_target = d_seq[self.domain_mask == 0]
        loss_E_gan = -torch.log(d_seq_target + 1e-10).mean()
        # E_pred: encoder loss from predicting the label
        y_seq_source = self.y_seq[self.domain_mask == 1]
        f_seq_source = self.f_seq[self.domain_mask == 1]
        loss_E_pred = F.nll_loss(
            flat(f_seq_source), flat(y_seq_source)
        )

        loss_E = loss_E_gan * self.opt.lambda_gan + loss_E_pred

        return loss_E, loss_E_pred, loss_E_gan


class MDD(BaseModel):
    """
    Margin Disparity Discrepancy
    """

    def __init__(self, opt):
        super(MDD, self).__init__(opt)
        self.netE = FeatureNet(opt).to(opt.device)
        self.netF = PredNet(opt).to(opt.device)
        self.netG = GNet(opt).to(opt.device)
        self.netD = PredNet(opt).to(opt.device)
        self.__init_weight__()
        EF_parameters = list(self.netE.parameters()) + list(
            self.netF.parameters()
        )
        self.optimizer_EF = optim.Adam(
            EF_parameters, lr=opt.lr_e, betas=(opt.beta1, 0.999)
        )
        self.optimizer_D = optim.Adam(
            self.netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999)
        )
        if not self.use_g_encode:
            self.optimizer_G = optim.Adam(
                self.netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999)
            )

        self.lr_scheduler_EF = lr_scheduler.ExponentialLR(
            optimizer=self.optimizer_EF, gamma=0.5 ** (1 / 100)
        )
        self.lr_scheduler_D = lr_scheduler.ExponentialLR(
            optimizer=self.optimizer_D, gamma=0.5 ** (1 / 100)
        )
        if not self.use_g_encode:
            self.lr_scheduler_G = lr_scheduler.ExponentialLR(
                optimizer=self.optimizer_G, gamma=0.5 ** (1 / 100)
            )
            self.lr_schedulers = [
                self.lr_scheduler_EF,
                self.lr_scheduler_D,
                self.lr_scheduler_G,
            ]
        else:
            self.lr_schedulers = [self.lr_scheduler_EF, self.lr_scheduler_D]
        self.loss_names = ["E_pred", "E_adv", "ADV_src", "ADV_tgt", "G"]
        if self.use_g_encode:
            self.loss_names.remove("G")

        self.lambda_src = opt.lambda_src
        self.lambda_tgt = opt.lambda_tgt

    def __train_forward__(self):
        self.z_seq = self.netG(self.one_hot_seq)
        self.e_seq = self.netE(self.x_seq, self.z_seq)  # encoder of the data
        self.f_seq = self.netF(self.e_seq)
        self.g_seq = torch.argmax(
            self.f_seq.detach(), dim=2
        )  # class of the prediction

    def __optimize__(self):
        loss_value = dict()
        if not self.use_g_encode:
            self.loss_G = self.__loss_G__()

        self.loss_D, self.loss_ADV_src, self.loss_ADV_tgt = self.__loss_D__()
        self.loss_E, self.loss_E_pred, self.loss_E_adv = self.__loss_EF__()

        if not self.use_g_encode:
            self.optimizer_G.zero_grad()
            self.loss_G.backward(retain_graph=True)
        self.optimizer_D.zero_grad()
        self.loss_D.backward(retain_graph=True)
        self.optimizer_EF.zero_grad()
        self.loss_E.backward()

        if not self.use_g_encode:
            self.optimizer_G.step()
        self.optimizer_D.step()
        self.optimizer_EF.step()

        if not self.use_g_encode:
            loss_value["G"] = self.loss_G.item()
        loss_value["ADV_src"], loss_value["ADV_tgt"], loss_value["E_pred"], loss_value["E_adv"] = \
            self.loss_ADV_src.item(), self.loss_ADV_tgt.item(), self.loss_E_pred.item(), self.loss_E_adv.item()
        return loss_value

        # loss_value = dict()
        # if not self.use_g_encode:
        #     loss_value["G"] = self.__optimize_G__()
        # # if self.opt.lambda_gan != 0:
        # #     (
        # #         loss_value["ADV_src"],
        # #         loss_value["ADV_tgt"],
        # #     ) = self.__optimize_D__()
        # # else:
        # #     loss_value["ADV_src"], loss_value["ADV_tgt"] = 0
        # # (
        #     loss_value["E_pred"],
        #     loss_value["E_adv"],
        # ) = self.__optimize_EF__()
        # return loss_value

    def __loss_D__(self):
        # self.optimizer_D.zero_grad()

        # # backward process:
        # self.loss_D.backward(retain_graph=True)
        # return self.loss_D.item()
        f_adv, f_adv_softmax = self.netD(
            self.e_seq.detach(), return_softmax=True
        )
        # agreement with netF on source domain
        loss_ADV_src = F.nll_loss(
            flat(f_adv[self.domain_mask == 1]),
            flat(self.g_seq[self.domain_mask == 1]),
        )
        f_adv_tgt = torch.log(
            1 - f_adv_softmax[self.domain_mask == 0] + 1e-10
        )
        # disagreement with netF on target domain
        loss_ADV_tgt = F.nll_loss(
            flat(f_adv_tgt), flat(self.g_seq[self.domain_mask == 0])
        )
        # minimizing this loss maximizes the agreement with netF on the source
        # domain and the disagreement with netF on the target domain
        loss_D = (
            loss_ADV_src * self.lambda_src
            + loss_ADV_tgt * self.lambda_tgt
        ) / (self.lambda_src + self.lambda_tgt)

        return loss_D, loss_ADV_src, loss_ADV_tgt

    def __loss_EF__(self):
        loss_E_pred = F.nll_loss(
            flat(self.f_seq[self.domain_mask == 1]),
            flat(self.y_seq[self.domain_mask == 1]),
        )

        f_adv, f_adv_softmax = self.netD(
            self.e_seq, return_softmax=True
        )
        loss_ADV_src = F.nll_loss(
            flat(f_adv[self.domain_mask == 1]),
            flat(self.g_seq[self.domain_mask == 1]),
        )
        f_adv_tgt = torch.log(
            1 - f_adv_softmax[self.domain_mask == 0] + 1e-10
        )
        loss_ADV_tgt = F.nll_loss(
            flat(f_adv_tgt), flat(self.g_seq[self.domain_mask == 0])
        )
        loss_E_adv = -(
            loss_ADV_src * self.lambda_src
            + loss_ADV_tgt * self.lambda_tgt
        ) / (self.lambda_src + self.lambda_tgt)
        loss_E = loss_E_pred + self.opt.lambda_gan * loss_E_adv

        return loss_E, loss_E_pred, loss_E_adv


class GDA(BaseModel):
    """
    GDA Model
    """

    def __init__(self, opt):
        super(GDA, self).__init__(opt)
        self.netE = FeatureNet(opt).to(opt.device)
        self.netF = PredNet(opt).to(opt.device)
        self.netG = GNet(opt).to(opt.device)
        self.netD = GraphDNet(opt).to(opt.device)
        self.__init_weight__()

        EF_parameters = list(self.netE.parameters()) + list(
            self.netF.parameters()
        )
        self.optimizer_EF = optim.Adam(
            EF_parameters, lr=opt.lr_e, betas=(opt.beta1, 0.999)
        )
        self.optimizer_D = optim.Adam(
            self.netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999)
        )
        if not self.use_g_encode:
            self.optimizer_G = optim.Adam(
                self.netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999)
            )

        self.lr_scheduler_EF = lr_scheduler.ExponentialLR(
            optimizer=self.optimizer_EF, gamma=0.5 ** (1 / 100)
        )
        self.lr_scheduler_D = lr_scheduler.ExponentialLR(
            optimizer=self.optimizer_D, gamma=0.5 ** (1 / 100)
        )
        if not self.use_g_encode:
            self.lr_scheduler_G = lr_scheduler.ExponentialLR(
                optimizer=self.optimizer_G, gamma=0.5 ** (1 / 100)
            )
            self.lr_schedulers = [
                self.lr_scheduler_EF,
                self.lr_scheduler_D,
                self.lr_scheduler_G,
            ]
        else:
            self.lr_schedulers = [self.lr_scheduler_EF, self.lr_scheduler_D]
        self.loss_names = ["E_pred", "E_gan", "D", "G"]
        if self.use_g_encode:
            self.loss_names.remove("G")

    def __train_forward__(self):
        self.z_seq = self.netG(self.one_hot_seq)
        self.e_seq = self.netE(self.x_seq, self.z_seq)  # encoder of the data
        self.f_seq = self.netF(self.e_seq)  # prediction
        self.d_seq = self.netD(self.e_seq)

    def __loss_EF__(self):
        # loss_D was already computed in __optimize__; reuse it here
        if self.opt.lambda_gan != 0:
            loss_E_gan = -self.loss_D
        else:
            loss_E_gan = torch.tensor(
                0, dtype=torch.float, device=self.opt.device
            )

        y_seq_source = self.y_seq[self.domain_mask == 1]
        f_seq_source = self.f_seq[self.domain_mask == 1]

        loss_E_pred = F.nll_loss(flat(f_seq_source), flat(y_seq_source))

        loss_E = loss_E_gan * self.opt.lambda_gan + loss_E_pred

        return loss_E, loss_E_pred, loss_E_gan

    def __loss_D__(self):
        criterion = nn.BCEWithLogitsLoss()
        d = self.d_seq
        # randomly sample a sub-graph (e.g. a sub-chain) of domains to train D on;
        # connected and disconnected pairs are averaged separately to balance
        # positive and negative pairs
        sub_graph = self.__sub_graph__(my_sample_v=self.opt.sample_v)

        errorD_connected = torch.zeros((1,)).to(self.device)  # .double()
        errorD_disconnected = torch.zeros((1,)).to(self.device)  # .double()

        count_connected = 0
        count_disconnected = 0

        for i in range(self.opt.sample_v):
            v_i = sub_graph[i]
            # no self-loops: only pairs with j > i
            for j in range(i + 1, self.opt.sample_v):
                v_j = sub_graph[j]
                label = torch.full(
                    (self.tmp_batch_size,),
                    self.opt.A[v_i][v_j],
                    device=self.device,
                )
                # per-sample dot product of the two domains' discriminator outputs
                if v_i == v_j:
                    idx = torch.randperm(self.tmp_batch_size)
                    output = (d[v_i][idx] * d[v_j]).sum(1)
                else:
                    output = (d[v_i] * d[v_j]).sum(1)

                if self.opt.A[v_i][v_j]:  # connected
                    errorD_connected += criterion(output, label)
                    count_connected += 1
                else:
                    errorD_disconnected += criterion(output, label)
                    count_disconnected += 1

        errorD = 0.5 * (
            errorD_connected / count_connected
            + errorD_disconnected / count_disconnected
        )
        # scale by the number of domains to balance this loss against the others
        return errorD * self.num_domain
--------------------------------------------------------------------------------
/model/modules.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x


class GNet(nn.Module):
    def __init__(self, opt):
        super(GNet, self).__init__()
        self.use_g_encode = opt.use_g_encode
        if self.use_g_encode:
            G = np.zeros((opt.num_domain, opt.nt))
            for i in range(opt.num_domain):
                G[i] = opt.g_encode[str(i)]
            self.G = torch.from_numpy(G).float().to(device=opt.device)
        else:
            self.fc1 = nn.Linear(opt.num_domain, opt.nh)
            self.fc_final = nn.Linear(opt.nh, opt.nt)

    def forward(self, x):
        re = x.dim() == 3
        if re:
            T, B, C = x.shape
            x = x.reshape(T * B, -1)

        if self.use_g_encode:
            x = torch.matmul(x.float(), self.G)
        else:
            x = F.relu(self.fc1(x.float()))
            # x = nn.Dropout(p=p)(x)
            x = self.fc_final(x)
        # note: the output stays flattened to (T * B, nt)
        return x


class FeatureNet(nn.Module):
    def __init__(self, opt):
        super(FeatureNet, self).__init__()

        nx, nh, nt, p = opt.nx, opt.nh, opt.nt, opt.p
        self.p = p

        self.fc1 = nn.Linear(nx, nh)
        self.fc2 = nn.Linear(nh * 2, nh * 2)
        self.fc3 = nn.Linear(nh * 2, nh * 2)
        self.fc4 = nn.Linear(nh * 2, nh * 2)
        self.fc_final = nn.Linear(nh * 2, nh)

        # branch that embeds the domain encoding t (dimension nt) into nh features
        self.fc1_var = nn.Linear(nt, nh)
        self.fc2_var = nn.Linear(nh, nh)

    def forward(self, x, t):
        re = x.dim() == 3
        if re:
            T, B, C = x.shape
            x = x.reshape(T * B, -1)
            t = t.reshape(T * B, -1)

        x = F.relu(self.fc1(x))
        t = F.relu(self.fc1_var(t))
        t = F.relu(self.fc2_var(t))

        # concatenate the data features and the domain-encoding features
        x = torch.cat((x, t), dim=1)

        # main
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        x = self.fc_final(x)

        if re:
            return x.reshape(T, B, -1)
        else:
            return x


class GraphDNet(nn.Module):
    """
    Generate z' for connection loss
    """

    def __init__(self, opt):
        super(GraphDNet, self).__init__()
        nh = opt.nh
        nin = nh
        self.fc3 = nn.Linear(nin, nh)
        self.bn3 = nn.BatchNorm1d(nh)

        self.fc4 = nn.Linear(nh, nh)
        self.bn4 = nn.BatchNorm1d(nh)

        self.fc5 = nn.Linear(nh, nh)
        self.bn5 = nn.BatchNorm1d(nh)

        self.fc6 = nn.Linear(nh, nh)
        self.bn6 = nn.BatchNorm1d(nh)

        self.fc7 = nn.Linear(nh, nh)
        self.bn7 = nn.BatchNorm1d(nh)

        self.fc_final = nn.Linear(nh, opt.nd_out)

        if opt.no_bn:
            self.bn3 = Identity()
            self.bn4 = Identity()
            self.bn5 = Identity()
            self.bn6 = Identity()
            self.bn7 = Identity()

    def forward(self, x):
        re = x.dim() == 3

        if re:
            T, B, C = x.shape
            x = x.reshape(T * B, -1)

        x = F.relu(self.bn3(self.fc3(x)))
        x = F.relu(self.bn4(self.fc4(x)))
        x = F.relu(self.bn5(self.fc5(x)))
        x = F.relu(self.bn6(self.fc6(x)))
        x = F.relu(self.bn7(self.fc7(x)))

        x = self.fc_final(x)

        if re:
            return x.reshape(T, B, -1)
        else:
            return x


class ResGraphDNet(nn.Module):
    """
    Generate z' for connection loss
    """

    def __init__(self, opt):
        super(ResGraphDNet, self).__init__()
        nh = opt.nh
        nin = nh
        self.fc3 = nn.Linear(nin, nh)
        self.bn3 = nn.BatchNorm1d(nh)

        self.fc4 = nn.Linear(nh, nh)
        self.bn4 = nn.BatchNorm1d(nh)

        self.fc5 = nn.Linear(nh, nh)
        self.bn5 = nn.BatchNorm1d(nh)

        self.fc6 = nn.Linear(nh, nh)
        self.bn6 = nn.BatchNorm1d(nh)

        self.fc7 = nn.Linear(nh, nh)
        self.bn7 = nn.BatchNorm1d(nh)

        self.fc8 = nn.Linear(nh, nh)
        self.bn8 = nn.BatchNorm1d(nh)

        self.fc9 = nn.Linear(nh, nh)
        self.bn9 = nn.BatchNorm1d(nh)

        self.fc10 = nn.Linear(nh, nh)
        self.bn10 = nn.BatchNorm1d(nh)

        self.fc11 = nn.Linear(nh, nh)
        self.bn11 = nn.BatchNorm1d(nh)

        self.fc_final = nn.Linear(nh, opt.nd_out)

        if opt.no_bn:
            self.bn3 = Identity()
            self.bn4 = Identity()
            self.bn5 = Identity()
            self.bn6 = Identity()
            self.bn7 = Identity()
            self.bn8 = Identity()
            self.bn9 = Identity()
            self.bn10 = Identity()
            self.bn11 = Identity()

    def forward(self, x):
        re = x.dim() == 3

        if re:
            T, B, C = x.shape
            x = x.reshape(T * B, -1)

        x = F.relu(self.bn3(self.fc3(x)))
        id1 = x
        out = F.relu(self.bn4(self.fc4(x)))
        out = self.bn5(self.fc5(out))
        x = F.relu(out + id1)

        id2 = x
        out = F.relu(self.bn6(self.fc6(x)))
        out = self.bn7(self.fc7(out))
        x = F.relu(out + id2)

        id3 = x
        out = F.relu(self.bn8(self.fc8(x)))
        out = self.bn9(self.fc9(out))
        x = F.relu(out + id3)

        id4 = x
        out = F.relu(self.bn10(self.fc10(x)))
        out = self.bn11(self.fc11(out))
        x = F.relu(out + id4)

        x = self.fc_final(x)

        if re:
            return x.reshape(T, B, -1)
        else:
            return x


class DiscNet(nn.Module):
    # Discriminator doing binary classification: source vs. target

    def __init__(self, opt):
        super(DiscNet, self).__init__()
        nh = opt.nh

        nin = nh
        self.fc3 = nn.Linear(nin, nh)
        self.bn3 = nn.BatchNorm1d(nh)

        self.fc4 = nn.Linear(nh, nh)
        self.bn4 = nn.BatchNorm1d(nh)

        self.fc5 = nn.Linear(nh, nh)
        self.bn5 = nn.BatchNorm1d(nh)

        self.fc6 = nn.Linear(nh, nh)
        self.bn6 = nn.BatchNorm1d(nh)

        self.fc7 = nn.Linear(nh, nh)
        self.bn7 = nn.BatchNorm1d(nh)

        if opt.no_bn:
            self.bn3 = Identity()
            self.bn4 = Identity()
            self.bn5 = Identity()
            self.bn6 = Identity()
            self.bn7 = Identity()

        self.fc_final = nn.Linear(nh, 1)
        if opt.model in ["ADDA", "CUA"]:
            print("===> Discriminator Output Activation: sigmoid")
            self.output = lambda x: torch.sigmoid(x)
        else:
            print("===> Discriminator Output Activation: identity")
            self.output = lambda x: x

    def forward(self, x):
        re = x.dim() == 3

        if re:
            T, B, C = x.shape
            x = x.reshape(T * B, -1)

        x = F.relu(self.bn3(self.fc3(x)))
        x = F.relu(self.bn4(self.fc4(x)))
        x = F.relu(self.bn5(self.fc5(x)))
        x = F.relu(self.bn6(self.fc6(x)))
        x = F.relu(self.bn7(self.fc7(x)))
        x = self.output(self.fc_final(x))

        if re:
            return x.reshape(T, B, -1)
        else:
            return x


class ClassDiscNet(nn.Module):
    """
    Discriminator doing multi-class classification on the domain
    """

    def __init__(self, opt):
        super(ClassDiscNet, self).__init__()
        nh = opt.nh
        nc = opt.nc
        nin = nh
        nout = opt.num_domain

        if opt.cond_disc:
            print("===> Conditioned Discriminator")
            nmid = nh * 2
            self.cond = nn.Sequential(
                nn.Linear(nc, nh),
                nn.ReLU(True),
                nn.Linear(nh, nh),
                nn.ReLU(True),
            )
        else:
            print("===> Unconditioned Discriminator")
            nmid = nh
            self.cond = None

        print(f"===> Discriminator will distinguish {nout} domains")

        self.fc3 = nn.Linear(nin, nh)
        self.bn3 = nn.BatchNorm1d(nh)

        self.fc4 = nn.Linear(nmid, nh)
        self.bn4 = nn.BatchNorm1d(nh)

        self.fc5 = nn.Linear(nh, nh)
        self.bn5 = nn.BatchNorm1d(nh)

        self.fc6 = nn.Linear(nh, nh)
        self.bn6 = nn.BatchNorm1d(nh)

        self.fc7 = nn.Linear(nh, nh)
        self.bn7 = nn.BatchNorm1d(nh)

        if opt.no_bn:
            self.bn3 = Identity()
            self.bn4 = Identity()
            self.bn5 = Identity()
            self.bn6 = Identity()
            self.bn7 = Identity()

        self.fc_final = nn.Linear(nh, nout)

    def forward(self, x):
        re = x.dim() == 3

        if re:
            T, B, C = x.shape
            x = x.reshape(T * B, -1)
            # f_exp = f_exp.reshape(T * B, -1)

        x = F.relu(self.bn3(self.fc3(x)))
        x = F.relu(self.bn4(self.fc4(x)))
        x = F.relu(self.bn5(self.fc5(x)))
        x = F.relu(self.bn6(self.fc6(x)))
        x = F.relu(self.bn7(self.fc7(x)))
        # x = self.fc_final(x)
        x = F.relu(self.fc_final(x))
        x = torch.log_softmax(x, dim=1)
        if re:
            return x.reshape(T, B, -1)
        else:
            return x


class CondClassDiscNet(nn.Module):
    """
    Discriminator doing multi-class classification on the domain
    """

    def __init__(self, opt):
        super(CondClassDiscNet, self).__init__()
        nh = opt.nh
        nc = opt.nc
        nin = nh
        nout = opt.num_domain

        if opt.cond_disc:
            print("===> Conditioned Discriminator")
            nmid = nh * 2
            self.cond = nn.Sequential(
                nn.Linear(nc, nh),
                nn.ReLU(True),
                nn.Linear(nh, nh),
                nn.ReLU(True),
            )
        else:
            print("===> Unconditioned Discriminator")
            nmid = nh
            self.cond = None

        print(f"===> Discriminator will distinguish {nout} domains")

        self.fc3 = nn.Linear(nin, nh)
        self.bn3 = nn.BatchNorm1d(nh)

        self.fc4 = nn.Linear(nmid, nh)
        self.bn4 = nn.BatchNorm1d(nh)

        self.fc5 = nn.Linear(nh, nh)
        self.bn5 = nn.BatchNorm1d(nh)

        self.fc6 = nn.Linear(nh, nh)
        self.bn6 = nn.BatchNorm1d(nh)

        self.fc7 = nn.Linear(nh, nh)
        self.bn7 = nn.BatchNorm1d(nh)

        if opt.no_bn:
            self.bn3 = Identity()
            self.bn4 = Identity()
            self.bn5 = Identity()
            self.bn6 = Identity()
            self.bn7 = Identity()

        self.fc_final = nn.Linear(nh, nout)

    def forward(self, x, f_exp):
        re = x.dim() == 3

        if re:
            T, B, C = x.shape
            x = x.reshape(T * B, -1)
            f_exp = f_exp.reshape(T * B, -1)

        x = F.relu(self.bn3(self.fc3(x)))
        if self.cond is not None:
            f = self.cond(f_exp)
            x = torch.cat([x, f], dim=1)
        x = F.relu(self.bn4(self.fc4(x)))
        x = F.relu(self.bn5(self.fc5(x)))
        x = F.relu(self.bn6(self.fc6(x)))
        x = F.relu(self.bn7(self.fc7(x)))
        x = self.fc_final(x)
        x = torch.log_softmax(x, dim=1)
        if re:
            return x.reshape(T, B, -1)
        else:
            return x


class ProbDiscNet(nn.Module):
    def __init__(self, opt):
        super(ProbDiscNet, self).__init__()

        nmix = opt.nmix

        nh = opt.nh

        nin = nh
        nout = opt.dim_domain * nmix * 3

        self.fc3 = nn.Linear(nin, nh)
        self.bn3 = nn.BatchNorm1d(nh)

        self.fc4 = nn.Linear(nh, nh)
        self.bn4 = nn.BatchNorm1d(nh)

        self.fc5 = nn.Linear(nh, nh)
        self.bn5 = nn.BatchNorm1d(nh)

        self.fc6 = nn.Linear(nh, nh)
        self.bn6 = nn.BatchNorm1d(nh)

        if opt.no_bn:
            self.bn3 = Identity()
            self.bn4 = Identity()
            self.bn5 = Identity()
            self.bn6 = Identity()

        self.fc_final = nn.Linear(nh, nout)

        self.ndomain = opt.dim_domain
        self.nmix = nmix

    def forward(self, x):
        re = x.dim() == 3
        if re:
            T, B, C = x.shape
            x = x.reshape(T * B, -1)

        x = F.relu(self.bn3(self.fc3(x)))
        x = F.relu(self.bn4(self.fc4(x)))
        x = F.relu(self.bn5(self.fc5(x)))
        x = F.relu(self.bn6(self.fc6(x)))

        x = self.fc_final(x).reshape(-1, 3, self.ndomain, self.nmix)
        x_mean, x_std, x_weight = x[:, 0], x[:, 1], x[:, 2]
        x_std = torch.sigmoid(x_std) * 2 + 0.1
        x_weight = torch.softmax(x_weight, dim=1)

        if re:
            return (
                x_mean.reshape(T, B, -1),
                x_std.reshape(T, B, -1),
                x_weight.reshape(T, B, -1),
            )
        else:
            return x_mean, x_std, x_weight


class PredNet(nn.Module):
    def __init__(self, opt):
        super(PredNet, self).__init__()

        nh, nc = opt.nh, opt.nc
        nin = nh
        self.fc3 = nn.Linear(nin, nh)
        self.bn3 = nn.BatchNorm1d(nh)
        self.fc4 = nn.Linear(nh, nh)
        self.bn4 = nn.BatchNorm1d(nh)
        self.fc_final = nn.Linear(nh, nc)
        if opt.no_bn:
            self.bn3 = Identity()
            self.bn4 = Identity()

    def forward(self, x, return_softmax=False):
        re = x.dim() == 3
        if re:
            T, B, C = x.shape
            x = x.reshape(T * B, -1)

        x = F.relu(self.bn3(self.fc3(x)))
        x = F.relu(self.bn4(self.fc4(x)))
        x = self.fc_final(x)
        x_softmax = F.softmax(x, dim=1)

        # x = F.log_softmax(x, dim=1)
        # x = torch.clamp_max(x_softmax + 1e-4, 1)
        # x = torch.log(x)
        # log-probabilities for F.nll_loss; the 1e-4 offset keeps log() finite
        x = torch.log(x_softmax + 1e-4)

        if re:
            x = x.reshape(T, B, -1)
            x_softmax = x_softmax.reshape(T, B, -1)

        if return_softmax:
            return x, x_softmax
        else:
            return x


# ======================================================================================================================
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
pandas
networkx
matplotlib
visdom
easydict
torch==1.9.0
torchvision==0.10.0
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
import pickle


def read_pickle(name):
    with open(name, "rb") as f:
        data = pickle.load(f)
    return data


def write_pickle(data, name):
    with open(name, "wb") as f:
        pickle.dump(data, f)
--------------------------------------------------------------------------------
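As a reading aid for `GDA.__loss_D__` in `model/model.py`, the sketch below restates its balanced pairwise objective in plain Python without PyTorch. It assumes the discriminator outputs are nested lists (`d[v][b]` is the output vector for sample `b` of domain `v`) and the domain graph is a 0/1 adjacency matrix `A`. The helper names `bce_with_logits` and `graph_disc_loss` are hypothetical, and, unlike the original, the sketch skips an empty pair type instead of dividing by zero. It is an illustration under these assumptions, not the repository's implementation.

```python
import math
from itertools import combinations


def bce_with_logits(logit, target):
    """Numerically stable binary cross-entropy on one raw logit (same formula as nn.BCEWithLogitsLoss)."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def graph_disc_loss(d, A, sub_graph):
    """Balanced pairwise loss in the spirit of GDA.__loss_D__ (hypothetical helper).

    d[v][b]   : discriminator output vector for sample b of domain v
    A[u][v]   : 1 if domains u and v are adjacent, else 0
    sub_graph : distinct domain indices, e.g. a sampled chain
    """
    total = {1: 0.0, 0: 0.0}
    count = {1: 0, 0: 0}
    # every unordered pair (i < j), i.e. no self-loops
    for v_i, v_j in combinations(sub_graph, 2):
        label = 1 if A[v_i][v_j] else 0
        # per-sample dot product is the logit for "v_i and v_j are adjacent"
        logits = [
            sum(p * q for p, q in zip(x_i, x_j))
            for x_i, x_j in zip(d[v_i], d[v_j])
        ]
        # mean over the batch, as BCEWithLogitsLoss does by default
        total[label] += sum(bce_with_logits(z, label) for z in logits) / len(logits)
        count[label] += 1
    # average connected and disconnected pairs separately, then combine them
    parts = [total[k] / count[k] for k in (1, 0) if count[k] > 0]
    if not parts:  # fewer than two sampled domains: no pairs to score
        return 0.0
    # scale by the number of domains, as the original does
    return sum(parts) / len(parts) * len(A)


# toy chain graph 0 - 1 - 2, one sample per domain
A = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
zeros = [[[0.0, 0.0]] for _ in range(3)]
aligned = [[[2.0, 0.0]], [[2.0, 2.0]], [[0.0, 2.0]]]
print(graph_disc_loss(zeros, A, [0, 1, 2]))    # about 2.079 (= 3 * log 2)
print(graph_disc_loss(aligned, A, [0, 1, 2]))  # about 1.067, lower than above
```

With all-zero outputs every logit is 0, so each pair costs log 2 and the loss is 0.5 · (log 2 + log 2) · 3 = 3 log 2. Averaging connected and disconnected pairs separately keeps the pair type that happens to be more frequent in the sample (typically the disconnected pairs in a sparse graph) from dominating the objective.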