├── .gitignore ├── CreateDataset ├── InMemoryDataset.py ├── LargerDataset.py └── test001.py ├── Example ├── CommonBenchmarkDatasets.py ├── DataHandle.py ├── GCNModel.py ├── MiniBatches.py └── test001.py ├── GitExample └── agnn.py ├── LICENSE ├── MessagePassing ├── GCN.py └── test001.py └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | .idea
132 | data/
--------------------------------------------------------------------------------
/CreateDataset/InMemoryDataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.data import InMemoryDataset
3 |
4 | class MyDataset(InMemoryDataset):
5 |     def __init__(self, root, transform=None, pre_transform=None):
6 |         # Downloading and processing are triggered inside the parent class constructor
7 |         super(MyDataset, self).__init__(root, transform, pre_transform)
8 |         # Load the processed data
9 |         self.data, self.slices = torch.load(self.processed_paths[0])
10 |
11 |     # The decorator turns these methods into read-only class properties
12 |     @property
13 |     def raw_file_names(self):
14 |         return ['file_1', 'file_2']
15 |
16 |     @property
17 |     def processed_file_names(self):
18 |         return ['data.pt']
19 |
20 |     def download(self):
21 |         # download to self.raw_dir
22 |         pass
23 |
24 |     def process(self):
25 |         data_list = [...]
26 |
27 |         if self.pre_filter is not None:
28 |             data_list = [data for data in data_list if self.pre_filter(data)]
29 |
30 |         if self.pre_transform is not None:
31 |             data_list = [self.pre_transform(data) for data in data_list]
32 |
33 |         data, slices = self.collate(data_list)
34 |         # The save format and path here must match the load in __init__
35 |         torch.save((data, slices), self.processed_paths[0])
--------------------------------------------------------------------------------
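# A minimal runnable sketch of the InMemoryDataset pattern shown above. The class name, the
# two toy graphs and the root path are made up for illustration; process() builds the graphs,
# collates them and saves them so the whole download/process/load flow can be tried end to end.
import torch
from torch_geometric.data import Data, InMemoryDataset

class TwoGraphDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super(TwoGraphDataset, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return []  # nothing needs to be downloaded

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        pass

    def process(self):
        edge_index = torch.tensor([[0, 1, 1, 2],
                                   [1, 0, 2, 1]], dtype=torch.long)
        data_list = [Data(x=torch.randn(3, 1), edge_index=edge_index) for _ in range(2)]
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

# dataset = TwoGraphDataset(root='data/two_graphs/')   # hypothetical root directory
# print(len(dataset), dataset[0])                      # 2 Data(edge_index=[2, 4], x=[3, 1])
--------------------------------------------------------------------------------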
/CreateDataset/LargerDataset.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 |
3 | import torch
4 | # InMemoryDataset is not suitable here: the dataset is too large to hold in memory at once
5 | from torch_geometric.data import Data, Dataset
6 |
7 | class MyDataset(Dataset):
8 |     # The preprocessing callbacks all default to None
9 |     def __init__(self, root, transform=None, pre_transform=None):
10 |         super(MyDataset, self).__init__(root, transform, pre_transform)
11 |
12 |     @property
13 |     def raw_file_names(self):
14 |         return ['file_1', 'file_2']
15 |
16 |     @property
17 |     def processed_file_names(self):
18 |         # The data cannot be loaded in one go, so it is split into one file per graph (names must match what process() writes)
19 |         return ['data_0.pt', 'data_1.pt', 'data_2.pt']
20 |
21 |     def download(self):
22 |         # Download to raw_dir
23 |         pass
24 |
25 |     def process(self):
26 |         i = 0
27 |         # Iterate over every raw file path
28 |         for raw_path in self.raw_paths:
29 |             data = Data(...)  # build a Data object from this raw file
30 |
31 |             if self.pre_filter is not None and not self.pre_filter(data):
32 |                 continue
33 |
34 |             if self.pre_transform is not None:
35 |                 data = self.pre_transform(data)
36 |
37 |             torch.save(data, osp.join(self.processed_dir, 'data_{}.pt'.format(i)))
38 |             i += 1
39 |
40 |
41 |     def len(self):
42 |         return len(self.processed_file_names)
43 |
44 |     def get(self, idx):
45 |         data = torch.load(osp.join(self.processed_dir, 'data_{}.pt'.format(idx)))
46 |         return data
--------------------------------------------------------------------------------
/CreateDataset/test001.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.data import Data
3 | from itertools import product
4 |
5 | edge_index = torch.tensor([
6 |     [0, 1, 1, 2],
7 |     [1, 0, 2, 1]], dtype=torch.long)
8 | x = torch.tensor([[-1],
9 |                   [0],
10 |                   [1]], dtype=torch.float)
11 |
12 | d = Data(x=x, edge_index=edge_index)
13 | # print(type(d))  # <class 'torch_geometric.data.data.Data'>
14 |
15 | data_list = [0]
16 | data_list[0] = d
17 |
18 | keys = data_list[0].keys
19 | # data -> a fresh, empty Data()
20 | data = data_list[0].__class__()
21 | # print(data_list[0].keys) # ['x', 'edge_index']
22 | # print(type(data))  # <class 'torch_geometric.data.data.Data'>
23 |
24 | for key in keys:
25 |     data[key] = []
26 | print(data) # Data(edge_index=[0], x=[0])
27 |
28 | slices = {key: [0] for key in keys}
29 | # print(slices) # {'x': [0], 'edge_index': [0]}
30 |
31 | for item, key in product(data_list, keys):
32 |     print(item, key)
33 |     data[key].append(item[key])
34 |     print(item[key])
35 |     if torch.is_tensor(item[key]):
36 |         s = slices[key][-1] + item[key].size(
37 |             item.__cat_dim__(key, item[key]))
38 |     else:
39 |         s = slices[key][-1] + 1
40 |     slices[key].append(s)
41 |
42 | print(data)
43 |
44 | if hasattr(data_list[0], '__num_nodes__'):
45 |     data.__num_nodes__ = []
46 |     for item in data_list:
47 |         data.__num_nodes__.append(item.num_nodes)
48 |
49 | for key in keys:
50 |     item = data_list[0][key]
51 |     if torch.is_tensor(item):
52 |         data[key] = torch.cat(data[key],
53 |                               dim=data.__cat_dim__(key, item))
54 |     elif isinstance(item, int) or isinstance(item, float):
55 |         data[key] = torch.tensor(data[key])
56 |
57 |     slices[key] = torch.tensor(slices[key], dtype=torch.long)
58 |
59 |
--------------------------------------------------------------------------------
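# The script above rebuilds, by hand, what InMemoryDataset.collate() produces: one big `data`
# object plus a `slices` dict of per-key boundaries. The sketch below goes the other way and
# cuts graph `idx` back out again -- roughly what InMemoryDataset.get() does internally,
# simplified here and assuming tensor-valued keys.
import torch

def uncollate(data, slices, idx):
    out = data.__class__()
    for key in data.keys:
        item, sl = data[key], slices[key]
        start, end = sl[idx].item(), sl[idx + 1].item()
        if torch.is_tensor(item):
            out[key] = item.narrow(data.__cat_dim__(key, item), start, end - start)
        else:
            out[key] = item[start:end]
    return out

# With the single-graph example above, uncollate(data, slices, 0) prints
# Data(edge_index=[2, 4], x=[3, 1]) again.
--------------------------------------------------------------------------------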
/Example/CommonBenchmarkDatasets.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.datasets import TUDataset
2 | import torch
3 |
4 | # Load the dataset; this triggers the download and conversion steps
5 | dataset = TUDataset(root='data/', name='ENZYMES')
6 | print(dataset)
7 | # ENZYMES(600)
8 | print(type(dataset))
9 | # <class 'torch_geometric.datasets.tu_dataset.TUDataset'>
10 | print(len(dataset))
11 | # 600
12 | print(dataset.num_node_features)
13 | # 3
14 |
15 | # dataset is iterable and every element is a Data instance; y holds a single label per graph, so this is a graph-level dataset
16 | data = dataset[0]
17 | print(data)
18 | # Data(edge_index=[2, 168], x=[37, 3], y=[1])
19 |
20 | # Dataset splitting
21 | dataset_train = dataset[:500]
22 | dataset_test = dataset[500:]
23 | print(dataset_train, dataset_test)
24 | # ENZYMES(500) ENZYMES(100)
25 | dataset_sample1 = dataset[torch.tensor([i for i in range(500)], dtype=torch.long)]
26 | print(dataset_sample1)
27 | # ENZYMES(500)
28 | dataset_sample2 = dataset[torch.tensor([True, False])]
29 | print(dataset_sample2)
30 | # ENZYMES(1)
31 | print(dataset[0])
32 | # Data(edge_index=[2, 168], x=[37, 3], y=[1])
33 | print(dataset[1])
34 | # Data(edge_index=[2, 102], x=[23, 3], y=[1])
35 | print(dataset_sample2[0])
36 | # Data(edge_index=[2, 168], x=[37, 3], y=[1])
37 |
38 | dataset = dataset.shuffle()
39 | # Equivalent to
40 | dataset = dataset[torch.randperm(len(dataset))]
41 |
--------------------------------------------------------------------------------
/Example/DataHandle.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.data import Data
3 |
4 | edge_index = torch.tensor([
5 |     [0, 1, 1, 2],
6 |     [1, 0, 2, 1]], dtype=torch.long)
7 | # Note that x is 2-D, not 1-D: each row is one node's feature vector; here the feature dimension is 1
8 | x = torch.tensor([[-1],
9 |                   [0],
10 |                   [1]], dtype=torch.float)
11 |
12 | data = Data(x=x, edge_index=edge_index)
13 | print(data)
14 | '''
15 | # The same edges given as a list of node pairs
16 | edge_index = torch.tensor([
17 |     [0, 1], [1, 0], [1, 2], [2, 1]
18 | ], dtype=torch.long)
19 | data = Data(x=x, edge_index=edge_index.t().contiguous())
20 | print(data)
21 | '''
22 | # Print data's attribute keys; only attributes that were actually passed in show up
23 | print(data.keys)
24 | # ['x', 'edge_index']
25 |
26 | # Index by keyword; note that the key is a string
27 | print(data['x'])
28 | # tensor([[-1.],
29 | #         [ 0.],
30 | #         [ 1.]])
31 | print(data['edge_index'])
32 | # tensor([[0, 1, 1, 2],
33 | #         [1, 0, 2, 1]])
34 |
35 | print('edge_attr: ', data['edge_attr'])
36 | # edge_attr:  None
37 |
38 | # Iterate over all keys and their values
39 | for key, item in data:
40 |     print(key, '---', item)
41 |
42 | # Membership can be checked against data.keys or against data itself
43 | if 'edge_attr' not in data.keys:
44 |     print('Not in')
45 |     # Not in
46 |
47 | if 'x' in data:
48 |     print('In')
49 |     # In
50 |
51 | # print(type(data.keys))
52 | # <class 'list'>
53 |
54 | print(data.num_nodes)
55 | # 3
56 |
57 | # The number of (directed) edges here is 4
58 | print(data.num_edges)
59 | # 4
60 |
61 | print(data.num_edge_features)
62 | # 0
63 |
64 | print(data.num_node_features)
65 | # 1
66 |
67 | print(data.contains_isolated_nodes())
68 | # False
69 |
70 | print(data.contains_self_loops())
71 | # False
72 |
73 | print(data.is_undirected())
74 | # True
75 |
--------------------------------------------------------------------------------
/Example/GCNModel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch_geometric.datasets import Planetoid
4 | from torch_geometric.nn import GCNConv
5 |
6 | dataset = Planetoid(root='data/', name='Cora')
7 |
8 | # Subclass torch.nn.Module
9 | class Net(torch.nn.Module):
10 |     def __init__(self):
11 |         super(Net, self).__init__()
12 |         self.conv1 = GCNConv(dataset.num_node_features, 16)
13 |         self.conv2 = GCNConv(16, dataset.num_classes)
14 |
15 |     def forward(self, data):
16 |         x, edge_index = data.x, data.edge_index
17 |
18 |         x = self.conv1(x, edge_index)
19 |         x = F.relu(x)
20 |         x = F.dropout(x, training=self.training)
21 |         x = self.conv2(x, edge_index)
22 |
23 |         return F.log_softmax(x, dim=1)
24 |
25 | if __name__ == '__main__':
26 |     # Load the dataset
27 |     dataset = Planetoid(root='data/', name='Cora')
28 |     # Train
29 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
30 |     model = Net().to(device)
31 |     data = dataset[0].to(device)
32 |     optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
33 |
34 |     model.train()
35 |     for epoch in range(200):
36 |         optimizer.zero_grad()
37 |         out = model(data)
38 |         loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
39 |         loss.backward()
40 |         optimizer.step()
41 |
42 |     # Test
43 |     model.eval()
44 |     _, pred = model(data).max(dim=1)
45 |     correct = float(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
46 |     acc = correct / data.test_mask.sum().item()
47 |     print('Accuracy: {:.4f}'.format(acc))
--------------------------------------------------------------------------------
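# Example/GCNModel.py above trains on data.train_mask and tests on data.test_mask without
# showing what those masks contain. The quick check below prints the standard Planetoid split
# for Cora: 140 training, 500 validation and 1000 test nodes out of 2708 (expected outputs in
# the comments).
from torch_geometric.datasets import Planetoid

data = Planetoid(root='data/', name='Cora')[0]
print(data)
# Data(edge_index=[2, 10556], test_mask=[2708], train_mask=[2708], val_mask=[2708], x=[2708, 1433], y=[2708])
print(data.train_mask.sum().item(),   # 140
      data.val_mask.sum().item(),     # 500
      data.test_mask.sum().item())    # 1000
--------------------------------------------------------------------------------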
/Example/MiniBatches.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.datasets import TUDataset
2 | from torch_geometric.data import DataLoader
3 | from torch_scatter import scatter_mean
4 |
5 | dataset = TUDataset(root='data/', name='ENZYMES', use_node_attr=True)
6 | loader = DataLoader(dataset, batch_size=32, shuffle=True)
7 |
8 | for data in loader:
9 |     print(data)
10 |     # Batch(batch=[1005], edge_index=[2, 3948], x=[1005, 21], y=[32])
11 |     x = scatter_mean(data.x, data.batch, dim=0)
12 |     print(x.size())
13 |     # torch.Size([32, 21])
--------------------------------------------------------------------------------
/Example/test001.py:
--------------------------------------------------------------------------------
1 | ### Q1: the dimensions of X and Y do not match
2 | import torch
3 | from torch_geometric.data import Data
4 |
5 | # Build the edges
6 | edge_index = torch.tensor([
7 |     [3, 1, 1, 2],
8 |     [1, 3, 2, 1]], dtype=torch.long)
9 | # Build X
10 | x = torch.tensor([[-1],
11 |                   [0],
12 |                   [1], [2]], dtype=torch.float)
13 | y = torch.tensor([[1], [2], [3], [4], [5]], dtype=torch.float)
14 | data = Data(x=x, y=y, edge_index=edge_index)
15 |
16 | print(data)
17 |
18 | ### Q2: load a dataset manually
19 | from torch_geometric.datasets import Planetoid
20 |
21 | dataset = Planetoid(root='data/', name='Cora')
22 | print(dataset)
--------------------------------------------------------------------------------
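# Supplement to Example/MiniBatches.py above: scatter_mean over data.batch is graph-wise mean
# pooling, i.e. the same thing torch_geometric.nn.global_mean_pool computes. The two toy graphs
# below are made up so the equivalence can be checked without downloading a dataset.
import torch
from torch_geometric.data import Batch, Data
from torch_geometric.nn import global_mean_pool
from torch_scatter import scatter_mean

g1 = Data(x=torch.tensor([[1.], [2.]]), edge_index=torch.tensor([[0, 1], [1, 0]]))
g2 = Data(x=torch.tensor([[3.]]), edge_index=torch.empty(2, 0, dtype=torch.long))
batch = Batch.from_data_list([g1, g2])

print(batch.batch)                                 # tensor([0, 0, 1])
print(scatter_mean(batch.x, batch.batch, dim=0))   # tensor([[1.5000], [3.0000]])
print(global_mean_pool(batch.x, batch.batch))      # the same result
--------------------------------------------------------------------------------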
/GitExample/agnn.py:
--------------------------------------------------------------------------------
1 | # https://github.com/rusty1s/pytorch_geometric/blob/master/examples/agnn.py
2 | # Annotated walk-through of the official example
3 |
4 | import os.path as osp
5 |
6 | import torch
7 | import torch_geometric
8 | import torch.nn.functional as F
9 | from torch_geometric.datasets import Planetoid
10 | import torch_geometric.transforms as T
11 | from torch_geometric.nn import AGNNConv
12 |
13 | dataset = 'Cora'
14 |
15 | # osp.realpath(__file__) gives the absolute path of this source file
16 | # '..' goes up one level from the directory containing this file
17 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset)
18 |
19 | # path is the root directory and dataset is the name; path + dataset is where the data actually lives (where processed/ ends up)
20 | dataset = Planetoid(path, dataset, T.NormalizeFeatures())
21 |
22 | # dataset is iterable and every element is a Data instance
23 | # All of Cora sits in the first data object, because there is only one graph
24 | data = dataset[0]
25 | # print(len(dataset)) # 1
26 | print(type(data))
27 |
28 | # Define the network architecture
29 | class Net(torch.nn.Module):
30 |     def __init__(self):
31 |         super(Net, self).__init__()
32 |         # Set up the layers
33 |         self.lin1 = torch.nn.Linear(dataset.num_features, 16)
34 |         self.prop1 = AGNNConv(requires_grad=False)
35 |         self.prop2 = AGNNConv(requires_grad=True)
36 |         # This is a node classification dataset, so map to one score per class
37 |         self.lin2 = torch.nn.Linear(16, dataset.num_classes)
38 |
39 |     # The official code does not pass data as an argument; it uses data as a global inside forward
40 |     def forward(self, data: torch_geometric.data.data.Data):
41 |         # F.dropout defaults to training=False and, being a plain function, does not see the module's
42 |         # training flag, so the model's training state has to be passed to it explicitly
43 |         x = F.dropout(data.x, training=self.training)
44 |         x = F.relu(self.lin1(x))
45 |         x = self.prop1(x, data.edge_index)
46 |         x = self.prop2(x, data.edge_index)
47 |         x = F.dropout(x, training=self.training)
48 |         x = self.lin2(x)
49 |         # Return log-probabilities over the classes
50 |         return F.log_softmax(x, dim=1)
51 |
52 | # Pick the GPU if available, otherwise the CPU
53 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
54 | # Net() creates an instance; .to(device) comes from the parent Module class
55 | model, data = Net().to(device), data.to(device)
56 | # Build the optimizer over the model parameters, with learning rate and weight decay
57 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
58 |
59 | # Define the training function
60 | def train():
61 |     # 1. Switch the model to training mode
62 |     model.train()
63 |     # 2. Reset the optimizer gradients
64 |     optimizer.zero_grad()
65 |     # 3. Forward pass, loss and backward pass (same pattern as Example/GCNModel.py)
66 |     F.nll_loss(model(data)[data.train_mask], data.y[data.train_mask]).backward()
67 |     # 4. Take an optimizer step
68 |     optimizer.step()
69 |
70 |
71 | # Define the test function
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 |
--------------------------------------------------------------------------------
/MessagePassing/GCN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.nn import MessagePassing
3 | from torch_geometric.utils import add_self_loops, degree
4 |
5 | class GCNConv(MessagePassing):
6 |     def __init__(self, in_channels, out_channels):
7 |         super(GCNConv, self).__init__(aggr='add')
8 |         self.lin = torch.nn.Linear(in_channels, out_channels)
9 |
10 |     def forward(self, x, edge_index):
11 |         # x: [N, in_channels]
12 |         # edge_index: [2, E]
13 |
14 |         # 1. Add self-loops to the adjacency matrix
15 |         edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
16 |
17 |         # 2. Apply a linear transformation to the node features
18 |         # x goes from [N, in_channels] to [N, out_channels]
19 |         x = self.lin(x)
20 |
21 |         # 3. Compute the normalization coefficients
22 |         # The first row of edge_index gives the source indices, the second row the target indices
23 |         row, col = edge_index
24 |         deg = degree(row, x.size(0), dtype=x.dtype)
25 |         deg_inv_sqrt = deg.pow(-1/2)
26 |         # The first entry of norm is the coefficient of the first edge (first column) in edge_index
27 |         # Tensor multiplication here is element-wise, and tensor1[tensor2] has the same length as tensor2
28 |         norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
29 |
30 |         # Kick off steps 4-6: propagate() runs message -> aggregate -> update (step 5, the sum, comes from aggr='add')
31 |         return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x, norm=norm)
32 |
33 |     def message(self, x_j, norm):
34 |         # x_j has shape [E, out_channels]
35 |         print(x_j)
36 |         # 4. Build the messages: scale each neighbour's features by the normalization coefficient
37 |         return norm.view(-1, 1) * x_j
38 |
39 |     def update(self, aggr_out):
40 |         # aggr_out has shape [N, out_channels]
41 |
42 |         # 6. Return the new node embeddings; no extra transformation is applied here
43 |         return aggr_out
44 |
45 | # Instantiate the layer
46 | conv = GCNConv(3, 3)
47 | # Build toy data
48 | edge_index = torch.tensor([
49 |     [0, 1, 1, 2],
50 |     [1, 0, 2, 1]
51 | ], dtype=torch.long)
52 | x = torch.tensor([
53 |     [0, 0, 0],
54 |     [1, 1, 1],
55 |     [2, 2, 2]
56 | ], dtype=torch.float)
57 |
58 | # Calling the object invokes its forward() by default
59 | x = conv(x, edge_index)
60 |
61 |
62 |
63 |
64 |
65 |
--------------------------------------------------------------------------------
/MessagePassing/test001.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.utils import add_self_loops, degree
3 |
4 | # In the commented-out version the node ids start at 1, but add_self_loops below assumes they start at 0
5 | '''
6 | edge_index = torch.tensor([
7 |     [1, 2, 3],
8 |     [2, 3, 2]
9 | ], dtype=torch.long)
10 | '''
11 | edge_index = torch.tensor([
12 |     [0, 1, 2],
13 |     [1, 2, 1]
14 | ], dtype=torch.long)
15 |
16 | # If the ids started at 1 and num_nodes were not set, self-loops [0,1,2,3] would be added; with num_nodes=3, self-loops [0,1,2] are added
17 | edge_index, _ = add_self_loops(edge_index, num_nodes=3)
18 | print(edge_index)
19 | # tensor([[0, 1, 2, 0, 1, 2],
20 | #         [1, 2, 1, 0, 1, 2]])
21 |
22 | # Take the first and the second index of every edge
23 | row, col = edge_index
24 | print(row)
25 | # tensor([0, 1, 2, 0, 1, 2])
26 |
27 | # For the commented-out row below, node 2 appears four times, so nodes 0 and 1 get degree 0 and node 2 gets degree 4: degree() counts how often each id occurs, starting from 0, and ids that never occur are recorded as 0 rather than skipped
28 | # row = torch.tensor([2, 2, 2, 2], dtype=torch.long) # degree: tensor([0., 0., 4.])
29 |
30 | deg = degree(row)
31 | # print('deg: ', deg)
32 |
33 | row, col = edge_index
34 | norm = deg[row] * deg[col]
35 | print(norm)
36 |
37 | # Reshape to n rows and 1 column
38 | norm_ = norm.view(-1, 1)
39 | # print(norm_.shape)
40 |
41 | # Feature matrix: each row is one element's feature vector
42 | x_j = torch.tensor([
43 |     [1, 2, 3],
44 |     [2, 3, 4],
45 |     [1, 2, 3],
46 |     [2, 3, 4],
47 |     [1, 2, 3],
48 |     [2, 3, 4]
49 | ])
50 | # (n, 1) * (n, m) broadcasting multiplication
51 | print(norm_ * x_j)
--------------------------------------------------------------------------------
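# Tying the two MessagePassing files together: scattering norm * x_j into the target nodes,
# with norm = deg^-1/2[row] * deg^-1/2[col] as in MessagePassing/GCN.py, reproduces the dense
# GCN propagation D^-1/2 (A + I) D^-1/2 @ x. A small made-up graph is enough to check this;
# the values in the comments are the expected outputs.
import torch
from torch_geometric.utils import add_self_loops, degree

edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
edge_index, _ = add_self_loops(edge_index, num_nodes=3)
row, col = edge_index

deg = degree(row, 3, dtype=torch.float)              # tensor([2., 3., 2.])
norm = deg.pow(-0.5)[row] * deg.pow(-0.5)[col]

x = torch.tensor([[0.], [1.], [2.]])
out = torch.zeros(3, 1).index_add_(0, col, norm.view(-1, 1) * x[row])

A = torch.tensor([[0., 1., 0.], [1., 0., 1.], [0., 1., 0.]])
D_inv_sqrt = torch.diag(deg.pow(-0.5))
dense = D_inv_sqrt @ (A + torch.eye(3)) @ D_inv_sqrt @ x

print(torch.allclose(out, dense))   # True
--------------------------------------------------------------------------------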
/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch-Geometric-Study
2 | Code: https://github.com/FutureTwT/PyTorch-Geometric-Study
3 |
--------------------------------------------------------------------------------