├── .gitignore
├── README.md
├── configs
│   ├── config_random_15_new.py
│   └── config_random_60_new.py
├── data
│   ├── TPT-48.zip
│   ├── toy_d15_spiral_tight_boundary.pkl
│   └── toy_d60_spiral.pkl
├── dataset_generator
│   ├── dataset_generator.py
│   └── draw_graph_utils.py
├── dataset_utils
│   └── dataset.py
├── derive_g_encode
│   ├── 499_pred_GDA_new.pkl
│   ├── derive.py
│   ├── g_encode.pkl
│   ├── g_encode_15.pkl
│   └── g_encode_60.pkl
├── fig
│   ├── GRDA-DG-15-results.png
│   ├── GRDA-domain-graph-US.png
│   ├── blog-method-DA-vs-GRDA.png
│   ├── compcar_quantitive_result.jpg
│   ├── dg_15_60_quantitive_result.jpg
│   └── tpt_48_quantitive_result.jpg
├── main.py
├── model
│   ├── model.py
│   └── modules.py
├── requirements.txt
└── utils
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | __pycache__/
3 | dataset_utils/__pycache__/
4 | dataset_utils/__init__.py
5 | dump/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Graph-Relational Domain Adaptation (GRDA)
2 | This repo contains the code for our ICLR 2022 paper:
3 | **Graph-Relational Domain Adaptation**
4 | Zihao Xu, Hao He, Guang-He Lee, Yuyang Wang, Hao Wang
5 | *Tenth International Conference on Learning Representations (ICLR), 2022*
6 | [[Paper](http://wanghao.in/paper/ICLR22_GRDA.pdf)] [[Talk](https://www.youtube.com/watch?v=oNM5hZGVv34)] [[Slides](http://wanghao.in/slides/GRDA_slides.pptx)] [[TPT-48 Dataset](data/TPT-48.zip)]
7 |
8 |
9 | ## Beyond Domain Adaptation: Brief Introduction for GRDA
10 | Essentially, GRDA goes beyond the current (categorical) domain adaptation regime and proposes the first approach to **adapt across graph-relational domains**. We introduce a new notion, dubbed the "**domain graph**", to encode domain adjacency, e.g., a graph of states in the US with each state as a domain and each edge indicating adjacency. Theoretical analysis shows that *at equilibrium, GRDA recovers classic domain adaptation when the graph is a clique, and achieves non-trivial alignment for other types of graphs*. See the following example (black nodes are source domains and white nodes are target domains).
11 |
12 |
13 |
14 |
15 |
16 | ## Sample Results
17 | Consider a DA problem with 15 domains connected by a domain graph, using domains 0, 3, 4, 8, 12, and 14 as source domains (left of the figure below) and the rest as target domains. Below are some sample results from previous domain adaptation methods and GRDA (right of the figure); GRDA successfully generalizes across the different domains in the graph.
18 |
19 |
20 |
21 |
22 |
23 | ## Method Overview
24 | We provide a simple yet effective learning framework with **theoretical guarantees** (see the [Theory section](#theory-informal) at the end of this README). Below is a quick comparison between previous domain adaptation methods and GRDA (differences marked in red).
25 | * Previous domain adaptation methods use a discriminator to classify different domains (as categorical values), while GRDA's discriminator directly reconstructs the domain graph (as an adjacency matrix); a schematic sketch follows this list.
26 | * Previous domain adaptation methods' encoders ignore domain IDs, while GRDA's encoder takes the domain IDs along with the domain graph as input.
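Below is a minimal schematic sketch of the graph-reconstruction idea (our simplification; not the exact loss in `model/model.py`): the discriminator scores every pair of per-domain embeddings and is trained to reproduce the domain graph's adjacency matrix, whereas a DANN-style discriminator would instead end in a categorical domain classifier.

```python
import torch
import torch.nn.functional as F

def graph_reconstruction_loss(e, A):
    # e: [num_domain, d] per-domain embeddings from the discriminator
    # A: [num_domain, num_domain] 0/1 adjacency matrix of the domain graph
    logits = e @ e.t()  # pairwise scores for all domain pairs
    return F.binary_cross_entropy_with_logits(logits, A.float())
```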
27 |
28 |
29 |
30 |
31 | ## Quantitative Result
32 | #### Toy Dataset: DG-15 and DG-60
33 |
34 |
35 |
36 |
37 | #### TPT-48
38 |
39 |
40 |
41 |
42 | #### CompCars
43 |
44 |
45 |
46 |
47 | ## Theory (Informal)
48 | - Traditional DA is equivalent to using our GRDA with a fully-connected graph (i.e., a clique).
49 | - D and E converge if and only if $\mathbb{E}_{i \sim p(u|e),\, j \sim p(u|e')}[\mathbf{A}_{ij} \mid e, e'] = \mathbb{E}_{i,j}[\mathbf{A}_{ij}]$ (see the display form below).
50 | - The global optimum of the two-player game between E and D matches the three-player game between E, D, and F.
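In display form (our paraphrase of the second condition), with $\mathbf{A}$ the domain graph's adjacency matrix and $e, e'$ the encodings:

```latex
\mathbb{E}_{i \sim p(u \mid e),\; j \sim p(u \mid e')}
  \left[\mathbf{A}_{ij} \mid e, e'\right]
  \;=\; \mathbb{E}_{i,j}\left[\mathbf{A}_{ij}\right]
```

Intuitively, at equilibrium the encodings carry no information about expected adjacency beyond its graph-wide mean.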
51 |
52 |
53 |
54 | ## Installation
55 | `pip install -r requirements.txt`
56 |
57 | ## How to Train GRDA
58 | `python main.py`
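`main.py` imports its settings from `configs/config_random_15_new.py` by default; to train on DG-60 instead, swap the config import at the top of `main.py` (as the in-file comment suggests):

```python
# in main.py: choose exactly one experiment config
# from configs.config_random_15_new import opt  # DG-15 (default)
from configs.config_random_60_new import opt    # DG-60
```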
59 |
60 | ## Visualization
61 | We use Visdom for visualization. We assume the code runs on a remote GPU machine.
62 |
63 | ### Change Configurations
64 | Find the configs in the "configs" folder. Choose the config you need and set "opt.use_visdom" to "True".
65 |
66 | ### Start a Visdom Server on Your Machine
67 | `python -m visdom.server -p 2000`
68 | Now connect to the GPU server and forward port 2000 to your local machine. You can then open:
69 | http://localhost:2000 (your local address)
70 | to see the visualizations during training.
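For example, with plain SSH port forwarding (placeholder user/host):

```bash
ssh -L 2000:localhost:2000 <user>@<gpu-server>
```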
71 |
72 | ## Also Check Out Relevant Work
73 | **Continuously Indexed Domain Adaptation**
74 | Hao Wang*, Hao He*, Dina Katabi
75 | *Thirty-Seventh International Conference on Machine Learning (ICML), 2020*
76 | [[Paper](http://wanghao.in/paper/ICML20_CIDA.pdf)] [[Code](https://github.com/hehaodele/CIDA)] [[Talk](https://www.youtube.com/watch?v=KtZPSCD-WhQ)] [[Blog](http://wanghao.in/CIDA-Blog/CIDA.html)] [[Slides](http://wanghao.in/slides/CIDA_slides.pptx)]
77 |
78 | ## Reference
79 | [Graph-Relational Domain Adaptation](http://wanghao.in/paper/ICLR22_GRDA.pdf)
80 | ```bib
81 | @inproceedings{GRDA,
82 | title={Graph-Relational Domain Adaptation},
83 | author={Xu, Zihao and He, Hao and Lee, Guang-He and Wang, Yuyang and Wang, Hao},
84 | booktitle={International Conference on Learning Representations},
85 | year={2022}
86 | }
87 | ```
88 |
--------------------------------------------------------------------------------
/configs/config_random_15_new.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict
2 | import numpy as np
3 | import pickle
4 |
5 |
6 | def read_pickle(name):
7 | with open(name, "rb") as f:
8 | data = pickle.load(f)
9 | return data
10 |
11 |
12 | # load/output dir
13 | opt = EasyDict()
14 | opt.loadf = "./dump"
15 | opt.outf = "./dump"
16 |
17 | # normalize each data domain
18 | # opt.normalize_domain = False
19 |
20 | # number of domains in the DG-15 toy spiral dataset
21 | opt.num_domain = 15
22 | # the specific source and target domain:
23 | opt.src_domain = np.array([0, 12, 3, 4, 14, 8]) # tight_boundary
24 | opt.num_source = opt.src_domain.shape[0]
25 | opt.num_target = opt.num_domain - opt.num_source
26 | opt.test_on_all_dmn = True
27 |
28 |
29 | print("src domain: {}".format(opt.src_domain))
30 |
31 | # opt.model = "DANN"
32 | # opt.model = "CDANN"
33 | # opt.model = "ADDA"
34 | # opt.model = 'MDD'
35 | opt.model = "GDA"
36 | opt.cond_disc = (
37 |     False  # whether to use a conditional discriminator (for CDANN)
38 | )
39 | print("model: {}".format(opt.model))
40 | opt.use_visdom = False
41 | opt.visdom_port = 2000
42 |
43 | opt.use_g_encode = True  # use the precomputed graph (vertex) encoding
44 | if opt.use_g_encode:
45 | opt.g_encode = read_pickle("derive_g_encode/g_encode_15.pkl")
46 |
47 | opt.device = "cuda"
48 | opt.seed = 233 # 1# 101 # 1 # 233 # 1
49 |
50 | opt.lambda_gan = 1 # 0.5
51 | # for MDD use only
52 | opt.lambda_src = 0.5
53 | opt.lambda_tgt = 0.5
54 |
55 |
56 | opt.num_epoch = 500
57 | opt.batch_size = 10
58 | opt.lr_d = 1e-4 # 3e-5 # 1e-4
59 | opt.lr_e = 1e-4 # 3e-5 # 1e-4
60 | opt.lr_g = 1e-4
61 | opt.gamma = 100
62 | opt.beta1 = 0.9
63 | opt.weight_decay = 5e-4
64 | opt.wgan = False # do not use wgan to train
65 | opt.no_bn = True # do not use batch normalization # True
66 |
67 | # model size configs, used for D, E, F
68 | opt.nx = 2 # dimension of the input data
69 | opt.nt = 2 # dimension of the vertex embedding
70 | opt.nh = 512 # dimension of hidden # 512
71 | opt.nc = 2 # number of label class
72 | opt.nd_out = 2 # dimension of D's output
73 |
74 | # sample how many vertices for training D
75 | opt.sample_v = 10
76 |
77 | # # sample how many vertices for training G
78 | opt.sample_v_g = 15
79 |
80 | opt.test_interval = 20
81 | opt.save_interval = 100
82 | # drop out rate
83 | opt.p = 0.2
84 | opt.shuffle = True
85 |
86 | # dataset
87 | opt.dataset = "data/toy_d15_spiral_tight_boundary.pkl"
88 |
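# Sanity-check sketch (hypothetical, not used by the training code): the graph
# encoding maps str(domain_id) -> vertex embedding (see derive_g_encode/derive.py),
# so it should cover every domain:
# assert all(str(i) in opt.g_encode for i in range(opt.num_domain))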
--------------------------------------------------------------------------------
/configs/config_random_60_new.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict
2 | import numpy as np
3 | import pickle
4 |
5 |
6 | def read_pickle(name):
7 | with open(name, "rb") as f:
8 | data = pickle.load(f)
9 | return data
10 |
11 |
12 | # load/output dir
13 | opt = EasyDict() # set experiment configs
14 | opt.loadf = "./dump"
15 | opt.outf = "./dump"
16 |
17 | # normalize each data domain
18 | # opt.normalize_domain = False
19 |
20 | # number of domains in the DG-60 toy spiral dataset
21 | opt.num_domain = 60
22 | # the specific source and target domain:
23 | opt.src_domain = np.array([2, 14, 41, 23, 59, 33]) # 60-spiral
24 | opt.num_source = opt.src_domain.shape[0]
25 | opt.num_target = opt.num_domain - opt.num_source
26 | opt.test_on_all_dmn = True
27 |
28 |
29 | print("src domain: {}".format(opt.src_domain))
30 |
31 | # opt.model = "DANN"
32 | # opt.model = "CDANN"
33 | # opt.model = "ADDA"
34 | # opt.model = 'MDD'
35 | opt.model = "GDA"
36 | opt.cond_disc = (
37 |     False  # whether to use a conditional discriminator (for CDANN)
38 | )
39 | print("model: {}".format(opt.model))
40 | opt.use_visdom = True
41 | opt.visdom_port = 2000
42 |
43 | # use the precomputed graph encoding for the random-60 dataset
44 | opt.use_g_encode = True
45 | if opt.use_g_encode:
46 | opt.g_encode = read_pickle("derive_g_encode/g_encode_60.pkl")
47 |
48 | opt.device = "cuda"
49 | opt.seed = 233 # 1# 101 # 1 # 233 # 1
50 |
51 | opt.lambda_gan = 0.5 # 0.5 # 0.3125 # 0.5 # 0.5
52 |
53 | # for MDD use only
54 | opt.lambda_src = 0.5
55 | opt.lambda_tgt = 0.5
56 |
57 | opt.num_epoch = 500
58 | opt.batch_size = 10 # 10
59 | opt.lr_d = 1e-5 # 3e-5 # 1e-4 # 2.9 * 1e-5 #3e-5 # 1e-4
60 | opt.lr_e = 1e-5 # 3e-5 # 1e-4 # 2.9 * 1e-5
61 | opt.lr_g = 1e-3 # 3e-4
62 | opt.gamma = 100
63 | opt.beta1 = 0.9
64 | opt.weight_decay = 5e-4
65 | opt.wgan = False # do not use wgan to train
66 | opt.no_bn = True # do not use batch normalization # True
67 |
68 | # model size configs, used for D, E, F
69 | opt.nx = 2 # dimension of the input data
70 | opt.nt = 2 # dimension of the vertex embedding
71 | opt.nh = 512 # dimension of hidden # 512
72 | opt.nc = 2 # number of label class
73 | opt.nd_out = 2 # dimension of D's output
74 |
75 | # sample how many vertices for training D
76 | opt.sample_v = 30
77 |
78 | # # sample how many vertices for training G
79 | opt.sample_v_g = 60
80 |
81 | opt.test_interval = 20 # 20
82 | opt.save_interval = 100
83 | # drop out rate
84 | opt.p = 0.2
85 | opt.shuffle = True
86 |
87 | # dataset
88 | opt.dataset = "data/toy_d60_spiral.pkl"
89 |
--------------------------------------------------------------------------------
/data/TPT-48.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wang-ML-Lab/GRDA/96800d62577b8046af3fbc9ade0a51c2b71129c1/data/TPT-48.zip
--------------------------------------------------------------------------------
/data/toy_d15_spiral_tight_boundary.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wang-ML-Lab/GRDA/96800d62577b8046af3fbc9ade0a51c2b71129c1/data/toy_d15_spiral_tight_boundary.pkl
--------------------------------------------------------------------------------
/data/toy_d60_spiral.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wang-ML-Lab/GRDA/96800d62577b8046af3fbc9ade0a51c2b71129c1/data/toy_d60_spiral.pkl
--------------------------------------------------------------------------------
/dataset_generator/dataset_generator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import pickle
4 | import networkx as nx
5 | from draw_graph_utils import draw
16 |
17 |
18 | def read_pickle(name):
19 | with open(name, 'rb') as f:
20 | data = pickle.load(f)
21 | return data
22 |
23 | def write_pickle(data, name):
24 | with open(name,'wb') as f:
25 | # the default protocol level is 4
26 | pickle.dump(data, f)
27 |
28 | def show_graph_with_labels(adjacency_matrix, my_angles):
29 | rows, cols = np.where(adjacency_matrix == 1)
30 | edges = zip(rows.tolist(), cols.tolist())
31 | gr = nx.Graph()
32 | gr.add_edges_from(edges)
33 |
34 | pos = nx.kamada_kawai_layout(gr)  # good, slightly better than spring layout
35 |
36 | num_domain = adjacency_matrix.shape[0]
37 |
38 | # stretch the layout horizontally and compress it vertically
39 | for i in range(num_domain):
40 | pos[i][0] *= 1.4
41 | pos[i][1] *= 0.8
42 |
43 |
44 | labels = dict()
45 | for i in range(num_domain):
46 | labels[i] = i
47 |
48 | # draw with the self-defined drawing function
49 | fig, ax = plt.subplots(1, 1)
50 | draw(gr, pos, node_radius=0.077, font_color='white', node_angles=my_angles, labels=labels, with_labels=True, ax=ax)
51 | ax.set_aspect("equal")
52 | plt.show()
53 |
54 | # generate data given the mean/std, radius and number
55 | def generate_data(mean, std, radius, num):
56 | dim = mean.shape[0]
57 | m_data = np.random.randn(num, dim)
58 | print('raw gaussian sample shape:', m_data.shape)
59 | m_data *= std[None, :]
60 | m_radius = m_data[:, 0] ** 2 + m_data[:, 1] ** 2
61 | m_data += mean[None, :]
62 | m_data = m_data[m_radius <= radius ** 2, :]
63 |
64 | # randomly keep 50 of the points that fell within the radius
65 | choice = np.random.choice(m_data.shape[0], size=50, replace=False)
66 | m_data = m_data[choice, :]
67 |
68 | # note: m_data has already been subsampled, so this always reports 50
69 | print('num of data points within radius', radius, ':', m_data.shape[0])
70 | return m_data
71 |
72 | # generate label for circle-shape data
73 | def generate_label(m_data, radius):
74 | m_radius = m_data[:, 0] ** 2 + m_data[:, 1] ** 2
75 | m_label = np.zeros((m_data.shape[0],))
76 | m_label[m_radius > radius ** 2] = 1
77 | print("=============")
78 | print("label 0's num: {}".format(np.sum(m_label == 0)))
79 | print("label 1's num: {}".format(np.sum(m_label == 1)))
80 | return m_label
81 |
82 | # create dataset, circle-shape domain manifold
83 | def create_toy_data():
84 | fname = 'toy_d60_spiral.pkl'
85 | num_domain = 60
86 | # fname = 'toy_d30_spiral.pkl'
87 | # l_angle = np.random.rand(15) * np.pi / 2
88 | # l_angle = np.random.rand(num_domain) * np.pi * 2
89 | l_angle = np.random.rand(num_domain) * np.pi / 30 * num_domain
90 |
91 | radius_start = 1
92 | std_small = 1
93 | radius_step = 0.1
94 | radius_small = 1
95 |
96 | lm_data = []
97 | l_domain = []
98 | l_label = []
99 | for i, angle in enumerate(l_angle):
100 | # radius = radius_start + angle / (np.pi / 2) * radius_step
101 | # radius = radius_start + angle / (np.pi) * radius_step
102 | radius = radius_start + angle / (np.pi / 30 * num_domain) * radius_step * 60
103 | mean = np.array([np.cos(angle), np.sin(angle)]) * radius
104 | std = np.ones((2,)) * std_small
105 | m_data = generate_data(mean, std, radius_small, 300)
106 | m_data = np.append(m_data, generate_data(-mean, std, radius_small, 300), axis=0)
107 | m_label = np.ones(m_data.shape[0],)
108 | m_label[0:int(0.5 * m_data.shape[0])] = 0
109 | # m_data = generate_data(mean, std, radius_small, 300)
110 | # m_label = generate_label(m_data, ra)
111 | l_label.append(m_label)
112 | lm_data.append(m_data)
113 | l_domain.append(np.ones(m_data.shape[0],) * i)
114 |
115 | angle_all = np.array(l_angle)
116 | data_all = np.concatenate(lm_data, axis=0)
117 | domain_all = np.concatenate(l_domain, axis=0)
118 | label_all = np.concatenate(l_label, axis=0)# generate_label(data_all, radius_large)
119 |
120 | # generate A
121 | # A's generation:
122 | A = np.zeros((num_domain, num_domain))
123 | for i in range(num_domain):
124 | for j in range(i + 1, num_domain):
125 | p = np.cos(angle_all[i]) * np.cos(angle_all[j]) + np.sin(angle_all[i]) * np.sin(angle_all[j])
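# note: p = cos(a_i)cos(a_j) + sin(a_i)sin(a_j) = cos(a_i - a_j), so domains
# with similar decision-boundary angles are more likely to be connected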
126 |
127 | if num_domain == 15:
128 | if p < 0.5:
129 | c = 0
130 | else:
131 | c = np.random.binomial(1, p)
132 | elif num_domain == 60:
133 | if p < 0.2:
134 | c = 0
135 | else:
136 | c = np.random.binomial(1, p)
137 |
138 | A[i][j] = c
139 | A[j][i] = c
140 |
141 | show_graph_with_labels(A, angle_all)
142 | print(angle_all)
143 |
144 |
145 | d_pkl = dict()
146 | d_pkl['data'] = data_all
147 | d_pkl['label'] = label_all
148 | d_pkl['domain'] = domain_all
149 | d_pkl['A'] = A
150 | d_pkl['angle'] = angle_all
151 | write_pickle(d_pkl, fname)
152 |
153 | l_style = ['k*', 'r*', 'b*', 'y*', 'k.', 'r.']
154 | for i in range(2):
155 | data_sub = data_all[label_all == i, :]
156 | plt.plot(data_sub[:, 0], data_sub[:, 1], l_style[i])
157 | plt.grid()
158 | plt.show()
159 |
160 | if __name__ == '__main__':
161 | create_toy_data()
--------------------------------------------------------------------------------
/dataset_generator/draw_graph_utils.py:
--------------------------------------------------------------------------------
1 | import matplotlib.patches as mpatches
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 |
5 | # this is a modified version of networkx
6 |
7 |
8 | def arc_patch(center, radius, theta1, theta2, ax=None, resolution=50, **kwargs):
9 | # be sure that theta is in radian!!
10 | # make sure ax is not empty
11 | if ax is None:
12 | ax = plt.gca()
13 | # generate the points
14 | theta = np.linspace(theta1, theta2, resolution)
15 | points = np.vstack((radius*np.cos(theta) + center[0],
16 | radius*np.sin(theta) + center[1]))
17 | # build the polygon and add it to the axes
18 | # print(**kwargs)
19 | poly = mpatches.Polygon(points.T, closed=True, **kwargs)
20 | # poly = mpatches.Polygon(points.T, closed=True, fill=True, color='blue')
21 | ax.add_patch(poly)
22 | ax.plot()
23 | return poly
24 |
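# Usage sketch (hypothetical values, mirroring my_draw_networkx_nodes below):
# draw one node as two half-discs split at angle a, blue on one side and red
# on the other:
#   a = np.pi / 4
#   arc_patch((0, 0), 0.5, a, a + np.pi, fill=True, color='blue')
#   arc_patch((0, 0), 0.5, a + np.pi, a + 2 * np.pi, fill=True, color='red')
#   plt.gca().set_aspect('equal'); plt.show()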
25 |
26 | """
27 | **********
28 | Matplotlib
29 | **********
30 |
31 | Draw networks with matplotlib.
32 |
33 | See Also
34 | --------
35 |
36 | matplotlib: http://matplotlib.org/
37 |
38 | pygraphviz: http://pygraphviz.github.io/
39 |
40 | """
41 | from numbers import Number
42 | import networkx as nx
43 | from networkx.drawing.layout import (
44 | shell_layout,
45 | circular_layout,
46 | kamada_kawai_layout,
47 | spectral_layout,
48 | spring_layout,
49 | random_layout,
50 | planar_layout,
51 | )
52 |
53 | __all__ = [
54 | "draw",
55 | "draw_networkx",
56 | "draw_networkx_nodes",
57 | "draw_networkx_edges",
58 | "draw_networkx_labels",
59 | "draw_networkx_edge_labels",
60 | "draw_circular",
61 | "draw_kamada_kawai",
62 | "draw_random",
63 | "draw_spectral",
64 | "draw_spring",
65 | "draw_planar",
66 | "draw_shell",
67 | ]
68 |
69 |
70 | def draw(G, pos=None, ax=None, **kwds):
71 | """Draw the graph G with Matplotlib.
72 |
73 | Draw the graph as a simple representation with no node
74 | labels or edge labels and using the full Matplotlib figure area
75 | and no axis labels by default. See draw_networkx() for more
76 | full-featured drawing that allows title, axis labels etc.
77 |
78 | Parameters
79 | ----------
80 | G : graph
81 | A networkx graph
82 |
83 | pos : dictionary, optional
84 | A dictionary with nodes as keys and positions as values.
85 | If not specified a spring layout positioning will be computed.
86 | See :py:mod:`networkx.drawing.layout` for functions that
87 | compute node positions.
88 |
89 | ax : Matplotlib Axes object, optional
90 | Draw the graph in specified Matplotlib axes.
91 |
92 | kwds : optional keywords
93 | See networkx.draw_networkx() for a description of optional keywords.
94 |
95 | Examples
96 | --------
97 | >>> G = nx.dodecahedral_graph()
98 | >>> nx.draw(G)
99 | >>> nx.draw(G, pos=nx.spring_layout(G)) # use spring layout
100 |
101 | See Also
102 | --------
103 | draw_networkx()
104 | draw_networkx_nodes()
105 | draw_networkx_edges()
106 | draw_networkx_labels()
107 | draw_networkx_edge_labels()
108 |
109 | Notes
110 | -----
111 | This function has the same name as pylab.draw and pyplot.draw
112 | so beware when using `from networkx import *`
113 |
114 | since you might overwrite the pylab.draw function.
115 |
116 | With pyplot use
117 |
118 | >>> import matplotlib.pyplot as plt
119 | >>> G = nx.dodecahedral_graph()
120 | >>> nx.draw(G) # networkx draw()
121 | >>> plt.draw() # pyplot draw()
122 |
123 | Also see the NetworkX drawing examples at
124 | https://networkx.github.io/documentation/latest/auto_examples/index.html
125 | """
126 | try:
127 | import matplotlib.pyplot as plt
128 | except ImportError as e:
129 | raise ImportError("Matplotlib required for draw()") from e
130 | except RuntimeError:
131 | print("Matplotlib unable to open display")
132 | raise
133 |
134 | if ax is None:
135 | cf = plt.gcf()
136 | else:
137 | cf = ax.get_figure()
138 | cf.set_facecolor("w")
139 | if ax is None:
140 | if cf._axstack() is None:
141 | ax = cf.add_axes((0, 0, 1, 1))
142 | else:
143 | ax = cf.gca()
144 |
145 | if "with_labels" not in kwds:
146 | kwds["with_labels"] = "labels" in kwds
147 |
148 | draw_networkx(G, pos=pos, ax=ax, **kwds)
149 | ax.set_axis_off()
150 | plt.draw_if_interactive()
151 | return
152 |
153 |
154 |
155 | def draw_networkx(G, pos=None, arrows=True, with_labels=True, **kwds):
156 | """Draw the graph G using Matplotlib.
157 |
158 | Draw the graph with Matplotlib with options for node positions,
159 | labeling, titles, and many other drawing features.
160 | See draw() for simple drawing without labels or axes.
161 |
162 | Parameters
163 | ----------
164 | G : graph
165 | A networkx graph
166 |
167 | pos : dictionary, optional
168 | A dictionary with nodes as keys and positions as values.
169 | If not specified a spring layout positioning will be computed.
170 | See :py:mod:`networkx.drawing.layout` for functions that
171 | compute node positions.
172 |
173 | arrows : bool, optional (default=True)
174 | For directed graphs, if True draw arrowheads.
175 | Note: Arrows will be the same color as edges.
176 |
177 | arrowstyle : str, optional (default='-|>')
178 | For directed graphs, choose the style of the arrowheads.
179 | See :py:class: `matplotlib.patches.ArrowStyle` for more
180 | options.
181 |
182 | arrowsize : int, optional (default=10)
183 | For directed graphs, choose the size of the arrow head's length and
184 | width. See :py:class: `matplotlib.patches.FancyArrowPatch` for attribute
185 | `mutation_scale` for more info.
186 |
187 | with_labels : bool, optional (default=True)
188 | Set to True to draw labels on the nodes.
189 |
190 | ax : Matplotlib Axes object, optional
191 | Draw the graph in the specified Matplotlib axes.
192 |
193 | nodelist : list, optional (default G.nodes())
194 | Draw only specified nodes
195 |
196 | edgelist : list, optional (default=G.edges())
197 | Draw only specified edges
198 |
199 | node_size : scalar or array, optional (default=300)
200 | Size of nodes. If an array is specified it must be the
201 | same length as nodelist.
202 |
203 | node_color : color or array of colors (default='#1f78b4')
204 | Node color. Can be a single color or a sequence of colors with the same
205 | length as nodelist. Color can be string, or rgb (or rgba) tuple of
206 | floats from 0-1. If numeric values are specified they will be
207 | mapped to colors using the cmap and vmin,vmax parameters. See
208 | matplotlib.scatter for more details.
209 |
210 | node_shape : string, optional (default='o')
211 | The shape of the node. Specification is as matplotlib.scatter
212 | marker, one of 'so^>v<dh8p'.
273 | >>> G = nx.dodecahedral_graph()
274 | >>> nx.draw(G)
275 | >>> nx.draw(G, pos=nx.spring_layout(G)) # use spring layout
276 |
277 | >>> import matplotlib.pyplot as plt
278 | >>> limits = plt.axis("off") # turn off axis
279 |
280 | Also see the NetworkX drawing examples at
281 | https://networkx.github.io/documentation/latest/auto_examples/index.html
282 |
283 | See Also
284 | --------
285 | draw()
286 | draw_networkx_nodes()
287 | draw_networkx_edges()
288 | draw_networkx_labels()
289 | draw_networkx_edge_labels()
290 | """
291 | try:
292 | import matplotlib.pyplot as plt
293 | except ImportError as e:
294 | raise ImportError("Matplotlib required for draw()") from e
295 | except RuntimeError:
296 | print("Matplotlib unable to open display")
297 | raise
298 |
299 | valid_node_kwds = (
300 | # some of my own parameters:
301 | "node_angles",
302 | "node_radius",
303 |
304 | "nodelist",
305 | "node_size",
306 | "node_color",
307 | "node_shape",
308 | "alpha",
309 | "cmap",
310 | "vmin",
311 | "vmax",
312 | "ax",
313 | "linewidths",
314 | "edgecolors",
315 | "label",
316 | )
317 |
318 | valid_edge_kwds = (
319 | "edgelist",
320 | "width",
321 | "edge_color",
322 | "style",
323 | "alpha",
324 | "arrowstyle",
325 | "arrowsize",
326 | "edge_cmap",
327 | "edge_vmin",
328 | "edge_vmax",
329 | "ax",
330 | "label",
331 | "node_size",
332 | "nodelist",
333 | "node_shape",
334 | "connectionstyle",
335 | "min_source_margin",
336 | "min_target_margin",
337 | )
338 |
339 | valid_label_kwds = (
340 | "labels",
341 | "font_size",
342 | "font_color",
343 | "font_family",
344 | "font_weight",
345 | "alpha",
346 | "bbox",
347 | "ax",
348 | "horizontalalignment",
349 | "verticalalignment",
350 | )
351 |
352 | valid_kwds = valid_node_kwds + valid_edge_kwds + valid_label_kwds
353 |
354 | if any([k not in valid_kwds for k in kwds]):
355 | invalid_args = ", ".join([k for k in kwds if k not in valid_kwds])
356 | raise ValueError(f"Received invalid argument(s): {invalid_args}")
357 |
358 | node_kwds = {k: v for k, v in kwds.items() if k in valid_node_kwds}
359 | edge_kwds = {k: v for k, v in kwds.items() if k in valid_edge_kwds}
360 | label_kwds = {k: v for k, v in kwds.items() if k in valid_label_kwds}
361 |
362 | if pos is None:
363 | pos = nx.drawing.spring_layout(G) # default to spring layout
364 |
365 | # draw_networkx_nodes(G, pos, **node_kwds)
366 |
367 | # use my own nodes instead
368 | my_draw_networkx_nodes(G, pos, **node_kwds)
369 |
370 | draw_networkx_edges(G, pos, arrows=arrows, **edge_kwds)
371 | if with_labels:
372 | draw_networkx_labels(G, pos, **label_kwds)
373 | plt.draw_if_interactive()
374 |
375 |
376 | def my_draw_networkx_nodes(
377 | G,
378 | pos,
379 | node_angles=None,
380 | node_radius=1,
381 | nodelist=None,
382 | ax=None,
383 | label=None,
384 | ):
385 | from collections.abc import Iterable
386 |
387 | if ax is None:
388 | ax = plt.gca()
389 |
390 | if nodelist is None:
391 | nodelist = list(G)
392 |
393 | # xys = np.asarray([pos[v] for v in nodelist])
394 | # num_node =
395 | # for i in range(num_node):
396 | for v in nodelist:
397 | arc_patch(pos[v], node_radius, node_angles[v], node_angles[v] + np.pi, ax, fill=True, color='blue')
398 | arc_patch(pos[v], node_radius, node_angles[v]+np.pi, node_angles[v] + 2 * np.pi, ax, fill=True, color='red')
399 |
400 | ax.tick_params(
401 | axis="both",
402 | which="both",
403 | bottom=False,
404 | left=False,
405 | labelbottom=False,
406 | labelleft=False,
407 | )
408 |
409 |
410 |
411 | def draw_networkx_nodes(
412 | G,
413 | pos,
414 | nodelist=None,
415 | node_size=300,
416 | node_color="#1f78b4",
417 | node_shape="o",
418 | alpha=None,
419 | cmap=None,
420 | vmin=None,
421 | vmax=None,
422 | ax=None,
423 | linewidths=None,
424 | edgecolors=None,
425 | label=None,
426 | ):
427 | """Draw the nodes of the graph G.
428 |
429 | This draws only the nodes of the graph G.
430 |
431 | Parameters
432 | ----------
433 | G : graph
434 | A networkx graph
435 |
436 | pos : dictionary
437 | A dictionary with nodes as keys and positions as values.
438 | Positions should be sequences of length 2.
439 |
440 | ax : Matplotlib Axes object, optional
441 | Draw the graph in the specified Matplotlib axes.
442 |
443 | nodelist : list, optional
444 | Draw only specified nodes (default G.nodes())
445 |
446 | node_size : scalar or array
447 | Size of nodes (default=300). If an array is specified it must be the
448 | same length as nodelist.
449 |
450 | node_color : color or array of colors (default='#1f78b4')
451 | Node color. Can be a single color or a sequence of colors with the same
452 | length as nodelist. Color can be string, or rgb (or rgba) tuple of
453 | floats from 0-1. If numeric values are specified they will be
454 | mapped to colors using the cmap and vmin,vmax parameters. See
455 | matplotlib.scatter for more details.
456 |
457 | node_shape : string
458 | The shape of the node. Specification is as matplotlib.scatter
459 | marker, one of 'so^>v<dh8p'.
489 | >>> G = nx.dodecahedral_graph()
490 | >>> nodes = nx.draw_networkx_nodes(G, pos=nx.spring_layout(G))
491 |
492 | Also see the NetworkX drawing examples at
493 | https://networkx.github.io/documentation/latest/auto_examples/index.html
494 |
495 | See Also
496 | --------
497 | draw()
498 | draw_networkx()
499 | draw_networkx_edges()
500 | draw_networkx_labels()
501 | draw_networkx_edge_labels()
502 | """
503 | from collections.abc import Iterable
504 |
505 | try:
506 | import matplotlib.pyplot as plt
507 | from matplotlib.collections import PathCollection
508 | import numpy as np
509 | except ImportError as e:
510 | raise ImportError("Matplotlib required for draw()") from e
511 | except RuntimeError:
512 | print("Matplotlib unable to open display")
513 | raise
514 |
515 | if ax is None:
516 | ax = plt.gca()
517 |
518 | if nodelist is None:
519 | nodelist = list(G)
520 |
521 | if len(nodelist) == 0: # empty nodelist, no drawing
522 | return PathCollection(None)
523 |
524 | try:
525 | xy = np.asarray([pos[v] for v in nodelist])
526 | except KeyError as e:
527 | raise nx.NetworkXError(f"Node {e} has no position.") from e
528 | except ValueError as e:
529 | raise nx.NetworkXError("Bad value in node positions.") from e
530 |
531 | if isinstance(alpha, Iterable):
532 | node_color = apply_alpha(node_color, alpha, nodelist, cmap, vmin, vmax)
533 | alpha = None
534 |
535 | node_collection = ax.scatter(
536 | xy[:, 0],
537 | xy[:, 1],
538 | s=node_size,
539 | c=node_color,
540 | marker=node_shape,
541 | cmap=cmap,
542 | vmin=vmin,
543 | vmax=vmax,
544 | alpha=alpha,
545 | linewidths=linewidths,
546 | edgecolors=edgecolors,
547 | label=label,
548 | )
549 | ax.tick_params(
550 | axis="both",
551 | which="both",
552 | bottom=False,
553 | left=False,
554 | labelbottom=False,
555 | labelleft=False,
556 | )
557 |
558 | node_collection.set_zorder(2)
559 | return node_collection
560 |
561 |
562 |
563 | def draw_networkx_edges(
564 | G,
565 | pos,
566 | edgelist=None,
567 | width=1.0,
568 | edge_color="k",
569 | style="solid",
570 | alpha=None,
571 | arrowstyle="-|>",
572 | arrowsize=10,
573 | edge_cmap=None,
574 | edge_vmin=None,
575 | edge_vmax=None,
576 | ax=None,
577 | arrows=True,
578 | label=None,
579 | node_size=300,
580 | nodelist=None,
581 | node_shape="o",
582 | connectionstyle=None,
583 | min_source_margin=0,
584 | min_target_margin=0,
585 | ):
586 | """Draw the edges of the graph G.
587 |
588 | This draws only the edges of the graph G.
589 |
590 | Parameters
591 | ----------
592 | G : graph
593 | A networkx graph
594 |
595 | pos : dictionary
596 | A dictionary with nodes as keys and positions as values.
597 | Positions should be sequences of length 2.
598 |
599 | edgelist : collection of edge tuples
600 | Draw only specified edges(default=G.edges())
601 |
602 | width : float, or array of floats
603 | Line width of edges (default=1.0)
604 |
605 | edge_color : color or array of colors (default='k')
606 | Edge color. Can be a single color or a sequence of colors with the same
607 | length as edgelist. Color can be string, or rgb (or rgba) tuple of
608 | floats from 0-1. If numeric values are specified they will be
609 | mapped to colors using the edge_cmap and edge_vmin,edge_vmax parameters.
610 |
611 | style : string
612 | Edge line style (default='solid') (solid|dashed|dotted,dashdot)
613 |
614 | alpha : float
615 | The edge transparency (default=None)
616 |
617 | edge_cmap : Matplotlib colormap
618 | Colormap for mapping intensities of edges (default=None)
619 |
620 | edge_vmin,edge_vmax : floats
621 | Minimum and maximum for edge colormap scaling (default=None)
622 |
623 | ax : Matplotlib Axes object, optional
624 | Draw the graph in the specified Matplotlib axes.
625 |
626 | arrows : bool, optional (default=True)
627 | For directed graphs, if True draw arrowheads.
628 | Note: Arrows will be the same color as edges.
629 |
630 | arrowstyle : str, optional (default='-|>')
631 | For directed graphs, choose the style of the arrow heads.
632 | See :py:class: `matplotlib.patches.ArrowStyle` for more
633 | options.
634 |
635 | arrowsize : int, optional (default=10)
636 | For directed graphs, choose the size of the arrow head's length and
637 | width. See :py:class: `matplotlib.patches.FancyArrowPatch` for attribute
638 | `mutation_scale` for more info.
639 |
640 | connectionstyle : str, optional (default=None)
641 | Pass the connectionstyle parameter to create curved arc of rounding
642 | radius rad. For example, connectionstyle='arc3,rad=0.2'.
643 | See :py:class: `matplotlib.patches.ConnectionStyle` and
644 | :py:class: `matplotlib.patches.FancyArrowPatch` for more info.
645 |
646 | label : [None| string]
647 | Label for legend
648 |
649 | min_source_margin : int, optional (default=0)
650 | The minimum margin (gap) at the beginning of the edge at the source.
651 |
652 | min_target_margin : int, optional (default=0)
653 | The minimum margin (gap) at the end of the edge at the target.
654 |
655 | Returns
656 | -------
657 | matplotlib.collection.LineCollection
658 | `LineCollection` of the edges
659 |
660 | list of matplotlib.patches.FancyArrowPatch
661 | `FancyArrowPatch` instances of the directed edges
662 |
663 | Depending whether the drawing includes arrows or not.
664 |
665 | Notes
666 | -----
667 | For directed graphs, arrows are drawn at the head end. Arrows can be
668 | turned off with keyword arrows=False. Be sure to include `node_size` as a
669 | keyword argument; arrows are drawn considering the size of nodes.
670 |
671 | Examples
672 | --------
673 | >>> G = nx.dodecahedral_graph()
674 | >>> edges = nx.draw_networkx_edges(G, pos=nx.spring_layout(G))
675 |
676 | >>> G = nx.DiGraph()
677 | >>> G.add_edges_from([(1, 2), (1, 3), (2, 3)])
678 | >>> arcs = nx.draw_networkx_edges(G, pos=nx.spring_layout(G))
679 | >>> alphas = [0.3, 0.4, 0.5]
680 | >>> for i, arc in enumerate(arcs): # change alpha values of arcs
681 | ... arc.set_alpha(alphas[i])
682 |
683 | Also see the NetworkX drawing examples at
684 | https://networkx.github.io/documentation/latest/auto_examples/index.html
685 |
686 | See Also
687 | --------
688 | draw()
689 | draw_networkx()
690 | draw_networkx_nodes()
691 | draw_networkx_labels()
692 | draw_networkx_edge_labels()
693 | """
694 | try:
695 | import matplotlib.pyplot as plt
696 | from matplotlib.colors import colorConverter, Colormap, Normalize
697 | from matplotlib.collections import LineCollection
698 | from matplotlib.patches import FancyArrowPatch
699 | import numpy as np
700 | except ImportError as e:
701 | raise ImportError("Matplotlib required for draw()") from e
702 | except RuntimeError:
703 | print("Matplotlib unable to open display")
704 | raise
705 |
706 | if ax is None:
707 | ax = plt.gca()
708 |
709 | if edgelist is None:
710 | edgelist = list(G.edges())
711 |
712 | if len(edgelist) == 0: # no edges!
713 | if not G.is_directed() or not arrows:
714 | return LineCollection(None)
715 | else:
716 | return []
717 |
718 | if nodelist is None:
719 | nodelist = list(G.nodes())
720 |
721 | # FancyArrowPatch handles color=None different from LineCollection
722 | if edge_color is None:
723 | edge_color = "k"
724 |
725 | # set edge positions
726 | edge_pos = np.asarray([(pos[e[0]], pos[e[1]]) for e in edgelist])
727 |
728 | # Check if edge_color is an array of floats and map to edge_cmap.
729 | # This is the only case handled differently from matplotlib
730 | if (
731 | np.iterable(edge_color)
732 | and (len(edge_color) == len(edge_pos))
733 | and np.alltrue([isinstance(c, Number) for c in edge_color])
734 | ):
735 | if edge_cmap is not None:
736 | assert isinstance(edge_cmap, Colormap)
737 | else:
738 | edge_cmap = plt.get_cmap()
739 | if edge_vmin is None:
740 | edge_vmin = min(edge_color)
741 | if edge_vmax is None:
742 | edge_vmax = max(edge_color)
743 | color_normal = Normalize(vmin=edge_vmin, vmax=edge_vmax)
744 | edge_color = [edge_cmap(color_normal(e)) for e in edge_color]
745 |
746 | if not G.is_directed() or not arrows:
747 | edge_collection = LineCollection(
748 | edge_pos,
749 | colors=edge_color,
750 | linewidths=width,
751 | antialiaseds=(1,),
752 | linestyle=style,
753 | transOffset=ax.transData,
754 | alpha=alpha,
755 | )
756 |
757 | edge_collection.set_cmap(edge_cmap)
758 | edge_collection.set_clim(edge_vmin, edge_vmax)
759 |
760 | edge_collection.set_zorder(1) # edges go behind nodes
761 | edge_collection.set_label(label)
762 | ax.add_collection(edge_collection)
763 |
764 | return edge_collection
765 |
766 | arrow_collection = None
767 |
768 | if G.is_directed() and arrows:
769 | # Note: Waiting for someone to implement arrow to intersection with
770 | # marker. Meanwhile, this works well for polygons with more than 4
771 | # sides and circle.
772 |
773 | def to_marker_edge(marker_size, marker):
774 | if marker in "s^>v<d":  # `d` is a thin diamond
928 | >>> G = nx.dodecahedral_graph()
929 | >>> labels = nx.draw_networkx_labels(G, pos=nx.spring_layout(G))
930 |
931 | Also see the NetworkX drawing examples at
932 | https://networkx.github.io/documentation/latest/auto_examples/index.html
933 |
934 | See Also
935 | --------
936 | draw()
937 | draw_networkx()
938 | draw_networkx_nodes()
939 | draw_networkx_edges()
940 | draw_networkx_edge_labels()
941 | """
942 | try:
943 | import matplotlib.pyplot as plt
944 | except ImportError as e:
945 | raise ImportError("Matplotlib required for draw()") from e
946 | except RuntimeError:
947 | print("Matplotlib unable to open display")
948 | raise
949 |
950 | if ax is None:
951 | ax = plt.gca()
952 |
953 | if labels is None:
954 | labels = {n: n for n in G.nodes()}
955 |
956 | text_items = {} # there is no text collection so we'll fake one
957 | for n, label in labels.items():
958 | (x, y) = pos[n]
959 | if not isinstance(label, str):
960 | label = str(label) # this makes "1" and 1 labeled the same
961 | t = ax.text(
962 | x,
963 | y,
964 | label,
965 | size=font_size,
966 | color=font_color,
967 | family=font_family,
968 | weight=font_weight,
969 | alpha=alpha,
970 | horizontalalignment=horizontalalignment,
971 | verticalalignment=verticalalignment,
972 | transform=ax.transData,
973 | bbox=bbox,
974 | clip_on=True,
975 | )
976 | text_items[n] = t
977 |
978 | ax.tick_params(
979 | axis="both",
980 | which="both",
981 | bottom=False,
982 | left=False,
983 | labelbottom=False,
984 | labelleft=False,
985 | )
986 |
987 | return text_items
988 |
989 |
990 |
991 | def draw_networkx_edge_labels(
992 | G,
993 | pos,
994 | edge_labels=None,
995 | label_pos=0.5,
996 | font_size=10,
997 | font_color="k",
998 | font_family="sans-serif",
999 | font_weight="normal",
1000 | alpha=None,
1001 | bbox=None,
1002 | horizontalalignment="center",
1003 | verticalalignment="center",
1004 | ax=None,
1005 | rotate=True,
1006 | ):
1007 | """Draw edge labels.
1008 |
1009 | Parameters
1010 | ----------
1011 | G : graph
1012 | A networkx graph
1013 |
1014 | pos : dictionary
1015 | A dictionary with nodes as keys and positions as values.
1016 | Positions should be sequences of length 2.
1017 |
1018 | ax : Matplotlib Axes object, optional
1019 | Draw the graph in the specified Matplotlib axes.
1020 |
1021 | alpha : float or None
1022 | The text transparency (default=None)
1023 |
1024 | edge_labels : dictionary
1025 | Edge labels in a dictionary keyed by edge two-tuple of text
1026 | labels (default=None). Only labels for the keys in the dictionary
1027 | are drawn.
1028 |
1029 | label_pos : float
1030 | Position of edge label along edge (0=head, 0.5=center, 1=tail)
1031 |
1032 | font_size : int
1033 | Font size for text labels (default=12)
1034 |
1035 | font_color : string
1036 | Font color string (default='k' black)
1037 |
1038 | font_weight : string
1039 | Font weight (default='normal')
1040 |
1041 | font_family : string
1042 | Font family (default='sans-serif')
1043 |
1044 | bbox : Matplotlib bbox
1045 | Specify text box shape and colors.
1046 |
1047 | clip_on : bool
1048 | Turn on clipping at axis boundaries (default=True)
1049 |
1050 | horizontalalignment : {'center', 'right', 'left'}
1051 | Horizontal alignment (default='center')
1052 |
1053 | verticalalignment : {'center', 'top', 'bottom', 'baseline', 'center_baseline'}
1054 | Vertical alignment (default='center')
1055 |
1056 | ax : Matplotlib Axes object, optional
1057 | Draw the graph in the specified Matplotlib axes.
1058 |
1059 | Returns
1060 | -------
1061 | dict
1062 | `dict` of labels keyed on the edges
1063 |
1064 | Examples
1065 | --------
1066 | >>> G = nx.dodecahedral_graph()
1067 | >>> edge_labels = nx.draw_networkx_edge_labels(G, pos=nx.spring_layout(G))
1068 |
1069 | Also see the NetworkX drawing examples at
1070 | https://networkx.github.io/documentation/latest/auto_examples/index.html
1071 |
1072 | See Also
1073 | --------
1074 | draw()
1075 | draw_networkx()
1076 | draw_networkx_nodes()
1077 | draw_networkx_edges()
1078 | draw_networkx_labels()
1079 | """
1080 | try:
1081 | import matplotlib.pyplot as plt
1082 | import numpy as np
1083 | except ImportError as e:
1084 | raise ImportError("Matplotlib required for draw()") from e
1085 | except RuntimeError:
1086 | print("Matplotlib unable to open display")
1087 | raise
1088 |
1089 | if ax is None:
1090 | ax = plt.gca()
1091 | if edge_labels is None:
1092 | labels = {(u, v): d for u, v, d in G.edges(data=True)}
1093 | else:
1094 | labels = edge_labels
1095 | text_items = {}
1096 | for (n1, n2), label in labels.items():
1097 | (x1, y1) = pos[n1]
1098 | (x2, y2) = pos[n2]
1099 | (x, y) = (
1100 | x1 * label_pos + x2 * (1.0 - label_pos),
1101 | y1 * label_pos + y2 * (1.0 - label_pos),
1102 | )
1103 |
1104 | if rotate:
1105 | # in degrees
1106 | angle = np.arctan2(y2 - y1, x2 - x1) / (2.0 * np.pi) * 360
1107 | # make label orientation "right-side-up"
1108 | if angle > 90:
1109 | angle -= 180
1110 | if angle < -90:
1111 | angle += 180
1112 | # transform data coordinate angle to screen coordinate angle
1113 | xy = np.array((x, y))
1114 | trans_angle = ax.transData.transform_angles(
1115 | np.array((angle,)), xy.reshape((1, 2))
1116 | )[0]
1117 | else:
1118 | trans_angle = 0.0
1119 | # use default box of white with white border
1120 | if bbox is None:
1121 | bbox = dict(boxstyle="round", ec=(1.0, 1.0, 1.0), fc=(1.0, 1.0, 1.0))
1122 | if not isinstance(label, str):
1123 | label = str(label) # this makes "1" and 1 labeled the same
1124 |
1125 | t = ax.text(
1126 | x,
1127 | y,
1128 | label,
1129 | size=font_size,
1130 | color=font_color,
1131 | family=font_family,
1132 | weight=font_weight,
1133 | alpha=alpha,
1134 | horizontalalignment=horizontalalignment,
1135 | verticalalignment=verticalalignment,
1136 | rotation=trans_angle,
1137 | transform=ax.transData,
1138 | bbox=bbox,
1139 | zorder=1,
1140 | clip_on=True,
1141 | )
1142 | text_items[(n1, n2)] = t
1143 |
1144 | ax.tick_params(
1145 | axis="both",
1146 | which="both",
1147 | bottom=False,
1148 | left=False,
1149 | labelbottom=False,
1150 | labelleft=False,
1151 | )
1152 |
1153 | return text_items
1154 |
1155 |
1156 |
1157 | def draw_circular(G, **kwargs):
1158 | """Draw the graph G with a circular layout.
1159 |
1160 | Parameters
1161 | ----------
1162 | G : graph
1163 | A networkx graph
1164 |
1165 | kwargs : optional keywords
1166 | See networkx.draw_networkx() for a description of optional keywords,
1167 | with the exception of the pos parameter which is not used by this
1168 | function.
1169 | """
1170 | draw(G, circular_layout(G), **kwargs)
1171 |
1172 |
1173 |
1174 | def draw_kamada_kawai(G, **kwargs):
1175 | """Draw the graph G with a Kamada-Kawai force-directed layout.
1176 |
1177 | Parameters
1178 | ----------
1179 | G : graph
1180 | A networkx graph
1181 |
1182 | kwargs : optional keywords
1183 | See networkx.draw_networkx() for a description of optional keywords,
1184 | with the exception of the pos parameter which is not used by this
1185 | function.
1186 | """
1187 | draw(G, kamada_kawai_layout(G), **kwargs)
1188 |
1189 |
1190 |
1191 | def draw_random(G, **kwargs):
1192 | """Draw the graph G with a random layout.
1193 |
1194 | Parameters
1195 | ----------
1196 | G : graph
1197 | A networkx graph
1198 |
1199 | kwargs : optional keywords
1200 | See networkx.draw_networkx() for a description of optional keywords,
1201 | with the exception of the pos parameter which is not used by this
1202 | function.
1203 | """
1204 | draw(G, random_layout(G), **kwargs)
1205 |
1206 |
1207 |
1208 | def draw_spectral(G, **kwargs):
1209 | """Draw the graph G with a spectral 2D layout.
1210 |
1211 | Using the unnormalized Laplacian, the layout shows possible clusters of
1212 | nodes which are an approximation of the ratio cut. The positions are the
1213 | entries of the second and third eigenvectors corresponding to the
1214 | ascending eigenvalues starting from the second one.
1215 |
1216 | Parameters
1217 | ----------
1218 | G : graph
1219 | A networkx graph
1220 |
1221 | kwargs : optional keywords
1222 | See networkx.draw_networkx() for a description of optional keywords,
1223 | with the exception of the pos parameter which is not used by this
1224 | function.
1225 | """
1226 | draw(G, spectral_layout(G), **kwargs)
1227 |
1228 |
1229 |
1230 | def draw_spring(G, **kwargs):
1231 | """Draw the graph G with a spring layout.
1232 |
1233 | Parameters
1234 | ----------
1235 | G : graph
1236 | A networkx graph
1237 |
1238 | kwargs : optional keywords
1239 | See networkx.draw_networkx() for a description of optional keywords,
1240 | with the exception of the pos parameter which is not used by this
1241 | function.
1242 | """
1243 | draw(G, spring_layout(G), **kwargs)
1244 |
1245 |
1246 |
1247 | def draw_shell(G, **kwargs):
1248 | """Draw networkx graph with shell layout.
1249 |
1250 | Parameters
1251 | ----------
1252 | G : graph
1253 | A networkx graph
1254 |
1255 | kwargs : optional keywords
1256 | See networkx.draw_networkx() for a description of optional keywords,
1257 | with the exception of the pos parameter which is not used by this
1258 | function.
1259 | """
1260 | nlist = kwargs.get("nlist", None)
1261 | if nlist is not None:
1262 | del kwargs["nlist"]
1263 | draw(G, shell_layout(G, nlist=nlist), **kwargs)
1264 |
1265 |
1266 |
1267 | def draw_planar(G, **kwargs):
1268 | """Draw a planar networkx graph with planar layout.
1269 |
1270 | Parameters
1271 | ----------
1272 | G : graph
1273 | A planar networkx graph
1274 |
1275 | kwargs : optional keywords
1276 | See networkx.draw_networkx() for a description of optional keywords,
1277 | with the exception of the pos parameter which is not used by this
1278 | function.
1279 | """
1280 | draw(G, planar_layout(G), **kwargs)
1281 |
1282 |
1283 |
1284 | def apply_alpha(colors, alpha, elem_list, cmap=None, vmin=None, vmax=None):
1285 | """Apply an alpha (or list of alphas) to the colors provided.
1286 |
1287 | Parameters
1288 | ----------
1289 |
1290 | colors : color string, or array of floats
1291 | Color of element. Can be a single color format string (default='r'),
1292 | or a sequence of colors with the same length as nodelist.
1293 | If numeric values are specified they will be mapped to
1294 | colors using the cmap and vmin,vmax parameters. See
1295 | matplotlib.scatter for more details.
1296 |
1297 | alpha : float or array of floats
1298 | Alpha values for elements. This can be a single alpha value, in
1299 | which case it will be applied to all the elements of color. Otherwise,
1300 | if it is an array, the elements of alpha will be applied to the colors
1301 | in order (cycling through alpha multiple times if necessary).
1302 |
1303 | elem_list : array of networkx objects
1304 | The list of elements which are being colored. These could be nodes,
1305 | edges or labels.
1306 |
1307 | cmap : matplotlib colormap
1308 | Color map for use if colors is a list of floats corresponding to points
1309 | on a color mapping.
1310 |
1311 | vmin, vmax : float
1312 | Minimum and maximum values for normalizing colors if a color mapping is
1313 | used.
1314 |
1315 | Returns
1316 | -------
1317 |
1318 | rgba_colors : numpy ndarray
1319 | Array containing RGBA format values for each of the node colours.
1320 |
1321 | """
1322 | from itertools import islice, cycle
1323 |
1324 | try:
1325 | import numpy as np
1326 | from matplotlib.colors import colorConverter
1327 | import matplotlib.cm as cm
1328 | except ImportError as e:
1329 | raise ImportError("Matplotlib required for draw()") from e
1330 |
1331 | # If we have been provided with a list of numbers as long as elem_list,
1332 | # apply the color mapping.
1333 | if len(colors) == len(elem_list) and isinstance(colors[0], Number):
1334 | mapper = cm.ScalarMappable(cmap=cmap)
1335 | mapper.set_clim(vmin, vmax)
1336 | rgba_colors = mapper.to_rgba(colors)
1337 | # Otherwise, convert colors to matplotlib's RGB using the colorConverter
1338 | # object. These are converted to numpy ndarrays to be consistent with the
1339 | # to_rgba method of ScalarMappable.
1340 | else:
1341 | try:
1342 | rgba_colors = np.array([colorConverter.to_rgba(colors)])
1343 | except ValueError:
1344 | rgba_colors = np.array([colorConverter.to_rgba(color) for color in colors])
1345 | # Set the final column of the rgba_colors to have the relevant alpha values
1346 | try:
1347 | # If alpha is longer than the number of colors, resize to the number of
1348 | # elements. Also, if rgba_colors.size (the number of elements of
1349 | # rgba_colors) is the same as the number of elements, resize the array,
1350 | # to avoid it being interpreted as a colormap by scatter()
1351 | if len(alpha) > len(rgba_colors) or rgba_colors.size == len(elem_list):
1352 | rgba_colors = np.resize(rgba_colors, (len(elem_list), 4))
1353 | rgba_colors[1:, 0] = rgba_colors[0, 0]
1354 | rgba_colors[1:, 1] = rgba_colors[0, 1]
1355 | rgba_colors[1:, 2] = rgba_colors[0, 2]
1356 | rgba_colors[:, 3] = list(islice(cycle(alpha), len(rgba_colors)))
1357 | except TypeError:
1358 | rgba_colors[:, -1] = alpha
1359 | return rgba_colors
--------------------------------------------------------------------------------
/dataset_utils/dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.utils.data import Dataset
3 | import pickle
4 |
5 |
6 | def read_pickle(name):
7 | with open(name, "rb") as f:
8 | data = pickle.load(f)
9 | return data
10 |
11 |
12 | def write_pickle(data, name):
13 | with open(name, "wb") as f:
14 | pickle.dump(data, f)
15 |
16 |
17 | class ToyDataset(Dataset):
18 | def __init__(self, pkl, domain_id, opt=None):
19 | idx = pkl["domain"] == domain_id
20 | self.data = pkl["data"][idx].astype(np.float32)
21 | self.label = pkl["label"][idx].astype(np.int64)
22 | self.domain = domain_id
23 |
24 | # if opt.normalize_domain:
25 | # print('===> Normalize in every domain')
26 | # self.data_m, self.data_s = self.data.mean(0, keepdims=True), self.data.std(0, keepdims=True)
27 | # self.data = (self.data - self.data_m) / self.data_s
28 |
29 | def __getitem__(self, idx):
30 | return self.data[idx], self.label[idx], self.domain
31 |
32 | def __len__(self):
33 | return len(self.data)
34 |
35 |
36 | class SeqToyDataset(Dataset):
37 | def __init__(self, datasets, size=3 * 200):
38 | self.datasets = datasets
39 | self.size = size
40 | print(
41 | "SeqDataset Size {} Sub Size {}".format(
42 | size, [len(ds) for ds in datasets]
43 | )
44 | )
45 |
46 | def __len__(self):
47 | return self.size
48 |
49 | def __getitem__(self, i):
50 | return [ds[i] for ds in self.datasets]
51 |
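# Usage sketch (mirrors main.py): the toy pickle stores flattened arrays plus a
# per-sample domain id, so one ToyDataset is built per domain and SeqToyDataset
# zips them so that index i yields one sample from every domain:
#   pkl = read_pickle("data/toy_d15_spiral_tight_boundary.pkl")
#   datasets = [ToyDataset(pkl, i) for i in range(15)]
#   dataset = SeqToyDataset(datasets, size=len(datasets[0]))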
--------------------------------------------------------------------------------
/derive_g_encode/499_pred_GDA_new.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wang-ML-Lab/GRDA/96800d62577b8046af3fbc9ade0a51c2b71129c1/derive_g_encode/499_pred_GDA_new.pkl
--------------------------------------------------------------------------------
/derive_g_encode/derive.py:
--------------------------------------------------------------------------------
1 | import pickle
2 |
3 |
4 | def read_pickle(name):
5 | with open(name, "rb") as f:
6 | data = pickle.load(f)
7 | return data
8 |
9 |
10 | def write_pickle(data, name):
11 | with open(name, "wb") as f:
12 | pickle.dump(data, f)
13 |
14 |
15 | # read_file = "499_pred_GDA_new.pkl"  # prediction pickle shipped in this folder
16 | read_file = "499_pred.pkl"  # expects a local prediction dump from a trained run
17 | num_domain = 60
18 |
19 | info = read_pickle(read_file)
20 | z = info["z"]
21 |
22 | # print(z)
23 |
24 | g_encode = dict()
25 | for i in range(num_domain):
26 | g_encode[str(i)] = z[i]
27 |
28 | write_pickle(g_encode, "g_encode_60.pkl")
29 | print("success!")
30 |
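# The resulting pickle maps str(domain_id) -> that domain's vertex embedding
# (info["z"][i], presumably exported from a trained run at epoch 499). The
# configs then load it as opt.g_encode, e.g.:
#   opt.g_encode = read_pickle("derive_g_encode/g_encode_60.pkl")
#   opt.g_encode["0"]  # embedding for domain 0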
--------------------------------------------------------------------------------
/derive_g_encode/g_encode.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wang-ML-Lab/GRDA/96800d62577b8046af3fbc9ade0a51c2b71129c1/derive_g_encode/g_encode.pkl
--------------------------------------------------------------------------------
/derive_g_encode/g_encode_15.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wang-ML-Lab/GRDA/96800d62577b8046af3fbc9ade0a51c2b71129c1/derive_g_encode/g_encode_15.pkl
--------------------------------------------------------------------------------
/derive_g_encode/g_encode_60.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wang-ML-Lab/GRDA/96800d62577b8046af3fbc9ade0a51c2b71129c1/derive_g_encode/g_encode_60.pkl
--------------------------------------------------------------------------------
/fig/GRDA-DG-15-results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wang-ML-Lab/GRDA/96800d62577b8046af3fbc9ade0a51c2b71129c1/fig/GRDA-DG-15-results.png
--------------------------------------------------------------------------------
/fig/GRDA-domain-graph-US.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wang-ML-Lab/GRDA/96800d62577b8046af3fbc9ade0a51c2b71129c1/fig/GRDA-domain-graph-US.png
--------------------------------------------------------------------------------
/fig/blog-method-DA-vs-GRDA.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wang-ML-Lab/GRDA/96800d62577b8046af3fbc9ade0a51c2b71129c1/fig/blog-method-DA-vs-GRDA.png
--------------------------------------------------------------------------------
/fig/compcar_quantitive_result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wang-ML-Lab/GRDA/96800d62577b8046af3fbc9ade0a51c2b71129c1/fig/compcar_quantitive_result.jpg
--------------------------------------------------------------------------------
/fig/dg_15_60_quantitive_result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wang-ML-Lab/GRDA/96800d62577b8046af3fbc9ade0a51c2b71129c1/fig/dg_15_60_quantitive_result.jpg
--------------------------------------------------------------------------------
/fig/tpt_48_quantitive_result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wang-ML-Lab/GRDA/96800d62577b8046af3fbc9ade0a51c2b71129c1/fig/tpt_48_quantitive_result.jpg
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import random
4 | import pickle
5 |
6 | # import the experiment setting
7 |
8 | from configs.config_random_15_new import opt
9 |
10 | # load the data
11 | from torch.utils.data import DataLoader
12 | from dataset_utils.dataset import ToyDataset, SeqToyDataset
13 |
14 | # the two configs share most settings; to train on DG-60, import this instead:
15 | # from configs.config_random_60_new import opt
16 |
17 | np.random.seed(opt.seed)
18 | random.seed(opt.seed)
19 | torch.manual_seed(opt.seed)
20 |
21 | if opt.model == "DANN":
22 | from model.model import DANN as Model
23 | elif opt.model == "GDA":
24 | from model.model import GDA as Model
25 | elif opt.model == "CDANN":
26 | from model.model import CDANN as Model
27 |
28 | opt.cond_disc = True
29 | elif opt.model == "ADDA":
30 | from model.model import ADDA as Model
31 | elif opt.model == "MDD":
32 | from model.model import MDD as Model
33 | model = Model(opt).to(opt.device) # .double()
34 |
35 | data_source = opt.dataset
36 |
37 | with open(data_source, "rb") as data_file:
38 | data_pkl = pickle.load(data_file)
39 | print(f"Data: {data_pkl['data'].shape}\nLabel: {data_pkl['label'].shape}")
40 |
41 | # build dataset
42 | opt.A = data_pkl["A"]
43 |
44 | data = data_pkl["data"]
45 | data_mean = data.mean(0, keepdims=True)
46 | data_std = data.std(0, keepdims=True)
47 | data_pkl["data"] = (data - data_mean) / data_std # normalize the raw data
48 | datasets = [
49 | ToyDataset(data_pkl, i, opt) for i in range(opt.num_domain)
50 | ] # one sub-dataset per domain
51 |
52 |
53 | dataset = SeqToyDataset(
54 | datasets, size=len(datasets[0])
) # merge the per-domain sub-datasets into one dataset
56 | dataloader = DataLoader(
57 | dataset=dataset, shuffle=True, batch_size=opt.batch_size
58 | )
59 |
60 | # train
61 | for epoch in range(opt.num_epoch):
62 | model.learn(epoch, dataloader)
63 | if (epoch + 1) % opt.save_interval == 0 or (epoch + 1) == opt.num_epoch:
64 | model.save()
65 | if (epoch + 1) % opt.test_interval == 0 or (epoch + 1) == opt.num_epoch:
66 | model.test(epoch, dataloader)
67 |
--------------------------------------------------------------------------------
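
A note on the model dispatch in main.py: the if/elif chain that maps opt.model to a model class is equivalent to a small registry, which is easier to extend with new baselines. A sketch of that equivalent (all five classes live in model/model.py below; opt is the imported config, and behavior matches the chain above, including the CDANN-specific cond_disc flag):

    # registry-style equivalent of main.py's model selection (sketch)
    from model import model as models

    MODEL_REGISTRY = {
        "DANN": models.DANN,
        "GDA": models.GDA,
        "CDANN": models.CDANN,
        "ADDA": models.ADDA,
        "MDD": models.MDD,
    }

    if opt.model == "CDANN":
        opt.cond_disc = True  # CDANN conditions its discriminator on predictions
    model = MODEL_REGISTRY[opt.model](opt).to(opt.device)
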
/model/model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | import torch.nn.functional as F
6 | import torch.optim.lr_scheduler as lr_scheduler
7 | import numpy as np
8 | from model.modules import (
9 | FeatureNet,
10 | PredNet,
11 | GNet,
12 | ClassDiscNet,
13 | CondClassDiscNet,
14 | DiscNet,
15 | GraphDNet,
16 | )
17 | import pickle
18 | from visdom import Visdom
19 |
20 | # ===========================================================================================================
21 |
22 |
23 | def to_np(x):
24 | return x.detach().cpu().numpy()
25 |
26 |
27 | def to_tensor(x, device="cuda"):
28 | if isinstance(x, np.ndarray):
29 | x = torch.from_numpy(x).to(device)
30 | else:
31 | x = x.to(device)
32 | return x
33 |
34 |
35 | def flat(x):
36 | n, m = x.shape[:2]
37 | return x.reshape(n * m, *x.shape[2:])
38 |
39 |
40 | def write_pickle(data, name):
41 | with open(name, "wb") as f:
42 | pickle.dump(data, f)
43 |
44 |
45 | # ======================================================================================================================
46 |
47 |
48 | # the base model
49 | class BaseModel(nn.Module):
50 | def __init__(self, opt):
51 | super(BaseModel, self).__init__()
52 | # set output format
53 | np.set_printoptions(suppress=True, precision=6)
54 |
55 | self.opt = opt
56 | self.device = opt.device
57 | self.batch_size = opt.batch_size
58 | # visualization
59 | self.use_visdom = opt.use_visdom
60 | self.use_g_encode = opt.use_g_encode
61 | if opt.use_visdom:
62 | self.env = Visdom(port=opt.visdom_port)
63 | self.test_pane = dict()
64 |
65 | self.num_domain = opt.num_domain
66 | if self.opt.test_on_all_dmn:
67 | self.test_dmn_num = self.num_domain
68 | else:
69 | self.test_dmn_num = self.opt.tgt_dmn_num
70 |
71 | self.train_log = self.opt.outf + "/loss.log"
72 | self.model_path = opt.outf + "/model.pth"
73 | self.out_pic_f = opt.outf + "/plt_pic"
74 | if not os.path.exists(self.opt.outf):
75 | os.mkdir(self.opt.outf)
76 | if not os.path.exists(self.out_pic_f):
77 | os.mkdir(self.out_pic_f)
78 | with open(self.train_log, "w") as f:
79 | f.write("log start!\n")
80 |
81 | mask_list = np.zeros(opt.num_domain)
82 | mask_list[opt.src_domain] = 1
83 | self.domain_mask = torch.IntTensor(mask_list).to(
84 | opt.device
85 | )
86 |
87 | def learn(self, epoch, dataloader):
88 | self.train()
89 | if self.use_g_encode:
90 | self.netG.eval()
91 | self.epoch = epoch
92 | loss_values = {loss: 0 for loss in self.loss_names}
93 |
94 | count = 0
95 | for data in dataloader:
96 | count += 1
97 | self.__set_input__(data)
98 | self.__train_forward__()
99 | new_loss_values = self.__optimize__()
100 |
101 | # for the loss visualization
102 | for key, loss in new_loss_values.items():
103 | loss_values[key] += loss
104 |
105 | for key, _ in new_loss_values.items():
106 | loss_values[key] /= count
107 |
108 | if self.use_visdom:
109 | self.__vis_loss__(loss_values)
110 |
111 | if (self.epoch + 1) % 10 == 0:
112 | print("epoch {}: {}".format(self.epoch, loss_values))
113 |
114 | # learning rate decay
115 | for lr_scheduler in self.lr_schedulers:
116 | lr_scheduler.step()
117 |
118 | def test(self, epoch, dataloader):
119 | self.eval()
120 |
121 | acc_curve = []
122 | l_x = []
123 | l_domain = []
124 | l_label = []
125 | l_encode = []
126 | z_seq = 0
127 |
128 | for data in dataloader:
129 | self.__set_input__(data)
130 |
131 | # forward
132 | with torch.no_grad():
133 | self.__test_forward__()
134 |
135 | acc_curve.append(
136 | self.g_seq.eq(self.y_seq)
137 | .to(torch.float)
138 | .mean(-1, keepdim=True)
139 | )
140 | l_x.append(to_np(self.x_seq))
141 | l_domain.append(to_np(self.domain_seq))
142 | l_encode.append(to_np(self.e_seq))
143 | l_label.append(to_np(self.g_seq))
144 |
145 | x_all = np.concatenate(l_x, axis=1)
146 | e_all = np.concatenate(l_encode, axis=1)
147 | domain_all = np.concatenate(l_domain, axis=1)
148 | label_all = np.concatenate(l_label, axis=1)
149 |
150 | z_seq = to_np(self.z_seq)
151 | z_seq_all = z_seq[
152 | 0 : self.tmp_batch_size * self.test_dmn_num : self.tmp_batch_size, :
153 | ]
154 |
155 | d_all = dict()
156 |
157 | d_all["data"] = flat(x_all)
158 | d_all["domain"] = flat(domain_all)
159 | d_all["label"] = flat(label_all)
160 | d_all["encodeing"] = flat(e_all)
161 | d_all["z"] = z_seq_all
162 |
163 | acc = to_np(torch.cat(acc_curve, 1).mean(-1))
164 | test_acc = (
165 | (acc.sum() - acc[self.opt.src_domain].sum())
166 | / (self.opt.num_target)
167 | * 100
168 | )
169 | acc_msg = "[Test][{}] Accuracy: total average {:.1f}, test average {:.1f}, in each domain {}".format(
170 | epoch, acc.mean() * 100, test_acc, np.around(acc * 100, decimals=1)
171 | )
172 | self.__log_write__(acc_msg)
173 | if self.use_visdom:
174 | self.__vis_test_error__(test_acc, "test acc")
175 |
176 | d_all["acc_msg"] = acc_msg
177 |
178 | write_pickle(d_all, self.opt.outf + "/" + str(epoch) + "_pred.pkl")
179 |
180 | def __vis_test_error__(self, loss, title):
181 | if self.epoch == self.opt.test_interval - 1:
182 | # initialize
183 | self.test_pane[title] = self.env.line(
184 | X=np.array([self.epoch]),
185 | Y=np.array([loss]),
186 | opts=dict(title=title),
187 | )
188 | else:
189 | self.env.line(
190 | X=np.array([self.epoch]),
191 | Y=np.array([loss]),
192 | win=self.test_pane[title],
193 | update="append",
194 | )
195 |
196 | def save(self):
197 | torch.save(self.state_dict(), self.model_path)
198 |
199 | def __set_input__(self, data, train=True):
200 | """
201 | :param
202 | x_seq: Number of domain x Batch size x Data dim
203 | y_seq: Number of domain x Batch size x Predict Data dim
204 | one_hot_seq: Number of domain x Batch size x Number of vertices (domains)
205 | domain_seq: Number of domain x Batch size x domain dim (1)
206 | """
207 | # domain_seq is the third element (d[2]) of each batch tuple
208 | x_seq, y_seq, domain_seq = (
209 | [d[0][None, :, :] for d in data],
210 | [d[1][None, :] for d in data],
211 | [d[2][None, :] for d in data],
212 | )
213 | self.x_seq = torch.cat(x_seq, 0).to(self.device)
214 | self.y_seq = torch.cat(y_seq, 0).to(self.device)
215 | self.domain_seq = torch.cat(domain_seq, 0).to(self.device)
216 | self.tmp_batch_size = self.x_seq.shape[1]
217 | one_hot_seq = [
218 | torch.nn.functional.one_hot(d[2], self.num_domain)
219 | for d in data
220 | ]
221 |
222 | if train:
223 | self.one_hot_seq = (
224 | torch.cat(one_hot_seq, 0)
225 | .reshape(self.num_domain, self.tmp_batch_size, -1)
226 | .to(self.device)
227 | )
228 | else:
229 | self.one_hot_seq = (
230 | torch.cat(one_hot_seq, 0)
231 | .reshape(self.test_dmn_num, self.tmp_batch_size, -1)
232 | .to(self.device)
233 | )
234 |
235 | def __train_forward__(self):
236 | pass
237 |
238 | def __test_forward__(self):
239 | self.z_seq = self.netG(self.one_hot_seq)
240 | self.e_seq = self.netE(self.x_seq, self.z_seq) # encoder of the data
241 | self.f_seq = self.netF(self.e_seq)
242 | self.g_seq = torch.argmax(
243 | self.f_seq.detach(), dim=2
244 | ) # class of the prediction
245 |
246 | def __optimize__(self):
247 | loss_value = dict()
248 | if not self.use_g_encode:
249 | self.loss_G = self.__loss_G__()
250 |
251 | if self.opt.lambda_gan != 0:  # note: the D updates below assume this holds; loss_D is undefined when lambda_gan == 0
252 | self.loss_D = self.__loss_D__()
253 |
254 | self.loss_E, self.loss_E_pred, self.loss_E_gan = self.__loss_EF__()
255 |
256 | if not self.use_g_encode:
257 | self.optimizer_G.zero_grad()
258 | self.loss_G.backward(retain_graph=True)
259 | self.optimizer_D.zero_grad()
260 | self.loss_D.backward(retain_graph=True)
261 | self.optimizer_EF.zero_grad()
262 | self.loss_E.backward()
263 |
264 | if not self.use_g_encode:
265 | self.optimizer_G.step()
266 | self.optimizer_D.step()
267 | self.optimizer_EF.step()
268 |
269 | if not self.use_g_encode:
270 | loss_value["G"] = self.loss_G.item()
271 | loss_value["D"], loss_value["E_pred"], loss_value["E_gan"] = \
272 | self.loss_D.item(), self.loss_E_pred.item(), self.loss_E_gan.item()
273 | return loss_value
274 |
275 | def __loss_G__(self):
276 | criterion = nn.BCEWithLogitsLoss()
277 |
278 | sub_graph = self.__sub_graph__(my_sample_v=self.opt.sample_v_g)
279 | errorG = torch.zeros((1,)).to(self.device)
280 | sample_v = self.opt.sample_v_g
281 |
282 | for i in range(sample_v):
283 | v_i = sub_graph[i]
284 | for j in range(i + 1, sample_v):
285 | v_j = sub_graph[j]
286 | label = torch.tensor(self.opt.A[v_i][v_j], dtype=torch.float, device=self.device)  # explicit float dtype for BCEWithLogitsLoss
287 | output = (
288 | self.z_seq[v_i * self.tmp_batch_size]
289 | * self.z_seq[v_j * self.tmp_batch_size]
290 | ).sum()
291 | errorG += criterion(output, label)
292 |
293 | errorG /= sample_v * (sample_v - 1) / 2
294 | return errorG
295 |
296 | def __loss_D__(self):
297 | pass
298 |
299 | def __loss_EF__(self):
300 | pass
301 |
302 | def __log_write__(self, loss_msg):
303 | print(loss_msg)
304 | with open(self.train_log, "a") as f:
305 | f.write(loss_msg + "\n")
306 |
307 | def __vis_loss__(self, loss_values):
308 | if self.epoch == 0:
309 | self.panes = {
310 | loss_name: self.env.line(
311 | X=np.array([self.epoch]),
312 | Y=np.array([loss_values[loss_name]]),
313 | opts=dict(title="loss for {} on epochs".format(loss_name)),
314 | )
315 | for loss_name in self.loss_names
316 | }
317 | else:
318 | for loss_name in self.loss_names:
319 | self.env.line(
320 | X=np.array([self.epoch]),
321 | Y=np.array([loss_values[loss_name]]),
322 | win=self.panes[loss_name],
323 | update="append",
324 | )
325 |
326 | def __init_weight__(self, net=None):
327 | if net is None:
328 | net = self
329 | for m in net.modules():
330 | if isinstance(m, nn.Linear):
331 | nn.init.normal_(m.weight, mean=0, std=0.01)
332 | # nn.init.normal_(m.weight, mean=0, std=0.1)
333 | # nn.init.xavier_normal_(m.weight, gain=10)
334 | nn.init.constant_(m.bias, val=0)
335 |
336 | # for graph random sampling:
337 | def __rand_walk__(self, vis, left_nodes):
338 | chain_node = []
339 | node_num = 0
340 | # choose node
341 | node_index = np.where(vis == 0)[0]
342 | st = np.random.choice(node_index)
343 | vis[st] = 1
344 | chain_node.append(st)
345 | left_nodes -= 1
346 | node_num += 1
347 |
348 | cur_node = st
349 | while left_nodes > 0:
350 | nx_node = -1
351 |
352 | node_to_choose = np.where(vis == 0)[0]
353 | num = node_to_choose.shape[0]
354 | node_to_choose = np.random.choice(
355 | node_to_choose, num, replace=False
356 | )
357 |
358 | for i in node_to_choose:
359 | if cur_node != i:
360 | # has an edge to cur_node and hasn't been visited yet
361 | if self.opt.A[cur_node][i] and not vis[i]:
362 | nx_node = i
363 | vis[nx_node] = 1
364 | chain_node.append(nx_node)
365 | left_nodes -= 1
366 | node_num += 1
367 | break
368 | if nx_node >= 0:
369 | cur_node = nx_node
370 | else:
371 | break
372 | return chain_node, node_num
373 |
374 | def __sub_graph__(self, my_sample_v):
375 | if np.random.randint(0, 2) == 0:  # with prob. 0.5, sample nodes uniformly instead of random walks
376 | return np.random.choice(
377 | self.num_domain, size=my_sample_v, replace=False
378 | )
379 |
380 | # subsample a chain (or multiple chains in graph)
381 | left_nodes = my_sample_v
382 | chosen_nodes = []
383 | vis = np.zeros(self.num_domain)
384 | while left_nodes > 0:
385 | chain_node, node_num = self.__rand_walk__(vis, left_nodes)
386 | chosen_nodes.extend(chain_node)
387 | left_nodes -= node_num
388 |
389 | return chosen_nodes
390 |
391 |
392 | class DANN(BaseModel):
393 | """
394 | DANN Model
395 | """
396 |
397 | def __init__(self, opt):
398 | super(DANN, self).__init__(opt)
399 | self.netE = FeatureNet(opt).to(opt.device)
400 | self.netF = PredNet(opt).to(opt.device)
401 | self.netG = GNet(opt).to(opt.device)
402 | self.netD = ClassDiscNet(opt).to(opt.device)
403 |
404 | self.__init_weight__()
405 | EF_parameters = list(self.netE.parameters()) + list(
406 | self.netF.parameters()
407 | )
408 | self.optimizer_EF = optim.Adam(
409 | EF_parameters, lr=opt.lr_e, betas=(opt.beta1, 0.999)
410 | )
411 | self.optimizer_D = optim.Adam(
412 | self.netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999)
413 | )
414 | if not self.use_g_encode:
415 | self.optimizer_G = optim.Adam(
416 | self.netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999)
417 | )
418 |
419 | self.lr_scheduler_EF = lr_scheduler.ExponentialLR(
420 | optimizer=self.optimizer_EF, gamma=0.5 ** (1 / 100)
421 | )
422 | self.lr_scheduler_D = lr_scheduler.ExponentialLR(
423 | optimizer=self.optimizer_D, gamma=0.5 ** (1 / 100)
424 | )
425 | if not self.use_g_encode:
426 | self.lr_scheduler_G = lr_scheduler.ExponentialLR(
427 | optimizer=self.optimizer_G, gamma=0.5 ** (1 / 100)
428 | )
429 | self.lr_schedulers = [
430 | self.lr_scheduler_EF,
431 | self.lr_scheduler_D,
432 | self.lr_scheduler_G,
433 | ]
434 | else:
435 | self.lr_schedulers = [self.lr_scheduler_EF, self.lr_scheduler_D]
436 | self.loss_names = ["E_pred", "E_gan", "D", "G"]
437 | if self.use_g_encode:
438 | self.loss_names.remove("G")
439 |
440 | self.lambda_gan = self.opt.lambda_gan
441 |
442 | def __train_forward__(self):
443 | self.z_seq = self.netG(self.one_hot_seq)
444 | self.e_seq = self.netE(self.x_seq, self.z_seq) # encoder of the data
445 | self.f_seq = self.netF(self.e_seq)
446 | self.d_seq = self.netD(self.e_seq)
447 |
448 | def __loss_D__(self):
449 | return F.nll_loss(flat(self.d_seq), flat(self.domain_seq))
450 |
451 | def __loss_EF__(self):
452 | # loss_D has already been computed before the EF loss; reuse it directly
453 | loss_E_gan = - self.loss_D
454 |
455 | y_seq_source = self.y_seq[self.domain_mask == 1]
456 | f_seq_source = self.f_seq[self.domain_mask == 1]
457 |
458 | loss_E_pred = F.nll_loss(
459 | flat(f_seq_source), flat(y_seq_source)
460 | )
461 |
462 | loss_E = loss_E_gan * self.lambda_gan + loss_E_pred
463 |
464 | return loss_E, loss_E_pred, loss_E_gan
465 |
466 |
467 | class CDANN(BaseModel):
468 | """
469 | CDANN Model
470 | """
471 |
472 | def __init__(self, opt):
473 | super(CDANN, self).__init__(opt)
474 | self.netE = FeatureNet(opt).to(opt.device)
475 | self.netF = PredNet(opt).to(opt.device)
476 | self.netG = GNet(opt).to(opt.device)
477 | self.netD = CondClassDiscNet(opt).to(opt.device)
478 |
479 | self.__init_weight__()
480 | EF_parameters = list(self.netE.parameters()) + list(
481 | self.netF.parameters()
482 | )
483 | self.optimizer_EF = optim.Adam(
484 | EF_parameters, lr=opt.lr_e, betas=(opt.beta1, 0.999)
485 | )
486 | self.optimizer_D = optim.Adam(
487 | self.netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999)
488 | )
489 |
490 | if not self.use_g_encode:
491 | self.optimizer_G = optim.Adam(
492 | self.netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999)
493 | )
494 |
495 | self.lr_scheduler_EF = lr_scheduler.ExponentialLR(
496 | optimizer=self.optimizer_EF, gamma=0.5 ** (1 / 100)
497 | )
498 | self.lr_scheduler_D = lr_scheduler.ExponentialLR(
499 | optimizer=self.optimizer_D, gamma=0.5 ** (1 / 100)
500 | )
501 | if not self.use_g_encode:
502 | self.lr_scheduler_G = lr_scheduler.ExponentialLR(
503 | optimizer=self.optimizer_G, gamma=0.5 ** (1 / 100)
504 | )
505 | self.lr_schedulers = [
506 | self.lr_scheduler_EF,
507 | self.lr_scheduler_D,
508 | self.lr_scheduler_G,
509 | ]
510 | else:
511 | self.lr_schedulers = [self.lr_scheduler_EF, self.lr_scheduler_D]
512 | self.loss_names = ["E_pred", "E_gan", "D", "G"]
513 | if self.use_g_encode:
514 | self.loss_names.remove("G")
515 |
516 | self.lambda_gan = self.opt.lambda_gan
517 |
518 | def __train_forward__(self):
519 | self.z_seq = self.netG(self.one_hot_seq)
520 | self.e_seq = self.netE(self.x_seq, self.z_seq) # encoder of the data
521 | self.f_seq = self.netF(self.e_seq)
522 | self.f_seq_sig = torch.sigmoid(self.f_seq.detach())
523 |
524 | def __loss_D__(self):
525 | d_seq = self.netD(self.e_seq.detach(), self.f_seq_sig.detach())
526 | return F.nll_loss(flat(d_seq), flat(self.domain_seq))
527 |
528 | def __loss_EF__(self):
529 | d_seq = self.netD(self.e_seq, self.f_seq_sig.detach())
530 |
531 | loss_E_gan = -F.nll_loss(flat(d_seq), flat(self.domain_seq))
532 |
533 | y_seq_source = self.y_seq[self.domain_mask == 1]
534 | f_seq_source = self.f_seq[self.domain_mask == 1]
535 |
536 | loss_E_pred = F.nll_loss(
537 | flat(f_seq_source), flat(y_seq_source)
538 | )
539 |
540 | loss_E = loss_E_gan * self.lambda_gan + loss_E_pred
541 | return loss_E, loss_E_pred, loss_E_gan
542 |
543 |
544 | class ADDA(BaseModel):
545 | def __init__(self, opt):
546 | super(ADDA, self).__init__(opt)
547 | self.netE = FeatureNet(opt).to(opt.device)
548 | self.netF = PredNet(opt).to(opt.device)
549 | self.netG = GNet(opt).to(opt.device)
550 | self.netD = DiscNet(opt).to(opt.device)
551 | self.__init_weight__()
552 | EF_parameters = list(self.netE.parameters()) + list(
553 | self.netF.parameters()
554 | )
555 | self.optimizer_EF = optim.Adam(
556 | EF_parameters, lr=opt.lr_e, betas=(opt.beta1, 0.999)
557 | )
558 | self.optimizer_D = optim.Adam(
559 | self.netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999)
560 | )
561 |
562 | if not self.use_g_encode:
563 | self.optimizer_G = optim.Adam(
564 | self.netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999)
565 | )
566 |
567 | self.lr_scheduler_EF = lr_scheduler.ExponentialLR(
568 | optimizer=self.optimizer_EF, gamma=0.5 ** (1 / 100)
569 | )
570 | self.lr_scheduler_D = lr_scheduler.ExponentialLR(
571 | optimizer=self.optimizer_D, gamma=0.5 ** (1 / 100)
572 | )
573 | if not self.use_g_encode:
574 | self.lr_scheduler_G = lr_scheduler.ExponentialLR(
575 | optimizer=self.optimizer_G, gamma=0.5 ** (1 / 100)
576 | )
577 | self.lr_schedulers = [
578 | self.lr_scheduler_EF,
579 | self.lr_scheduler_D,
580 | self.lr_scheduler_G,
581 | ]
582 | else:
583 | self.lr_schedulers = [self.lr_scheduler_EF, self.lr_scheduler_D]
584 | self.loss_names = ["E_pred", "E_gan", "D", "G"]
585 | if self.use_g_encode:
586 | self.loss_names.remove("G")
587 |
588 | def __train_forward__(self):
589 | self.z_seq = self.netG(self.one_hot_seq)
590 | self.e_seq = self.netE(self.x_seq, self.z_seq) # encoder of the data
591 | self.f_seq = self.netF(self.e_seq)
592 |
593 | def __loss_D__(self):
594 | d_seq = self.netD(self.e_seq.detach())
595 | d_seq_source = d_seq[self.domain_mask == 1]
596 | d_seq_target = d_seq[self.domain_mask == 0]
597 | # D: discriminator loss for classifying source vs. target
598 | loss_D = (
599 | -torch.log(d_seq_source + 1e-10).mean()
600 | - torch.log(1 - d_seq_target + 1e-10).mean()
601 | )
602 | return loss_D
603 |
604 | def __loss_EF__(self):
605 | d_seq = self.netD(self.e_seq)
606 | d_seq_target = d_seq[self.domain_mask == 0]
607 | loss_E_gan = -torch.log(d_seq_target + 1e-10).mean()
608 | # E_pred: encoder loss from predicting the label
609 | y_seq_source = self.y_seq[self.domain_mask == 1]
610 | f_seq_source = self.f_seq[self.domain_mask == 1]
611 | loss_E_pred = F.nll_loss(
612 | flat(f_seq_source), flat(y_seq_source)
613 | )
614 |
615 | loss_E = loss_E_gan * self.opt.lambda_gan + loss_E_pred
616 |
617 | return loss_E, loss_E_pred, loss_E_gan
618 |
619 |
620 | class MDD(BaseModel):
621 | """
622 | Margin Disparity Discrepancy
623 | """
624 |
625 | def __init__(self, opt):
626 | super(MDD, self).__init__(opt)
627 | self.netE = FeatureNet(opt).to(opt.device)
628 | self.netF = PredNet(opt).to(opt.device)
629 | self.netG = GNet(opt).to(opt.device)
630 | self.netD = PredNet(opt).to(opt.device)
631 | self.__init_weight__()
632 | EF_parameters = list(self.netE.parameters()) + list(
633 | self.netF.parameters()
634 | )
635 | self.optimizer_EF = optim.Adam(
636 | EF_parameters, lr=opt.lr_e, betas=(opt.beta1, 0.999)
637 | )
638 | self.optimizer_D = optim.Adam(
639 | self.netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999)
640 | )
641 | if not self.use_g_encode:
642 | self.optimizer_G = optim.Adam(
643 | self.netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999)
644 | )
645 |
646 | self.lr_scheduler_EF = lr_scheduler.ExponentialLR(
647 | optimizer=self.optimizer_EF, gamma=0.5 ** (1 / 100)
648 | )
649 | self.lr_scheduler_D = lr_scheduler.ExponentialLR(
650 | optimizer=self.optimizer_D, gamma=0.5 ** (1 / 100)
651 | )
652 | if not self.use_g_encode:
653 | self.lr_scheduler_G = lr_scheduler.ExponentialLR(
654 | optimizer=self.optimizer_G, gamma=0.5 ** (1 / 100)
655 | )
656 | self.lr_schedulers = [
657 | self.lr_scheduler_EF,
658 | self.lr_scheduler_D,
659 | self.lr_scheduler_G,
660 | ]
661 | else:
662 | self.lr_schedulers = [self.lr_scheduler_EF, self.lr_scheduler_D]
663 | self.loss_names = ["E_pred", "E_adv", "ADV_src", "ADV_tgt", "G"]
664 | if self.use_g_encode:
665 | self.loss_names.remove("G")
666 |
667 | self.lambda_src = opt.lambda_src
668 | self.lambda_tgt = opt.lambda_tgt
669 |
670 | def __train_forward__(self):
671 | self.z_seq = self.netG(self.one_hot_seq)
672 | self.e_seq = self.netE(self.x_seq, self.z_seq) # encoder of the data
673 | self.f_seq = self.netF(self.e_seq)
674 | self.g_seq = torch.argmax(
675 | self.f_seq.detach(), dim=2
676 | ) # class of the prediction
677 |
678 | def __optimize__(self):
679 | loss_value = dict()
680 | if not self.use_g_encode:
681 | self.loss_G = self.__loss_G__()
682 |
683 | self.loss_D, self.loss_ADV_src, self.loss_ADV_tgt = self.__loss_D__()
684 | self.loss_E, self.loss_E_pred, self.loss_E_adv = self.__loss_EF__()
685 |
686 | if not self.use_g_encode:
687 | self.optimizer_G.zero_grad()
688 | self.loss_G.backward(retain_graph=True)
689 | self.optimizer_D.zero_grad()
690 | self.loss_D.backward(retain_graph=True)
691 | self.optimizer_EF.zero_grad()
692 | self.loss_E.backward()
693 |
694 | if not self.use_g_encode:
695 | self.optimizer_G.step()
696 | self.optimizer_D.step()
697 | self.optimizer_EF.step()
698 |
699 | if not self.use_g_encode:
700 | loss_value["G"] = self.loss_G.item()
701 | loss_value["ADV_src"], loss_value["ADV_tgt"], loss_value["E_pred"], loss_value["E_adv"] = \
702 | self.loss_ADV_src.item(), self.loss_ADV_tgt.item(), self.loss_E_pred.item(), self.loss_E_adv.item()
703 | return loss_value
704 |
722 | def __loss_D__(self):
728 | f_adv, f_adv_softmax = self.netD(
729 | self.e_seq.detach(), return_softmax=True
730 | )
731 | # agreement with netF on source domain
732 | loss_ADV_src = F.nll_loss(
733 | flat(f_adv[self.domain_mask == 1]),
734 | flat(self.g_seq[self.domain_mask == 1]),
735 | )
736 | f_adv_tgt = torch.log(
737 | 1 - f_adv_softmax[self.domain_mask == 0] + 1e-10
738 | )
739 | # disagreement with netF on target domain
740 | loss_ADV_tgt = F.nll_loss(
741 | flat(f_adv_tgt), flat(self.g_seq[self.domain_mask == 0])
742 | )
743 | # train netD to agree with netF on the source domains while disagreeing on the target domains
744 | loss_D = (
745 | loss_ADV_src * self.lambda_src
746 | + loss_ADV_tgt * self.lambda_tgt
747 | ) / (self.lambda_src + self.lambda_tgt)
748 |
749 | return loss_D, loss_ADV_src, loss_ADV_tgt
750 |
751 | def __loss_EF__(self):
752 | loss_E_pred = F.nll_loss(
753 | flat(self.f_seq[self.domain_mask == 1]),
754 | flat(self.y_seq[self.domain_mask == 1]),
755 | )
756 |
757 | f_adv, f_adv_softmax = self.netD(
758 | self.e_seq, return_softmax=True
759 | )
760 | loss_ADV_src = F.nll_loss(
761 | flat(f_adv[self.domain_mask == 1]),
762 | flat(self.g_seq[self.domain_mask == 1]),
763 | )
764 | f_adv_tgt = torch.log(
765 | 1 - f_adv_softmax[self.domain_mask == 0] + 1e-10
766 | )
767 | loss_ADV_tgt = F.nll_loss(
768 | flat(f_adv_tgt), flat(self.g_seq[self.domain_mask == 0])
769 | )
770 | loss_E_adv = -(
771 | loss_ADV_src * self.lambda_src
772 | + loss_ADV_tgt * self.lambda_tgt
773 | ) / (self.lambda_src + self.lambda_tgt)
774 | loss_E = loss_E_pred + self.opt.lambda_gan * loss_E_adv
775 |
776 | return loss_E, loss_E_pred, loss_E_adv
777 |
778 |
779 | class GDA(BaseModel):
780 | """
781 | GDA Model
782 | """
783 |
784 | def __init__(self, opt):
785 | super(GDA, self).__init__(opt)
786 | self.netE = FeatureNet(opt).to(opt.device)
787 | self.netF = PredNet(opt).to(opt.device)
788 | self.netG = GNet(opt).to(opt.device)
789 | self.netD = GraphDNet(opt).to(opt.device)
790 | self.__init_weight__()
791 |
792 | EF_parameters = list(self.netE.parameters()) + list(
793 | self.netF.parameters()
794 | )
795 | self.optimizer_EF = optim.Adam(
796 | EF_parameters, lr=opt.lr_e, betas=(opt.beta1, 0.999)
797 | )
798 | self.optimizer_D = optim.Adam(
799 | self.netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999)
800 | )
801 | if not self.use_g_encode:
802 | self.optimizer_G = optim.Adam(
803 | self.netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999)
804 | )
805 |
806 | self.lr_scheduler_EF = lr_scheduler.ExponentialLR(
807 | optimizer=self.optimizer_EF, gamma=0.5 ** (1 / 100)
808 | )
809 | self.lr_scheduler_D = lr_scheduler.ExponentialLR(
810 | optimizer=self.optimizer_D, gamma=0.5 ** (1 / 100)
811 | )
812 | if not self.use_g_encode:
813 | self.lr_scheduler_G = lr_scheduler.ExponentialLR(
814 | optimizer=self.optimizer_G, gamma=0.5 ** (1 / 100)
815 | )
816 | self.lr_schedulers = [
817 | self.lr_scheduler_EF,
818 | self.lr_scheduler_D,
819 | self.lr_scheduler_G,
820 | ]
821 | else:
822 | self.lr_schedulers = [self.lr_scheduler_EF, self.lr_scheduler_D]
823 | self.loss_names = ["E_pred", "E_gan", "D", "G"]
824 | if self.use_g_encode:
825 | self.loss_names.remove("G")
826 |
827 | def __train_forward__(self):
828 | self.z_seq = self.netG(self.one_hot_seq)
829 | self.e_seq = self.netE(self.x_seq, self.z_seq) # encoder of the data
830 | self.f_seq = self.netF(self.e_seq) # prediction
831 | self.d_seq = self.netD(self.e_seq)
832 |
833 | def __loss_EF__(self):
834 | # we have already got loss D
835 | if self.opt.lambda_gan != 0:
836 | loss_E_gan = -self.loss_D
837 | else:
838 | loss_E_gan = torch.tensor(
839 | 0, dtype=torch.float, device=self.opt.device
840 | )
841 |
842 | y_seq_source = self.y_seq[self.domain_mask == 1]
843 | f_seq_source = self.f_seq[self.domain_mask == 1]
844 |
845 | loss_E_pred = F.nll_loss(flat(f_seq_source), flat(y_seq_source))
846 |
847 | loss_E = loss_E_gan * self.opt.lambda_gan + loss_E_pred
848 |
849 | return loss_E, loss_E_pred, loss_E_gan
850 |
851 | def __loss_D__(self):
852 | criterion = nn.BCEWithLogitsLoss()
853 | d = self.d_seq
854 | # randomly pick a sub-graph and optimize D on it;
855 | # the balance coefficient is computed from the pos/neg (connected/disconnected) ratio
856 | sub_graph = self.__sub_graph__(my_sample_v=self.opt.sample_v)
857 |
858 | errorD_connected = torch.zeros((1,)).to(self.device) # .double()
859 | errorD_disconnected = torch.zeros((1,)).to(self.device) # .double()
860 |
861 | count_connected = 0
862 | count_disconnected = 0
863 |
864 | for i in range(self.opt.sample_v):
865 | v_i = sub_graph[i]
866 | # no self-loops: j starts from i + 1
867 | for j in range(i + 1, self.opt.sample_v):
868 | v_j = sub_graph[j]
869 | label = torch.full(
870 | (self.tmp_batch_size,),
871 | self.opt.A[v_i][v_j],
872 | dtype=torch.float, device=self.device,  # explicit float dtype for BCEWithLogitsLoss
873 | )
874 | # dot product
875 | if v_i == v_j:
876 | idx = torch.randperm(self.tmp_batch_size)
877 | output = (d[v_i][idx] * d[v_j]).sum(1)
878 | else:
879 | output = (d[v_i] * d[v_j]).sum(1)
880 |
881 | if self.opt.A[v_i][v_j]: # connected
882 | errorD_connected += criterion(output, label)
883 | count_connected += 1
884 | else:
885 | errorD_disconnected += criterion(output, label)
886 | count_disconnected += 1
887 |
888 | errorD = 0.5 * (
889 | errorD_connected / count_connected
890 | + errorD_disconnected / count_disconnected
891 | )
892 | # scale by num_domain to balance this loss against the others
893 | return errorD * self.num_domain
894 |
--------------------------------------------------------------------------------
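
The heart of GDA.__loss_D__ above is reconstructing the domain-graph adjacency A from per-sample dot products of discriminator outputs. Below is a self-contained sketch of just that reconstruction term, with toy shapes and a hand-written A; the real loss additionally samples sub-graphs via __sub_graph__, permutes within-domain pairs, and balances connected vs. disconnected pairs:

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    num_domain, batch, nd_out = 4, 8, 16

    # toy domain graph: a chain 0-1-2-3 (symmetric, no self-loops)
    A = torch.tensor([[0., 1., 0., 0.],
                      [1., 0., 1., 0.],
                      [0., 1., 0., 1.],
                      [0., 0., 1., 0.]])

    # d[i] plays the role of self.d_seq for domain i: one vector per sample
    d = torch.randn(num_domain, batch, nd_out, requires_grad=True)

    criterion = nn.BCEWithLogitsLoss()
    loss = torch.zeros(())
    for i in range(num_domain):
        for j in range(i + 1, num_domain):
            logits = (d[i] * d[j]).sum(1)   # per-sample dot products
            labels = A[i, j].expand(batch)  # 1 if domains i and j are adjacent
            loss = loss + criterion(logits, labels)
    loss = loss / (num_domain * (num_domain - 1) / 2)
    loss.backward()  # gradients flow back into d
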
/model/modules.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class Identity(nn.Module):
8 | def __init__(self):
9 | super(Identity, self).__init__()
10 |
11 | def forward(self, x):
12 | return x
13 |
14 |
15 | class GNet(nn.Module):
16 | def __init__(self, opt):
17 | super(GNet, self).__init__()
18 | self.use_g_encode = opt.use_g_encode
19 | if self.use_g_encode:
20 | G = np.zeros((opt.num_domain, opt.nt))
21 | for i in range(opt.num_domain):
22 | G[i] = opt.g_encode[str(i)]
23 | self.G = torch.from_numpy(G).float().to(device=opt.device)
24 | else:
25 | self.fc1 = nn.Linear(opt.num_domain, opt.nh)
26 | self.fc_final = nn.Linear(opt.nh, opt.nt)
27 |
28 | def forward(self, x):
29 | re = x.dim() == 3
30 | if re:
31 | T, B, C = x.shape
32 | x = x.reshape(T * B, -1)
33 |
34 | if self.use_g_encode:
35 | x = torch.matmul(x.float(), self.G)
36 | else:
37 | x = F.relu(self.fc1(x.float()))
38 | # x = nn.Dropout(p=p)(x)
39 | x = self.fc_final(x)
40 | return x
41 |
42 |
43 | class FeatureNet(nn.Module):
44 | def __init__(self, opt):
45 | super(FeatureNet, self).__init__()
46 |
47 | nx, nh, nt, p = opt.nx, opt.nh, opt.nt, opt.p
48 | self.p = p
49 |
50 | self.fc1 = nn.Linear(nx, nh)
51 | self.fc2 = nn.Linear(nh * 2, nh * 2)
52 | self.fc3 = nn.Linear(nh * 2, nh * 2)
53 | self.fc4 = nn.Linear(nh * 2, nh * 2)
54 | self.fc_final = nn.Linear(nh * 2, nh)
55 |
56 | # the conditioning branch takes the nt-dimensional domain embedding as input
57 | self.fc1_var = nn.Linear(nt, nh)
58 | self.fc2_var = nn.Linear(nh, nh)
59 |
60 | def forward(self, x, t):
61 | re = x.dim() == 3
62 | if re:
63 | T, B, C = x.shape
64 | x = x.reshape(T * B, -1)
65 | t = t.reshape(T * B, -1)
66 |
67 | x = F.relu(self.fc1(x))
68 | t = F.relu(self.fc1_var(t))
69 | t = F.relu(self.fc2_var(t))
70 |
71 | # combine feature in the middle
72 | x = torch.cat((x, t), dim=1)
73 |
74 | # main
75 | x = F.relu(self.fc2(x))
76 | x = F.relu(self.fc3(x))
77 | x = F.relu(self.fc4(x))
78 | x = self.fc_final(x)
79 |
80 | if re:
81 | return x.reshape(T, B, -1)
82 | else:
83 | return x
84 |
85 |
86 | class GraphDNet(nn.Module):
87 | """
88 | Generate z' for connection loss
89 | """
90 |
91 | def __init__(self, opt):
92 | super(GraphDNet, self).__init__()
93 | nh = opt.nh
94 | nin = nh
95 | self.fc3 = nn.Linear(nin, nh)
96 | self.bn3 = nn.BatchNorm1d(nh)
97 |
98 | self.fc4 = nn.Linear(nh, nh)
99 | self.bn4 = nn.BatchNorm1d(nh)
100 |
101 | self.fc5 = nn.Linear(nh, nh)
102 | self.bn5 = nn.BatchNorm1d(nh)
103 |
104 | self.fc6 = nn.Linear(nh, nh)
105 | self.bn6 = nn.BatchNorm1d(nh)
106 |
107 | self.fc7 = nn.Linear(nh, nh)
108 | self.bn7 = nn.BatchNorm1d(nh)
109 |
110 | self.fc_final = nn.Linear(nh, opt.nd_out)
111 |
112 | if opt.no_bn:
113 | self.bn3 = Identity()
114 | self.bn4 = Identity()
115 | self.bn5 = Identity()
116 | self.bn6 = Identity()
117 | self.bn7 = Identity()
118 |
119 | def forward(self, x):
120 | re = x.dim() == 3
121 |
122 | if re:
123 | T, B, C = x.shape
124 | x = x.reshape(T * B, -1)
125 |
126 | x = F.relu(self.bn3(self.fc3(x)))
127 | x = F.relu(self.bn4(self.fc4(x)))
128 | x = F.relu(self.bn5(self.fc5(x)))
129 | x = F.relu(self.bn6(self.fc6(x)))
130 | x = F.relu(self.bn7(self.fc7(x)))
131 |
132 | x = self.fc_final(x)
133 |
134 | if re:
135 | return x.reshape(T, B, -1)
136 | else:
137 | return x
138 |
139 |
140 | class ResGraphDNet(nn.Module):
141 | """
142 | Generate z' for connection loss
143 | """
144 |
145 | def __init__(self, opt):
146 | super(ResGraphDNet, self).__init__()
147 | nh = opt.nh
148 | nin = nh
149 | self.fc3 = nn.Linear(nin, nh)
150 | self.bn3 = nn.BatchNorm1d(nh)
151 |
152 | self.fc4 = nn.Linear(nh, nh)
153 | self.bn4 = nn.BatchNorm1d(nh)
154 |
155 | self.fc5 = nn.Linear(nh, nh)
156 | self.bn5 = nn.BatchNorm1d(nh)
157 |
158 | self.fc6 = nn.Linear(nh, nh)
159 | self.bn6 = nn.BatchNorm1d(nh)
160 |
161 | self.fc7 = nn.Linear(nh, nh)
162 | self.bn7 = nn.BatchNorm1d(nh)
163 |
164 | self.fc8 = nn.Linear(nh, nh)
165 | self.bn8 = nn.BatchNorm1d(nh)
166 |
167 | self.fc9 = nn.Linear(nh, nh)
168 | self.bn9 = nn.BatchNorm1d(nh)
169 |
170 | self.fc10 = nn.Linear(nh, nh)
171 | self.bn10 = nn.BatchNorm1d(nh)
172 |
173 | self.fc11 = nn.Linear(nh, nh)
174 | self.bn11 = nn.BatchNorm1d(nh)
175 |
176 | self.fc_final = nn.Linear(nh, opt.nd_out)
177 |
178 | if opt.no_bn:
179 | self.bn3 = Identity()
180 | self.bn4 = Identity()
181 | self.bn5 = Identity()
182 | self.bn6 = Identity()
183 | self.bn7 = Identity()
184 | self.bn8 = Identity()
185 | self.bn9 = Identity()
186 | self.bn10 = Identity()
187 | self.bn11 = Identity()
188 |
189 | def forward(self, x):
190 | re = x.dim() == 3
191 |
192 | if re:
193 | T, B, C = x.shape
194 | x = x.reshape(T * B, -1)
195 |
196 | x = F.relu(self.bn3(self.fc3(x)))
197 | id1 = x
198 | out = F.relu(self.bn4(self.fc4(x)))
199 | out = self.bn5(self.fc5(out))
200 | x = F.relu(out + id1)
201 |
202 | id2 = x
203 | out = F.relu(self.bn6(self.fc6(x)))
204 | out = self.bn7(self.fc7(out))
205 | x = F.relu(out + id2)
206 |
207 | id3 = x
208 | out = F.relu(self.bn8(self.fc8(x)))
209 | out = self.bn9(self.fc9(out))
210 | x = F.relu(out + id3)
211 |
212 | id4 = x
213 | out = F.relu(self.bn10(self.fc10(x)))
214 | out = self.bn11(self.fc11(out))
215 | x = F.relu(out + id4)
216 |
217 | x = self.fc_final(x)
218 |
219 | if re:
220 | return x.reshape(T, B, -1)
221 | else:
222 | return x
223 |
224 |
225 | class DiscNet(nn.Module):
226 | # Discriminator doing binary classification: source vs. target
227 |
228 | def __init__(self, opt):
229 | super(DiscNet, self).__init__()
230 | nh = opt.nh
231 |
232 | nin = nh
233 | self.fc3 = nn.Linear(nin, nh)
234 | self.bn3 = nn.BatchNorm1d(nh)
235 |
236 | self.fc4 = nn.Linear(nh, nh)
237 | self.bn4 = nn.BatchNorm1d(nh)
238 |
239 | self.fc5 = nn.Linear(nh, nh)
240 | self.bn5 = nn.BatchNorm1d(nh)
241 |
242 | self.fc6 = nn.Linear(nh, nh)
243 | self.bn6 = nn.BatchNorm1d(nh)
244 |
245 | self.fc7 = nn.Linear(nh, nh)
246 | self.bn7 = nn.BatchNorm1d(nh)
247 |
248 | if opt.no_bn:
249 | self.bn3 = Identity()
250 | self.bn4 = Identity()
251 | self.bn5 = Identity()
252 | self.bn6 = Identity()
253 | self.bn7 = Identity()
254 |
255 | self.fc_final = nn.Linear(nh, 1)
256 | if opt.model in ["ADDA", "CUA"]:
257 | print("===> Discrinimator Output Activation: sigmoid")
258 | self.output = lambda x: torch.sigmoid(x)
259 | else:
260 | print("===> Discrinimator Output Activation: identity")
261 | self.output = lambda x: x
262 |
263 | def forward(self, x):
264 | re = x.dim() == 3
265 |
266 | if re:
267 | T, B, C = x.shape
268 | x = x.reshape(T * B, -1)
269 |
270 | x = F.relu(self.bn3(self.fc3(x)))
271 | x = F.relu(self.bn4(self.fc4(x)))
272 | x = F.relu(self.bn5(self.fc5(x)))
273 | x = F.relu(self.bn6(self.fc6(x)))
274 | x = F.relu(self.bn7(self.fc7(x)))
275 | x = self.output(self.fc_final(x))
276 |
277 | if re:
278 | return x.reshape(T, B, -1)
279 | else:
280 | return x
281 |
282 |
283 | class ClassDiscNet(nn.Module):
284 | """
285 | Discriminator doing multi-class classification on the domain
286 | """
287 |
288 | def __init__(self, opt):
289 | super(ClassDiscNet, self).__init__()
290 | nh = opt.nh
291 | nc = opt.nc
292 | nin = nh
293 | nout = opt.num_domain
294 |
295 | if opt.cond_disc:
296 | print("===> Conditioned Discriminator")
297 | nmid = nh * 2
298 | self.cond = nn.Sequential(
299 | nn.Linear(nc, nh),
300 | nn.ReLU(True),
301 | nn.Linear(nh, nh),
302 | nn.ReLU(True),
303 | )
304 | else:
305 | print("===> Unconditioned Discriminator")
306 | nmid = nh
307 | self.cond = None
308 |
309 | print(f"===> Discriminator will distinguish {nout} domains")
310 |
311 | self.fc3 = nn.Linear(nin, nh)
312 | self.bn3 = nn.BatchNorm1d(nh)
313 |
314 | self.fc4 = nn.Linear(nmid, nh)
315 | self.bn4 = nn.BatchNorm1d(nh)
316 |
317 | self.fc5 = nn.Linear(nh, nh)
318 | self.bn5 = nn.BatchNorm1d(nh)
319 |
320 | self.fc6 = nn.Linear(nh, nh)
321 | self.bn6 = nn.BatchNorm1d(nh)
322 |
323 | self.fc7 = nn.Linear(nh, nh)
324 | self.bn7 = nn.BatchNorm1d(nh)
325 |
326 | if opt.no_bn:
327 | self.bn3 = Identity()
328 | self.bn4 = Identity()
329 | self.bn5 = Identity()
330 | self.bn6 = Identity()
331 | self.bn7 = Identity()
332 |
333 | self.fc_final = nn.Linear(nh, nout)
334 |
335 | def forward(self, x):
336 | re = x.dim() == 3
337 |
338 | if re:
339 | T, B, C = x.shape
340 | x = x.reshape(T * B, -1)
341 | # f_exp = f_exp.reshape(T * B, -1)
342 |
343 | x = F.relu(self.bn3(self.fc3(x)))
344 | x = F.relu(self.bn4(self.fc4(x)))
345 | x = F.relu(self.bn5(self.fc5(x)))
346 | x = F.relu(self.bn6(self.fc6(x)))
347 | x = F.relu(self.bn7(self.fc7(x)))
348 | # x = self.fc_final(x)
349 | x = F.relu(self.fc_final(x))
350 | x = torch.log_softmax(x, dim=1)
351 | if re:
352 | return x.reshape(T, B, -1)
353 | else:
354 | return x
355 |
356 |
357 | class CondClassDiscNet(nn.Module):
358 | """
359 | Discriminator doing multi-class classification on the domain
360 | """
361 |
362 | def __init__(self, opt):
363 | super(CondClassDiscNet, self).__init__()
364 | nh = opt.nh
365 | nc = opt.nc
366 | nin = nh
367 | nout = opt.num_domain
368 |
369 | if opt.cond_disc:
370 | print("===> Conditioned Discriminator")
371 | nmid = nh * 2
372 | self.cond = nn.Sequential(
373 | nn.Linear(nc, nh),
374 | nn.ReLU(True),
375 | nn.Linear(nh, nh),
376 | nn.ReLU(True),
377 | )
378 | else:
379 | print("===> Unconditioned Discriminator")
380 | nmid = nh
381 | self.cond = None
382 |
383 | print(f"===> Discriminator will distinguish {nout} domains")
384 |
385 | self.fc3 = nn.Linear(nin, nh)
386 | self.bn3 = nn.BatchNorm1d(nh)
387 |
388 | self.fc4 = nn.Linear(nmid, nh)
389 | self.bn4 = nn.BatchNorm1d(nh)
390 |
391 | self.fc5 = nn.Linear(nh, nh)
392 | self.bn5 = nn.BatchNorm1d(nh)
393 |
394 | self.fc6 = nn.Linear(nh, nh)
395 | self.bn6 = nn.BatchNorm1d(nh)
396 |
397 | self.fc7 = nn.Linear(nh, nh)
398 | self.bn7 = nn.BatchNorm1d(nh)
399 |
400 | if opt.no_bn:
401 | self.bn3 = Identity()
402 | self.bn4 = Identity()
403 | self.bn5 = Identity()
404 | self.bn6 = Identity()
405 | self.bn7 = Identity()
406 |
407 | self.fc_final = nn.Linear(nh, nout)
408 |
409 | def forward(self, x, f_exp):
410 | re = x.dim() == 3
411 |
412 | if re:
413 | T, B, C = x.shape
414 | x = x.reshape(T * B, -1)
415 | f_exp = f_exp.reshape(T * B, -1)
416 |
417 | x = F.relu(self.bn3(self.fc3(x)))
418 | if self.cond is not None:
419 | f = self.cond(f_exp)
420 | x = torch.cat([x, f], dim=1)
421 | x = F.relu(self.bn4(self.fc4(x)))
422 | x = F.relu(self.bn5(self.fc5(x)))
423 | x = F.relu(self.bn6(self.fc6(x)))
424 | x = F.relu(self.bn7(self.fc7(x)))
425 | x = self.fc_final(x)
426 | x = torch.log_softmax(x, dim=1)
427 | if re:
428 | return x.reshape(T, B, -1)
429 | else:
430 | return x
431 |
432 |
433 | class ProbDiscNet(nn.Module):
434 | def __init__(self, opt):
435 | super(ProbDiscNet, self).__init__()
436 |
437 | nmix = opt.nmix
438 |
439 | nh = opt.nh
440 |
441 | nin = nh
442 | nout = opt.dim_domain * nmix * 3
443 |
444 | self.fc3 = nn.Linear(nin, nh)
445 | self.bn3 = nn.BatchNorm1d(nh)
446 |
447 | self.fc4 = nn.Linear(nh, nh)
448 | self.bn4 = nn.BatchNorm1d(nh)
449 |
450 | self.fc5 = nn.Linear(nh, nh)
451 | self.bn5 = nn.BatchNorm1d(nh)
452 |
453 | self.fc6 = nn.Linear(nh, nh)
454 | self.bn6 = nn.BatchNorm1d(nh)
455 |
456 | if opt.no_bn:
457 | self.bn3 = Identity()
458 | self.bn4 = Identity()
459 | self.bn5 = Identity()
460 | self.bn6 = Identity()
461 |
462 | self.fc_final = nn.Linear(nh, nout)
463 |
464 | self.ndomain = opt.dim_domain
465 | self.nmix = nmix
466 |
467 | def forward(self, x):
468 | re = x.dim() == 3
469 | if re:
470 | T, B, C = x.shape
471 | x = x.reshape(T * B, -1)
472 |
473 | x = F.relu(self.bn3(self.fc3(x)))
474 | x = F.relu(self.bn4(self.fc4(x)))
475 | x = F.relu(self.bn5(self.fc5(x)))
476 | x = F.relu(self.bn6(self.fc6(x)))
477 |
478 | x = self.fc_final(x).reshape(-1, 3, self.ndomain, self.nmix)
479 | x_mean, x_std, x_weight = x[:, 0], x[:, 1], x[:, 2]
480 | x_std = torch.sigmoid(x_std) * 2 + 0.1
481 | x_weight = torch.softmax(x_weight, dim=1)
482 |
483 | if re:
484 | return (
485 | x_mean.reshape(T, B, -1),
486 | x_std.reshape(T, B, -1),
487 | x_weight.reshape(T, B, -1),
488 | )
489 | else:
490 | return x_mean, x_std, x_weight
491 |
492 |
493 | class PredNet(nn.Module):
494 | def __init__(self, opt):
495 | super(PredNet, self).__init__()
496 |
497 | nh, nc = opt.nh, opt.nc
498 | nin = nh
499 | self.fc3 = nn.Linear(nin, nh)
500 | self.bn3 = nn.BatchNorm1d(nh)
501 | self.fc4 = nn.Linear(nh, nh)
502 | self.bn4 = nn.BatchNorm1d(nh)
503 | self.fc_final = nn.Linear(nh, nc)
504 | if opt.no_bn:
505 | self.bn3 = Identity()
506 | self.bn4 = Identity()
507 |
508 | def forward(self, x, return_softmax=False):
509 | re = x.dim() == 3
510 | if re:
511 | T, B, C = x.shape
512 | x = x.reshape(T * B, -1)
513 |
514 | x = F.relu(self.bn3(self.fc3(x)))
515 | x = F.relu(self.bn4(self.fc4(x)))
516 | x = self.fc_final(x)
517 | x_softmax = F.softmax(x, dim=1)
518 |
519 | # use log(softmax + eps) rather than log_softmax so that
520 | # log-probabilities are bounded below by log(1e-4),
521 | # which keeps the NLL losses finite
522 | x = torch.log(x_softmax + 1e-4)
523 |
524 | if re:
525 | x = x.reshape(T, B, -1)
526 | x_softmax = x_softmax.reshape(T, B, -1)
527 |
528 | if return_softmax:
529 | return x, x_softmax
530 | else:
531 | return x
532 |
533 |
534 | # ======================================================================================================================
535 |
--------------------------------------------------------------------------------
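
To make the tensor shapes concrete, here is a minimal forward pass through the four core modules with made-up hyperparameters (a sketch: the opt fields mirror the attributes the modules read, easydict is already in requirements.txt, and no_bn=True only keeps the sketch independent of batch statistics):

    import torch
    from easydict import EasyDict
    from model.modules import GNet, FeatureNet, PredNet, GraphDNet

    opt = EasyDict(num_domain=15, nx=2, nh=64, nt=4, nc=2, nd_out=16,
                   p=0.2, no_bn=True, use_g_encode=False, device="cpu")

    T, B = opt.num_domain, 10                            # domains x batch size
    one_hot = torch.eye(T).unsqueeze(1).repeat(1, B, 1)  # one-hot domain IDs
    x = torch.randn(T, B, opt.nx)                        # toy inputs

    netG, netE = GNet(opt), FeatureNet(opt)
    netF, netD = PredNet(opt), GraphDNet(opt)

    z = netG(one_hot)  # (T*B, nt)       domain embeddings (GNet returns 2-D)
    e = netE(x, z)     # (T, B, nh)      encodings conditioned on the domain
    f = netF(e)        # (T, B, nc)      log-probabilities over classes
    d = netD(e)        # (T, B, nd_out)  vectors for graph reconstruction
    print(z.shape, e.shape, f.shape, d.shape)
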
/requirements.txt:
--------------------------------------------------------------------------------
1 | pandas
2 | networkx
3 | matplotlib
4 | visdom
5 | easydict
6 | torch==1.9.0
7 | torchvision==0.10.0
8 |
--------------------------------------------------------------------------------
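
The torch/torchvision pins date the codebase; after pip install -r requirements.txt, a quick sanity check is cheap. Note that to_tensor in model/model.py defaults to device="cuda", so a CUDA build is assumed by default:

    import torch
    import torchvision

    print(torch.__version__, torchvision.__version__)   # expect 1.9.0 / 0.10.0
    print("CUDA available:", torch.cuda.is_available())
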
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import pickle
2 |
3 |
4 | def read_pickle(name):
5 | with open(name, "rb") as f:
6 | data = pickle.load(f)
7 | return data
8 |
9 |
10 | def write_pickle(data, name):
11 | with open(name, "wb") as f:
12 | pickle.dump(data, f)
13 |
--------------------------------------------------------------------------------
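
A small usage sketch for these helpers, e.g. to inspect the per-epoch prediction pickles that BaseModel.test writes under opt.outf (the path below is illustrative, not a file shipped with the repo; the keys match model.py):

    import os
    from utils.utils import read_pickle, write_pickle

    # simple round trip
    write_pickle({"hello": [1, 2, 3]}, "example.pkl")
    print(read_pickle("example.pkl"))

    # inspect a prediction dump written by BaseModel.test()
    pred_path = "dump/499_pred.pkl"  # illustrative; see opt.outf in the configs
    if os.path.exists(pred_path):
        pred = read_pickle(pred_path)
        print(pred["acc_msg"])                      # logged accuracy summary
        print(pred["data"].shape, pred["z"].shape)  # flattened inputs, domain embeddings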