├── polygoncode
│   ├── polygonembed
│   │   ├── __init__.py
│   │   ├── dbtopo
│   │   │   ├── train_generative.py
│   │   │   └── train_spa_rel.py
│   │   ├── animal
│   │   │   ├── train_img_cla.py
│   │   │   └── train_pgon.py
│   │   ├── sftp-config.json
│   │   ├── mnist
│   │   │   └── train_pgon.py
│   │   ├── lenet.py
│   │   ├── resnet2d.py
│   │   ├── ops.py
│   │   ├── spec_pool.py
│   │   ├── atten.py
│   │   ├── data_util.py
│   │   ├── PolygonDecoder.py
│   │   ├── trainer_img.py
│   │   ├── dla_resnext.py
│   │   ├── ddsl.py
│   │   ├── module.py
│   │   ├── ddsl_utils.py
│   │   ├── dla.py
│   │   └── resnet.py
│   └── 1_pgon_dbtopo.sh
├── image
│   └── model.png
├── requirements.txt
├── data_processing
│   └── dbtopo
│       └── data
│           └── spa_rel.csv
├── .gitignore
├── README.md
└── LICENSE
/polygoncode/polygonembed/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /image/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gengchenmai/polygon_encoder/HEAD/image/model.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | backoff==2.2.1 2 | Fiona==1.8.18 3 | geopandas==0.8.1 4 | matplotlib==3.3.3 5 | numpy==1.19.5 6 | pandas==1.2.0 7 | Requests==2.31.0 8 | scikit_learn==0.24.1 9 | Shapely==1.7.1 10 | torch==1.7.1+cu110 11 | torchvision==0.8.2+cu110 12 | tqdm==4.55.1 13 | -------------------------------------------------------------------------------- /data_processing/dbtopo/data/spa_rel.csv: -------------------------------------------------------------------------------- 1 | http://dbpedia.org/ontology/isPartOf 2 | http://dbpedia.org/ontology/location 3 | http://dbpedia.org/ontology/locatedInArea 4 | http://dbpedia.org/property/location 5 | http://dbpedia.org/property/east 6 | http://dbpedia.org/property/south 7 | http://dbpedia.org/property/west 8 | http://dbpedia.org/property/north 9 | http://dbpedia.org/property/southwest 10 | http://dbpedia.org/property/northeast 11 | http://dbpedia.org/property/southeast 12 | http://dbpedia.org/ontology/nearestCity 13 | http://dbpedia.org/property/northwest 14 | http://dbpedia.org/ontology/mouthPlace -------------------------------------------------------------------------------- /polygoncode/polygonembed/dbtopo/train_generative.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import init 6 | import torch.nn.functional as F 7 | 8 | from polygonembed.module import * 9 | from polygonembed.SpatialRelationEncoder import * 10 | from polygonembed.resnet import * 11 | from polygonembed.PolygonEncoder import * 12 | from polygonembed.utils import * 13 | from polygonembed.data_util import * 14 | from polygonembed.dataset import * 15 | from polygonembed.trainer import * 16 | from polygonembed.trainer_helper import * 17 | 18 | 19 | 20 | parser = make_args_parser() 21 | args = parser.parse_args() 22 | 23 | 24 | pgon_gdf = load_dataframe(args.data_dir, args.pgon_filename) 25 | 26 | 27 | trainer = Trainer(args, pgon_gdf, console = True) 28 | 29 | 30 | trainer.run_polygon_generative_train() -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled source # 2 | ################### 3 | *.com 4 | *.class 5 | *.dll 6 | *.exe 7 | *.o 8 | *.so 9 | 10 | # Packages # 
11 | ############ 12 | # it's better to unpack these files and commit the raw source 13 | # git has its own built in compression methods 14 | *.7z 15 | *.dmg 16 | *.gz 17 | *.iso 18 | *.jar 19 | *.rar 20 | *.tar 21 | *.zip 22 | 23 | # Logs and databases # 24 | ###################### 25 | *.log 26 | *.sql 27 | *.sqlite 28 | 29 | # OS generated files # 30 | ###################### 31 | .DS_Store 32 | .DS_Store? 33 | ._* 34 | .Spotlight-V100 35 | .Trashes 36 | ehthumbs.db 37 | Thumbs.db 38 | 39 | # code # 40 | *.pyc 41 | 42 | # data # 43 | ######## 44 | # *.pkl 45 | *.pth 46 | *.pth.tar 47 | *.log 48 | *.crdownload 49 | 50 | # ignore directory 51 | ./polygoncode/model_dir/* 52 | ./polygoncode/polygonembed/.ipynb_checkpoints 53 | 54 | 55 | -------------------------------------------------------------------------------- /polygoncode/1_pgon_dbtopo.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | python3 -m polygonembed.dbtopo.train_spa_rel \ 4 | --data_dir ./data_proprocessing/dbtopo/output/ \ 5 | --model_dir ./model_dir/dbtopo/ \ 6 | --log_dir ./model_dir/dbtopo/ \ 7 | --pgon_filename pgon_300_gdf_prj.pkl \ 8 | --triple_filename pgon_triples_geom_300_norm_df.pkl \ 9 | --geom_type_list norm \ 10 | --data_split_num 0 \ 11 | --task rel \ 12 | --model_type cat \ 13 | --pgon_enc nuft_ddsl \ 14 | --pgon_embed_dim 512 \ 15 | --nuft_freqXY 32 32 \ 16 | --j 2 \ 17 | --padding_mode circular \ 18 | --do_polygon_random_start T \ 19 | --do_data_augment F \ 20 | --do_online_data_augment F \ 21 | --data_augment_type none \ 22 | --num_augment 0 \ 23 | --dropout 0.1 \ 24 | --spa_enc kdelta \ 25 | --spa_embed_dim 26 \ 26 | --freq 16 \ 27 | --max_radius 2 \ 28 | --min_radius 1e-6 \ 29 | --spa_f_act relu \ 30 | --freq_init geometric \ 31 | --spa_enc_use_postmat F \ 32 | --k_delta 12 \ 33 | --num_hidden_layer 1 \ 34 | --hidden_dim 512 \ 35 | --use_layn T \ 36 | --skip_connection T \ 37 | --pgon_dec explicit_conv \ 38 | --pgon_dec_grid_init circle \ 39 | --pgon_dec_grid_enc_type spa_enc \ 40 | --grt_loss_func LOOPL2 \ 41 | --do_weight_norm F \ 42 | --weight_decay 0.000 \ 43 | --pgon_norm_reg_weight 0.02 \ 44 | --task_loss_weight 0.95 \ 45 | --grt_epoches 0 \ 46 | --cla_epoches 5 \ 47 | --log_every 100 \ 48 | --val_every 100 \ 49 | --batch_size 128 \ 50 | --lr 0.01 \ 51 | --opt adam \ 52 | --act relu \ 53 | --balanced_train_loader F \ 54 | --device cuda:0 \ 55 | --tb F -------------------------------------------------------------------------------- /polygoncode/polygonembed/animal/train_img_cla.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import BatchNorm1d 5 | 6 | 7 | import functools 8 | from functools import partial 9 | from dataclasses import dataclass 10 | from collections import OrderedDict 11 | 12 | 13 | import fiona 14 | import geopandas as gpd 15 | import pandas as pd 16 | 17 | import math 18 | import os 19 | import json 20 | import pickle 21 | import numpy as np 22 | from tqdm import tqdm 23 | import matplotlib 24 | import matplotlib.pyplot as plt 25 | import random 26 | import re 27 | import requests 28 | import copy 29 | 30 | import shapely 31 | from shapely.ops import transform 32 | 33 | from shapely.geometry.point import Point 34 | from shapely.geometry.polygon import Polygon 35 | from shapely.geometry.multipolygon import MultiPolygon 36 | from shapely.geometry.linestring import LineString 37 | from 
shapely.geometry.polygon import LinearRing 38 | from shapely.geometry import box 39 | 40 | from fiona.crs import from_epsg 41 | 42 | 43 | from datetime import datetime 44 | 45 | import torch 46 | import torch.nn as nn 47 | from torch.nn import init 48 | import torch.nn.functional as F 49 | 50 | 51 | from polygonembed.dataset import * 52 | from polygonembed.resnet2d import * 53 | from polygonembed.model_utils import * 54 | from polygonembed.data_util import * 55 | from polygonembed.trainer_helper import * 56 | 57 | from polygonembed.trainer_img import * 58 | 59 | 60 | parser = make_args_parser() 61 | args = parser.parse_args() 62 | 63 | img_gdf = load_dataframe(args.data_dir, args.img_filename) 64 | 65 | trainer = Trainer(args, img_gdf, console = True) 66 | 67 | trainer.run_train() -------------------------------------------------------------------------------- /polygoncode/polygonembed/sftp-config.json: -------------------------------------------------------------------------------- 1 | { 2 | // The tab key will cycle through the settings when first created 3 | // Visit http://wbond.net/sublime_packages/sftp/settings for help 4 | 5 | // sftp, ftp or ftps 6 | "type": "sftp", 7 | 8 | "save_before_upload": true, 9 | "upload_on_save": true, 10 | "sync_down_on_open": true, 11 | "sync_skip_deletes": false, 12 | "sync_same_age": true, 13 | "confirm_downloads": false, 14 | "confirm_sync": true, 15 | "confirm_overwrite_newer": false, 16 | 17 | "host": "stko-wrksrv.geog.ucsb.edu", 18 | "user": "gengchen", 19 | // "password": "", 20 | //"port": "22", 21 | 22 | "remote_path": "/home/gengchen/polygon2vec/polygonnet/polygoncode/polygonembed", 23 | "ignore_regexes": [ 24 | "\\.sublime-(project|workspace)", "sftp-config(-alt\\d?)?\\.json", 25 | "sftp-settings\\.json", "/venv/", "\\.svn/", "\\.hg/", "\\.git/", 26 | "\\.bzr", "_darcs", "CVS", "\\.DS_Store", "Thumbs\\.db", "desktop\\.ini", 27 | "\\.pdf", "\\.txt", "/node_modules(_old)?", "/server/data(_old)?" 
28 | ], 29 | //"file_permissions": "664", 30 | //"dir_permissions": "775", 31 | 32 | //"extra_list_connections": 0, 33 | 34 | "connect_timeout": 30, 35 | //"keepalive": 120, 36 | //"ftp_passive_mode": true, 37 | //"ftp_obey_passive_host": false, 38 | //"ssh_key_file": "~/.ssh/id_rsa", 39 | //"sftp_flags": ["-F", "/path/to/ssh_config"], 40 | 41 | //"preserve_modification_times": false, 42 | //"remote_time_offset_in_hours": 0, 43 | //"remote_encoding": "utf-8", 44 | //"remote_locale": "C", 45 | //"allow_config_upload": false, 46 | } 47 | -------------------------------------------------------------------------------- /polygoncode/polygonembed/animal/train_pgon.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import BatchNorm1d 5 | 6 | 7 | import functools 8 | from functools import partial 9 | from dataclasses import dataclass 10 | from collections import OrderedDict 11 | 12 | 13 | import fiona 14 | import geopandas as gpd 15 | import pandas as pd 16 | 17 | import math 18 | import os 19 | import json 20 | import pickle 21 | import numpy as np 22 | from tqdm import tqdm 23 | import matplotlib 24 | import matplotlib.pyplot as plt 25 | import random 26 | import re 27 | import requests 28 | import copy 29 | 30 | import shapely 31 | from shapely.ops import transform 32 | 33 | from shapely.geometry.point import Point 34 | from shapely.geometry.polygon import Polygon 35 | from shapely.geometry.multipolygon import MultiPolygon 36 | from shapely.geometry.linestring import LineString 37 | from shapely.geometry.polygon import LinearRing 38 | from shapely.geometry import box 39 | 40 | from fiona.crs import from_epsg 41 | 42 | 43 | from datetime import datetime 44 | 45 | import torch 46 | import torch.nn as nn 47 | from torch.nn import init 48 | import torch.nn.functional as F 49 | 50 | from polygonembed.module import * 51 | from polygonembed.SpatialRelationEncoder import * 52 | from polygonembed.resnet import * 53 | from polygonembed.PolygonEncoder import * 54 | from polygonembed.utils import * 55 | from polygonembed.data_util import * 56 | from polygonembed.dataset import * 57 | from polygonembed.trainer import * 58 | from polygonembed.trainer_helper import * 59 | 60 | 61 | 62 | parser = make_args_parser() 63 | args = parser.parse_args() 64 | 65 | pgon_gdf = load_dataframe(args.data_dir, args.pgon_filename) 66 | 67 | trainer = Trainer(args, pgon_gdf, console = True) 68 | 69 | trainer.run_train() -------------------------------------------------------------------------------- /polygoncode/polygonembed/dbtopo/train_spa_rel.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | import fiona 5 | import geopandas as gpd 6 | import pandas as pd 7 | 8 | import math 9 | import os 10 | import json 11 | import pickle 12 | import numpy as np 13 | from tqdm import tqdm 14 | import matplotlib 15 | import matplotlib.pyplot as plt 16 | import random 17 | import re 18 | import requests 19 | import copy 20 | 21 | import shapely 22 | from shapely.ops import transform 23 | 24 | from shapely.geometry.point import Point 25 | from shapely.geometry.polygon import Polygon 26 | from shapely.geometry.multipolygon import MultiPolygon 27 | from shapely.geometry.linestring import LineString 28 | from shapely.geometry.polygon import LinearRing 29 | from shapely.geometry import box 30 | 31 | from fiona.crs import from_epsg 32 | 33 | 34 | from datetime import 
datetime 35 | 36 | import torch 37 | import torch.nn as nn 38 | from torch.nn import init 39 | import torch.nn.functional as F 40 | 41 | from polygonembed.module import * 42 | from polygonembed.SpatialRelationEncoder import * 43 | from polygonembed.resnet import * 44 | from polygonembed.PolygonEncoder import * 45 | from polygonembed.utils import * 46 | from polygonembed.data_util import * 47 | from polygonembed.dataset import * 48 | from polygonembed.trainer_helper import * 49 | 50 | 51 | from polygonembed.PolygonDecoder import * 52 | from polygonembed.enc_dec import * 53 | from polygonembed.model_utils import * 54 | from polygonembed.trainer_helper import * 55 | 56 | from polygonembed.trainer_dbtopo import * 57 | 58 | 59 | from polygonembed.ddsl_utils import * 60 | from polygonembed.ddsl import * 61 | 62 | 63 | parser = make_args_parser() 64 | args = parser.parse_args() 65 | 66 | triple_gdf = load_dataframe(args.data_dir, args.triple_filename) 67 | 68 | # pgon_gdf = load_dataframe(args.data_dir, args.pgon_filename) 69 | 70 | trainer = Trainer(args, triple_gdf, pgon_gdf = None, console = True) 71 | 72 | # trainer.run_train() 73 | trainer.run_eval(save_eval = True) -------------------------------------------------------------------------------- /polygoncode/polygonembed/mnist/train_pgon.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import BatchNorm1d 5 | 6 | 7 | import functools 8 | from functools import partial 9 | from dataclasses import dataclass 10 | from collections import OrderedDict 11 | 12 | 13 | import fiona 14 | import geopandas as gpd 15 | import pandas as pd 16 | 17 | import math 18 | import os 19 | import json 20 | import pickle 21 | import numpy as np 22 | from tqdm import tqdm 23 | import matplotlib 24 | import matplotlib.pyplot as plt 25 | import random 26 | import re 27 | import requests 28 | import copy 29 | 30 | import shapely 31 | from shapely.ops import transform 32 | 33 | from shapely.geometry.point import Point 34 | from shapely.geometry.polygon import Polygon 35 | from shapely.geometry.multipolygon import MultiPolygon 36 | from shapely.geometry.linestring import LineString 37 | from shapely.geometry.polygon import LinearRing 38 | from shapely.geometry import box 39 | 40 | from fiona.crs import from_epsg 41 | 42 | 43 | from datetime import datetime 44 | 45 | import torch 46 | import torch.nn as nn 47 | from torch.nn import init 48 | import torch.nn.functional as F 49 | 50 | from polygonembed.module import * 51 | from polygonembed.SpatialRelationEncoder import * 52 | from polygonembed.resnet import * 53 | from polygonembed.PolygonEncoder import * 54 | from polygonembed.utils import * 55 | from polygonembed.data_util import * 56 | from polygonembed.dataset import * 57 | from polygonembed.trainer import * 58 | from polygonembed.trainer_helper import * 59 | 60 | from polygonembed.ddsl_utils import * 61 | from polygonembed.ddsl import * 62 | 63 | 64 | 65 | parser = make_args_parser() 66 | args = parser.parse_args() 67 | 68 | pgon_gdf = load_dataframe(args.data_dir, args.pgon_filename) 69 | 70 | trainer = Trainer(args, pgon_gdf, console = True) 71 | 72 | if args.load_model: 73 | print("load_model...") 74 | trainer.load_model() 75 | 76 | if not args.eval_only: 77 | trainer.run_train(save_eval = args.save_eval) 78 | else: 79 | trainer.run_eval(save_eval = args.save_eval) 
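Note: the train_*.py entry points above all share one driver pattern: build the argument parser, load a pickled (Geo)DataFrame via load_dataframe(), construct a Trainer, then call run_train()/run_eval(). A minimal sketch of writing a pickle that load_dataframe() (defined in data_util.py below) can read; the file name and "ID" column here are illustrative placeholders, not the project's real schema:

import pickle

import geopandas as gpd
from shapely.geometry.polygon import Polygon

# a toy GeoDataFrame with a single polygon; real inputs carry task-specific columns
gdf = gpd.GeoDataFrame({"ID": [0]},
                       geometry=[Polygon([(0, 0), (1, 0), (1, 1), (0, 1)])])

# protocol=2 mirrors pickle_dump() in data_util.py; the .pkl suffix is required
# because load_dataframe() raises on any other extension
with open("pgon_toy_gdf.pkl", "wb") as f:
    pickle.dump(gdf, f, protocol=2)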
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Towards General-Purpose Representation Learning of Polygonal Geometries 2 | Code for recreating the results in [our GeoInformatica 2023 paper](https://link.springer.com/article/10.1007/s10707-022-00481-2) 3 | 4 | 5 | ## Related Links 6 | 1. [Springer Paper](https://link.springer.com/article/10.1007/s10707-022-00481-2) 7 | 2. [arXiv Paper](https://arxiv.org/abs/2209.15458) 8 | 9 | ## Award 10 | 1. This paper won the [AAG 2023 J. Warren Nystrom Award (1 award recipient every year)](https://www.aag.org/award-grant/nystrom/) 11 | 2. This paper won the [AAG 2022 William L. Garrison Award for Best Dissertation in Computational Geography (1 award recipient every other year)](https://www.aag.org/aag-announces-2022-award-recipients/) 12 | 13 | ## Our Model Overview 14 |

15 | ![model](image/model.png) 16 |

17 | 18 | ## Dependencies 19 | - Python 3.7+ 20 | - Torch 1.7.1+ 21 | - Other required packages are summarized in `requirements.txt`. 22 | 23 | ## Data 24 | Download the required dbtopo datasets from [here](https://www.dropbox.com/scl/fo/ubokquibjibxqb71lduto/h?rlkey=gnex7g3gx51g06gmd1v1um9u1&dl=0) and put them in the `./data_proprocessing/dbtopo/output/` folder. The folder has two datasets: 25 | 1) DBSR-46K: the `pgon_triples_geom_300_norm_df.pkl` file, a GeoDataFrame containing the DBSR-46K spatial relation prediction dataset created from DBpedia and OpenStreetMap. Each row indicates a triple from DBpedia, and its subject and object are each represented as a simple polygon with 300 vertices. 26 | 2) DBSR-cplx46K: the `pgon_triples_geom_300_norm_df_complex.pkl` file, a GeoDataFrame containing the same spatial relation prediction dataset. The only difference is that each row's subject and object are represented as complex polygons with 300 vertices. 27 | 28 | 29 | 30 | ## Train and Evaluation 31 | The main code is located in the `polygoncode` folder. 32 | 33 | 1) `1_pgon_dbtopo.sh` does supervised training on both the DBSR-46K and DBSR-cplx46K datasets (see the example invocation below). 34 | 35 | 36 | 37 | ### Reference 38 | If you find our work useful in your research, please consider citing [our GeoInformatica 2023 paper](https://link.springer.com/article/10.1007/s10707-022-00481-2). 39 | ``` 40 | @article{mai2023towards, 41 | title={Towards general-purpose representation learning of polygonal geometries}, 42 | author={Mai, Gengchen and Jiang, Chiyu and Sun, Weiwei and Zhu, Rui and Xuan, Yao and Cai, Ling and Janowicz, Krzysztof and Ermon, Stefano and Lao, Ni}, 43 | journal={GeoInformatica}, 44 | volume={27}, 45 | number={2}, 46 | pages={289--340}, 47 | year={2023}, 48 | publisher={Springer} 49 | } 50 | ``` 51 | 52 | 53 | Please go to [Dr. Gengchen Mai's Homepage](https://gengchenmai.github.io/) for more information about Spatially Explicit Machine Learning and Artificial Intelligence. 
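Note: a typical way to kick off the supervised DBSR runs described above is to invoke the shell script from inside the `polygoncode` folder, since its --data_dir and --model_dir paths are relative to that directory (a sketch; adjust --device inside the script if no GPU is available):

cd polygoncode
sh 1_pgon_dbtopo.sh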
-------------------------------------------------------------------------------- /polygoncode/polygonembed/lenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | import numpy as np 7 | 8 | class Norm(nn.Module): 9 | def __init__(self, mean=0, std=1): 10 | super(Norm, self).__init__() 11 | self.mean = mean 12 | self.std = std 13 | 14 | def forward(self, x): 15 | return (x - self.mean) / self.std 16 | 17 | class LeNet5(nn.Module): 18 | def __init__(self, in_channels, num_classes, signal_sizes=(28,28), hidden_dim = 250, mean=0, std=1): 19 | super(LeNet5, self).__init__() 20 | self.in_channels = in_channels 21 | self.num_classes = num_classes 22 | self.hidden_dim = hidden_dim 23 | 24 | self.conv1 = nn.Conv2d(in_channels, 10, kernel_size=5) 25 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 26 | self.norm = Norm(mean, std) 27 | self.conv2_drop = nn.Dropout2d() 28 | 29 | self.conv_embed_dim = self.compute_conv_embed_dim(signal_sizes, curent_channels = 20) 30 | self.fc1 = nn.Linear(self.conv_embed_dim, self.hidden_dim) 31 | self.fc2 = nn.Linear(self.hidden_dim, self.num_classes) 32 | self.signal_sizes = signal_sizes 33 | 34 | def compute_conv_embed_dim(self, signal_sizes, curent_channels): 35 | fx, fy = signal_sizes 36 | return math.floor((fx-12)/4) * math.floor((fy-12)/4) * curent_channels 37 | 38 | def forward(self, x): 39 | ''' 40 | Args: 41 | x: shape (batch_size, in_channels, fx, fy), image tensor 42 | signal_sizes = (fx, fy) 43 | Return: 44 | x: shape (batch_size, num_classes) 45 | image class distribution 46 | 47 | ''' 48 | batch_size, n_c, fx, fy = x.shape 49 | assert n_c == self.in_channels 50 | assert fx == self.signal_sizes[0] 51 | assert fy == self.signal_sizes[1] 52 | 53 | # x = x.view(-1, 1, self.signal_sizes[0], self.signal_sizes[1]) 54 | 55 | # x: shape (batch_size, in_channels, fx, fy) 56 | x = self.norm(x) 57 | # self.conv1(x): shape [batch_size, 10, fx-4, fy-4] 58 | # F.max_pool2d(self.conv1(x), 2): shape [batch_size, 10, (fx-4)/2, (fy-4)/2 ] 59 | # x: shape [batch_size, 10, (fx-4)/2, (fy-4)/2 ] 60 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 61 | 62 | # self.conv2(x): [batch_size, 20, (fx-12)/2, (fy-12)/2 ] 63 | # self.conv2_drop(self.conv2(x)): [batch_size, 20, (fx-12)/2, (fy-12)/2 ] 64 | # F.max_pool2d(self.conv2_drop(self.conv2(x)), 2): [batch_size, 20, (fx-12)/4, (fy-12)/4 ] 65 | # x: [batch_size, 20, (fx-12)/4, (fy-12)/4 ] 66 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 67 | 68 | 69 | _, n_cc, fx_, fy_ = x.shape 70 | assert n_cc == 20 71 | assert fx_ == math.floor((fx-12)/4) 72 | assert fy_ == math.floor((fy-12)/4) 73 | 74 | # x: shape (batch_size, conv_embed_dim = 20 * floor((fx-12)/4) * floor((fy-12)/4) ) 75 | x = x.reshape(batch_size, -1) 76 | # x: shape (batch_size, 250) 77 | x = F.relu(self.fc1(x)) 78 | x = F.dropout(x, training=self.training) 79 | 80 | # x: shape (batch_size, num_classes) 81 | x = self.fc2(x) 82 | return x -------------------------------------------------------------------------------- /polygoncode/polygonembed/resnet2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import torchvision 6 | 7 | class ResNet2D(torchvision.models.resnet.ResNet): 8 | ''' 9 | See the torchvision version of ResNet 10 | 
https://github.com/pytorch/vision/blob/21153802a3086558e9385788956b0f2808b50e51/torchvision/models/resnet.py 11 | ''' 12 | def __init__(self, block, layers, in_channels = 1, num_classes=1000, zero_init_residual=False): 13 | super(ResNet2D, self).__init__(block, layers, num_classes, zero_init_residual) 14 | self.conv1 = torch.nn.Conv2d(in_channels, 64, 15 | kernel_size=(7, 7), 16 | stride=(2, 2), 17 | padding=(3, 3), bias=False) 18 | 19 | 20 | def get_resnet_model(resnet_type, num_classes = 20, in_channels = 1): 21 | if resnet_type == "resnet18": 22 | model = ResNet2D(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], in_channels = in_channels, num_classes = num_classes) 23 | elif resnet_type == "resnet34": 24 | model = ResNet2D(torchvision.models.resnet.BasicBlock, [3, 4, 6, 3], in_channels = in_channels, num_classes = num_classes) 25 | elif resnet_type == "resnet50": 26 | model = ResNet2D(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], in_channels = in_channels, num_classes = num_classes) 27 | elif resnet_type == "resnet101": 28 | model = ResNet2D(torchvision.models.resnet.Bottleneck, [3, 4, 23, 3], in_channels = in_channels, num_classes = num_classes) 29 | elif resnet_type == "resnet152": 30 | model = ResNet2D(torchvision.models.resnet.Bottleneck, [3, 8, 36, 3], in_channels = in_channels, num_classes = num_classes) 31 | 32 | return model 33 | 34 | # def resnet18(pretrained=False, **kwargs): 35 | # """Constructs a ResNet-18 model. 36 | # Args: 37 | # pretrained (bool): If True, returns a model pre-trained on ImageNet 38 | # """ 39 | # model = ResNet2D(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], **kwargs) 40 | # # if pretrained: 41 | # # model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 42 | # return model 43 | 44 | 45 | # def resnet34(pretrained=False, **kwargs): 46 | # """Constructs a ResNet-34 model. 47 | # Args: 48 | # pretrained (bool): If True, returns a model pre-trained on ImageNet 49 | # """ 50 | # model = ResNet2D(torchvision.models.resnet.BasicBlock, [3, 4, 6, 3], **kwargs) 51 | # # if pretrained: 52 | # # model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 53 | # return model 54 | 55 | 56 | # def resnet50(pretrained=False, **kwargs): 57 | # """Constructs a ResNet-50 model. 58 | # Args: 59 | # pretrained (bool): If True, returns a model pre-trained on ImageNet 60 | # """ 61 | # model = ResNet2D(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], **kwargs) 62 | # # if pretrained: 63 | # # model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 64 | # return model 65 | 66 | 67 | # def resnet101(pretrained=False, **kwargs): 68 | # """Constructs a ResNet-101 model. 69 | # Args: 70 | # pretrained (bool): If True, returns a model pre-trained on ImageNet 71 | # """ 72 | # model = ResNet2D(torchvision.models.resnet.Bottleneck, [3, 4, 23, 3], **kwargs) 73 | # # if pretrained: 74 | # # model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 75 | # return model 76 | 77 | 78 | # def resnet152(pretrained=False, **kwargs): 79 | # """Constructs a ResNet-152 model. 
80 | # Args: 81 | # pretrained (bool): If True, returns a model pre-trained on ImageNet 82 | # """ 83 | # model = ResNet2D(torchvision.models.resnet.Bottleneck, [3, 8, 36, 3], **kwargs) 84 | # # if pretrained: 85 | # # model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 86 | # return model -------------------------------------------------------------------------------- /polygoncode/polygonembed/ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class _ModeNormalization(nn.Module): 7 | def __init__(self, dim, n_components, eps): 8 | super(_ModeNormalization, self).__init__() 9 | self.eps = eps 10 | self.dim = dim 11 | self.n_components = n_components 12 | 13 | self.alpha = nn.Parameter(torch.ones(1, dim, 1, 1)) 14 | self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1)) 15 | self.phi = lambda x: x.mean(3).mean(2) 16 | 17 | 18 | class ModeNorm(_ModeNormalization): 19 | """ 20 | An implementation of mode normalization. Input samples x are allocated into individual modes (their number is controlled by n_components) by a gating network; samples belonging together are jointly normalized and then passed back to the network. 21 | args: 22 | dim: int 23 | momentum: float 24 | n_components: int 25 | eps: float 26 | """ 27 | def __init__(self, dim, momentum, n_components, eps=1.e-5): 28 | super(ModeNorm, self).__init__(dim, n_components, eps) 29 | 30 | self.momentum = momentum 31 | 32 | self.x_ra = torch.zeros(n_components, 1, dim, 1, 1).cuda() 33 | self.x2_ra = torch.zeros(n_components, 1, dim, 1, 1).cuda() 34 | 35 | self.W = torch.nn.Linear(dim, n_components) 36 | self.W.weight.data = torch.ones(n_components, dim) / n_components + .01 * torch.randn(n_components, dim) 37 | self.softmax = torch.nn.Softmax(dim=1) 38 | 39 | self.weighted_mean = lambda w, x, n: (w * x).mean(3, keepdim=True).mean(2, keepdim=True).sum(0, keepdim=True) / n 40 | 41 | 42 | def forward(self, x): 43 | g = self._g(x) 44 | n_k = torch.sum(g, dim=1).squeeze() 45 | 46 | if self.training: 47 | self._update_running_means(g.detach(), x.detach()) 48 | 49 | x_split = torch.zeros(x.size()).cuda().to(x.device) 50 | 51 | for k in range(self.n_components): 52 | if self.training: 53 | mu_k = self.weighted_mean(g[k], x, n_k[k]) 54 | var_k = self.weighted_mean(g[k], (x - mu_k)**2, n_k[k]) 55 | else: 56 | mu_k, var_k = self._mu_var(k) 57 | mu_k = mu_k.to(x.device) 58 | var_k = var_k.to(x.device) 59 | 60 | x_split += g[k] * ((x - mu_k) / torch.sqrt(var_k + self.eps)) 61 | 62 | x = self.alpha * x_split + self.beta 63 | 64 | return x 65 | 66 | 67 | def _g(self, x): 68 | """ 69 | Image inputs are first flattened along their height and width dimensions by phi(x), then mode memberships are determined via a linear transformation, followed by a softmax activation. The gates are returned with size (k, n, c, 1, 1). 70 | args: 71 | x: torch.Tensor 72 | returns: 73 | g: torch.Tensor 74 | """ 75 | g = self.softmax(self.W(self.phi(x))).transpose(0, 1)[:, :, None, None, None] 76 | return g 77 | 78 | 79 | def _mu_var(self, k): 80 | """ 81 | At test time, this function is used to compute the k'th mean and variance from weighted running averages of x and x^2. 82 | args: 83 | k: int 84 | returns: 85 | mu, var: torch.Tensor, torch.Tensor 86 | """ 87 | mu = self.x_ra[k] 88 | var = self.x2_ra[k] - (self.x_ra[k] ** 2) 89 | return mu, var 90 | 91 | 92 | def _update_running_means(self, g, x): 93 | """ 94 | Updates weighted running averages. 
These are kept and used to compute estimators at test time. 95 | args: 96 | g: torch.Tensor 97 | x: torch.Tensor 98 | """ 99 | n_k = torch.sum(g, dim=1).squeeze() 100 | 101 | for k in range(self.n_components): 102 | x_new = self.weighted_mean(g[k], x, n_k[k]) 103 | x2_new = self.weighted_mean(g[k], x**2, n_k[k]) 104 | 105 | # ensure that tensors are on the right devices 106 | self.x_ra = self.x_ra.to(x_new.device) 107 | self.x2_ra = self.x2_ra.to(x2_new.device) 108 | self.x_ra[k] = self.momentum * x_new + (1-self.momentum) * self.x_ra[k] 109 | self.x2_ra[k] = self.momentum * x2_new + (1-self.momentum) * self.x2_ra[k] 110 | -------------------------------------------------------------------------------- /polygoncode/polygonembed/spec_pool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Function 4 | from torch.autograd.function import once_differentiable 5 | import math 6 | import numpy as np 7 | from math import pi 8 | 9 | from polygonembed.ddsl_utils import * 10 | 11 | 12 | class SpecturalPooling(nn.Module): 13 | """ 14 | Do spectural pooling similar to https://arxiv.org/pdf/1506.03767.pdf 15 | """ 16 | def __init__(self, min_freqXY_ratio = 0.5, max_pool_freqXY = [10, 10], freqXY = [16, 16], device = "cpu"): 17 | """ 18 | Args: 19 | min_freqXY_ratio: c_m, the ratio to control the lowest fx, fy to keep, 20 | c_m = alpha + m/M * (beta - alpha) 21 | maxPoolFreqXY: H_m, the maximum frequency we want to keep, 22 | freqXY: res in DDSL_spec(), [fx, fy] 23 | fx, fy: number of frequency in X or Y dimention 24 | device: 25 | """ 26 | super(SpecturalPooling, self).__init__() 27 | self.min_freqXY_ratio = min_freqXY_ratio 28 | self.max_pool_freqXY = max_pool_freqXY 29 | self.freqXY = freqXY 30 | self.device = device 31 | 32 | # the maximum pool fx, fy should smaller/equal to freqXY 33 | assert len(freqXY) == len(max_pool_freqXY) 34 | for i in range(len(freqXY)): 35 | assert freqXY[i] >= max_pool_freqXY[i] 36 | 37 | 38 | 39 | 40 | def get_pool_freqXY(self, freqXY_ratio, max_pool_freqXY): 41 | min_pool_freqXY = [] 42 | for f in max_pool_freqXY: 43 | min_pool_freqXY.append(math.floor(f * freqXY_ratio)) 44 | return min_pool_freqXY 45 | 46 | def make_select_mask(self, fx, fy, maxfx, maxfy, y_dim): 47 | ''' 48 | mask the non select freq elements into 0 49 | Args: 50 | fx, fy: the select dimention in x, y axis 51 | maxfx, maxfy: the original freq dimention 52 | y_dim: the y dimention 53 | Return: 54 | mask: torch.FloatTensor(), shape (maxfx, y_dim, 1) 55 | 56 | ''' 57 | assert maxfy > y_dim 58 | fxtop, fxlow, fytop = self.get_freqXY_select_idx(fx, fy, maxfx, maxfy) 59 | 60 | mask = np.zeros((maxfx, y_dim)) 61 | mask[0:fxtop, 0:fytop] = 1 62 | mask[-fxlow:, 0:fytop] = 1 63 | 64 | mask = torch.FloatTensor(mask).unsqueeze(-1) 65 | return mask 66 | 67 | def get_freqXY_select_idx(self, fx, fy, maxfx, maxfy): 68 | ''' 69 | mask the non select freq elements into 0 70 | Args: 71 | fx, fy: the select dimention in x, y axis 72 | maxfx, maxfy: the original freq dimention 73 | Return: 74 | fxtop, fxlow: 0..fxtop and -fylow ... 
-1 selected 75 | fytop: 0...fytop selected 76 | 77 | ''' 78 | fxtop = math.ceil(fx/2) 79 | fxlow = fx - fxtop 80 | 81 | fytop = math.ceil(fy/2) 82 | return fxtop, fxlow, fytop 83 | 84 | def crop_freqmap(self, x, cur_pool_freqXY, freqXY): 85 | ''' 86 | crop the frequency map to (fx, fy) 87 | Args: 88 | x: torch.FloatTensor(), the input features (e.g., polygons) in the spectral domain 89 | shape (batch_size, n_channel = 1, maxfx, maxfy//2+1, 2) 90 | cur_pool_freqXY: the pooled freq dimension 91 | freqXY: the original dimension 92 | Return: 93 | spec_pool_res: torch.FloatTensor(), 94 | shape (batch_size, n_channel = 1, fx, ceil(fy/2), 2) 95 | ''' 96 | maxfx, maxfy = freqXY 97 | fx, fy = cur_pool_freqXY 98 | fxtop, fxlow, fytop = self.get_freqXY_select_idx(fx, fy, maxfx, maxfy) 99 | 100 | upblock = x[:, :, 0:fxtop, 0:fytop, :] 101 | lowblock = x[:, :, -fxlow:, 0:fytop, :] 102 | 103 | spec_pool_res = torch.cat([upblock, lowblock], dim = 2) 104 | return spec_pool_res 105 | 106 | def forward(self, x): 107 | ''' 108 | Args: 109 | x: torch.FloatTensor(), the input features (e.g., polygons) in the spectral domain 110 | shape (batch_size, n_channel = 1, fx, fy//2+1, 2) 111 | ''' 112 | freqXY_ratio = np.random.uniform(self.min_freqXY_ratio, 1) 113 | # compute the current non-masked X, Y dimensions 114 | cur_pool_freqXY = self.get_pool_freqXY(freqXY_ratio, self.max_pool_freqXY) 115 | 116 | batch_size, n_channel, fx, fy2, n_dim = x.shape 117 | assert fx == self.freqXY[0] 118 | assert fy2 == self.freqXY[1]//2 + 1 119 | 120 | # crop the input x into max_pool_freqXY 121 | # spec_pool_res: shape (batch_size, n_channel = 1, max_pool_freqXY[0], ceil(max_pool_freqXY[1]/2), 2) 122 | spec_pool_res = self.crop_freqmap(x, self.max_pool_freqXY, self.freqXY) 123 | 124 | # mask the freq element outside of cur_pool_freqXY 125 | # mask: shape (max_pool_freqXY[0], ceil(max_pool_freqXY[1]/2), 1) 126 | mask = self.make_select_mask(fx = cur_pool_freqXY[0], 127 | fy = cur_pool_freqXY[1], 128 | maxfx = self.max_pool_freqXY[0], 129 | maxfy = self.max_pool_freqXY[1], 130 | y_dim = spec_pool_res.shape[-2]) 131 | 132 | # spec_pool_mask_res: shape (batch_size, n_channel = 1, max_pool_freqXY[0], ceil(max_pool_freqXY[1]/2), 2) 133 | spec_pool_mask_res = spec_pool_res * mask.to(x.device) 134 | return spec_pool_mask_res 135 | 136 | 137 | 138 | -------------------------------------------------------------------------------- /polygoncode/polygonembed/atten.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import pickle 4 | import math 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from math import sqrt 10 | 11 | 12 | class AttentionSet(nn.Module): 13 | def __init__(self, mode_dims, att_reg=0., att_tem=1., att_type="whole", bn='no', nat=1, name="Real"): 14 | ''' 15 | Paper https://openreview.net/forum?id=BJgr4kSFDS 16 | Modified from https://github.com/hyren/query2box/blob/master/codes/model.py 17 | ''' 18 | super(AttentionSet, self).__init__() 19 | 20 | self.att_reg = att_reg 21 | self.att_type = att_type 22 | self.att_tem = att_tem 23 | self.Attention_module = Attention(mode_dims, att_type=att_type, bn=bn, nat=nat, name = name) 24 | 25 | def forward(self, embeds): 26 | ''' 27 | Args: 28 | embeds: shape (B, mode_dims, L), B: batch_size; L: number of embeddings we need to aggregate 29 | Return: 30 | combined: shape (B, mode_dims) 31 | ''' 32 | # temp: shape (B, 1, L) or (B, mode_dims, L) 33 | temp = (self.Attention_module(embeds) + 
self.att_reg)/(self.att_tem+1e-4) 34 | if self.att_type == 'whole': 35 | # whole: we combine embeddings with a scalar attention coefficient 36 | # attens: shape (B, 1, L) 37 | attens = F.softmax(temp, dim=-1) 38 | # combined: shape (B, mode_dims, L) 39 | combined = embeds * attens 40 | # combined: shape (B, mode_dims) 41 | combined = torch.sum(combined, dim = -1) 42 | 43 | elif self.att_type == 'ele': 44 | # ele: we combine embeds1 and embeds2 with a vector attention coefficient 45 | # attens: shape (B, mode_dims, L) 46 | attens = F.softmax(temp, dim=-1) 47 | # combined: shape (B, mode_dims, L) 48 | combined = embeds * attens 49 | # combined: shape (B, mode_dims) 50 | combined = torch.sum(combined, dim = -1) 51 | 52 | return combined 53 | 54 | 55 | 56 | class Attention(nn.Module): 57 | def __init__(self, mode_dims, att_type="whole", bn = 'no', nat = 1, name="Real"): 58 | ''' 59 | Paper https://openreview.net/forum?id=BJgr4kSFDS 60 | Modified from https://github.com/hyren/query2box/blob/master/codes/model.py 61 | 62 | Args: 63 | mode_dims: the input embedding dimention we need to do attention 64 | att_type: the type of attention 65 | whole: we combine embeddings with a scalar attention coefficient 66 | ele: we combine embedding with a vector attention coefficient 67 | bn: the type of batch noralization type 68 | no: no batch norm 69 | before: batch norm before ReLU 70 | after: batch norm after ReLU 71 | nat: scalar = [1,2,3], the number of attention matrix we want to go through before atten_mats2 72 | ''' 73 | super(Attention, self).__init__() 74 | 75 | self.bn = bn 76 | self.nat = nat 77 | 78 | self.atten_mats1 = nn.Parameter(torch.FloatTensor(mode_dims, mode_dims)) 79 | nn.init.xavier_uniform_(self.atten_mats1) 80 | self.register_parameter("atten_mats1_%s"%name, self.atten_mats1) 81 | if self.nat >= 2: 82 | self.atten_mats1_1 = nn.Parameter(torch.FloatTensor(mode_dims, mode_dims)) 83 | nn.init.xavier_uniform_(self.atten_mats1_1) 84 | self.register_parameter("atten_mats1_1_%s"%name, self.atten_mats1_1) 85 | if self.nat >= 3: 86 | self.atten_mats1_2 = nn.Parameter(torch.FloatTensor(mode_dims, mode_dims)) 87 | nn.init.xavier_uniform_(self.atten_mats1_2) 88 | self.register_parameter("atten_mats1_2_%s"%name, self.atten_mats1_2) 89 | if bn != 'no': 90 | self.bn1 = nn.BatchNorm1d(mode_dims) 91 | self.bn1_1 = nn.BatchNorm1d(mode_dims) 92 | self.bn1_2 = nn.BatchNorm1d(mode_dims) 93 | if att_type == 'whole': 94 | self.atten_mats2 = nn.Parameter(torch.FloatTensor(1, mode_dims)) 95 | elif att_type == 'ele': 96 | self.atten_mats2 = nn.Parameter(torch.FloatTensor(mode_dims, mode_dims)) 97 | nn.init.xavier_uniform_(self.atten_mats2) 98 | self.register_parameter("atten_mats2_%s"%name, self.atten_mats2) 99 | 100 | def forward(self, center_embed): 101 | ''' 102 | Args: 103 | center_embed: shape (B, mode_dims, L), B: batch_size; L: number of embeddings we need to aggregate 104 | Return: 105 | temp3: 106 | if att_type == 'whole': 107 | temp3: shape (B, 1, L) 108 | elif att_type == 'ele': 109 | temp3: shape (B, mode_dims, L) 110 | ''' 111 | temp1 = center_embed 112 | if self.nat >= 1: 113 | # temp2: shape (B, mode_dims, L) 114 | temp2 = torch.einsum('kc,bcl->bkl', self.atten_mats1, temp1) 115 | if self.bn == 'no': 116 | temp2 = F.relu(temp2) 117 | elif self.bn == 'before': 118 | temp2 = F.relu(self.bn1(temp2)) 119 | elif self.bn == 'after': 120 | temp2 = self.bn1(F.relu(temp2)) 121 | # temp2: shape (B, mode_dims, L) 122 | if self.nat >= 2: 123 | temp2 = torch.einsum('kc,bcl->bkl', self.atten_mats1_1, temp2) 
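# ('kc,bcl->bkl' left-multiplies each of the L embedding columns by the (mode_dims, mode_dims) matrix, i.e. a per-embedding linear transform)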
124 | if self.bn == 'no': 125 | temp2 = F.relu(temp2) 126 | elif self.bn == 'before': 127 | temp2 = F.relu(self.bn1_1(temp2)) 128 | elif self.bn == 'after': 129 | temp2 = self.bn1_1(F.relu(temp2)) 130 | # temp2: shape (B, mode_dims, L) 131 | if self.nat >= 3: 132 | temp2 = torch.einsum('kc,bcl->bkl', self.atten_mats1_2, temp2) 133 | if self.bn == 'no': 134 | temp2 = F.relu(temp2) 135 | elif self.bn == 'before': 136 | temp2 = F.relu(self.bn1_2(temp2)) 137 | elif self.bn == 'after': 138 | temp2 = self.bn1_2(F.relu(temp2)) 139 | # temp2: shape (B, mode_dims, L) 140 | 141 | temp3 = torch.einsum('kc,bcl->bkl', self.atten_mats2, temp2) 142 | ''' 143 | if att_type == 'whole': 144 | temp3: shape (B, 1, L) 145 | elif att_type == 'ele': 146 | temp3: shape (B, mode_dims, L) 147 | ''' 148 | return temp3 -------------------------------------------------------------------------------- /polygoncode/polygonembed/data_util.py: -------------------------------------------------------------------------------- 1 | import fiona 2 | import geopandas as gpd 3 | import pandas as pd 4 | 5 | import math 6 | import os 7 | import json 8 | import pickle 9 | import numpy as np 10 | from tqdm import tqdm 11 | import matplotlib 12 | import matplotlib.pyplot as plt 13 | import random 14 | import re 15 | import requests 16 | import copy 17 | 18 | import shapely 19 | from shapely.ops import transform 20 | 21 | from shapely.geometry.point import Point 22 | from shapely.geometry.polygon import Polygon 23 | from shapely.geometry.multipolygon import MultiPolygon 24 | from shapely.geometry.linestring import LineString 25 | from shapely.geometry.polygon import LinearRing 26 | from shapely.geometry import box 27 | 28 | from fiona.crs import from_epsg 29 | 30 | 31 | from datetime import datetime 32 | 33 | 34 | from polygonembed.dataset import * 35 | 36 | 37 | 38 | def json_load(filepath): 39 | with open(filepath, "r") as json_file: 40 | data = json.load(json_file) 41 | return data 42 | 43 | def json_dump(data, filepath, pretty_format = True): 44 | with open(filepath, 'w') as fw: 45 | if pretty_format: 46 | json.dump(data, fw, indent=2, sort_keys=True) 47 | else: 48 | json.dump(data, fw) 49 | 50 | def pickle_dump(obj, pickle_filepath): 51 | with open(pickle_filepath, "wb") as f: 52 | pickle.dump(obj, f, protocol=2) 53 | 54 | def pickle_load(pickle_filepath): 55 | with open(pickle_filepath, "rb") as f: 56 | obj = pickle.load(f) 57 | return obj 58 | 59 | 60 | 61 | # ########################## GeoData Utility 62 | 63 | def make_projected_gdf(gdf, geometry_col = "geometry", epsg_code = 4326): 64 | gdf[geometry_col] = gdf[geometry_col].to_crs(epsg=epsg_code) 65 | gdf.crs = from_epsg(epsg_code) 66 | return gdf 67 | 68 | def explode(indata): 69 | if type(indata) == gpd.GeoDataFrame: 70 | indf = indata 71 | elif type(indata) == str: 72 | indf = gpd.GeoDataFrame.from_file(indata) 73 | else: 74 | raise Exception("Input no recognized") 75 | outdf = gpd.GeoDataFrame(columns=indf.columns, crs = indf.crs) 76 | for idx, row in indf.iterrows(): 77 | if type(row.geometry) == Polygon: 78 | outdf = outdf.append(row,ignore_index=True) 79 | if type(row.geometry) == MultiPolygon: 80 | multdf = gpd.GeoDataFrame(columns=indf.columns) 81 | recs = len(row.geometry) 82 | multdf = multdf.append([row]*recs,ignore_index=True) 83 | for geom in range(recs): 84 | multdf.loc[geom,'geometry'] = row.geometry[geom] 85 | outdf = outdf.append(multdf,ignore_index=True) 86 | return outdf 87 | 88 | def plot_multipolygon(multipygon, figsize = (32, 32), edgecolor='r', ax = 
None): 89 | show = ax is None 90 | if show: fig, ax = plt.subplots(figsize = figsize) 91 | p = gpd.GeoDataFrame(geometry = [multipygon]) 92 | p.geometry.boundary.plot(edgecolor=edgecolor, ax = ax) 93 | if show: 94 | plt.show() 95 | 96 | def get_bbox_of_df(indf): 97 | bbox_df = indf["geometry"].bounds 98 | 99 | minx = np.min(bbox_df["minx"]) 100 | miny = np.min(bbox_df["miny"]) 101 | maxx = np.max(bbox_df["maxx"]) 102 | maxy = np.max(bbox_df["maxy"]) 103 | extent = (minx, miny, maxx, maxy) 104 | return bbox_df, extent 105 | 106 | 107 | def get_polygon_exterior_max_num_vert(indf, col = "geometry"): 108 | coord_len_list = list(indf[col].apply(lambda x: len(x.exterior.coords) )) 109 | return max(coord_len_list) 110 | 111 | 112 | def upsample_polygon_exterior_gdf_by_num_vert(indf, num_vert): 113 | return indf["geometry"].apply( 114 | lambda x: Polygon(line_interpolate_by_num_vert(x.exterior, num_vert = num_vert).coords ) ) 115 | 116 | def line_interpolate_by_num_vert(geom, num_vert): 117 | if geom.geom_type in ['LineString', 'LinearRing']: 118 | num_vert_origin = len(geom.coords.xy[0]) 119 | 120 | if num_vert_origin == num_vert: 121 | return geom 122 | elif num_vert_origin > num_vert: 123 | raise Exception("The original line has a larger number of vertices than your input") 124 | num_vert_add = num_vert - num_vert_origin 125 | 126 | # print(num_vert, num_vert_origin, num_vert_add) 127 | pt_add_list = [] 128 | dist_add_list = [] 129 | for i in range(1, num_vert_add+1): 130 | pt_add = geom.interpolate(float(i) / (num_vert_add+1), normalized=True) 131 | dist_add = geom.project(pt_add) 132 | 133 | pt_add_list.append(pt_add) 134 | dist_add_list.append(dist_add) 135 | 136 | for idx in range(1, num_vert_origin - 1): 137 | pt = Point(geom.coords[idx]) 138 | dist = geom.project(pt) 139 | insert_idx = np.searchsorted(dist_add_list, dist) 140 | 141 | dist_add_list = dist_add_list[:insert_idx] + [dist] + dist_add_list[insert_idx:] 142 | pt_add_list = pt_add_list[:insert_idx] + [pt] + pt_add_list[insert_idx:] 143 | 144 | 145 | pt_add_list = [Point(geom.coords[0])] + pt_add_list + [Point(geom.coords[-1])] 146 | if geom.geom_type == 'LineString': 147 | return LineString(pt_add_list) 148 | elif geom.geom_type == 'LinearRing': 149 | line = LineString(pt_add_list) 150 | return LinearRing(line.coords) 151 | else: 152 | raise ValueError('unhandled geometry %s', (geom.geom_type,)) 153 | 154 | def normalize_geometry_by_extent(geom, extent = None): 155 | ''' 156 | Normalize the polygon coords to x: (-1, 1), y: (-1, 1) 157 | Args: 158 | geom: a geometry 159 | extent: (minx, miny, maxx, maxy) 160 | ''' 161 | if extent is not None: 162 | minx, miny, maxx, maxy = extent 163 | else: 164 | minx, miny, maxx, maxy = geom.bounds 165 | 166 | assert minx < maxx 167 | assert miny < maxy 168 | # compute extent center 169 | x_c = (maxx + minx)/2 170 | y_c = (maxy + miny)/2 171 | # 1. affinity to the extent's center 172 | geom_aff = shapely.affinity.affine_transform(geom, matrix = [1, 0, 0, 1, -x_c, -y_c]) 173 | # plot_multipolygon(geom_aff, figsize = (5, 5)) 174 | 175 | 176 | deltax = maxx - minx 177 | deltay = maxy - miny 178 | if deltax >= deltay: 179 | max_len = deltax 180 | else: 181 | max_len = deltay 182 | # 2. 
scale to x: (-1, 1), y: (-1, 1) 183 | geom_scale = shapely.affinity.scale(geom_aff, xfact=2.0/max_len, yfact=2.0/max_len, zfact=0, origin='center') 184 | # plot_multipolygon(geom_scale, figsize = (5, 5)) 185 | return geom_scale 186 | 187 | 188 | 189 | def load_dataframe(data_dir, filename): 190 | if filename.endswith(".pkl"): 191 | ingdf = pickle_load(os.path.join(data_dir, filename)) 192 | else: 193 | raise Exception('Unknow file type') 194 | return ingdf 195 | 196 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /polygoncode/polygonembed/PolygonDecoder.py: -------------------------------------------------------------------------------- 1 | import fiona 2 | import geopandas as gpd 3 | import pandas as pd 4 | 5 | import math 6 | import os 7 | import json 8 | import pickle 9 | import numpy as np 10 | from tqdm import tqdm 11 | import matplotlib 12 | import matplotlib.pyplot as plt 13 | import random 14 | import re 15 | import requests 16 | import copy 17 | 18 | import shapely 19 | from shapely.ops import transform 20 | 21 | from shapely.geometry.point import Point 22 | from shapely.geometry.polygon import Polygon 23 | from shapely.geometry.multipolygon import MultiPolygon 24 | from shapely.geometry.linestring import LineString 25 | from shapely.geometry.polygon import LinearRing 26 | from shapely.geometry import box 27 | 28 | from fiona.crs import from_epsg 29 | 30 | 31 | from datetime import datetime 32 | 33 | import torch 34 | import torch.nn as nn 35 | from torch.nn import init 36 | import torch.nn.functional as F 37 | 38 | 39 | 40 | from polygonembed.module import * 41 | from polygonembed.utils import * 42 | 43 | 44 | class ExplicitMLPPolygonDecoder(nn.Module): 45 | """ 46 | Decode a polygon embedding into an explicit vertex sequence with a shared per-point MLP (AtlasNet style). 47 | """ 48 | def __init__(self, spa_enc, pgon_embed_dim, num_vert, pgon_dec_grid_init = "uniform", pgon_dec_grid_enc_type = "none", 49 | coord_dim = 2, extent = (-1, 1, -1, 1), device = "cpu"): 50 | """ 51 | Args: 52 | spa_enc: a spatial encoder 53 | pgon_embed_dim: the output polygon embedding dimension 54 | num_vert: number of unique vertices each generated polygon will have; 55 | note that this does not include the extra point (same as the 1st one) to close the ring 56 | pgon_dec_grid_init: we generate a list of grid points for the polygon decoder; the types of grid points are: 57 | uniform: points uniformly sampled from (-1, 1, -1, 1) 58 | circle: points sampled equal-distance on a circle whose radius is randomly sampled 59 | kdgrid: k-d regular grid, (num_vert - k^2) is uniformly sampled 60 | pgon_dec_grid_enc_type: the type to encode the grid point 61 | none: no encoding, use the original grid point 62 | spa_enc: use the space encoder to encode each grid point before decoding 63 | coord_dim: 2 64 | device: 65 | """ 66 | super(ExplicitMLPPolygonDecoder, self).__init__() 67 | 68 | self.spa_enc = spa_enc 69 | self.spa_embed_dim = self.spa_enc.spa_embed_dim 70 | self.pgon_embed_dim = pgon_embed_dim 71 | self.num_vert = num_vert 72 | self.pgon_dec_grid_init = pgon_dec_grid_init 73 | self.coord_dim = coord_dim 74 | self.extent = extent 75 | 76 | self.device = device 77 | self.grid_dim = 2 78 | 79 | self.pgon_dec_grid_enc_type = pgon_dec_grid_enc_type 80 | 81 | # Partly borrowed from atlasnetV2 82 | # define point generator 83 | if self.pgon_dec_grid_enc_type == "none": 84 | self.nlatent = self.pgon_embed_dim + self.grid_dim 85 | elif self.pgon_dec_grid_enc_type == "spa_enc": 86 | self.nlatent = self.pgon_embed_dim + self.spa_embed_dim 87 | else: 88 | raise Exception("Unknown pgon_dec_grid_enc_type") 89 | 90 | # by default the bias = True, so with kernel_size = 1 each Conv1d acts as an individual per-vertex MLP 91 | self.conv1 = torch.nn.Conv1d(in_channels = self.nlatent, out_channels = self.nlatent, kernel_size = 1) 92 | self.conv2 = torch.nn.Conv1d(in_channels = self.nlatent, out_channels = self.nlatent//2, kernel_size = 1) 93 | self.conv3 = torch.nn.Conv1d(in_channels = self.nlatent//2, out_channels = self.nlatent//4, kernel_size = 1)
94 | self.conv4 = torch.nn.Conv1d(in_channels = self.nlatent//4, out_channels = coord_dim, kernel_size = 1) 95 | 96 | self.th = nn.Tanh() 97 | self.bn1 = torch.nn.BatchNorm1d(self.nlatent) 98 | self.bn2 = torch.nn.BatchNorm1d(self.nlatent//2) 99 | self.bn3 = torch.nn.BatchNorm1d(self.nlatent//4) 100 | 101 | 102 | 103 | def forward(self, pgon_embeds): 104 | ''' 105 | Args: 106 | pgon_embeds: tensor, shape (batch_size, pgon_embed_dim) 107 | Return: 108 | 109 | polygons: shape (batch_size, num_vert, coord_dim = 2) 110 | rand_grid: shape (batch_size, 2, num_vert) 111 | 112 | ''' 113 | device = pgon_embeds.device 114 | batch_size, pgon_embed_dim = pgon_embeds.shape 115 | 116 | # Concat grids and pgon_embeds: Bx(C+2)xN 117 | # pgon_embeds_dup: shape (batch_size, pgon_embed_dim, num_vert) 118 | pgon_embeds_dup = pgon_embeds.unsqueeze(2).expand( 119 | batch_size, pgon_embed_dim, self.num_vert).contiguous() # BxCxN 120 | 121 | # generate rand grids 122 | # rand_grid: shape (batch_size, 2, num_vert) 123 | rand_grid = generate_rand_grid(self.pgon_dec_grid_init, 124 | batch_size = batch_size, 125 | num_vert = self.num_vert, 126 | extent = self.extent, 127 | grid_dim = self.grid_dim) 128 | if self.pgon_dec_grid_enc_type == "none": 129 | rand_grid = torch.FloatTensor(rand_grid).to(device) 130 | elif self.pgon_dec_grid_enc_type == "spa_enc": 131 | # rand_grid: np.array(), shape (batch_size, num_vert, 2) 132 | rand_grid = rand_grid.transpose(0,2,1) 133 | # rand_grid: torch.tensor(), shape (batch_size, num_vert, spa_embed_dim) 134 | rand_grid = self.spa_enc(rand_grid) 135 | # rand_grid: torch.tensor(), shape (batch_size, spa_embed_dim, num_vert) 136 | rand_grid = rand_grid.permute(0,2,1) 137 | 138 | else: 139 | raise Exception("Unknown pgon_dec_grid_enc_type") 140 | 141 | 142 | 143 | # x: shape (batch_size, 2 + pgon_embed_dim, num_vert) or (batch_size, spa_embed_dim + pgon_embed_dim, num_vert) 144 | x = torch.cat([rand_grid, pgon_embeds_dup], dim=1) 145 | 146 | # Generate points 147 | # x: shape (batch_size, 2 + pgon_embed_dim, num_vert) 148 | x = F.relu(self.bn1(self.conv1(x))) 149 | 150 | # x: shape (batch_size, (2 + pgon_embed_dim)/2, num_vert) 151 | x = F.relu(self.bn2(self.conv2(x))) 152 | 153 | # x: shape (batch_size, (2 + pgon_embed_dim)/4, num_vert) 154 | x = F.relu(self.bn3(self.conv3(x))) 155 | 156 | # x: shape (batch_size, coord_dim, num_vert) 157 | x = self.th(self.conv4(x)) 158 | 159 | # polygons: shape (batch_size, num_vert, coord_dim = 2) 160 | polygons = x.transpose(2, 1) 161 | 162 | # rand_grid_: shape (batch_size, num_vert, grid_dim = 2) 163 | rand_grid_ = rand_grid.transpose(2, 1) 164 | return polygons, rand_grid_ 165 | 166 | 167 | 168 | 169 |
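# --- Editor's usage sketch (not part of the original file) ---------------------
# A minimal illustration of driving ExplicitMLPPolygonDecoder, assuming
# generate_rand_grid() from polygonembed.utils returns an np.array of shape
# (batch_size, grid_dim, num_vert). `_DummySpaEnc` is a hypothetical stand-in:
# with pgon_dec_grid_enc_type = "none" the spatial encoder is only read for its
# `spa_embed_dim` attribute, so any module exposing that attribute suffices.
def _mlp_decoder_usage_sketch():
    class _DummySpaEnc(nn.Module):
        def __init__(self, spa_embed_dim = 32):
            super(_DummySpaEnc, self).__init__()
            self.spa_embed_dim = spa_embed_dim

    dec = ExplicitMLPPolygonDecoder(spa_enc = _DummySpaEnc(), pgon_embed_dim = 64,
                                    num_vert = 30, pgon_dec_grid_enc_type = "none")
    pgon_embeds = torch.randn(8, 64)        # a batch of 8 polygon embeddings
    polygons, rand_grid = dec(pgon_embeds)  # polygons: (8, 30, 2), coords in (-1, 1)
    return polygons, rand_grid
# --------------------------------------------------------------------------------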
170 | class ExplicitConvPolygonDecoder(nn.Module): 171 | """ 172 | Decode a polygon embedding into an explicit vertex sequence with circular 1D convolutions over the vertex ring. 173 | """ 174 | def __init__(self, spa_enc, pgon_embed_dim, num_vert, pgon_dec_grid_init = "uniform", pgon_dec_grid_enc_type = "none", 175 | coord_dim = 2, padding_mode = 'circular', extent = (-1, 1, -1, 1), device = "cpu"): 176 | """ 177 | Args: 178 | spa_enc: a spatial encoder 179 | pgon_embed_dim: the output polygon embedding dimension 180 | num_vert: number of unique vertices each generated polygon will have; 181 | note that this does not include the extra point (same as the 1st one) to close the ring 182 | pgon_dec_grid_init: we generate a list of grid points for the polygon decoder; the types of grid points are: 183 | uniform: points uniformly sampled from (-1, 1, -1, 1) 184 | circle: points sampled equal-distance on a circle whose radius is randomly sampled 185 | kdgrid: k-d regular grid, (num_vert - k^2) is uniformly sampled 186 | pgon_dec_grid_enc_type: the type to encode the grid point 187 | none: no encoding, use the original grid point 188 | spa_enc: use the space encoder to encode each grid point before decoding 189 | coord_dim: 2 190 | padding_mode: 'circular' 191 | device: 192 | """ 193 | super(ExplicitConvPolygonDecoder, self).__init__() 194 | 195 | self.spa_enc = spa_enc 196 | self.spa_embed_dim = self.spa_enc.spa_embed_dim 197 | self.pgon_embed_dim = pgon_embed_dim 198 | self.num_vert = num_vert 199 | self.pgon_dec_grid_init = pgon_dec_grid_init 200 | self.coord_dim = coord_dim 201 | self.extent = extent 202 | 203 | self.device = device 204 | self.grid_dim = 2 205 | 206 | self.pgon_dec_grid_enc_type = pgon_dec_grid_enc_type 207 | 208 | # Partly borrowed from atlasnetV2 209 | # define point generator 210 | if self.pgon_dec_grid_enc_type == "none": 211 | self.nlatent = self.pgon_embed_dim + self.grid_dim 212 | elif self.pgon_dec_grid_enc_type == "spa_enc": 213 | self.nlatent = self.pgon_embed_dim + self.spa_embed_dim 214 | else: 215 | raise Exception("Unknown pgon_dec_grid_enc_type") 216 | 217 | # by default the bias = True; here kernel_size = 3 convolutions mix neighboring vertices, and circular padding keeps the vertex ring closed 218 | self.conv1 = torch.nn.Conv1d(in_channels = self.nlatent, out_channels = self.nlatent, kernel_size = 3, 219 | stride=1, padding=1, padding_mode = padding_mode) 220 | self.conv2 = torch.nn.Conv1d(in_channels = self.nlatent, out_channels = self.nlatent//2, kernel_size = 3, 221 | stride=1, padding=1, padding_mode = padding_mode) 222 | self.conv3 = torch.nn.Conv1d(in_channels = self.nlatent//2, out_channels = self.nlatent//4, kernel_size = 3, 223 | stride=1, padding=1, padding_mode = padding_mode) 224 | self.conv4 = torch.nn.Conv1d(in_channels = self.nlatent//4, out_channels = coord_dim, kernel_size = 1) 225 | 226 | self.th = nn.Tanh() 227 | self.bn1 = torch.nn.BatchNorm1d(self.nlatent) 228 | self.bn2 = torch.nn.BatchNorm1d(self.nlatent//2) 229 | self.bn3 = torch.nn.BatchNorm1d(self.nlatent//4) 230 | 231 | 232 | 233 | def forward(self, pgon_embeds): 234 | ''' 235 | Args: 236 | pgon_embeds: tensor, shape (batch_size, pgon_embed_dim) 237 | Return: 238 | 239 | polygons: shape (batch_size, num_vert, coord_dim = 2) 240 | rand_grid: shape (batch_size, 2, num_vert) 241 | 242 | ''' 243 | device = pgon_embeds.device 244 | batch_size, pgon_embed_dim = pgon_embeds.shape 245 | 246 | # Concat grids and pgon_embeds: Bx(C+2)xN 247 | # pgon_embeds_dup: shape (batch_size, pgon_embed_dim, num_vert) 248 | pgon_embeds_dup = pgon_embeds.unsqueeze(2).expand( 249 | batch_size, pgon_embed_dim, self.num_vert).contiguous() # BxCxN 250 | 251 | # generate rand grids 252 | # rand_grid: shape (batch_size, 2, num_vert) 253 | rand_grid = generate_rand_grid(self.pgon_dec_grid_init, 254 | batch_size = batch_size, 255 | num_vert = self.num_vert, 256 | extent = self.extent, 257 | grid_dim = self.grid_dim) 258 | if self.pgon_dec_grid_enc_type == "none": 259 | rand_grid = torch.FloatTensor(rand_grid).to(device) 260 | elif self.pgon_dec_grid_enc_type == "spa_enc": 261 | # rand_grid: np.array(), shape (batch_size, num_vert, 2) 262 | rand_grid = rand_grid.transpose(0,2,1) 263 | # rand_grid: torch.tensor(), shape (batch_size, num_vert, spa_embed_dim) 264 | rand_grid = self.spa_enc(rand_grid) 265 | # rand_grid: torch.tensor(), shape (batch_size, spa_embed_dim, num_vert) 266 | rand_grid = rand_grid.permute(0,2,1) 267 | 268 | else: 269 | raise Exception("Unknown pgon_dec_grid_enc_type") 270 | 271 | 272 | 273 | # x: shape (batch_size, 2 + pgon_embed_dim, num_vert) or (batch_size, spa_embed_dim + pgon_embed_dim,
num_vert) 274 | x = torch.cat([rand_grid, pgon_embeds_dup], dim=1) 275 | 276 | # Generate points 277 | # x: shape (batch_size, 2 + pgon_embed_dim, num_vert) 278 | x = F.relu(self.bn1(self.conv1(x))) 279 | 280 | # x: shape (batch_size, (2 + pgon_embed_dim)/2, num_vert) 281 | x = F.relu(self.bn2(self.conv2(x))) 282 | 283 | # x: shape (batch_size, (2 + pgon_embed_dim)/4, num_vert) 284 | x = F.relu(self.bn3(self.conv3(x))) 285 | 286 | # x: shape (batch_size, coord_dim, num_vert) 287 | x = self.th(self.conv4(x)) 288 | 289 | # polygons: shape (batch_size, num_vert, coord_dim = 2) 290 | polygons = x.transpose(2, 1) 291 | 292 | # rand_grid_: shape (batch_size, num_vert, grid_dim = 2) 293 | rand_grid_ = rand_grid.transpose(2, 1) 294 | return polygons, rand_grid_ -------------------------------------------------------------------------------- /polygoncode/polygonembed/trainer_img.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from torch import optim 3 | from torch.utils.tensorboard import SummaryWriter 4 | 5 | from polygonembed.dataset import * 6 | from polygonembed.resnet2d import * 7 | from polygonembed.model_utils import * 8 | from polygonembed.data_util import * 9 | from polygonembed.trainer_helper import * 10 | 11 | 12 | 13 | 14 | def make_args_parser(): 15 | parser = ArgumentParser() 16 | # dir 17 | parser.add_argument("--data_dir", type=str, default="./animal/") 18 | parser.add_argument("--model_dir", type=str, default="./model_dir/animal_img/") 19 | parser.add_argument("--log_dir", type=str, default="./model_dir/animal_img/") 20 | 21 | #data 22 | parser.add_argument("--img_filename", type=str, default="animal_img_mats.pkl") 23 | 24 | parser.add_argument("--data_split_num", type=int, default=0, 25 | help='''we might do multiple train/valid/test splits; 26 | this indicates which split we will use to train. 27 | Note that we use 1, 0, -1 to indicate train/test/valid 28 | 1: train 29 | 0: test 30 | -1: valid ''') 31 | parser.add_argument("--num_worker", type=int, default=0, 32 | help='the number of workers for the dataloader') 33 | 34 | 35 | 36 | # model type 37 | parser.add_argument("--model_type", type=str, default="", 38 | help='''the type of image classification model we use, 39 | resnet18 / resnet34 / resnet50 / resnet101 / resnet152 40 | ''') 41 | 42 | 43 | # model 44 | # parser.add_argument("--embed_dim", type=int, default=64, 45 | # help='Point feature embedding dim') 46 | # parser.add_argument("--dropout", type=float, default=0.5, 47 | # help='The dropout rate used in all fully connected layers') 48 | # parser.add_argument("--act", type=str, default='sigmoid', 49 | # help='the activation function for the encoder decoder') 50 | 51 | 52 | 53 | # # # encoder decoder 54 | # # parser.add_argument("--join_dec_type", type=str, default='max', 55 | # # help='the type of join_dec, min/max/mean/cat') 56 | 57 | # # polygon encoder 58 | # parser.add_argument("--pgon_enc", type=str, default="resnet", 59 | # help='''the type of polygon encoder: 60 | # resnet: ResNet based encoder 61 | # veercnn: the CNN model proposed in https://arxiv.org/pdf/1806.03857.pdf''') 62 | # parser.add_argument("--pgon_embed_dim", type=int, default=64, 63 | # help='the embedding dimension of polygon') 64 | # parser.add_argument("--padding_mode", type=str, default="circular", 65 | # help='the type of padding method for Conv1D: circular / zeros / reflect / replicate') 66 | # parser.add_argument("--resnet_add_middle_pool", type=str, default='F', 67 |
# help='whether to add MaxPool1D between the middle layers of ResNet') 68 | # parser.add_argument("--resnet_fl_pool_type", type=str, default="mean", 69 | # help='the type of final pooling method: mean / min / max') 70 | # parser.add_argument("--resnet_block_type", type=str, default="basic", 71 | # help='the type of ResNet block we will use: basic / bottleneck') 72 | # parser.add_argument("--resnet_layers_per_block", nargs='+', type=int, default=[], 73 | # help='the number of layers per resnet block, must be 3 layers') 74 | 75 | 76 | 77 | parser.add_argument("--do_data_augment", type=str, default="F", 78 | help = "whether to do data augmentation (flip, rotate, scale) in each batch") 79 | 80 | 81 | 82 | 83 | # train 84 | parser.add_argument("--opt", type=str, default="adam") 85 | parser.add_argument("--lr", type=float, default=0.01, 86 | help='learning rate') 87 | 88 | parser.add_argument("--weight_decay", type=float, default=0.001, 89 | help='weight decay of adam optimizer') 90 | # parser.add_argument("--task_loss_weight", type=float, default=0.8, 91 | # help='the weight of classification loss when we do joint training') 92 | # parser.add_argument("--pgon_norm_reg_weight", type=float, default=0.1, 93 | # help='the weight of polygon embedding norm regularizer') 94 | 95 | # parser.add_argument("--grt_epoches", type=int, default=50000000, 96 | # help='the maximum epochs for the generative model to converge') 97 | parser.add_argument("--cla_epoches", type=int, default=50000000, 98 | help='the maximum epochs for the image classifier model to converge') 99 | # parser.add_argument("--max_burn_in", type=int, default=5000, 100 | # help='the maximum iterations for the relative/global model to converge') 101 | parser.add_argument("--batch_size", type=int, default=512) 102 | # parser.add_argument("--tol", type=float, default=0.000001) 103 | 104 | parser.add_argument("--balanced_train_loader", type=str, default="T", 105 | help = "whether we use a BalancedSampler for image classification") 106 | 107 | 108 | 109 | # eval 110 | parser.add_argument("--log_every", type=int, default=50) 111 | parser.add_argument("--val_every", type=int, default=5000) 112 | 113 | 114 | # load old model 115 | parser.add_argument("--load_model", action='store_true') 116 | 117 | # cuda 118 | parser.add_argument("--device", type=str, default="cpu") 119 | 120 | return parser 121 | 122 | 123 | def bool_arg_handler(arg): 124 | return True if arg == "T" else False 125 | 126 | def update_args(args): 127 | select_args = ["balanced_train_loader", "do_data_augment"] 128 | for arg in select_args: 129 | args.__dict__[arg] = bool_arg_handler(getattr(args, arg)) 130 | 131 | return args 132 | 133 | 134 | def make_args_combine(args): 135 | args_data = "{data:s}-{data_split_num:d}".format( 136 | data=args.data_dir.strip().split("/")[-2], 137 | data_split_num = args.data_split_num 138 | ) 139 | 140 | args_train = "-{batch_size:d}-{lr:.6f}-{opt:s}-{weight_decay:.2f}-{balanced_train_loader:s}".format( 141 | # act = args.act, 142 | # dropout=args.dropout, 143 | batch_size=args.batch_size, 144 | lr=args.lr, 145 | opt = args.opt, 146 | weight_decay = args.weight_decay, 147 | # task_loss_weight = args.task_loss_weight, 148 | # pgon_norm_reg_weight = args.pgon_norm_reg_weight, 149 | balanced_train_loader = args.balanced_train_loader, 150 | # do_polygon_random_start = args.do_polygon_random_start 151 | ) 152 | 153 | args_combine = "/{args_data:s}-{model_type:s}-{args_train:s}".format( 154 | args_data = args_data, 155 | model_type=args.model_type, 156 | args_train
= args_train 157 | 158 | ) 159 | return args_combine 160 | 161 | 162 | 163 | class Trainer(): 164 | """ 165 | Trainer 166 | """ 167 | def __init__(self, args, img_gdf, console = True): 168 | 169 | 170 | self.args_combine = make_args_combine(args) #+ ".L2" 171 | 172 | self.log_file = args.log_dir + self.args_combine + ".log" 173 | self.model_file = args.model_dir + self.args_combine + ".pth" 174 | # tensorboard log directory 175 | # self.tb_log_dir = args.model_dir + self.args_combine 176 | self.tb_log_dir = args.model_dir + "/tb" 177 | 178 | if not os.path.exists(args.model_dir): 179 | os.makedirs(args.model_dir) 180 | if not os.path.exists(args.log_dir): 181 | os.makedirs(args.log_dir) 182 | 183 | self.logger = setup_logging(self.log_file, console = console, filemode='a') 184 | 185 | 186 | 187 | self.img_gdf = img_gdf 188 | 189 | args = update_args(args) 190 | self.args = args 191 | 192 | self.img_cla_dataset, self.img_cla_dataloader, self.split2num, self.data_split_col, self.train_sampler, self.num_classes = self.load_image_cla_dataset_dataloader( 193 | img_gdf, 194 | num_worker = args.num_worker, 195 | batch_size = args.batch_size, 196 | balanced_train_loader = args.balanced_train_loader, 197 | id_col = "ID", 198 | img_col = "IMAGE", 199 | class_col = "TYPEID", 200 | data_split_num = args.data_split_num, 201 | do_data_augment = args.do_data_augment, 202 | device = args.device) 203 | 204 | self.model = get_resnet_model(resnet_type = args.model_type, num_classes = self.num_classes).to(args.device) 205 | 206 | 207 | if args.opt == "sgd": 208 | self.optimizer = optim.SGD(filter(lambda p : p.requires_grad, self.model.parameters()), lr=args.lr, momentum=0) 209 | elif args.opt == "adam": 210 | self.optimizer = optim.Adam(filter(lambda p : p.requires_grad, self.model.parameters()), lr=args.lr, weight_decay = args.weight_decay) 211 | 212 | print("create model from {}".format(self.args_combine + ".pth")) 213 | self.logger.info("Save file at {}".format(self.args_combine + ".pth")) 214 | 215 | self.tb_writer = SummaryWriter(self.tb_log_dir) 216 | self.global_batch_idx = 0 217 | 218 | def load_image_cla_dataset_dataloader(self, img_gdf, 219 | num_worker, batch_size, balanced_train_loader, 220 | id_col = "ID", img_col = "IMAGE", class_col = "TYPEID", data_split_num = 0, 221 | do_data_augment = False, device = "cpu"): 222 | ''' 223 | load polygon classification dataset including training, validation, testing 224 | ''' 225 | img_cla_dataset = dict() 226 | img_cla_dataloader = dict() 227 | data_split_col = "SPLIT_{:d}".format(data_split_num) 228 | split2num = {"TRAIN": 1, "TEST": 0, "VALID": -1} 229 | split_nums = np.unique(np.array(img_gdf[data_split_col])) 230 | assert 1 in split_nums and 0 in split_nums 231 | dup_test = False 232 | if -1 not in split_nums: 233 | # we will make valid and test the same 234 | dup_test = True 235 | 236 | un_class = np.unique(np.array(img_gdf[class_col])) 237 | num_classes = len(un_class) 238 | max_num_exs_per_class = math.ceil(batch_size/num_classes) 239 | 240 | 241 | for split in ["TRAIN", "TEST", "VALID"]: 242 | # make dataset 243 | 244 | if split == "VALID" and dup_test: 245 | img_cla_dataset[split] = img_cla_dataset["TEST"] 246 | else: 247 | img_split_gdf = img_gdf[ img_gdf[data_split_col] == split2num[split] ] 248 | 249 | if split == "TRAIN": 250 | img_cla_dataset[split] = ImageDataset(img_gdf = img_split_gdf, 251 | id_col = id_col, 252 | img_col = img_col, 253 | class_col = class_col, 254 | do_data_augment = do_data_augment, 255 | device = device) 256 | else: 257 | 
img_cla_dataset[split] = ImageDataset(img_gdf = img_split_gdf, 258 | id_col = id_col, 259 | img_col = img_col, 260 | class_col = class_col, 261 | do_data_augment = False, 262 | device = device) 263 | 264 | # make dataloader 265 | if split == "TRAIN": 266 | if balanced_train_loader: 267 | train_sampler = BalancedSampler(classes = img_cla_dataset["TRAIN"].class_list.cpu().numpy(), 268 | num_per_class = max_num_exs_per_class, 269 | use_replace=False, 270 | multi_label=False) 271 | img_cla_dataloader[split] = torch.utils.data.DataLoader(img_cla_dataset[split], 272 | num_workers = num_worker, 273 | batch_size = batch_size, 274 | sampler=train_sampler, 275 | shuffle = False) 276 | else: 277 | train_sampler = None 278 | img_cla_dataloader[split] = torch.utils.data.DataLoader(img_cla_dataset[split], 279 | num_workers = num_worker, 280 | batch_size = batch_size, 281 | shuffle = True) 282 | elif split == "VALID" and dup_test: 283 | img_cla_dataloader[split] = img_cla_dataloader['TEST'] 284 | else: 285 | img_cla_dataloader[split] = torch.utils.data.DataLoader(img_cla_dataset[split], 286 | num_workers = num_worker, 287 | batch_size = batch_size, 288 | shuffle = False) 289 | 290 | 291 | return img_cla_dataset, img_cla_dataloader, split2num, data_split_col, train_sampler, num_classes 292 | 293 | def run_train(self): 294 | # assert "norm" in self.geom_type_list 295 | self.global_batch_idx = train_image_model( 296 | self.args, 297 | model = self.model, 298 | img_cla_dataloader = self.img_cla_dataloader, 299 | optimizer = self.optimizer, 300 | tb_writer = self.tb_writer, 301 | logger = self.logger, 302 | model_file = self.model_file, 303 | cla_epoches = self.args.cla_epoches, 304 | log_every = self.args.log_every, 305 | val_every = self.args.val_every, 306 | global_batch_idx = self.global_batch_idx) 307 | 308 | def load_model(self): 309 | self.model, self.optimizer, self.args = load_model(self.model, self.optimizer, self.model_file) -------------------------------------------------------------------------------- /polygoncode/polygonembed/dla_resnext.py: -------------------------------------------------------------------------------- 1 | """ 2 | ResNeXt for ImageNet-1K, implemented in PyTorch. 3 | Original paper: 'Aggregated Residual Transformations for Deep Neural Networks,' http://arxiv.org/abs/1611.05431. 4 | """ 5 | 6 | # __all__ = ['ResNeXt', 'resnext14_16x4d', 'resnext14_32x2d', 'resnext14_32x4d', 'resnext26_16x4d', 'resnext26_32x2d', 7 | # 'resnext26_32x4d', 'resnext38_32x4d', 'resnext50_32x4d', 'resnext101_32x4d', 'resnext101_64x4d', 8 | # 'ResNeXtBottleneck', 'ResNeXtUnit'] 9 | __all__ = ['ResNeXtBottleneck'] 10 | 11 | import os 12 | import math 13 | import torch.nn as nn 14 | import torch.nn.init as init 15 | from polygonembed.dla_common import conv1x1_block, conv3x3_block 16 | from polygonembed.dla_resnet import ResInitBlock 17 | 18 | 19 | class ResNeXtBottleneck(nn.Module): 20 | """ 21 | ResNeXt bottleneck block for residual path in ResNeXt unit. 22 | 23 | Parameters: 24 | ---------- 25 | in_channels : int 26 | Number of input channels. 27 | out_channels : int 28 | Number of output channels. 29 | stride : int or tuple/list of 2 int 30 | Strides of the convolution. 31 | cardinality: int 32 | Number of groups. 33 | bottleneck_width: int 34 | Width of bottleneck block. 35 | bottleneck_factor : int, default 4 36 | Bottleneck factor. 
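    Example (editor's sketch; assumes conv1x1_block / conv3x3_block from
    polygonembed.dla_common preserve spatial size at stride 1):

    >>> import torch
    >>> block = ResNeXtBottleneck(in_channels=64, out_channels=256, stride=1,
    ...                           cardinality=32, bottleneck_width=4)
    >>> block(torch.randn(1, 64, 56, 56)).shape
    torch.Size([1, 256, 56, 56])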
37 | """ 38 | def __init__(self, 39 | in_channels, 40 | out_channels, 41 | stride, 42 | cardinality, 43 | bottleneck_width, 44 | bottleneck_factor=4): 45 | super(ResNeXtBottleneck, self).__init__() 46 | mid_channels = out_channels // bottleneck_factor 47 | D = int(math.floor(mid_channels * (bottleneck_width / 64.0))) 48 | group_width = cardinality * D 49 | 50 | self.conv1 = conv1x1_block( 51 | in_channels=in_channels, 52 | out_channels=group_width) 53 | self.conv2 = conv3x3_block( 54 | in_channels=group_width, 55 | out_channels=group_width, 56 | stride=stride, 57 | groups=cardinality) 58 | self.conv3 = conv1x1_block( 59 | in_channels=group_width, 60 | out_channels=out_channels, 61 | activation=None) 62 | 63 | def forward(self, x): 64 | x = self.conv1(x) 65 | x = self.conv2(x) 66 | x = self.conv3(x) 67 | return x 68 | 69 | 70 | # class ResNeXtUnit(nn.Module): 71 | # """ 72 | # ResNeXt unit with residual connection. 73 | 74 | # Parameters: 75 | # ---------- 76 | # in_channels : int 77 | # Number of input channels. 78 | # out_channels : int 79 | # Number of output channels. 80 | # stride : int or tuple/list of 2 int 81 | # Strides of the convolution. 82 | # cardinality: int 83 | # Number of groups. 84 | # bottleneck_width: int 85 | # Width of bottleneck block. 86 | # """ 87 | # def __init__(self, 88 | # in_channels, 89 | # out_channels, 90 | # stride, 91 | # cardinality, 92 | # bottleneck_width): 93 | # super(ResNeXtUnit, self).__init__() 94 | # self.resize_identity = (in_channels != out_channels) or (stride != 1) 95 | 96 | # self.body = ResNeXtBottleneck( 97 | # in_channels=in_channels, 98 | # out_channels=out_channels, 99 | # stride=stride, 100 | # cardinality=cardinality, 101 | # bottleneck_width=bottleneck_width) 102 | # if self.resize_identity: 103 | # self.identity_conv = conv1x1_block( 104 | # in_channels=in_channels, 105 | # out_channels=out_channels, 106 | # stride=stride, 107 | # activation=None) 108 | # self.activ = nn.ReLU(inplace=True) 109 | 110 | # def forward(self, x): 111 | # if self.resize_identity: 112 | # identity = self.identity_conv(x) 113 | # else: 114 | # identity = x 115 | # x = self.body(x) 116 | # x = x + identity 117 | # x = self.activ(x) 118 | # return x 119 | 120 | 121 | # class ResNeXt(nn.Module): 122 | # """ 123 | # ResNeXt model from 'Aggregated Residual Transformations for Deep Neural Networks,' http://arxiv.org/abs/1611.05431. 124 | 125 | # Parameters: 126 | # ---------- 127 | # channels : list of list of int 128 | # Number of output channels for each unit. 129 | # init_block_channels : int 130 | # Number of output channels for the initial unit. 131 | # cardinality: int 132 | # Number of groups. 133 | # bottleneck_width: int 134 | # Width of bottleneck block. 135 | # in_channels : int, default 3 136 | # Number of input channels. 137 | # in_size : tuple of two ints, default (224, 224) 138 | # Spatial size of the expected input image. 139 | # num_classes : int, default 1000 140 | # Number of classification classes. 
141 | # """ 142 | # def __init__(self, 143 | # channels, 144 | # init_block_channels, 145 | # cardinality, 146 | # bottleneck_width, 147 | # in_channels=3, 148 | # in_size=(224, 224), 149 | # num_classes=1000): 150 | # super(ResNeXt, self).__init__() 151 | # self.in_size = in_size 152 | # self.num_classes = num_classes 153 | 154 | # self.features = nn.Sequential() 155 | # self.features.add_module("init_block", ResInitBlock( 156 | # in_channels=in_channels, 157 | # out_channels=init_block_channels)) 158 | # in_channels = init_block_channels 159 | # for i, channels_per_stage in enumerate(channels): 160 | # stage = nn.Sequential() 161 | # for j, out_channels in enumerate(channels_per_stage): 162 | # stride = 2 if (j == 0) and (i != 0) else 1 163 | # stage.add_module("unit{}".format(j + 1), ResNeXtUnit( 164 | # in_channels=in_channels, 165 | # out_channels=out_channels, 166 | # stride=stride, 167 | # cardinality=cardinality, 168 | # bottleneck_width=bottleneck_width)) 169 | # in_channels = out_channels 170 | # self.features.add_module("stage{}".format(i + 1), stage) 171 | # self.features.add_module("final_pool", nn.AvgPool2d( 172 | # kernel_size=7, 173 | # stride=1)) 174 | 175 | # self.output = nn.Linear( 176 | # in_features=in_channels, 177 | # out_features=num_classes) 178 | 179 | # self._init_params() 180 | 181 | # def _init_params(self): 182 | # for name, module in self.named_modules(): 183 | # if isinstance(module, nn.Conv2d): 184 | # init.kaiming_uniform_(module.weight) 185 | # if module.bias is not None: 186 | # init.constant_(module.bias, 0) 187 | 188 | # def forward(self, x): 189 | # x = self.features(x) 190 | # x = x.view(x.size(0), -1) 191 | # x = self.output(x) 192 | # return x 193 | 194 | 195 | # def get_resnext(blocks, 196 | # cardinality, 197 | # bottleneck_width, 198 | # model_name=None, 199 | # pretrained=False, 200 | # root=os.path.join("~", ".torch", "models"), 201 | # **kwargs): 202 | # """ 203 | # Create ResNeXt model with specific parameters. 204 | 205 | # Parameters: 206 | # ---------- 207 | # blocks : int 208 | # Number of blocks. 209 | # cardinality: int 210 | # Number of groups. 211 | # bottleneck_width: int 212 | # Width of bottleneck block. 213 | # model_name : str or None, default None 214 | # Model name for loading pretrained model. 215 | # pretrained : bool, default False 216 | # Whether to load the pretrained weights for model. 217 | # root : str, default '~/.torch/models' 218 | # Location for keeping the model parameters. 
219 | # """ 220 | 221 | # if blocks == 14: 222 | # layers = [1, 1, 1, 1] 223 | # elif blocks == 26: 224 | # layers = [2, 2, 2, 2] 225 | # elif blocks == 38: 226 | # layers = [3, 3, 3, 3] 227 | # elif blocks == 50: 228 | # layers = [3, 4, 6, 3] 229 | # elif blocks == 101: 230 | # layers = [3, 4, 23, 3] 231 | # else: 232 | # raise ValueError("Unsupported ResNeXt with number of blocks: {}".format(blocks)) 233 | 234 | # assert (sum(layers) * 3 + 2 == blocks) 235 | 236 | # init_block_channels = 64 237 | # channels_per_layers = [256, 512, 1024, 2048] 238 | 239 | # channels = [[ci] * li for (ci, li) in zip(channels_per_layers, layers)] 240 | 241 | # net = ResNeXt( 242 | # channels=channels, 243 | # init_block_channels=init_block_channels, 244 | # cardinality=cardinality, 245 | # bottleneck_width=bottleneck_width, 246 | # **kwargs) 247 | 248 | # if pretrained: 249 | # if (model_name is None) or (not model_name): 250 | # raise ValueError("Parameter `model_name` should be properly initialized for loading pretrained model.") 251 | # from .model_store import download_model 252 | # download_model( 253 | # net=net, 254 | # model_name=model_name, 255 | # local_model_store_dir_path=root) 256 | 257 | # return net 258 | 259 | 260 | # def resnext14_16x4d(**kwargs): 261 | # """ 262 | # ResNeXt-14 (16x4d) model from 'Aggregated Residual Transformations for Deep Neural Networks,' 263 | # http://arxiv.org/abs/1611.05431. 264 | 265 | # Parameters: 266 | # ---------- 267 | # pretrained : bool, default False 268 | # Whether to load the pretrained weights for model. 269 | # root : str, default '~/.torch/models' 270 | # Location for keeping the model parameters. 271 | # """ 272 | # return get_resnext(blocks=14, cardinality=16, bottleneck_width=4, model_name="resnext14_16x4d", **kwargs) 273 | 274 | 275 | # def resnext14_32x2d(**kwargs): 276 | # """ 277 | # ResNeXt-14 (32x2d) model from 'Aggregated Residual Transformations for Deep Neural Networks,' 278 | # http://arxiv.org/abs/1611.05431. 279 | 280 | # Parameters: 281 | # ---------- 282 | # pretrained : bool, default False 283 | # Whether to load the pretrained weights for model. 284 | # root : str, default '~/.torch/models' 285 | # Location for keeping the model parameters. 286 | # """ 287 | # return get_resnext(blocks=14, cardinality=32, bottleneck_width=2, model_name="resnext14_32x2d", **kwargs) 288 | 289 | 290 | # def resnext14_32x4d(**kwargs): 291 | # """ 292 | # ResNeXt-14 (32x4d) model from 'Aggregated Residual Transformations for Deep Neural Networks,' 293 | # http://arxiv.org/abs/1611.05431. 294 | 295 | # Parameters: 296 | # ---------- 297 | # pretrained : bool, default False 298 | # Whether to load the pretrained weights for model. 299 | # root : str, default '~/.torch/models' 300 | # Location for keeping the model parameters. 301 | # """ 302 | # return get_resnext(blocks=14, cardinality=32, bottleneck_width=4, model_name="resnext14_32x4d", **kwargs) 303 | 304 | 305 | # def resnext26_16x4d(**kwargs): 306 | # """ 307 | # ResNeXt-26 (16x4d) model from 'Aggregated Residual Transformations for Deep Neural Networks,' 308 | # http://arxiv.org/abs/1611.05431. 309 | 310 | # Parameters: 311 | # ---------- 312 | # pretrained : bool, default False 313 | # Whether to load the pretrained weights for model. 314 | # root : str, default '~/.torch/models' 315 | # Location for keeping the model parameters. 
316 | # """ 317 | # return get_resnext(blocks=26, cardinality=16, bottleneck_width=4, model_name="resnext26_16x4d", **kwargs) 318 | 319 | 320 | # def resnext26_32x2d(**kwargs): 321 | # """ 322 | # ResNeXt-26 (32x2d) model from 'Aggregated Residual Transformations for Deep Neural Networks,' 323 | # http://arxiv.org/abs/1611.05431. 324 | 325 | # Parameters: 326 | # ---------- 327 | # pretrained : bool, default False 328 | # Whether to load the pretrained weights for model. 329 | # root : str, default '~/.torch/models' 330 | # Location for keeping the model parameters. 331 | # """ 332 | # return get_resnext(blocks=26, cardinality=32, bottleneck_width=2, model_name="resnext26_32x2d", **kwargs) 333 | 334 | 335 | # def resnext26_32x4d(**kwargs): 336 | # """ 337 | # ResNeXt-26 (32x4d) model from 'Aggregated Residual Transformations for Deep Neural Networks,' 338 | # http://arxiv.org/abs/1611.05431. 339 | 340 | # Parameters: 341 | # ---------- 342 | # pretrained : bool, default False 343 | # Whether to load the pretrained weights for model. 344 | # root : str, default '~/.torch/models' 345 | # Location for keeping the model parameters. 346 | # """ 347 | # return get_resnext(blocks=26, cardinality=32, bottleneck_width=4, model_name="resnext26_32x4d", **kwargs) 348 | 349 | 350 | # def resnext38_32x4d(**kwargs): 351 | # """ 352 | # ResNeXt-38 (32x4d) model from 'Aggregated Residual Transformations for Deep Neural Networks,' 353 | # http://arxiv.org/abs/1611.05431. 354 | 355 | # Parameters: 356 | # ---------- 357 | # pretrained : bool, default False 358 | # Whether to load the pretrained weights for model. 359 | # root : str, default '~/.torch/models' 360 | # Location for keeping the model parameters. 361 | # """ 362 | # return get_resnext(blocks=38, cardinality=32, bottleneck_width=4, model_name="resnext38_32x4d", **kwargs) 363 | 364 | 365 | # def resnext50_32x4d(**kwargs): 366 | # """ 367 | # ResNeXt-50 (32x4d) model from 'Aggregated Residual Transformations for Deep Neural Networks,' 368 | # http://arxiv.org/abs/1611.05431. 369 | 370 | # Parameters: 371 | # ---------- 372 | # pretrained : bool, default False 373 | # Whether to load the pretrained weights for model. 374 | # root : str, default '~/.torch/models' 375 | # Location for keeping the model parameters. 376 | # """ 377 | # return get_resnext(blocks=50, cardinality=32, bottleneck_width=4, model_name="resnext50_32x4d", **kwargs) 378 | 379 | 380 | # def resnext101_32x4d(**kwargs): 381 | # """ 382 | # ResNeXt-101 (32x4d) model from 'Aggregated Residual Transformations for Deep Neural Networks,' 383 | # http://arxiv.org/abs/1611.05431. 384 | 385 | # Parameters: 386 | # ---------- 387 | # pretrained : bool, default False 388 | # Whether to load the pretrained weights for model. 389 | # root : str, default '~/.torch/models' 390 | # Location for keeping the model parameters. 391 | # """ 392 | # return get_resnext(blocks=101, cardinality=32, bottleneck_width=4, model_name="resnext101_32x4d", **kwargs) 393 | 394 | 395 | # def resnext101_64x4d(**kwargs): 396 | # """ 397 | # ResNeXt-101 (64x4d) model from 'Aggregated Residual Transformations for Deep Neural Networks,' 398 | # http://arxiv.org/abs/1611.05431. 399 | 400 | # Parameters: 401 | # ---------- 402 | # pretrained : bool, default False 403 | # Whether to load the pretrained weights for model. 404 | # root : str, default '~/.torch/models' 405 | # Location for keeping the model parameters. 
406 | # """ 407 | # return get_resnext(blocks=101, cardinality=64, bottleneck_width=4, model_name="resnext101_64x4d", **kwargs) 408 | 409 | 410 | # def _calc_width(net): 411 | # import numpy as np 412 | # net_params = filter(lambda p: p.requires_grad, net.parameters()) 413 | # weight_count = 0 414 | # for param in net_params: 415 | # weight_count += np.prod(param.size()) 416 | # return weight_count 417 | 418 | 419 | # def _test(): 420 | # import torch 421 | 422 | # pretrained = False 423 | 424 | # models = [ 425 | # resnext14_16x4d, 426 | # resnext14_32x2d, 427 | # resnext14_32x4d, 428 | # resnext26_16x4d, 429 | # resnext26_32x2d, 430 | # resnext26_32x4d, 431 | # resnext38_32x4d, 432 | # resnext50_32x4d, 433 | # resnext101_32x4d, 434 | # resnext101_64x4d, 435 | # ] 436 | 437 | # for model in models: 438 | 439 | # net = model(pretrained=pretrained) 440 | 441 | # # net.train() 442 | # net.eval() 443 | # weight_count = _calc_width(net) 444 | # print("m={}, {}".format(model.__name__, weight_count)) 445 | # assert (model != resnext14_16x4d or weight_count == 7127336) 446 | # assert (model != resnext14_32x2d or weight_count == 7029416) 447 | # assert (model != resnext14_32x4d or weight_count == 9411880) 448 | # assert (model != resnext26_16x4d or weight_count == 10119976) 449 | # assert (model != resnext26_32x2d or weight_count == 9924136) 450 | # assert (model != resnext26_32x4d or weight_count == 15389480) 451 | # assert (model != resnext38_32x4d or weight_count == 21367080) 452 | # assert (model != resnext50_32x4d or weight_count == 25028904) 453 | # assert (model != resnext101_32x4d or weight_count == 44177704) 454 | # assert (model != resnext101_64x4d or weight_count == 83455272) 455 | 456 | # x = torch.randn(1, 3, 224, 224) 457 | # y = net(x) 458 | # y.sum().backward() 459 | # assert (tuple(y.size()) == (1, 1000)) 460 | 461 | 462 | # if __name__ == "__main__": 463 | #     _test()  # commented out: `_test` is commented out above, so calling it would raise a NameError 464 | -------------------------------------------------------------------------------- /polygoncode/polygonembed/ddsl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Function 4 | from torch.autograd.function import once_differentiable 5 | import math  # `math` itself is used below (math.pi, math.factorial, math.ceil); the original `from math import pi` did not import the module 6 | 7 | from polygonembed.ddsl_utils import * 8 | 9 | 10 | 11 | class SimplexFT(Function): 12 | """ 13 | Fourier transform for signal defined on a j-simplex set in R^n space 14 | :param V: vertex tensor. float tensor of shape (n_vertex, n_dims) 15 | :param E: element tensor. int tensor of shape (n_elem, j or j+1) 16 | if j cols, triangulate/tetrahedronize interior first. 17 | :param D: int ndarray of shape (n_elem, n_channel) 18 | :param res: n_dims int tuple of number of frequency modes 19 | :param t: n_dims tuple of period in each dimension 20 | :param j: dimension of simplex set 21 | :param mode: normalization mode. 22 | 'density' for preserving density, 'mass' for preserving mass 23 | :return: F: ndarray of shape (res[0], res[1], ..., res[-1]/2, n_channel) 24 | last dimension is halved since the signal is assumed to be real 25 | """ 26 | @staticmethod 27 | def forward(ctx, V, E, D, res, t, j, 28 | min_freqXY, max_freqXY, mid_freqXY = None, freq_init = "fft", 29 | elem_batch=100, mode='density'): 30 | ''' 31 | Args: 32 | V: vertex tensor. float tensor of shape (batch_size, num_vert, n_dims = 2) 33 | E: element tensor. int tensor of shape (batch_size, n_elem = num_vert, j or j+1) 34 | if j cols, triangulate/tetrahedronize interior first.
35 | (num_vert, 2), indicates the connectivity 36 | 37 | 38 | D: int ndarray of shape (batch_size, n_elem = num_vert, n_channel = 1), all 1 39 | res: n_dims int tuple of number of frequency modes, (fx, fy) for polygon 40 | t: n_dims tuple of period in each dimension, (tx, ty) for polygon 41 | j: dimension of simplex set, 2 for polygon 42 | max_freqXY: the maximum frequency 43 | min_freqXY: the minimum frequency 44 | freq_init: frequency generation method, 45 | "geometric": geometric series 46 | "fft": fast Fourier transformation 47 | mode: normalization mode. 48 | 'density' for preserving density, 'mass' for preserving mass 49 | Return: 50 | F: ndarray of shape (res[0], res[1], ..., res[-1]/2, n_channel) 51 | last dimension is halved since the signal is assumed to be real 52 | Shape (fx, fy//2+1, n_channel = 1, 2) for polygon case 53 | ''' 54 | ## check if E is subdim 55 | ''' 56 | subdim: See ddsl_illustration.jpg 57 | True: V is a sequence of vertices and E is a sequence of lines 58 | False: V is a sequence of vertices + one and E is a sequence of triangles 59 | ''' 60 | assert V.shape[0] == E.shape[0] == D.shape[0] # batch_size 61 | batch_size = V.shape[0] 62 | subdim = E.shape[-1] == j and V.shape[-1] == j 63 | assert (E.shape[-1] == j+1 or subdim) 64 | # assert V.shape[0] == E.shape[0] == D.shape[0] # batch_size dim should match 65 | 66 | if subdim: 67 | # make sure all D have the same density 68 | D_repeat = torch.repeat_interleave(D[:, 0].unsqueeze(1), repeats = D.shape[1], dim=1) 69 | assert((D == D_repeat).sum().item() == D.numel()) # assert same densities for all simplices (homogeneous filling) 70 | # add (0,0) as an auxiliary vertex in V 71 | # aux_vert_mat: shape (batch_size, 1, n_dims = 2) 72 | aux_vert_mat = torch.zeros((batch_size, 1, V.shape[-1]), device=V.device, dtype=V.dtype) 73 | # V: (batch_size, num_vert + 1, n_dims = 2) 74 | V = torch.cat((V, aux_vert_mat), dim=1) 75 | 76 | # the index of the added auxiliary vertex (0, 0) 77 | aux_vert_idx = V.shape[1] - 1 78 | # add_aux_vert_mat: (batch_size, n_elem = num_vert, 1) 79 | # values are the index of the added auxiliary vertex 80 | add_aux_vert_mat = torch.zeros((batch_size, E.shape[1], 1), device=E.device, dtype=E.dtype) + aux_vert_idx 81 | # E: (batch_size, n_elem = num_vert, j+1 = 3); add (0, 0) as the 3rd vertex for each line in E, so we have constructed all 2-simplices 82 | E = torch.cat((E, add_aux_vert_mat), dim=-1) 83 | 84 | n_elem = E.shape[-2] # n_elem = num_vert, number of 2-simplices, the triangles 85 | n_vert = V.shape[-2] # n_vert = num_vert + 1, number of vertices 86 | n_channel = D.shape[-1] # n_channel = 1 87 | 88 | ## save context info for backwards 89 | ctx.mark_non_differentiable(E, D) # mark non-differentiable 90 | ctx.res = res 91 | ctx.t = t 92 | ctx.j = j 93 | ctx.mode = mode 94 | ctx.n_dims = V.shape[-1] # n_dims = 2 95 | ctx.elem_batch = elem_batch # elem_batch = 100 96 | ctx.subdim = subdim 97 | 98 | ctx.min_freqXY = min_freqXY 99 | ctx.max_freqXY = max_freqXY 100 | ctx.mid_freqXY = mid_freqXY  # fixed: the original trailing comma stored a 1-element tuple instead of the value 101 | ctx.freq_init = freq_init 102 | 103 | 104 | 105 | 106 | # compute content array 107 | # C: normalized simplex content, \gamma_n^j, shape (batch_size, n_elem = num_vert, 1) 108 | # unsigned: Equation 6 in https://openreview.net/pdf?id=B1G5ViAqFm 109 | # signed: Equation 8 in https://openreview.net/pdf?id=B1G5ViAqFm 110 | C = math.factorial(j) * simplex_content(V, E, signed=subdim) # [n_elem, 1] 111 | ctx.save_for_backward(V, E, D, C) 112 | 113 | ## compute frequencies F 114 | n_dims = ctx.n_dims 115 | assert(n_dims == len(res)) # consistent spatial dimensionality
116 | assert(E.shape[1] == D.shape[1]) # consistent vertex numbers 117 | assert(mode in ['density', 'mass']) 118 | 119 | 120 | 121 | # frequency tensor 122 | ''' 123 | omega: the fft frequency matrix 124 | if extract = True 125 | for res = (fx, fy) => shape (fx, fy//2+1, 2) 126 | for res = (f1, f2, ..., fn) => shape (f1, f2, ..., f_{n-1}, fn//2+1, n = n_dims) 127 | ''' 128 | # omega = fftfreqs(res, dtype=V.dtype).to(V.device) # [dim0, dim1, dim2, d] 129 | # omega: [dim0, dim1, dim2, d] 130 | omega = get_fourier_freqs(res = res, 131 | min_freqXY = min_freqXY, 132 | max_freqXY = max_freqXY, 133 | mid_freqXY = mid_freqXY, 134 | freq_init = freq_init, 135 | dtype=V.dtype).to(V.device) 136 | 137 | # normalize frequencies 138 | for dim in range(n_dims): 139 | omega[..., dim] *= 2 * math.pi / t[dim] 140 | 141 | 142 | 143 | # initialize output F 144 | # F_shape = [fx, fy//2+1] 145 | F_shape = list(omega.shape)[:-1] 146 | # F_shape = [fx, fy//2+1, n_channel = 1, 2] 147 | F_shape += [n_channel, 2] 148 | # F_shape = [batch_size, fx, fy//2+1, n_channel = 1, 2] 149 | F_shape = [batch_size] + F_shape 150 | # F: shape (batch_size, fx, fy//2+1, n_channel = 1, 2) 151 | F = torch.zeros(*F_shape, dtype=V.dtype, device=V.device) # [dimX, dimY, dimZ, n_chan, 2] 2: real/imag 152 | 153 | # compute element-point tensor 154 | # P: point tensor. float, shape (batch_size, n_elem = num_vert, j+1 = 3, n_dims = 2) 155 | # P = V[E] 156 | P = coord_P_lookup_from_V_E(V, E) 157 | 158 | 159 | 160 | # loop over element/simplex batches 161 | for idx in range(math.ceil(n_elem/elem_batch)): 162 | id_start = idx * elem_batch 163 | id_end = min((idx+1) * elem_batch, n_elem) 164 | # Xi: coordinate mat, shape [batch_size, elem_batch, j+1, n_dims = 2] 165 | Xi = P[:, id_start:id_end] 166 | # Di: simplex density mat, shape [batch_size, elem_batch, n_channel = 1] 167 | Di = D[:, id_start:id_end]
168 | # Ci: normalized simplex content, \gamma_n^j, Equ. 6, shape [batch_size, elem_batch, 1] 169 | Ci = C[:, id_start:id_end] 170 | # CDi: \gamma_n^j * \rho_n 171 | # CDi: shape [batch_size, elem_batch, n_channel] 172 | CDi = Ci.expand_as(Di) * Di 173 | # sig: shape (batch_size, elem_batch, j+1, fx, fy//2+1) 174 | # k dot x, Equation 3 in https://openreview.net/pdf?id=B1G5ViAqFm 175 | sig = torch.einsum('lbjd,...d->lbj...', (Xi, omega)) 176 | 177 | 178 | # sig: shape (batch_size, elem_batch, j+1, fx, fy//2+1, 1) 179 | sig = torch.unsqueeze(sig, dim=-1) # [elem_batch, j+1, dimX, dimY, dimZ, 1] 180 | # esig: e^{-i*\sigma_t}, Euler's formula, 181 | # shape (batch_size, elem_batch, j+1, fx, fy//2+1, 1, 2), 182 | esig = torch.stack((torch.cos(sig), -torch.sin(sig)), dim=-1) # [elem_batch, j+1, dimX, dimY, dimZ, 1, 2] 183 | # sig: shape (batch_size, elem_batch, j+1, fx, fy//2+1, 1, 1) 184 | sig = torch.unsqueeze(sig, dim=-1) # [elem_batch, j+1, dimX, dimY, dimZ, 1, 1] 185 | 186 | 187 | 188 | # denom: prod(sigma_t - sigma_l), the denominator of Equation 7 in https://openreview.net/pdf?id=B1G5ViAqFm 189 | # denom: shape (batch_size, elem_batch, j+1, fx, fy//2+1, 1, 1) 190 | denom = torch.ones_like(sig) # [elem_batch, j+1, dimX, dimY, dimZ, 1, 1] 191 | for dim in range(1, j+1): 192 | seq = permute_seq(dim, j+1) 193 | denom *= sig - sig[:, :, seq] 194 | # tmp: shape (batch_size, elem_batch, fx, fy//2+1, 1, 2) 195 | # sum e^{-i*\sigma_t}/prod(sigma_t - sigma_l), 196 | # signed context, Equation 4 in https://openreview.net/pdf?id=B1G5ViAqFm 197 | tmp = torch.sum(esig / denom, dim=2) # [elem_batch, dimX, dimY, dimZ, 1, 2] 198 | 199 | # select all cases when denom == 0 200 | mask = ((denom == 0).repeat_interleave(repeats = 2, dim=-1).sum(dim = 2) > 0) 201 | # mask all cases as 0 202 | tmp[mask] = 0 203 | 204 | # CDi: shape (batch_size, elem_batch, n_channel, 1) 205 | CDi.unsqueeze_(-1) # [elem_batch, n_channel, 1] 206 | # CDi: shape (batch_size, elem_batch, 1, 1, n_channel, 1) 207 | for _ in range(n_dims): # unsqueeze to broadcast 208 | CDi.unsqueeze_(dim=2) # [elem_batch, 1, 1, 1, n_channel, 2] 209 | 210 | # shape_ = (batch_size, elem_batch, fx, fy//2+1, n_channel, 2) 211 | shape_ = (list(tmp.shape[:-2])+[n_channel, 2]) 212 | # tmp: (batch_size, elem_batch, fx, fy//2+1, n_channel, 2) 213 | # \gamma_n^j * \rho_n * sum e^{-i*\sigma_t}/prod(sigma_t - sigma_l) 214 | tmp = tmp * CDi # [elem_batch, dimX, dimY, dimZ, n_channel, 2] 215 | # Fi: shape (batch_size, fx, fy//2+1, n_channel, 2) 216 | Fi = torch.sum(tmp, dim=1, keepdim=False) # [dimX, dimY, dimZ, n_channel, 2] 217 | 218 | 219 | # CDi_: shape (batch_size, 1, 1, n_channel, 1) 220 | # the sum of the polygon content 221 | CDi_ = torch.sum(CDi, dim=1) 222 | # CDi_: shape (batch_size, n_channel, 1) 223 | for _ in range(n_dims): # squeeze dims 224 | CDi_.squeeze_(dim=1) # [n_channel, 2] 225 | 226 | 227 | # Fi[:, tuple([0] * n_dims)] = - 1 / factorial(j) * CDi_  # superseded by the explicit n_dims == 2 / 3 branches below
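# Editor's note (added): the DC (zero-frequency) entry is filled in analytically
# below because the closed-form expression above is 0/0 at omega = 0 (those
# positions were masked to zero via `denom == 0`). The DC component of a Fourier
# transform is the integral of the signal, i.e. the total signed simplex content
# times the per-channel density; the -1/factorial(j) factor compensates for the
# factorial(j) scaling applied when C was computed above.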
228 | 229 | # Fi: shape (batch_size, fx, fy//2+1, n_channel = 1, 2) 230 | if n_dims == 2: 231 | Fi[:, 0, 0] = - 1 / math.factorial(j) * CDi_ 232 | elif n_dims == 3: 233 | Fi[:, 0, 0, 0] = - 1 / math.factorial(j) * CDi_ 234 | else: 235 | raise Exception("n_dims is not 2 or 3") 236 | F += Fi 237 | 238 | # F: shape (batch_size, fx, fy//2+1, n_channel = 1, 2) 239 | # multiply tensor F by i ** j 240 | # see Equation 4 and 7 in https://openreview.net/pdf?id=B1G5ViAqFm 241 | F = img(F, deg=j) # Fi *= i**j [dimX, dimY, dimZ, n_chan, 2] 2: real/imag 242 | 243 | if mode == 'density': 244 | res_t = torch.tensor(res) 245 | if not torch.equal(res_t, res[0]*torch.ones(len(res), dtype=res_t.dtype)): 246 | print("WARNING: density preserving mode not correctly implemented if not all res are equal") 247 | F *= res[0] ** j 248 | return F 249 | 250 | 251 | class DDSL_spec(nn.Module): 252 | """ 253 | Module for DDSL layer. Takes in a simplex mesh and returns the spectral raster. 254 | """ 255 | def __init__(self, res, t, j, 256 | min_freqXY, max_freqXY, mid_freqXY = None, freq_init = "fft", 257 | elem_batch=100, mode='density'): 258 | """ 259 | Args: 260 | res: n_dims int tuple of number of frequency modes 261 | t: n_dims tuple of period in each dimension 262 | j: dimension of simplex set 263 | max_freqXY: the maximum frequency 264 | min_freqXY: the minimum frequency 265 | freq_init: frequency generated method, 266 | "geometric": geometric series 267 | "fft": fast fourier transformation 268 | elem_batch: element-wise batch size. 269 | mode: 'density' for density conserving, or 'mass' for mess conserving. Defaults 'density' 270 | """ 271 | super(DDSL_spec, self).__init__() 272 | self.res = res 273 | self.t = t 274 | self.j = j 275 | self.min_freqXY = min_freqXY 276 | self.max_freqXY = max_freqXY 277 | self.mid_freqXY = mid_freqXY 278 | self.freq_init = freq_init 279 | self.elem_batch = elem_batch 280 | self.mode = mode 281 | def forward(self, V, E, D): 282 | """ 283 | V: vertex tensor. float tensor of shape (batch_size, num_vert, n_dims = 2) 284 | E: element tensor. int tensor of shape (batch_size, n_elem = num_vert, j or j+1) 285 | if j cols, triangulate/tetrahedronize interior first. 286 | (num_vert, 2), indicate the connectivity 287 | D: int ndarray of shape (batch_size, n_elem, n_channel) 288 | :return F: ndarray of shape (res[0], res[1], ..., res[-1]/2, n_channel) 289 | last dimension is halfed since the signal is assumed to be real 290 | F: shape (batch_size, fx, fy//2+1, n_channel = 1, 2) 291 | """ 292 | V, D = V.double(), D.double() 293 | return SimplexFT.apply(V,E,D,self.res,self.t,self.j, 294 | self.min_freqXY, self.max_freqXY, self.mid_freqXY, self.freq_init, 295 | self.elem_batch,self.mode) 296 | 297 | 298 | class DDSL_phys(nn.Module): 299 | """ 300 | Module for DDSL layer. Takes in a simplex mesh and returns a dealiased raster image (in physical domain). 301 | """ 302 | def __init__(self, res, t, j, 303 | min_freqXY, max_freqXY, mid_freqXY = None, freq_init = "fft", 304 | smoothing='gaussian', sig=2.0, elem_batch=100, mode='density'): 305 | """ 306 | Args: 307 | res: n_dims int tuple of number of frequency modes 308 | t: n_dims tuple of period in each dimension 309 | j: dimension of simplex set 310 | max_freqXY: the maximum frequency 311 | min_freqXY: the minimum frequency 312 | freq_init: frequency generated method, 313 | "geometric": geometric series 314 | "fft": fast fourier transformation 315 | smoothing: str, choice of spectral smoothing function. 
Defaults 'gaussian' 316 | sig: sigma of gaussian at highest frequency 317 | elem_batch: element-wise batch size. 318 | mode: 'density' for density conserving, or 'mass' for mess conserving. Defaults 'density' 319 | """ 320 | super(DDSL_phys, self).__init__() 321 | self.res = res 322 | self.t = t 323 | self.j = j 324 | self.min_freqXY = min_freqXY 325 | self.max_freqXY = max_freqXY 326 | self.mid_freqXY = mid_freqXY 327 | self.freq_init = freq_init 328 | self.elem_batch = elem_batch 329 | self.mode = mode 330 | self.filter = None 331 | self.sig = sig 332 | 333 | if isinstance(smoothing, str): 334 | assert(smoothing in ["gaussian"]) 335 | if smoothing == 'gaussian': 336 | # filter: shape (fx, fy//2+1, 1, 1) 337 | self.filter = self._gaussian_filter() 338 | 339 | def forward(self, V, E, D): 340 | """ 341 | V: vertex tensor. float tensor of shape (batch_size, num_vert, n_dims = 2) 342 | E: element tensor. int tensor of shape (batch_size, n_elem = num_vert, j or j+1) 343 | if j cols, triangulate/tetrahedronize interior first. 344 | (num_vert, 2), indicate the connectivity 345 | :param D: int ndarray of shape (batch_size, n_elem, n_channel) 346 | Return: 347 | f: dealiased raster image in physical domain of shape (batch_size, res[0], res[1], ..., res[-1], n_channel) 348 | shape (batch_size, fx, fy, n_channel = 1) 349 | """ 350 | 351 | V, D = V.double(), D.double() 352 | # F: shape (batch_size, fx, fy//2+1, n_channel = 1, 2) for polygon case 353 | F = SimplexFT.apply(V,E,D,self.res,self.t,self.j, 354 | self.min_freqXY, self.max_freqXY, self.mid_freqXY, self.freq_init, 355 | self.elem_batch,self.mode) 356 | F[torch.isnan(F)] = 0 # pad nans to 0 357 | if self.filter is not None: 358 | # filter: shape (fx, fy//2+1, 1, 1) 359 | self.filter = self.filter.to(F.device) 360 | # filter_: shape (fx, fy//2+1, 1, 2) 361 | filter_ = torch.repeat_interleave(self.filter, repeats = F.shape[-1], dim=-1) 362 | # F: shape (batch_size, fx, fy//2+1, n_channel = 1, 2) for polygon case 363 | F *= filter_ # [dim0, dim1, dim2, n_channel, 2] 364 | dim = len(self.res) 365 | # F: shape (batch_size, n_channel = 1, fx, fy//2+1, 2) for polygon case 366 | F = F.permute(*([0, dim+1] + list(range(1, dim+1)) + [dim+2])) # [n_channel, dim0, dim1, dim2, 2] 367 | # f: shape (batch_size, n_channel = 1, fx, fy) 368 | f = torch.irfft(F, dim, signal_sizes=self.res) 369 | # f: shape (batch_size, fx, fy, n_channel = 1) 370 | f = f.permute(*([0] + list(range(2, 2+dim)) + [1])) 371 | 372 | return f 373 | 374 | def _gaussian_filter(self): 375 | ''' 376 | Return: 377 | filter_: shape (fx, fy//2+1, 1, 1) 378 | ''' 379 | 380 | # omega = fftfreqs(self.res, dtype=torch.float64) # [dim0, dim1, dim2, d] 381 | 382 | # omega: shape (fx, fy//2+1, 2) 383 | omega = get_fourier_freqs(res = self.res, 384 | min_freqXY = self.min_freqXY, 385 | max_freqXY = self.max_freqXY, 386 | mid_freqXY = self.mid_freqXY, 387 | freq_init = self.freq_init, 388 | dtype=torch.float64) 389 | # dis: shape (fx, fy//2+1) 390 | dis = torch.sqrt(torch.sum(omega ** 2, dim=-1)) 391 | # filter_: shape (fx, fy//2+1, 1, 1) 392 | filter_ = torch.exp(-0.5*((self.sig*2*dis/self.res[0])**2)).unsqueeze(-1).unsqueeze(-1) 393 | filter_.requires_grad = False 394 | return filter_ -------------------------------------------------------------------------------- /polygoncode/polygonembed/module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import torch.nn.functional as F 5 | 6 | 
import torch.utils.data 7 | import math 8 | import numpy as np 9 | 10 | class LayerNorm(nn.Module): 11 | """ 12 | layer normalization 13 | Simple layer norm object optionally used with the convolutional encoder. 14 | """ 15 | 16 | def __init__(self, feature_dim, eps=1e-6): 17 | super(LayerNorm, self).__init__() 18 | self.gamma = nn.Parameter(torch.ones((feature_dim,))) 19 | self.register_parameter("gamma", self.gamma) # redundant: the nn.Parameter assignment above already registers it 20 | self.beta = nn.Parameter(torch.zeros((feature_dim,))) 21 | self.register_parameter("beta", self.beta) # redundant, see above 22 | self.eps = eps 23 | 24 | def forward(self, x): 25 | # x: [batch_size, embed_dim] 26 | # normalize for each embedding 27 | mean = x.mean(-1, keepdim=True) 28 | std = x.std(-1, keepdim=True) 29 | # output shape is the same as x 30 | # self.gamma and self.beta broadcast along the last (feature) dimension 31 | # output: [batch_size, embed_dim] 32 | return self.gamma * (x - mean) / (std + self.eps) + self.beta 33 | 34 | def get_activation_function(activation, context_str): 35 | if activation == "leakyrelu": 36 | return nn.LeakyReLU(negative_slope=0.2) 37 | elif activation == "relu": 38 | return nn.ReLU() 39 | elif activation == "sigmoid": 40 | return nn.Sigmoid() 41 | elif activation == 'tanh': 42 | return nn.Tanh() 43 | else: 44 | raise Exception("{} activation not recognized.".format(context_str)) 45 | 46 | 47 | def coord_normalize(coords, extent = (-1710000, -1690000, 1610000, 1640000)): 48 | """ 49 | Given a list of coords (X, Y), normalize them to [-1, 1] 50 | Args: 51 | coords: a python list with shape (batch_size, num_context_pt, coord_dim) 52 | extent: (x_min, x_max, y_min, y_max) 53 | Return: 54 | coords_mat: np.ndarray, shape (batch_size, num_context_pt, coord_dim) 55 | """ 56 | if type(coords) == list: 57 | coords_mat = np.asarray(coords).astype(float) 58 | elif type(coords) == np.ndarray: 59 | coords_mat = coords 60 | 61 | # x => [0,1] min_max normalize 62 | x = (coords_mat[:,:,0] - extent[0])*1.0/(extent[1] - extent[0]) 63 | # x => [-1,1] 64 | coords_mat[:,:,0] = (x * 2) - 1 65 | 66 | # y => [0,1] min_max normalize 67 | y = (coords_mat[:,:,1] - extent[2])*1.0/(extent[3] - extent[2]) 68 | # y => [-1,1] 69 | coords_mat[:,:,1] = (y * 2) - 1 70 | 71 | return coords_mat 72 |
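# Editor's example (added): with the default extent, the extent midpoint maps to
# the origin of the [-1, 1] x [-1, 1] square:
#
# >>> coord_normalize([[(-1700000.0, 1625000.0)]])  # (batch_size=1, 1 point, 2)
# array([[[0., 0.]]])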
72 | 
73 | class SingleFeedForwardNN(nn.Module):
74 |     """
75 |     Creates a single layer fully connected feed forward neural network.
76 |     this will use non-linearity, layer normalization, dropout
77 |     this is for a hidden layer, not the last layer of the feed forward NN
78 |     """
79 | 
80 |     def __init__(self, input_dim,
81 |                  output_dim,
82 |                  dropout_rate=None,
83 |                  activation="sigmoid",
84 |                  use_layernormalize=False,
85 |                  skip_connection = False,
86 |                  context_str = ''):
87 |         '''
88 | 
89 |         Args:
90 |             input_dim (int32): the input embedding dim
91 |             output_dim (int32): dimension of the output of the network.
92 |             dropout_rate (scalar tensor or float): dropout probability.
93 |             activation (string): tanh or relu or leakyrelu or sigmoid
94 |             use_layernormalize (bool): do layer normalization or not
95 |             skip_connection (bool): do skip connection or not
96 |             context_str (string): indicates which spatial relation encoder is using the current FFN
97 | 
98 |         '''
99 |         super(SingleFeedForwardNN, self).__init__()
100 |         self.input_dim = input_dim
101 |         self.output_dim = output_dim
102 | 
103 |         if dropout_rate is not None:
104 |             self.dropout = nn.Dropout(p=dropout_rate)
105 |         else:
106 |             self.dropout = None
107 | 
108 |         self.act = get_activation_function(activation, context_str)
109 | 
110 |         if use_layernormalize:
111 |             # the layer normalization is only used in the hidden layer, not the last layer
112 |             self.layernorm = nn.LayerNorm(self.output_dim)
113 |             self.layernorm.weight.data.fill_(1)
114 |             self.layernorm.bias.data.zero_()
115 |         else:
116 |             self.layernorm = None
117 | 
118 |         # the skip connection is only possible if the input and output dimensions are the same
119 |         if self.input_dim == self.output_dim:
120 |             self.skip_connection = skip_connection
121 |         else:
122 |             self.skip_connection = False
123 | 
124 |         self.linear = nn.Linear(self.input_dim, self.output_dim)
125 |         nn.init.xavier_uniform_(self.linear.weight)
126 |         nn.init.zeros_(self.linear.bias)
127 | 
128 | 
129 | 
130 | 
131 | 
132 |     def forward(self, input_tensor):
133 |         '''
134 |         Args:
135 |             input_tensor: shape [batch_size, ..., input_dim]
136 |         Returns:
137 |             tensor of shape [batch_size, ..., output_dim]
138 |             note the non-linearity is applied right after the linear layer, so the output has passed through the activation.
139 | 
140 |         Raises:
141 |             Exception: If given activation or normalizer not supported.
142 |         '''
143 |         assert input_tensor.size()[-1] == self.input_dim
144 |         # Linear layer
145 |         output = self.linear(input_tensor)
146 |         # non-linearity
147 |         output = self.act(output)
148 |         # dropout
149 |         if self.dropout is not None:
150 |             output = self.dropout(output)
151 | 
152 |         # skip connection
153 |         if self.skip_connection:
154 |             output = output + input_tensor
155 | 
156 |         # layer normalization
157 |         if self.layernorm is not None:
158 |             output = self.layernorm(output)
159 | 
160 |         return output
161 | 
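
A minimal sketch of a single hidden layer in isolation (argument values are illustrative, not from the repo):

    # With input_dim == output_dim, the skip connection and layer norm both apply.
    layer = SingleFeedForwardNN(input_dim=16, output_dim=16,
                                dropout_rate=0.5, activation="relu",
                                use_layernormalize=True, skip_connection=True,
                                context_str="demo")
    x = torch.randn(8, 10, 16)   # [batch_size, num_points, input_dim]
    y = layer(x)                 # same shape: [8, 10, 16]
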
162 | # num_rbf_anchor_pts = 100, rbf_kernal_size = 10e2, frequency_num = 16,
163 | # max_radius = 10000, dropout = 0.5, f_act = "sigmoid", freq_init = "geometric",
164 | # num_hidden_layer = 3, hidden_dim = 128, use_layn = "F", skip_connection = "F", use_post_mat = "T"):
165 | #     if use_layn == "T":
166 | #         use_layn = True
167 | #     else:
168 | #         use_layn = False
169 | #     if skip_connection == "T":
170 | #         skip_connection = True
171 | #     else:
172 | #         skip_connection = False
173 | #     if use_post_mat == "T":
174 | #         use_post_mat = True
175 | #     else:
176 | #         use_post_mat = False
177 | 
178 | class MultiLayerFeedForwardNN(nn.Module):
179 |     """
180 |     Creates a fully connected feed forward neural network.
181 |     N fully connected feed forward NN, each hidden layer will use non-linearity, layer normalization, dropout
182 |     The last layer skips layer normalization and the skip connection (activation and dropout still apply)
183 |     """
184 | 
185 |     def __init__(self, input_dim,
186 |                  output_dim,
187 |                  num_hidden_layers=0,
188 |                  dropout_rate=None,
189 |                  hidden_dim=-1,
190 |                  activation="sigmoid",
191 |                  use_layernormalize=False,
192 |                  skip_connection = False,
193 |                  context_str = None):
194 |         '''
195 | 
196 |         Args:
197 |             input_dim (int32): the input embedding dim
198 |             num_hidden_layers (int32): number of hidden layers in the network, set to 0 for a single-layer network.
199 |             output_dim (int32): dimension of the output of the network.
200 |             dropout_rate (scalar tensor or float): dropout probability.
201 |             hidden_dim (int32): size of the hidden layers
202 |             activation (string): tanh, relu, leakyrelu, or sigmoid
203 |             use_layernormalize (bool): do layer normalization or not
204 |             context_str (string): indicates which spatial relation encoder is using the current FFN
205 | 
206 |         '''
207 |         super(MultiLayerFeedForwardNN, self).__init__()
208 |         self.input_dim = input_dim
209 |         self.output_dim = output_dim
210 |         self.num_hidden_layers = num_hidden_layers
211 |         self.dropout_rate = dropout_rate
212 |         self.hidden_dim = hidden_dim
213 |         self.activation = activation
214 |         self.use_layernormalize = use_layernormalize
215 |         self.skip_connection = skip_connection
216 |         self.context_str = context_str
217 | 
218 |         self.layers = nn.ModuleList()
219 |         if self.num_hidden_layers <= 0:
220 |             self.layers.append( SingleFeedForwardNN(input_dim = self.input_dim,
221 |                                                     output_dim = self.output_dim,
222 |                                                     dropout_rate = self.dropout_rate,
223 |                                                     activation = self.activation,
224 |                                                     use_layernormalize = False,
225 |                                                     skip_connection = False,
226 |                                                     context_str = self.context_str))
227 |         else:
228 |             self.layers.append( SingleFeedForwardNN(input_dim = self.input_dim,
229 |                                                     output_dim = self.hidden_dim,
230 |                                                     dropout_rate = self.dropout_rate,
231 |                                                     activation = self.activation,
232 |                                                     use_layernormalize = self.use_layernormalize,
233 |                                                     skip_connection = self.skip_connection,
234 |                                                     context_str = self.context_str))
235 | 
236 |             for i in range(self.num_hidden_layers-1):
237 |                 self.layers.append( SingleFeedForwardNN(input_dim = self.hidden_dim,
238 |                                                         output_dim = self.hidden_dim,
239 |                                                         dropout_rate = self.dropout_rate,
240 |                                                         activation = self.activation,
241 |                                                         use_layernormalize = self.use_layernormalize,
242 |                                                         skip_connection = self.skip_connection,
243 |                                                         context_str = self.context_str))
244 | 
245 |             self.layers.append( SingleFeedForwardNN(input_dim = self.hidden_dim,
246 |                                                     output_dim = self.output_dim,
247 |                                                     dropout_rate = self.dropout_rate,
248 |                                                     activation = self.activation,
249 |                                                     use_layernormalize = False,
250 |                                                     skip_connection = False,
251 |                                                     context_str = self.context_str))
252 | 
253 | 
254 | 
255 |     def forward(self, input_tensor):
256 |         '''
257 |         Args:
258 |             input_tensor: shape [batch_size, ..., input_dim]
259 |         Returns:
260 |             tensor of shape [batch_size, ..., output_dim]
261 |             note the final layer still applies the activation and dropout (see SingleFeedForwardNN)
262 | 
263 |         Raises:
264 |             Exception: If given activation or normalizer not supported.
265 |         '''
266 |         assert input_tensor.size()[-1] == self.input_dim
267 |         output = input_tensor
268 |         for i in range(len(self.layers)):
269 |             output = self.layers[i](output)
270 | 
271 |         return output
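
How the multi-layer wrapper stacks these single layers (a sketch; all sizes are arbitrary):

    # num_hidden_layers=2 builds: 32 -> 64, 64 -> 64, then a plain 64 -> 8 layer.
    ffn = MultiLayerFeedForwardNN(input_dim=32, output_dim=8,
                                  num_hidden_layers=2, dropout_rate=0.5,
                                  hidden_dim=64, activation="relu",
                                  use_layernormalize=True, skip_connection=True,
                                  context_str="demo")
    print(ffn(torch.randn(4, 32)).shape)   # torch.Size([4, 8])
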
272 | 
273 | 
274 | 
275 | class MultiLayerFeedForwardNNFlexible(nn.Module):
276 |     """
277 |     Creates a fully connected feed forward neural network.
278 |     N fully connected feed forward NN, each hidden layer will use non-linearity, layer normalization, dropout
279 |     The last layer skips layer normalization and the skip connection (activation and dropout still apply)
280 |     """
281 | 
282 |     def __init__(self, input_dim,
283 |                  output_dim,
284 |                  hidden_layers=[],
285 |                  dropout_rate=None,
286 |                  activation="sigmoid",
287 |                  use_layernormalize=False,
288 |                  skip_connection = False,
289 |                  context_str = None):
290 |         '''
291 | 
292 |         Args:
293 |             input_dim (int32): the input embedding dim
294 |             hidden_layers (list): a list of hidden layer dimensions
295 |             output_dim (int32): dimension of the output of the network.
296 |             dropout_rate (scalar tensor or float): dropout probability.
297 |             activation (string): tanh, relu, leakyrelu, or sigmoid
298 |             use_layernormalize (bool): do layer normalization or not
299 |             context_str (string): indicates which spatial relation encoder is using the current FFN
300 | 
301 |         '''
302 |         super(MultiLayerFeedForwardNNFlexible, self).__init__()
303 |         self.input_dim = input_dim
304 |         self.output_dim = output_dim
305 |         self.hidden_layers = hidden_layers
306 |         self.dropout_rate = dropout_rate
307 |         self.activation = activation
308 |         self.use_layernormalize = use_layernormalize
309 |         self.skip_connection = skip_connection
310 |         self.context_str = context_str
311 | 
312 |         self.num_hidden_layers = len(self.hidden_layers)
313 | 
314 |         for dim in self.hidden_layers:
315 |             assert type(dim) == int
316 | 
317 |         self.layers = nn.ModuleList()
318 |         if self.num_hidden_layers == 0:
319 |             self.layers.append( SingleFeedForwardNN(input_dim = self.input_dim,
320 |                                                     output_dim = self.output_dim,
321 |                                                     dropout_rate = self.dropout_rate,
322 |                                                     activation = self.activation,
323 |                                                     use_layernormalize = False,
324 |                                                     skip_connection = False,
325 |                                                     context_str = self.context_str))
326 |         else:
327 |             self.layers.append( SingleFeedForwardNN(input_dim = self.input_dim,
328 |                                                     output_dim = self.hidden_layers[0],
329 |                                                     dropout_rate = self.dropout_rate,
330 |                                                     activation = self.activation,
331 |                                                     use_layernormalize = self.use_layernormalize,
332 |                                                     skip_connection = self.skip_connection,
333 |                                                     context_str = self.context_str))
334 | 
335 |             for i in range(self.num_hidden_layers-1):
336 |                 self.layers.append( SingleFeedForwardNN(input_dim = self.hidden_layers[i],
337 |                                                         output_dim = self.hidden_layers[i+1],
338 |                                                         dropout_rate = self.dropout_rate,
339 |                                                         activation = self.activation,
340 |                                                         use_layernormalize = self.use_layernormalize,
341 |                                                         skip_connection = self.skip_connection,
342 |                                                         context_str = self.context_str))
343 | 
344 |             self.layers.append( SingleFeedForwardNN(input_dim = self.hidden_layers[-1],
345 |                                                     output_dim = self.output_dim,
346 |                                                     dropout_rate = self.dropout_rate,
347 |                                                     activation = self.activation,
348 |                                                     use_layernormalize = False,
349 |                                                     skip_connection = False,
350 |                                                     context_str = self.context_str))
351 | 
352 | 
353 | 
354 |     def forward(self, input_tensor):
355 |         '''
356 |         Args:
357 |             input_tensor: shape [batch_size, ..., input_dim]
358 |         Returns:
359 |             tensor of shape [batch_size, ..., output_dim]
360 |             note the final layer still applies the activation and dropout (see SingleFeedForwardNN)
361 | 
362 |         Raises:
363 |             Exception: If given activation or normalizer not supported.
364 |         '''
365 |         assert input_tensor.size()[-1] == self.input_dim
366 |         output = input_tensor
367 |         for i in range(len(self.layers)):
368 |             output = self.layers[i](output)
369 | 
370 |         return output
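
The flexible variant takes an explicit list of widths instead of a single hidden_dim (again a sketch with made-up sizes):

    # hidden_layers=[64, 32] builds: 32 -> 64, 64 -> 32, then a plain 32 -> 8 layer.
    ffn = MultiLayerFeedForwardNNFlexible(input_dim=32, output_dim=8,
                                          hidden_layers=[64, 32],
                                          dropout_rate=0.5, activation="relu",
                                          use_layernormalize=True,
                                          skip_connection=True,
                                          context_str="demo")
    print(ffn(torch.randn(4, 32)).shape)   # torch.Size([4, 8])
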
371 | 
372 | 
373 | # from Presence-Only Geographical Priors for Fine-Grained Image Classification
374 | # www.vision.caltech.edu/~macaodha/projects/geopriors
375 | 
376 | class ResLayer(nn.Module):
377 |     def __init__(self, linear_size):
378 |         super(ResLayer, self).__init__()
379 |         self.l_size = linear_size
380 |         self.nonlin1 = nn.ReLU(inplace=True)
381 |         self.nonlin2 = nn.ReLU(inplace=True)
382 |         self.dropout1 = nn.Dropout()
383 |         self.w1 = nn.Linear(self.l_size, self.l_size)
384 |         self.w2 = nn.Linear(self.l_size, self.l_size)
385 | 
386 |     def forward(self, x):
387 |         y = self.w1(x)
388 |         y = self.nonlin1(y)
389 |         y = self.dropout1(y)
390 |         y = self.w2(y)
391 |         y = self.nonlin2(y)
392 |         out = x + y
393 | 
394 |         return out
395 | 
396 | 
397 | class FCNet(nn.Module):
398 |     # def __init__(self, num_inputs, num_classes, num_filts, num_users=1):
399 |     def __init__(self, num_inputs, num_filts, num_hidden_layers):
400 |         '''
401 |         Args:
402 |             num_inputs: input embedding dimension
403 |             num_filts: hidden embedding dimension
404 |             num_hidden_layers: number of hidden layers
405 |         '''
406 |         super(FCNet, self).__init__()
407 |         # self.inc_bias = False
408 |         # self.class_emb = nn.Linear(num_filts, num_classes, bias=self.inc_bias)
409 |         # self.user_emb = nn.Linear(num_filts, num_users, bias=self.inc_bias)
410 | 
411 |         # self.feats = nn.Sequential(nn.Linear(num_inputs, num_filts),
412 |         #                            nn.ReLU(inplace=True),
413 |         #                            ResLayer(num_filts),
414 |         #                            ResLayer(num_filts),
415 |         #                            ResLayer(num_filts),
416 |         #                            ResLayer(num_filts))
417 |         self.num_hidden_layers = num_hidden_layers
418 |         self.feats = nn.Sequential()
419 |         self.feats.add_module("ln_1", nn.Linear(num_inputs, num_filts))
420 |         self.feats.add_module("relu_1", nn.ReLU(inplace=True))
421 |         for i in range(num_hidden_layers):
422 |             self.feats.add_module("resnet_{}".format(i+1), ResLayer(num_filts))
423 | 
424 |     # def forward(self, x, class_of_interest=None, return_feats=False):
425 |     def forward(self, x):
426 |         loc_emb = self.feats(x)
427 |         # if return_feats:
428 |         #     return loc_emb
429 |         # if class_of_interest is None:
430 |         #     class_pred = self.class_emb(loc_emb)
431 |         # else:
432 |         #     class_pred = self.eval_single_class(loc_emb, class_of_interest)
433 | 
434 |         # return torch.sigmoid(class_pred)
435 |         return loc_emb
436 | 
437 |     # def eval_single_class(self, x, class_of_interest):
438 |     #     if self.inc_bias:
439 |     #         return torch.matmul(x, self.class_emb.weight[class_of_interest, :]) + self.class_emb.bias[class_of_interest]
440 |     #     else:
441 |     #         return torch.matmul(x, self.class_emb.weight[class_of_interest, :])
--------------------------------------------------------------------------------
/polygoncode/polygonembed/ddsl_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import math
4 | from math import factorial
5 | import os
6 | 
7 | 
8 | def get_fourier_freqs(res, min_freqXY, max_freqXY, mid_freqXY = None, freq_init = "geometric", dtype=torch.float32, exact=True, eps = 1e-4):
9 |     """
10 |     Helper function to return frequency tensors
11 |     This is a generalization of fftfreqs()
12 |     Args:
13 |         res: n_dims int tuple of number of frequency modes, (fx, fy) for polygon
14 |         max_freqXY: the maximum frequency
15 |         min_freqXY: the minimum frequency
16 |         freq_init: frequency generation method,
17 |             "geometric": geometric series
18 |             "fft": fast fourier transformation
19 |     :return: omega:
20 |         if exact = True
21 |             for res = (fx, fy) => shape (fx, fy//2+1, 2)
22 |             for res = (f1, f2, ..., fn) => shape (f1, f2, ..., f_{n-1}, fn//2+1, n)
23 |         if exact = False
24 |             for res = (fx, fy) => shape (fx, fy//2, 2)
25 |             for res = (f1, f2, ..., fn) => shape (f1, f2, ..., f_{n-1}, fn//2, n)
26 |     """
27 |     n_dims = len(res)
28 |     freqs = []
29 |     for dim in range(n_dims - 1):
30 |         r_ = res[dim]
31 |         freq = make_fourier_freq_vector(res_dim = r_,
32 |                                         min_freqXY = min_freqXY,
33 |                                         max_freqXY = max_freqXY,
34 |                                         mid_freqXY = mid_freqXY,
35 |                                         freq_init = freq_init,
36 |                                         get_full_vector = True)
37 |         freqs.append(torch.tensor(freq, dtype=dtype))
38 |     r_ = res[-1]
39 |     freq = make_fourier_freq_vector(res_dim = r_,
40 |                                     min_freqXY = min_freqXY,
41 |                                     max_freqXY = max_freqXY,
42 |                                     mid_freqXY = mid_freqXY,
43 |                                     freq_init = freq_init,
44 |                                     get_full_vector = False)
45 |     if exact:
46 |         freqs.append(torch.tensor(freq, dtype=dtype))
47 |     else:
48 |         freqs.append(torch.tensor(freq[:-1], dtype=dtype))
49 |     omega = torch.meshgrid(freqs)
50 |     omega = list(omega)
51 |     omega = torch.stack(omega, dim=-1)
52 | 
53 |     omega[0, 0, :] = torch.FloatTensor(np.random.rand(2)*eps).to(omega.dtype)
54 |     return omega
55 | 
56 | 
57 | def compute_geoemtric_series(min_, max_, num):
58 |     log_timescale_increment = (math.log(float(max_) / float(min_)) /
59 |                                (num*1.0 - 1))
60 | 
61 |     timescales = min_ * np.exp(
62 |         np.arange(num).astype(float) * log_timescale_increment)
63 |     return timescales
64 | 
65 | def make_freq_series(min_freqXY, max_freqXY, frequency_num, mid_freqXY = None, freq_init = "geometric"):
66 |     if freq_init == "geometric":
67 |         return compute_geoemtric_series(min_freqXY, max_freqXY, frequency_num)
68 |     elif freq_init == "arith_geometric":
69 |         assert mid_freqXY is not None
70 |         assert min_freqXY < mid_freqXY < max_freqXY
71 |         left_freq_num = int(frequency_num/2)
72 |         right_freq_num = int(frequency_num - left_freq_num)
73 | 
74 |         # left: arithmetic
75 |         left_freqs = np.linspace(start = min_freqXY, stop=mid_freqXY, num=left_freq_num, endpoint=False)
76 | 
77 |         # right: geometric
78 |         right_freqs = compute_geoemtric_series(min_ = mid_freqXY, max_ = max_freqXY, num = right_freq_num)
79 | 
80 |         freqs = np.concatenate([left_freqs, right_freqs], axis = -1)
81 |         return freqs
82 |     elif freq_init == "geometric_arith":
83 |         assert mid_freqXY is not None
84 |         assert min_freqXY < mid_freqXY < max_freqXY
85 |         left_freq_num = int(frequency_num/2)
86 |         right_freq_num = int(frequency_num - left_freq_num)
87 | 
88 |         # left: geometric
89 |         left_freqs = compute_geoemtric_series(min_ = min_freqXY, max_ = mid_freqXY, num = left_freq_num)
90 | 
91 |         # right: arithmetic
92 |         right_freqs = np.linspace(start = mid_freqXY, stop=max_freqXY, num=right_freq_num+1, endpoint=True)
93 | 
94 | 
95 | 
96 |         freqs = np.concatenate([left_freqs, right_freqs[1:]], axis = -1)
97 |         return freqs
98 |     else:
99 |         raise Exception(f"freq_init = {freq_init} is not implemented" )
100 | 
101 | def make_fourier_freq_vector(res_dim, min_freqXY, max_freqXY, mid_freqXY = None, freq_init = "geometric", get_full_vector = True):
102 |     '''
103 |     make the frequency vector for the X or Y dimension
104 |     Args:
105 |         res_dim: the total number of frequency modes we want
106 |         max_freqXY: the maximum frequency
107 |         min_freqXY: the minimum frequency
108 |         get_full_vector: get the full frequency vector, or half of it (Y dimension)
109 |     '''
110 |     if freq_init == "fft":
111 |         if get_full_vector:
112 |             freq = np.fft.fftfreq(res_dim, d=1/res_dim)
113 |         else:
114 |             freq = np.fft.rfftfreq(res_dim, d=1/res_dim)
115 |     else:
116 |         half_freqs = make_freq_series(min_freqXY, max_freqXY, frequency_num = res_dim//2,
117 |                                       mid_freqXY = mid_freqXY, freq_init = freq_init)
118 | 
119 |         if get_full_vector:
120 |             neg_half_freqs = -np.flip(half_freqs, axis = -1)
121 |             if res_dim % 2 == 0:
122 |                 freq = np.concatenate([np.array([0.0]), half_freqs[:-1], neg_half_freqs], axis = -1)
123 |             else:
124 |                 freq = np.concatenate([np.array([0.0]), half_freqs, neg_half_freqs], axis = -1)
125 |         else:
126 |             # the even and odd cases coincide here: [0, f_1, ..., f_{res_dim//2}]
127 |             freq = np.concatenate([np.array([0.0]), half_freqs], axis = -1)
128 | 
129 | 
130 | 
131 |     if get_full_vector:
132 |         assert freq.shape[0] == res_dim
133 |     else:
134 |         assert freq.shape[0] == math.floor(res_dim*1.0/2)+1
135 |     return freq
136 | 
137 | def fftfreqs(res, dtype=torch.float32, exact=True, eps = 1e-4):
138 |     """
139 |     Helper function to return frequency tensors
140 |     :param res: n_dims int tuple of number of frequency modes, (fx, fy) for polygon
141 |     :return: omega:
142 |         if exact = True
143 |             for res = (fx, fy) => shape (fx, fy//2+1, 2)
144 |             for res = (f1, f2, ..., fn) => shape (f1, f2, ..., f_{n-1}, fn//2+1, n)
145 |         if exact = False
146 |             for res = (fx, fy) => shape (fx, fy//2, 2)
147 |             for res = (f1, f2, ..., fn) => shape (f1, f2, ..., f_{n-1}, fn//2, n)
148 |     """
149 | 
150 |     n_dims = len(res)
151 |     freqs = []
152 |     for dim in range(n_dims - 1):
153 |         r_ = res[dim]
154 |         freq = np.fft.fftfreq(r_, d=1/r_)
155 |         freqs.append(torch.tensor(freq, dtype=dtype))
156 |     r_ = res[-1]
157 |     if exact:
158 |         freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1/r_), dtype=dtype))
159 |     else:
160 |         freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1/r_)[:-1], dtype=dtype))
161 |     omega = torch.meshgrid(freqs)
162 |     omega = list(omega)
163 |     omega = torch.stack(omega, dim=-1)
164 | 
165 |     omega[0, 0, :] = torch.FloatTensor(np.random.rand(2)*eps).to(omega.dtype)
166 | 
167 |     return omega
168 | 
169 | # def fftfreqs(res, dtype=torch.float32, exact=True):
170 | #     """
171 | #     Helper function to return frequency tensors
172 | #     :param res: n_dims int tuple of number of frequency modes, (fx, fy) for polygon
173 | #     :return: omega:
174 | #         if exact = True
175 | #             for res = (fx, fy) => shape (fx, fy//2+1, 2)
176 | #             for res = (f1, f2, ..., fn) => shape (f1, f2, ..., f_{n-1}, fn//2+1, n)
177 | #         if exact = False
178 | #             for res = (fx, fy) => shape (fx, fy//2, 2)
179 | #             for res = (f1, f2, ..., fn) => shape (f1, f2, ..., f_{n-1}, fn//2, n)
180 | #     """
181 | 
182 | #     n_dims = len(res)
183 | #     freqs = []
184 | #     for dim in range(n_dims - 1):
185 | #         r_ = res[dim]
186 | #         freq = np.fft.fftfreq(r_, d=1/r_)
187 | #         freqs.append(torch.tensor(freq, dtype=dtype))
188 | #     r_ = res[-1]
189 | #     if exact:
190 | #         freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1/r_), dtype=dtype))
191 | #     else:
192 | #         freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1/r_)[:-1], dtype=dtype))
193 | #     omega = torch.meshgrid(freqs)
194 | #     omega = list(omega)
195 | #     omega = torch.stack(omega, dim=-1)
196 | 
197 | #     return omega
198 | 
199 | 
200 | def permute_seq(i, len):
201 |     """
202 |     Permute the ordering of integer sequences
203 |     """
204 |     assert(i < len)
      [ lines 205-320 are missing from this dump: the text between the "<" above and the ">" below was stripped as if it were an HTML tag; they contained the rest of permute_seq and the start of the simplex content/volume routine whose tail follows, including the "if" that pairs with the "else:" at line 325 ]
321 |         if torch.sum(neg_mask) > 0:
322 |             vol2[:, neg_mask] = 0
323 |             print("[!]Warning: zeroing {0} small negative numbers".format(torch.sum(neg_mask).item()))
324 |         vol = torch.sqrt(vol2)
325 |     else:
326 |         # Equation 8 in https://openreview.net/pdf?id=B1G5ViAqFm
327 |         # compute based on: https://en.m.wikipedia.org/wiki/Simplex#Volume
328 |         assert(j == nd)
329 |         # matrix determinant
330 |         # x_{j+1} is an auxiliary node with (0, 0) coordinates
331 |         # [x1,x2,...,x_{j}] - [x_{j+1}]
332 |         # P[:, :, :-1]: [x1,x2,...,x_{j}], shape (batch_size, n_elem = num_vert, j = 2, n_dims = 2),
333 |         # P[:, :, -1:]: [x_{j+1}], shape (batch_size, n_elem = num_vert, 1, n_dims = 2)
334 |         # mat: shape (batch_size, n_elem = num_vert, j = 2, n_dims = 2)
335 |         mat = P[:, :, :-1] - P[:, :, -1:]
336 |         # vol: shape (batch_size, n_elem = num_vert)
337 |         vol = torch.linalg.det(mat) / math.factorial(j)
338 | 
339 |     return vol.unsqueeze(-1)
340 | 
341 | 
342 | def batch_det(A):
343 |     """
344 |     (Unused) Batch-compute the determinant of square matrices A of shape (*, N, N)
345 | 
346 |     We can use torch.linalg.det() directly
347 |     Return:
348 |         Tensor of shape (*)
349 | 
350 |     Elementary operation 1 -- swapping two rows or columns -- flips the determinant's sign, so only an even number of swaps leaves the value unchanged
351 |     Elementary operation 2 -- multiplying a row (column) by a nonzero real -- scales the determinant by that factor, so the product of all such factors must equal 1 to preserve the value
352 |     Elementary operation 3 -- adding a real multiple of one row (column) to another row (column) -- does not change the determinant
353 |     """
354 |     # LU-factorize the matrix; the product of the diagonal entries of the factorization equals the determinant of the original matrix
355 |     # LU factorization may perform row exchanges, and every row exchange flips the determinant's sign, so we must track the parity of the number of exchanges:
356 |     # an odd count flips the sign, an even count keeps it; the permutation bookkeeping below recovers exactly this parity
357 |     # LU: the packed LU factorization matrix, (*, N, N)
358 |     # pivots: (*, N)
359 |     LU, pivots = torch.lu(A)
360 |     # torch.einsum('...ii->...i', LU): the diagonal vector
361 |     # det_LU: the product of all diagonal values
362 |     det_LU = torch.einsum('...ii->...i', LU).prod(-1)
363 |     pivots -= 1
364 |     d = pivots.shape[-1]
365 |     perm = pivots - torch.arange(d, dtype=pivots.dtype, device=pivots.device).expand(pivots.shape)
366 |     det_P = (-1) ** ((perm != 0).sum(-1))
367 |     det = det_LU * det_P.type(det_LU.dtype)
368 | 
369 |     return det
370 | 
371 | 
372 | 
373 | def make_E(V):
374 |     '''
375 |     here, we assume V comes from a simple polygon
376 |     Given a polygon vertex tensor -> V with shape (batch_size, num_vert, 2)
377 |     Generate its edge matrix E
378 |     Note: num_vert counts the unique vertices; remove the repeated last/first node beforehand
379 |     Here, num_vert: number of vertices of the input polygon = number of edges = number of 2-simplices (with the auxiliary node)
380 | 
381 |     Args:
382 |         V with shape (batch_size, num_vert, 2)
383 |     Return:
384 |         E: torch.LongTensor(), shape (batch_size, num_vert, 2)
385 |     '''
386 |     batch_size, num_vert, n_dims = V.shape
387 |     a = torch.arange(0, num_vert)
388 |     E = torch.stack((a, a+1), dim = 0).permute(1,0)
389 |     E[-1, -1] = 0
390 |     E = torch.repeat_interleave(E.unsqueeze(0), repeats = batch_size, dim=0).to(V.device)
391 |     return E
392 | 
393 | def make_D(V):
394 |     '''
395 |     Given a polygon vertex tensor -> V with shape (batch_size, num_vert, 2)
396 |     Generate its density matrix D
397 |     Note: num_vert counts the unique vertices; remove the repeated last/first node beforehand
398 |     Here, num_vert: number of vertices of the input polygon = number of edges = number of 2-simplices (with the auxiliary node)
399 | 
400 |     Args:
401 |         V: shape (batch_size, num_vert, 2)
402 |     Return:
403 |         D: torch.FloatTensor(), shape (batch_size, num_vert, 1)
404 |     '''
405 |     batch_size, num_vert, n_dims = V.shape
406 |     D = torch.ones(batch_size, num_vert, 1).to(V.device)
407 |     return D
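
A small worked example of the connectivity these helpers build (illustrative only):

    # For a batch of two quadrilaterals (4 unique vertices each):
    V = torch.rand(2, 4, 2)
    E = make_E(V)   # each batch item: [[0, 1], [1, 2], [2, 3], [3, 0]] -- a closed ring
    D = make_D(V)   # all-ones density tensor of shape (2, 4, 1)
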
408 | 
409 | def affinity_V(V, extent):
410 |     '''
411 |     Shift the vertex tensor so that it lies in [0, periodX] x [0, periodY]
412 | 
413 |     Args:
414 |         V: torch.FloatTensor(), shape (batch_size, num_vert, n_dims = 2)
415 |            vertex tensor
416 |         extent: the maximum spatial extent of all polygons, (minx, maxx, miny, maxy)
417 | 
418 | 
419 |     Return:
420 |         V: torch.FloatTensor(), shape (batch_size, num_vert, n_dims = 2)
421 |     '''
422 |     device = V.device
423 |     minx, maxx, miny, maxy = extent
424 | 
425 |     # assert maxx - minx == periodXY[0]
426 |     # assert maxy - miny == periodXY[1]
427 | 
428 |     # shift all polygons so that they have positive coordinates
429 |     # e.g., move them to (0, 2, 0, 2)
430 |     V = V + torch.FloatTensor([-minx, -miny]).to(device)
431 | 
432 |     return V
433 | 
434 | 
435 | def add_noise_V(V, eps):
436 |     '''
437 |     add small noise to each vertex to make NUFT more robust
438 |     Args:
439 |         V: torch.FloatTensor(), shape (batch_size, num_vert, n_dims = 2)
440 |            vertex tensor
441 |         eps: the maximum noise we add to each polygon vertex
442 |     Return:
443 |         V: torch.FloatTensor(), shape (batch_size, num_vert, n_dims = 2)
444 |     '''
445 |     # add small noise
446 |     V = V + torch.rand(V.shape, device = V.device)*eps
447 |     return V
448 | 
449 | 
450 | 
451 | 
452 | 
453 | 
454 | def make_periodXY(extent):
455 |     '''
456 |     Make periodXY based on the spatial extent
457 |     Args:
458 |         extent: (minx, maxx, miny, maxy)
459 |     Return:
460 |         periodXY: t in DDSL_spec(), [periodX, periodY]
461 |         periodX, periodY: the spatial extents along X and Y, i.e., [0, periodX] and [0, periodY]
462 |     '''
463 |     minx, maxx, miny, maxy = extent
464 | 
465 |     periodX = maxx - minx
466 |     periodY = maxy - miny
467 | 
468 |     periodXY = [periodX, periodY]
469 |     return periodXY
470 | 
471 | 
472 | def polygon_nuft_input(polygons, extent, V = None, E = None):
473 |     '''
474 |     polygons: torch.FloatTensor(), shape (batch_size, num_vert, n_dims = 2)
475 |         the last point is not a repeat of the 1st one
476 |     extent: the maximum spatial extent of all polygons, (minx, maxx, miny, maxy)
477 |     '''
478 |     assert (polygons is None and V is not None and E is not None) or (polygons is not None and V is None and E is None)
479 |     if polygons is not None:
480 |         # V: torch.FloatTensor(), shape (batch_size, num_vert, n_dims = 2)
481 |         #    vertex tensor
482 |         V = polygons
483 |     # shift the vertex tensor so that it lies in [0, periodX] x [0, periodY]
484 |     V = affinity_V(V, extent)
485 |     if E is None:
486 |         # add noise
487 |         # V = add_noise_V(V, self.eps)
488 |         # E: torch.LongTensor(), shape (batch_size, num_vert, 2)
489 |         #    element tensor, each element [[0,1], [1,2],...,[num_vert-1, 0]]
490 |         E = make_E(V)
491 |     # D: torch.FloatTensor(), shape (batch_size, num_vert, 1)
492 |     #    all-ones density tensor
493 |     D = make_D(V)
494 | 
495 | 
496 |     return V, E, D
--------------------------------------------------------------------------------
/polygoncode/polygonembed/dla.py:
--------------------------------------------------------------------------------
1 | """
2 |     DLA for ImageNet-1K, implemented in PyTorch.
3 |     Original paper: 'Deep Layer Aggregation,' https://arxiv.org/abs/1707.06484.
4 | """
5 | 
6 | __all__ = ['DLA', 'dla34', 'dla46c', 'dla46xc', 'dla60', 'dla60x', 'dla60xc', 'dla102', 'dla102x', 'dla102x2', 'dla169']
7 | 
8 | import os
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.init as init
12 | from polygonembed.dla_common import conv1x1, conv1x1_block, conv3x3_block, conv7x7_block
13 | from polygonembed.dla_resnet import ResBlock, ResBottleneck
14 | from polygonembed.dla_resnext import ResNeXtBottleneck
15 | 
16 | 
17 | class DLABottleneck(ResBottleneck):
18 |     """
19 |     DLA bottleneck block for residual path in residual block.
20 | 
21 |     Parameters:
22 |     ----------
23 |     in_channels : int
24 |         Number of input channels.
25 |     out_channels : int
26 |         Number of output channels.
27 | stride : int or tuple/list of 2 int 28 | Strides of the convolution. 29 | bottleneck_factor : int, default 2 30 | Bottleneck factor. 31 | """ 32 | def __init__(self, 33 | in_channels, 34 | out_channels, 35 | stride, 36 | bottleneck_factor=2): 37 | super(DLABottleneck, self).__init__( 38 | in_channels=in_channels, 39 | out_channels=out_channels, 40 | stride=stride, 41 | bottleneck_factor=bottleneck_factor) 42 | 43 | 44 | class DLABottleneckX(ResNeXtBottleneck): 45 | """ 46 | DLA ResNeXt-like bottleneck block for residual path in residual block. 47 | 48 | Parameters: 49 | ---------- 50 | in_channels : int 51 | Number of input channels. 52 | out_channels : int 53 | Number of output channels. 54 | stride : int or tuple/list of 2 int 55 | Strides of the convolution. 56 | cardinality: int, default 32 57 | Number of groups. 58 | bottleneck_width: int, default 8 59 | Width of bottleneck block. 60 | """ 61 | def __init__(self, 62 | in_channels, 63 | out_channels, 64 | stride, 65 | cardinality=32, 66 | bottleneck_width=8): 67 | super(DLABottleneckX, self).__init__( 68 | in_channels=in_channels, 69 | out_channels=out_channels, 70 | stride=stride, 71 | cardinality=cardinality, 72 | bottleneck_width=bottleneck_width) 73 | 74 | 75 | class DLAResBlock(nn.Module): 76 | """ 77 | DLA residual block with residual connection. 78 | 79 | Parameters: 80 | ---------- 81 | in_channels : int 82 | Number of input channels. 83 | out_channels : int 84 | Number of output channels. 85 | stride : int or tuple/list of 2 int 86 | Strides of the convolution. 87 | body_class : nn.Module, default ResBlock 88 | Residual block body class. 89 | return_down : bool, default False 90 | Whether return downsample result. 91 | """ 92 | def __init__(self, 93 | in_channels, 94 | out_channels, 95 | stride, 96 | body_class=ResBlock, 97 | return_down=False): 98 | super(DLAResBlock, self).__init__() 99 | self.return_down = return_down 100 | self.downsample = (stride > 1) 101 | self.project = (in_channels != out_channels) 102 | 103 | self.body = body_class( 104 | in_channels=in_channels, 105 | out_channels=out_channels, 106 | stride=stride) 107 | self.activ = nn.ReLU(inplace=True) 108 | if self.downsample: 109 | self.downsample_pool = nn.MaxPool2d( 110 | kernel_size=stride, 111 | stride=stride) 112 | if self.project: 113 | self.project_conv = conv1x1_block( 114 | in_channels=in_channels, 115 | out_channels=out_channels, 116 | activation=None) 117 | 118 | def forward(self, x): 119 | down = self.downsample_pool(x) if self.downsample else x 120 | identity = self.project_conv(down) if self.project else down 121 | if identity is None: 122 | identity = x 123 | x = self.body(x) 124 | x += identity 125 | x = self.activ(x) 126 | if self.return_down: 127 | return x, down 128 | else: 129 | return x 130 | 131 | 132 | class DLARoot(nn.Module): 133 | """ 134 | DLA root block. 135 | 136 | Parameters: 137 | ---------- 138 | in_channels : int 139 | Number of input channels. 140 | out_channels : int 141 | Number of output channels. 142 | residual : bool 143 | Whether use residual connection. 
144 | """ 145 | def __init__(self, 146 | in_channels, 147 | out_channels, 148 | residual): 149 | super(DLARoot, self).__init__() 150 | self.residual = residual 151 | 152 | self.conv = conv1x1_block( 153 | in_channels=in_channels, 154 | out_channels=out_channels, 155 | activation=None) 156 | self.activ = nn.ReLU(inplace=True) 157 | 158 | def forward(self, x2, x1, extra): 159 | last_branch = x2 160 | x = torch.cat((x2, x1) + tuple(extra), dim=1) 161 | x = self.conv(x) 162 | if self.residual: 163 | x += last_branch 164 | x = self.activ(x) 165 | return x 166 | 167 | 168 | class DLATree(nn.Module): 169 | """ 170 | DLA tree unit. It's like iterative stage. 171 | 172 | Parameters: 173 | ---------- 174 | levels : int 175 | Number of levels in the stage. 176 | in_channels : int 177 | Number of input channels. 178 | out_channels : int 179 | Number of output channels. 180 | res_body_class : nn.Module 181 | Residual block body class. 182 | stride : int or tuple/list of 2 int 183 | Strides of the convolution in a residual block. 184 | root_residual : bool 185 | Whether use residual connection in the root. 186 | root_dim : int 187 | Number of input channels in the root block. 188 | first_tree : bool, default False 189 | Is this tree stage the first stage in the net. 190 | input_level : bool, default True 191 | Is this tree unit the first unit in the stage. 192 | return_down : bool, default False 193 | Whether return downsample result. 194 | """ 195 | def __init__(self, 196 | levels, 197 | in_channels, 198 | out_channels, 199 | res_body_class, 200 | stride, 201 | root_residual, 202 | root_dim=0, 203 | first_tree=False, 204 | input_level=True, 205 | return_down=False): 206 | super(DLATree, self).__init__() 207 | self.return_down = return_down 208 | self.add_down = (input_level and not first_tree) 209 | self.root_level = (levels == 1) 210 | 211 | if root_dim == 0: 212 | root_dim = 2 * out_channels 213 | if self.add_down: 214 | root_dim += in_channels 215 | 216 | if self.root_level: 217 | self.tree1 = DLAResBlock( 218 | in_channels=in_channels, 219 | out_channels=out_channels, 220 | stride=stride, 221 | body_class=res_body_class, 222 | return_down=True) 223 | self.tree2 = DLAResBlock( 224 | in_channels=out_channels, 225 | out_channels=out_channels, 226 | stride=1, 227 | body_class=res_body_class, 228 | return_down=False) 229 | else: 230 | self.tree1 = DLATree( 231 | levels=levels - 1, 232 | in_channels=in_channels, 233 | out_channels=out_channels, 234 | res_body_class=res_body_class, 235 | stride=stride, 236 | root_residual=root_residual, 237 | root_dim=0, 238 | input_level=False, 239 | return_down=True) 240 | self.tree2 = DLATree( 241 | levels=levels - 1, 242 | in_channels=out_channels, 243 | out_channels=out_channels, 244 | res_body_class=res_body_class, 245 | stride=1, 246 | root_residual=root_residual, 247 | root_dim=root_dim + out_channels, 248 | input_level=False, 249 | return_down=False) 250 | if self.root_level: 251 | self.root = DLARoot( 252 | in_channels=root_dim, 253 | out_channels=out_channels, 254 | residual=root_residual) 255 | 256 | def forward(self, x, extra=None): 257 | extra = [] if extra is None else extra 258 | x1, down = self.tree1(x) 259 | if self.add_down: 260 | extra.append(down) 261 | if self.root_level: 262 | x2 = self.tree2(x1) 263 | x = self.root(x2, x1, extra) 264 | else: 265 | extra.append(x1) 266 | x = self.tree2(x1, extra) 267 | if self.return_down: 268 | return x, down 269 | else: 270 | return x 271 | 272 | 273 | class DLAInitBlock(nn.Module): 274 | """ 275 | DLA specific 
initial block. 276 | 277 | Parameters: 278 | ---------- 279 | in_channels : int 280 | Number of input channels. 281 | out_channels : int 282 | Number of output channels. 283 | """ 284 | def __init__(self, 285 | in_channels, 286 | out_channels): 287 | super(DLAInitBlock, self).__init__() 288 | mid_channels = out_channels // 2 289 | 290 | self.conv1 = conv7x7_block( 291 | in_channels=in_channels, 292 | out_channels=mid_channels) 293 | self.conv2 = conv3x3_block( 294 | in_channels=mid_channels, 295 | out_channels=mid_channels) 296 | self.conv3 = conv3x3_block( 297 | in_channels=mid_channels, 298 | out_channels=out_channels, 299 | stride=2) 300 | 301 | def forward(self, x): 302 | ''' 303 | Args: 304 | x: shape (N, in_channels, H, W) 305 | ''' 306 | # x: shape (N, mid_channels, H, W) 307 | x = self.conv1(x) 308 | # x: shape (N, mid_channels, H, W) 309 | x = self.conv2(x) 310 | # x: shape (N, out_channels, (H+1)/2, (W+1)/2) 311 | x = self.conv3(x) 312 | return x 313 | 314 | 315 | class DLA(nn.Module): 316 | """ 317 | DLA model from 'Deep Layer Aggregation,' https://arxiv.org/abs/1707.06484. 318 | 319 | Parameters: 320 | ---------- 321 | levels : int 322 | Number of levels in each stage. 323 | channels : list of int 324 | Number of output channels for each stage. 325 | init_block_channels : int 326 | Number of output channels for the initial unit. 327 | res_body_class : nn.Module 328 | Residual block body class. 329 | residual_root : bool 330 | Whether use residual connection in the root blocks. 331 | in_channels : int, default 3 332 | Number of input channels. 333 | in_size : tuple of two ints, default (224, 224) 334 | Spatial size of the expected input image. 335 | num_classes : int, default 1000 336 | Number of classification classes. 337 | """ 338 | def __init__(self, 339 | levels, 340 | channels, 341 | init_block_channels, 342 | res_body_class, 343 | residual_root, 344 | in_channels=3, 345 | in_size=(224, 224), 346 | num_classes=1000): 347 | super(DLA, self).__init__() 348 | self.in_size = in_size 349 | self.num_classes = num_classes 350 | 351 | self.features = nn.Sequential() 352 | self.features.add_module("init_block", DLAInitBlock( 353 | in_channels=in_channels, 354 | out_channels=init_block_channels)) 355 | in_channels = init_block_channels 356 | 357 | for i in range(len(levels)): 358 | levels_i = levels[i] 359 | out_channels = channels[i] 360 | first_tree = (i == 0) 361 | self.features.add_module("stage{}".format(i + 1), DLATree( 362 | levels=levels_i, 363 | in_channels=in_channels, 364 | out_channels=out_channels, 365 | res_body_class=res_body_class, 366 | stride=2, 367 | root_residual=residual_root, 368 | first_tree=first_tree)) 369 | in_channels = out_channels 370 | 371 | self.features.add_module("final_pool", nn.AvgPool2d( 372 | kernel_size=7, 373 | stride=1)) 374 | 375 | self.output = conv1x1( 376 | in_channels=in_channels, 377 | out_channels=num_classes, 378 | bias=True) 379 | 380 | self._init_params() 381 | 382 | def _init_params(self): 383 | for name, module in self.named_modules(): 384 | if isinstance(module, nn.Conv2d): 385 | init.kaiming_uniform_(module.weight) 386 | if module.bias is not None: 387 | init.constant_(module.bias, 0) 388 | 389 | def forward(self, x): 390 | x = self.features(x) 391 | print(x.shape) 392 | x = self.output(x) 393 | print(x.shape) 394 | x = x.view(x.size(0), -1) 395 | print(x.shape) 396 | return x 397 | 398 | 399 | def get_dla(levels, 400 | channels, 401 | res_body_class, 402 | residual_root=False, 403 | model_name=None, 404 | pretrained=False, 405 | 
root=os.path.join("~", ".torch", "models"), 406 | **kwargs): 407 | """ 408 | Create DLA model with specific parameters. 409 | 410 | Parameters: 411 | ---------- 412 | levels : int 413 | Number of levels in each stage. 414 | channels : list of int 415 | Number of output channels for each stage. 416 | res_body_class : nn.Module 417 | Residual block body class. 418 | residual_root : bool, default False 419 | Whether use residual connection in the root blocks. 420 | model_name : str or None, default None 421 | Model name for loading pretrained model. 422 | pretrained : bool, default False 423 | Whether to load the pretrained weights for model. 424 | root : str, default '~/.torch/models' 425 | Location for keeping the model parameters. 426 | """ 427 | init_block_channels = 32 428 | 429 | net = DLA( 430 | levels=levels, 431 | channels=channels, 432 | init_block_channels=init_block_channels, 433 | res_body_class=res_body_class, 434 | residual_root=residual_root, 435 | **kwargs) 436 | 437 | if pretrained: 438 | if (model_name is None) or (not model_name): 439 | raise ValueError("Parameter `model_name` should be properly initialized for loading pretrained model.") 440 | from .model_store import download_model 441 | download_model( 442 | net=net, 443 | model_name=model_name, 444 | local_model_store_dir_path=root) 445 | 446 | return net 447 | 448 | 449 | def dla34(**kwargs): 450 | """ 451 | DLA-34 model from 'Deep Layer Aggregation,' https://arxiv.org/abs/1707.06484. 452 | 453 | Parameters: 454 | ---------- 455 | pretrained : bool, default False 456 | Whether to load the pretrained weights for model. 457 | root : str, default '~/.torch/models' 458 | Location for keeping the model parameters. 459 | """ 460 | return get_dla(levels=[1, 2, 2, 1], channels=[64, 128, 256, 512], res_body_class=ResBlock, model_name="dla34", 461 | **kwargs) 462 | 463 | 464 | def dla46c(**kwargs): 465 | """ 466 | DLA-46-C model from 'Deep Layer Aggregation,' https://arxiv.org/abs/1707.06484. 467 | 468 | Parameters: 469 | ---------- 470 | pretrained : bool, default False 471 | Whether to load the pretrained weights for model. 472 | root : str, default '~/.torch/models' 473 | Location for keeping the model parameters. 474 | """ 475 | return get_dla(levels=[1, 2, 2, 1], channels=[64, 64, 128, 256], res_body_class=DLABottleneck, model_name="dla46c", 476 | **kwargs) 477 | 478 | 479 | def dla46xc(**kwargs): 480 | """ 481 | DLA-X-46-C model from 'Deep Layer Aggregation,' https://arxiv.org/abs/1707.06484. 482 | 483 | Parameters: 484 | ---------- 485 | pretrained : bool, default False 486 | Whether to load the pretrained weights for model. 487 | root : str, default '~/.torch/models' 488 | Location for keeping the model parameters. 489 | """ 490 | return get_dla(levels=[1, 2, 2, 1], channels=[64, 64, 128, 256], res_body_class=DLABottleneckX, 491 | model_name="dla46xc", **kwargs) 492 | 493 | 494 | def dla60(**kwargs): 495 | """ 496 | DLA-60 model from 'Deep Layer Aggregation,' https://arxiv.org/abs/1707.06484. 497 | 498 | Parameters: 499 | ---------- 500 | pretrained : bool, default False 501 | Whether to load the pretrained weights for model. 502 | root : str, default '~/.torch/models' 503 | Location for keeping the model parameters. 504 | """ 505 | return get_dla(levels=[1, 2, 3, 1], channels=[128, 256, 512, 1024], res_body_class=DLABottleneck, 506 | model_name="dla60", **kwargs) 507 | 508 | 509 | def dla60x(**kwargs): 510 | """ 511 | DLA-X-60 model from 'Deep Layer Aggregation,' https://arxiv.org/abs/1707.06484. 
512 | 513 | Parameters: 514 | ---------- 515 | pretrained : bool, default False 516 | Whether to load the pretrained weights for model. 517 | root : str, default '~/.torch/models' 518 | Location for keeping the model parameters. 519 | """ 520 | return get_dla(levels=[1, 2, 3, 1], channels=[128, 256, 512, 1024], res_body_class=DLABottleneckX, 521 | model_name="dla60x", **kwargs) 522 | 523 | 524 | def dla60xc(**kwargs): 525 | """ 526 | DLA-X-60-C model from 'Deep Layer Aggregation,' https://arxiv.org/abs/1707.06484. 527 | 528 | Parameters: 529 | ---------- 530 | pretrained : bool, default False 531 | Whether to load the pretrained weights for model. 532 | root : str, default '~/.torch/models' 533 | Location for keeping the model parameters. 534 | """ 535 | return get_dla(levels=[1, 2, 3, 1], channels=[64, 64, 128, 256], res_body_class=DLABottleneckX, 536 | model_name="dla60xc", **kwargs) 537 | 538 | 539 | def dla102(**kwargs): 540 | """ 541 | DLA-102 model from 'Deep Layer Aggregation,' https://arxiv.org/abs/1707.06484. 542 | 543 | Parameters: 544 | ---------- 545 | pretrained : bool, default False 546 | Whether to load the pretrained weights for model. 547 | root : str, default '~/.torch/models' 548 | Location for keeping the model parameters. 549 | """ 550 | return get_dla(levels=[1, 3, 4, 1], channels=[128, 256, 512, 1024], res_body_class=DLABottleneck, 551 | residual_root=True, model_name="dla102", **kwargs) 552 | 553 | 554 | def dla102x(**kwargs): 555 | """ 556 | DLA-X-102 model from 'Deep Layer Aggregation,' https://arxiv.org/abs/1707.06484. 557 | 558 | Parameters: 559 | ---------- 560 | pretrained : bool, default False 561 | Whether to load the pretrained weights for model. 562 | root : str, default '~/.torch/models' 563 | Location for keeping the model parameters. 564 | """ 565 | return get_dla(levels=[1, 3, 4, 1], channels=[128, 256, 512, 1024], res_body_class=DLABottleneckX, 566 | residual_root=True, model_name="dla102x", **kwargs) 567 | 568 | 569 | def dla102x2(**kwargs): 570 | """ 571 | DLA-X2-102 model from 'Deep Layer Aggregation,' https://arxiv.org/abs/1707.06484. 572 | 573 | Parameters: 574 | ---------- 575 | pretrained : bool, default False 576 | Whether to load the pretrained weights for model. 577 | root : str, default '~/.torch/models' 578 | Location for keeping the model parameters. 579 | """ 580 | class DLABottleneckX64(DLABottleneckX): 581 | def __init__(self, in_channels, out_channels, stride): 582 | super(DLABottleneckX64, self).__init__(in_channels, out_channels, stride, cardinality=64) 583 | 584 | return get_dla(levels=[1, 3, 4, 1], channels=[128, 256, 512, 1024], res_body_class=DLABottleneckX64, 585 | residual_root=True, model_name="dla102x2", **kwargs) 586 | 587 | 588 | def dla169(**kwargs): 589 | """ 590 | DLA-169 model from 'Deep Layer Aggregation,' https://arxiv.org/abs/1707.06484. 591 | 592 | Parameters: 593 | ---------- 594 | pretrained : bool, default False 595 | Whether to load the pretrained weights for model. 596 | root : str, default '~/.torch/models' 597 | Location for keeping the model parameters. 
598 |     """
599 |     return get_dla(levels=[2, 3, 5, 1], channels=[128, 256, 512, 1024], res_body_class=DLABottleneck,
600 |                    residual_root=True, model_name="dla169", **kwargs)
601 | 
602 | 
603 | def _calc_width(net):
604 |     import numpy as np
605 |     net_params = filter(lambda p: p.requires_grad, net.parameters())
606 |     weight_count = 0
607 |     for param in net_params:
608 |         weight_count += np.prod(param.size())
609 |     return weight_count
610 | 
611 | 
612 | def _test():
613 |     import torch
614 | 
615 |     pretrained = False
616 | 
617 |     models = [
618 |         dla34,
619 |         dla46c,
620 |         dla46xc,
621 |         dla60,
622 |         dla60x,
623 |         dla60xc,
624 |         dla102,
625 |         dla102x,
626 |         dla102x2,
627 |         dla169,
628 |     ]
629 | 
630 |     for model in models:
631 | 
632 |         net = model(pretrained=pretrained)
633 | 
634 |         # net.train()
635 |         net.eval()
636 |         weight_count = _calc_width(net)
637 |         print("m={}, {}".format(model.__name__, weight_count))
638 |         assert (model != dla34 or weight_count == 15742104)
639 |         assert (model != dla46c or weight_count == 1301400)
640 |         assert (model != dla46xc or weight_count == 1068440)
641 |         assert (model != dla60 or weight_count == 22036632)
642 |         assert (model != dla60x or weight_count == 17352344)
643 |         assert (model != dla60xc or weight_count == 1319832)
644 |         assert (model != dla102 or weight_count == 33268888)
645 |         assert (model != dla102x or weight_count == 26309272)
646 |         assert (model != dla102x2 or weight_count == 41282200)
647 |         assert (model != dla169 or weight_count == 53389720)
648 | 
649 |         x = torch.randn(1, 3, 224, 224)
650 |         y = net(x)
651 |         y.sum().backward()
652 |         assert (tuple(y.size()) == (1, 1000))
653 | 
654 | 
655 | if __name__ == "__main__":
656 |     _test()
657 | 
--------------------------------------------------------------------------------
/polygoncode/polygonembed/resnet.py:
--------------------------------------------------------------------------------
1 | '''
2 | Code rewritten based on https://github.com/ldeecke/mn-torch/blob/master/nn/resnet.py
3 | '''
4 | 
5 | 
6 | import functools
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | 
11 | from math import sqrt
12 | from torch.nn import BatchNorm1d
13 | 
14 | from polygonembed.atten import *
15 | 
16 | # from polygonembed.ops import ModeNorm
17 | 
18 | # resnet20 = lambda config: ResNet(BasicBlock, [3, 3, 3], config)
19 | # resnet56 = lambda config: ResNet(BasicBlock, [9, 9, 9], config)
20 | # resnet110 = lambda config: ResNet(BasicBlock, [18, 18, 18], config)
21 | 
22 | 
23 | def get_agg_func(agg_func_type, out_channels, name = "resnet1d"):
24 |     if agg_func_type == "mean":
25 |         # global average pool
26 |         return torch.mean
27 |     elif agg_func_type == "min":
28 |         # global min pool
29 |         return torch.min
30 |     elif agg_func_type == "max":
31 |         # global max pool
32 |         return torch.max
33 |     elif agg_func_type.startswith("atten"):
34 |         # agg_func_type: atten_whole_no_1
35 |         atten_flag, att_type, bn, nat = agg_func_type.split("_")
36 |         assert atten_flag == "atten"
37 |         return AttentionSet(mode_dims = out_channels,
38 |                             att_reg = 0.,
39 |                             att_tem = 1.,
40 |                             att_type = att_type,
41 |                             bn = bn,
42 |                             nat = int(nat),
43 |                             name = name)
44 |     else:
45 |         raise Exception("{} aggregation function not recognized.".format(agg_func_type))
46 | class ResNet1D(nn.Module):
47 |     def __init__(self, block, num_layer_list, in_channels, out_channels, add_middle_pool = False, final_pool = "mean", padding_mode = 'circular', dropout_rate = 0.5):
48 |         '''
49 |         Args:
50 |             block: BasicBlock() or BottleneckBlock()
51 |             num_layer_list: [num_blocks0, num_blocks1, num_blocks2]
52 |             in_channels: number of input channels
53 |         '''
54 | 
super(ResNet1D, self).__init__() 55 | 56 | Norm = functools.partial(BatchNorm1d) 57 | 58 | self.num_layer_list = num_layer_list 59 | 60 | self.in_channels = in_channels 61 | # For simplicity, make outplanes dividable by block.expansion 62 | assert out_channels % block.expansion == 0 63 | self.out_channels = out_channels 64 | planes = int(out_channels / block.expansion) 65 | 66 | self.inplanes = out_channels 67 | 68 | self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode = padding_mode, bias=False) 69 | 70 | 71 | self.norm1 = Norm(out_channels) 72 | self.relu = nn.ReLU(inplace=True) 73 | 74 | self.maxpool = nn.MaxPool1d(kernel_size = 2, stride=2, padding = 0) 75 | 76 | resnet_layers = [] 77 | 78 | if len(self.num_layer_list) >= 1: 79 | # you have at least one number in self.num_layer_list 80 | layer1 = self._make_layer(block, planes, self.num_layer_list[0], Norm, padding_mode = padding_mode) 81 | resnet_layers.append(layer1) 82 | if add_middle_pool and len(self.num_layer_list) > 1: 83 | maxpool = nn.MaxPool1d(kernel_size = 2, stride=2, padding = 0) 84 | resnet_layers.append(maxpool) 85 | 86 | if len(self.num_layer_list) >= 2: 87 | # you have at least two numbers in self.num_layer_list 88 | for i in range(1, len(self.num_layer_list)): 89 | layerk = self._make_layer(block, planes, self.num_layer_list[i], Norm, stride=2, padding_mode = padding_mode) 90 | resnet_layers.append(layerk) 91 | if add_middle_pool and i < len(self.num_layer_list) - 1: 92 | maxpool = nn.MaxPool1d(kernel_size = 2, stride=2, padding = 0) 93 | resnet_layers.append(maxpool) 94 | 95 | self.resnet_layers = nn.Sequential(*resnet_layers) 96 | 97 | self.final_pool = final_pool 98 | 99 | self.final_pool_func = get_agg_func(agg_func_type = final_pool, 100 | out_channels = out_channels, 101 | name = "resnet1d") 102 | 103 | self.dropout = nn.Dropout(p=dropout_rate) 104 | 105 | # if len(self.num_layer_list) == 3: 106 | # # you have three numbers in self.num_layer_list 107 | # self.layer3 = self._make_layer(block, planes, self.num_layer_list[2], Norm, stride=2, padding_mode = padding_mode) 108 | 109 | # self.avgpool = nn.AvgPool1d(8, stride=1) 110 | # self.fc = nn.Linear(64 * block.expansion, num_classes) 111 | 112 | self._init_weights() 113 | 114 | # def finalPool1d_func(self, final_pool, out_channels, name = "resnet1d"): 115 | # if final_pool == "mean": 116 | # # global average pool 117 | # return torch.mean 118 | # elif final_pool == "min": 119 | # # global min pool 120 | # return torch.min 121 | # elif final_pool == "max": 122 | # # global max pool 123 | # return torch.max 124 | # elif final_pool.startswith("atten"): 125 | # # final_pool: atten_whole_no_1 126 | # atten_flag, att_type, bn, nat = final_pool.split("_") 127 | # assert atten_flag == "atten" 128 | # return AttentionSet(mode_dims = out_channels, 129 | # att_reg = 0., 130 | # att_tem = 1., 131 | # att_type = att_type, 132 | # bn = bn, 133 | # nat= int(nat), 134 | # name = name) 135 | 136 | def finalPool1d(self, x, final_pool = "mean"): 137 | ''' 138 | Args: 139 | x: shape (batch_size, out_channels, (seq_len+k-2)/2^k ) 140 | Return: 141 | 142 | ''' 143 | if final_pool == "mean": 144 | # global average pool 145 | # x: shape (batch_size, out_channels) 146 | x = self.final_pool_func(x, dim = -1, keepdim = False) 147 | elif final_pool == "min": 148 | # global min pool 149 | # x: shape (batch_size, out_channels) 150 | x, indice = self.final_pool_func(x, dim = -1, keepdim = False) 151 | elif final_pool == "max": 152 | # global max 
pool 153 | # x: shape (batch_size, out_channels) 154 | x, indice = self.final_pool_func(x, dim = -1, keepdim = False) 155 | elif final_pool.startswith("atten"): 156 | # attenion based aggregation 157 | # x: shape (batch_size, out_channels) 158 | x = self.final_pool_func(x) 159 | return x 160 | 161 | def forward(self, x): 162 | ''' 163 | Args: 164 | x: shape (batch_size, in_channels, seq_len) 165 | Return: 166 | x: shape (batch_size, out_channels) 167 | ''' 168 | # x: shape (batch_size, out_channels, seq_len) 169 | x = self.conv1(x) 170 | x = self.norm1(x) 171 | x = self.relu(x) 172 | # print("conv1:", x.shape) 173 | 174 | # x: shape (batch_size, out_channels, seq_len/2) 175 | x = self.maxpool(x) 176 | 177 | # x: shape (batch_size, out_channels, (seq_len+k-2)/2^k ) 178 | x = self.resnet_layers(x) 179 | 180 | # if len(self.num_layer_list) >= 1: 181 | # # After 1st block: shape (batch_size, out_channels, seq_len/2) 182 | # # x: shape (batch_size, out_channels, seq_len/2) 183 | # x = self.layer1(x) 184 | # # print("layer1:", x.shape) 185 | 186 | # if len(self.num_layer_list) >= 2: 187 | # # After 1st block: shape (batch_size, out_channels, (seq_len+2)/4 ) 188 | # # x: shape (batch_size, out_channels, (seq_len+2)/4 ) 189 | # x = self.layer2(x) 190 | # # print("layer2:", x.shape) 191 | 192 | # if len(self.num_layer_list) == 3: 193 | # # After 1st block: shape (batch_size, out_channels, (seq_len+6)/8 ) 194 | # # x: shape (batch_size, out_channels, (seq_len+6)/8 ) 195 | # x = self.layer3(x) 196 | # # print("layer3:", x.shape) 197 | 198 | 199 | 200 | # global pool 201 | # x: shape (batch_size, out_channels) 202 | x = self.finalPool1d(x, self.final_pool) 203 | # x = torch.mean(x, dim = -1, keepdim = False) 204 | # print("avgpool:", x.shape) 205 | 206 | 207 | x = self.dropout(x) 208 | 209 | # x: shape (batch_size, 64*expansion, (seq_len-25)/4 ) 210 | # x = self.avgpool(x) 211 | # x: shape (batch_size, 64*expansion * (seq_len-25)/4 ) 212 | # x = x.view(x.size(0), -1) 213 | # x: shape (batch_size, 64*expansion * (seq_len-25)/4 ) 214 | # x = self.fc(x) 215 | # print("output:", x.shape) 216 | 217 | return x 218 | 219 | 220 | def _init_weights(self): 221 | for m in self.modules(): 222 | if isinstance(m, nn.Conv1d): 223 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 224 | n = m.kernel_size[0] * m.out_channels 225 | m.weight.data.normal_(0, sqrt(2./n)) 226 | elif isinstance(m, nn.BatchNorm1d): 227 | m.weight.data.fill_(1) 228 | m.bias.data.zero_() 229 | # elif isinstance(m, ModeNorm): 230 | # m.alpha.data.fill_(1) 231 | # m.beta.data.zero_() 232 | 233 | 234 | def _make_layer(self, block, planes, blocks, norm, stride=1, padding_mode = 'circular'): 235 | downsample = None 236 | if (stride != 1) or (self.inplanes != planes * block.expansion): 237 | downsample = nn.Sequential( 238 | nn.Conv1d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), 239 | norm(planes * block.expansion) 240 | ) 241 | 242 | layers = [] 243 | layers.append(block(self.inplanes, planes, norm, stride, downsample, padding_mode)) 244 | # self.inplanes = planes * block.expansion 245 | for _ in range(1, blocks): 246 | layers.append(block(self.inplanes, planes, norm, padding_mode = padding_mode)) 247 | 248 | return nn.Sequential(*layers) 249 | 250 | 251 | class ResNet1D3(nn.Module): 252 | def __init__(self, block, num_layer_list, in_channels, out_channels, padding_mode = 'circular'): 253 | ''' 254 | Args: 255 | block: BasicBlock() or BottleneckBlock() 256 | num_layer_list: [num_blocks0, num_blocks1, 
num_blocks2] 257 | inplanes: input number of channel 258 | ''' 259 | super(ResNet1D3, self).__init__() 260 | 261 | Norm = functools.partial(BatchNorm1d) 262 | 263 | self.num_layer_list = num_layer_list 264 | 265 | self.in_channels = in_channels 266 | # For simplicity, make outplanes dividable by block.expansion 267 | assert out_channels % block.expansion == 0 268 | self.out_channels = out_channels 269 | planes = int(out_channels / block.expansion) 270 | 271 | self.inplanes = out_channels 272 | 273 | self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode = padding_mode, bias=False) 274 | 275 | 276 | self.norm1 = Norm(out_channels) 277 | self.relu = nn.ReLU(inplace=True) 278 | 279 | self.maxpool = nn.MaxPool1d(kernel_size = 2, stride=2, padding = 0) 280 | 281 | if len(self.num_layer_list) >= 1: 282 | # you have at least one number in self.num_layer_list 283 | self.layer1 = self._make_layer(block, planes, self.num_layer_list[0], Norm, padding_mode = padding_mode) 284 | 285 | if len(self.num_layer_list) >= 2: 286 | # you have at least two numbers in self.num_layer_list 287 | self.layer2 = self._make_layer(block, planes, self.num_layer_list[1], Norm, stride=2, padding_mode = padding_mode) 288 | 289 | if len(self.num_layer_list) == 3: 290 | # you have three numbers in self.num_layer_list 291 | self.layer3 = self._make_layer(block, planes, self.num_layer_list[2], Norm, stride=2, padding_mode = padding_mode) 292 | 293 | # self.avgpool = nn.AvgPool1d(8, stride=1) 294 | # self.fc = nn.Linear(64 * block.expansion, num_classes) 295 | 296 | self._init_weights() 297 | 298 | 299 | def forward(self, x): 300 | ''' 301 | Args: 302 | x: shape (batch_size, in_channels, seq_len) 303 | Return: 304 | x: shape (batch_size, out_channels) 305 | ''' 306 | # x: shape (batch_size, out_channels, seq_len) 307 | x = self.conv1(x) 308 | x = self.norm1(x) 309 | x = self.relu(x) 310 | # print("conv1:", x.shape) 311 | 312 | # x: shape (batch_size, out_channels, seq_len/2) 313 | x = self.maxpool(x) 314 | 315 | if len(self.num_layer_list) >= 1: 316 | # After 1st block: shape (batch_size, out_channels, seq_len/2) 317 | # x: shape (batch_size, out_channels, seq_len/2) 318 | x = self.layer1(x) 319 | # print("layer1:", x.shape) 320 | 321 | if len(self.num_layer_list) >= 2: 322 | # After 1st block: shape (batch_size, out_channels, (seq_len+2)/4 ) 323 | # x: shape (batch_size, out_channels, (seq_len+2)/4 ) 324 | x = self.layer2(x) 325 | # print("layer2:", x.shape) 326 | 327 | if len(self.num_layer_list) == 3: 328 | # After 1st block: shape (batch_size, out_channels, (seq_len+6)/8 ) 329 | # x: shape (batch_size, out_channels, (seq_len+6)/8 ) 330 | x = self.layer3(x) 331 | # print("layer3:", x.shape) 332 | 333 | # global average pool 334 | # x: shape (batch_size, out_channels) 335 | x = torch.mean(x, dim = -1, keepdim = False) 336 | # print("avgpool:", x.shape) 337 | 338 | # x: shape (batch_size, 64*expansion, (seq_len-25)/4 ) 339 | # x = self.avgpool(x) 340 | # x: shape (batch_size, 64*expansion * (seq_len-25)/4 ) 341 | # x = x.view(x.size(0), -1) 342 | # x: shape (batch_size, 64*expansion * (seq_len-25)/4 ) 343 | # x = self.fc(x) 344 | # print("output:", x.shape) 345 | 346 | return x 347 | 348 | 349 | def _init_weights(self): 350 | for m in self.modules(): 351 | if isinstance(m, nn.Conv1d): 352 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 353 | n = m.kernel_size[0] * m.out_channels 354 | m.weight.data.normal_(0, sqrt(2./n)) 355 | elif isinstance(m, nn.BatchNorm1d): 356 | 
379 | 
380 | class ResNet1DLucas(nn.Module):
381 |     def __init__(self, block, layers, inplanes, num_classes, padding_mode = 'circular'):
382 |         '''
383 |         Args:
384 |             layers: [num_blocks0, num_blocks1, num_blocks2]
385 |         '''
386 |         super(ResNet1DLucas, self).__init__()
387 |         # self.mn = config.mn
388 | 
389 |         # if config.mn == "full":
390 |         #     Norm = functools.partial(ModeNorm, momentum=config.momentum, n_components=config.num_components)
391 |         # elif config.mn == "init":
392 |         #     InitNorm = functools.partial(ModeNorm, momentum=config.momentum, n_components=config.num_components)
393 |         #     Norm = functools.partial(BatchNorm1d, momentum=config.momentum)
394 |         Norm = functools.partial(BatchNorm1d)
395 | 
396 |         self.inplanes = 16
397 |         self.conv1 = nn.Conv1d(inplanes, 16, kernel_size=3, stride=1, padding=1, padding_mode = padding_mode, bias=False)
398 | 
399 |         # self.norm1 = InitNorm(16) if config.mn == "init" else Norm(16)
400 |         self.norm1 = Norm(16)
401 |         self.relu = nn.ReLU(inplace=True)
402 | 
403 |         self.layer1 = self._make_layer(block, 16, layers[0], Norm, padding_mode = padding_mode)
404 |         self.layer2 = self._make_layer(block, 32, layers[1], Norm, stride=2, padding_mode = padding_mode)
405 |         self.layer3 = self._make_layer(block, 64, layers[2], Norm, stride=2, padding_mode = padding_mode)
406 |         # self.avgpool = nn.AvgPool1d(8, stride=1)
407 |         self.fc = nn.Linear(64 * block.expansion, num_classes)
408 | 
409 |         self._init_weights()
410 | 
411 | 
412 |     def forward(self, x):
413 |         '''
414 |         Args:
415 |             x: shape (batch_size, inplanes, seq_len)
416 |         '''
417 |         # x: shape (batch_size, 16, seq_len)
418 |         x = self.conv1(x)
419 |         x = self.norm1(x)
420 |         x = self.relu(x)
421 | 
422 |         # After the 1st block of layer1: shape (batch_size, 16*expansion, seq_len)
423 |         # x: shape (batch_size, 16*expansion, seq_len)
424 |         x = self.layer1(x)
425 | 
426 |         # After the 1st (strided) block of layer2: shape (batch_size, 32*expansion, (seq_len+1)/2 )
427 |         # x: shape (batch_size, 32*expansion, (seq_len+1)/2 )
428 |         x = self.layer2(x)
429 | 
430 |         # After the 1st (strided) block of layer3: shape (batch_size, 64*expansion, (seq_len+3)/4 )
431 |         # x: shape (batch_size, 64*expansion, (seq_len+3)/4 )
432 |         x = self.layer3(x)
433 | 
434 |         # global average pool
435 |         # x: shape (batch_size, 64*expansion)
436 |         x = torch.mean(x, dim = -1, keepdim = False)
437 | 
438 |         # x: shape (batch_size, 64*expansion, (seq_len-25)/4 )
439 |         # x = self.avgpool(x)
440 |         # x: shape (batch_size, 64*expansion * (seq_len-25)/4 )
441 |         # x = x.view(x.size(0), -1)
442 |         # x: shape (batch_size, num_classes)
443 |         x = self.fc(x)
444 | 
445 |         return x
446 | 
447 | 
448 |     def _init_weights(self):
449 |         for m in self.modules():
450 |             if isinstance(m, nn.Conv1d):
451 |                 # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
452 |                 n = m.kernel_size[0] * m.out_channels
453 |                 m.weight.data.normal_(0, sqrt(2./n))
454 |             elif isinstance(m, nn.BatchNorm1d):
455 |                 m.weight.data.fill_(1)
456 |                 m.bias.data.zero_()
457 |             # elif isinstance(m, ModeNorm):
458 |             #     m.alpha.data.fill_(1)
459 |             #     m.beta.data.zero_()
460 | 
461 | 
462 |     def _make_layer(self, block, planes, blocks, norm, stride=1, padding_mode = 'circular'):
463 |         downsample = None
464 |         if (stride != 1) or (self.inplanes != planes * block.expansion):
465 |             downsample = nn.Sequential(
466 |                 nn.Conv1d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
467 |                 norm(planes * block.expansion)
468 |             )
469 | 
470 |         layers = []
471 |         layers.append(block(self.inplanes, planes, norm, stride, downsample, padding_mode))
472 |         self.inplanes = planes * block.expansion
473 |         for _ in range(1, blocks):
474 |             layers.append(block(self.inplanes, planes, norm, padding_mode = padding_mode))
475 | 
476 |         return nn.Sequential(*layers)
477 | 
478 | 
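# A minimal usage sketch for ResNet1DLucas (illustrative, not part of the
# original module; assumes the imports at the top of this file). Unlike the
# two encoder classes above, its _make_layer *does* update self.inplanes,
# so the channel width grows 16 -> 32 -> 64 (times block.expansion) before
# the final linear classifier head.
def _demo_resnet1d_lucas():
    model = ResNet1DLucas(BasicBlock, [2, 2, 2], inplanes = 3, num_classes = 10)
    x = torch.randn(4, 3, 100)        # (batch_size, inplanes, seq_len)
    logits = model(x)                 # (batch_size, num_classes)
    assert logits.shape == (4, 10)
    return logits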
479 | class BasicBlock(nn.Module):
480 |     expansion = 1
481 | 
482 |     def __init__(self, inplanes, planes, norm, stride=1, downsample=None, padding_mode = 'circular'):
483 |         super(BasicBlock, self).__init__()
484 |         self.conv1 = self._conv3(inplanes, planes, stride, padding_mode = padding_mode)
485 |         self.norm1 = norm(planes)
486 |         self.relu = nn.ReLU(inplace=True)
487 |         self.conv2 = self._conv3(planes, planes, padding_mode = padding_mode)
488 |         self.norm2 = norm(planes)
489 |         self.downsample = downsample
490 |         self.stride = stride
491 |         self.padding_mode = padding_mode
492 | 
493 | 
494 |     def forward(self, x):
495 |         '''
496 |         Args:
497 |             x: shape (batch_size, inplanes, seq_len)
498 |         Return:
499 |             out: shape (batch_size, planes, (seq_len-1)/stride + 1 )
500 |         '''
501 |         residual = x
502 | 
503 |         # out: shape (batch_size, planes, (seq_len-1)/stride + 1 )
504 |         out = self.conv1(x)
505 |         out = self.norm1(out)
506 |         out = self.relu(out)
507 | 
508 |         # out: shape (batch_size, planes, (seq_len-1)/stride + 1 )
509 |         out = self.conv2(out)
510 |         out = self.norm2(out)
511 | 
512 |         if self.downsample is not None:
513 |             # residual: shape (batch_size, planes, (seq_len-1)/stride + 1 )
514 |             residual = self.downsample(x)
515 | 
516 |         out += residual
517 |         out = self.relu(out)
518 | 
519 |         return out
520 | 
521 | 
522 |     def _conv3(self, in_planes, out_planes, stride=1, padding_mode = 'circular'):
523 |         '''kernel-size-3 1D convolution with padding'''
524 |         return nn.Conv1d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, padding_mode = padding_mode, bias=False)
525 | 
526 | 
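# A minimal sketch of when BasicBlock needs an explicit downsample
# (illustrative, not part of the original module; assumes the imports at
# the top of this file). Whenever stride != 1 or the channel count changes,
# the identity shortcut no longer matches `out`, so a 1x1 strided conv +
# norm must be supplied -- exactly what _make_layer builds above.
def _demo_basic_block_downsample():
    downsample = nn.Sequential(
        nn.Conv1d(16, 32, kernel_size = 1, stride = 2, bias = False),
        nn.BatchNorm1d(32),
    )
    block = BasicBlock(16, 32, nn.BatchNorm1d, stride = 2, downsample = downsample)
    x = torch.randn(4, 16, 64)
    out = block(x)                    # seq_len: (64 - 1) // 2 + 1 = 32
    assert out.shape == (4, 32, 32)
    return out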
527 | class BottleneckBlock(nn.Module):
528 |     expansion = 4
529 | 
530 |     def __init__(self, in_planes, planes, norm, stride=1, downsample=None, padding_mode = 'circular'):
531 |         super(BottleneckBlock, self).__init__()
532 |         self.conv1 = nn.Conv1d(in_planes, planes, kernel_size=1, bias=False)
533 |         self.norm1 = norm(planes)
534 |         self.conv2 = nn.Conv1d(planes, planes, kernel_size=3, stride=stride, padding=1, padding_mode = padding_mode, bias=False)
535 |         self.norm2 = norm(planes)
536 |         self.conv3 = nn.Conv1d(planes, self.expansion*planes, kernel_size=1, bias=False)
537 |         self.norm3 = norm(self.expansion*planes)
538 |         # note: any passed-in `downsample` is ignored; the shortcut is built below
539 |         self.shortcut = nn.Sequential()
540 |         if stride != 1 or in_planes != self.expansion*planes:
541 |             self.shortcut = nn.Sequential(
542 |                 nn.Conv1d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
543 |                 norm(self.expansion*planes)
544 |             )
545 | 
546 | 
547 |     def forward(self, x):
548 |         '''
549 |         Args:
550 |             x: shape (batch_size, in_planes, seq_len)
551 |         Return:
552 |             out: shape (batch_size, expansion*planes, (seq_len-1)/stride + 1 )
553 |         '''
554 |         # out: shape (batch_size, planes, seq_len)
555 |         out = F.relu(self.norm1(self.conv1(x)))
556 |         # out: shape (batch_size, planes, (seq_len-1)/stride + 1 )
557 |         out = F.relu(self.norm2(self.conv2(out)))
558 |         # out: shape (batch_size, expansion*planes, (seq_len-1)/stride + 1 )
559 |         out = self.norm3(self.conv3(out))
560 |         # self.shortcut(x): shape (batch_size, expansion*planes, (seq_len-1)/stride + 1 )
561 |         # out: shape (batch_size, expansion*planes, (seq_len-1)/stride + 1 )
562 |         out += self.shortcut(x)
563 |         out = F.relu(out)
564 |         return out
565 | 
--------------------------------------------------------------------------------