├── polygoncode
│   ├── polygonembed
│   │   ├── __init__.py
│   │   ├── dbtopo
│   │   │   ├── train_generative.py
│   │   │   └── train_spa_rel.py
│   │   ├── animal
│   │   │   ├── train_img_cla.py
│   │   │   └── train_pgon.py
│   │   ├── sftp-config.json
│   │   ├── mnist
│   │   │   └── train_pgon.py
│   │   ├── lenet.py
│   │   ├── resnet2d.py
│   │   ├── ops.py
│   │   ├── spec_pool.py
│   │   ├── atten.py
│   │   ├── data_util.py
│   │   ├── PolygonDecoder.py
│   │   ├── trainer_img.py
│   │   ├── dla_resnext.py
│   │   ├── ddsl.py
│   │   ├── module.py
│   │   ├── ddsl_utils.py
│   │   ├── dla.py
│   │   └── resnet.py
│   └── 1_pgon_dbtopo.sh
├── image
│   └── model.png
├── requirements.txt
├── data_processing
│   └── dbtopo
│       └── data
│           └── spa_rel.csv
├── .gitignore
├── README.md
└── LICENSE
/polygoncode/polygonembed/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/image/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gengchenmai/polygon_encoder/HEAD/image/model.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | backoff==2.2.1
2 | Fiona==1.8.18
3 | geopandas==0.8.1
4 | matplotlib==3.3.3
5 | numpy==1.19.5
6 | pandas==1.2.0
7 | Requests==2.31.0
8 | scikit_learn==0.24.1
9 | Shapely==1.7.1
10 | torch==1.7.1+cu110
11 | torchvision==0.8.2+cu110
12 | tqdm==4.55.1
13 |
--------------------------------------------------------------------------------
/data_processing/dbtopo/data/spa_rel.csv:
--------------------------------------------------------------------------------
1 | http://dbpedia.org/ontology/isPartOf
2 | http://dbpedia.org/ontology/location
3 | http://dbpedia.org/ontology/locatedInArea
4 | http://dbpedia.org/property/location
5 | http://dbpedia.org/property/east
6 | http://dbpedia.org/property/south
7 | http://dbpedia.org/property/west
8 | http://dbpedia.org/property/north
9 | http://dbpedia.org/property/southwest
10 | http://dbpedia.org/property/northeast
11 | http://dbpedia.org/property/southeast
12 | http://dbpedia.org/ontology/nearestCity
13 | http://dbpedia.org/property/northwest
14 | http://dbpedia.org/ontology/mouthPlace
--------------------------------------------------------------------------------
/polygoncode/polygonembed/dbtopo/train_generative.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn import init
6 | import torch.nn.functional as F
7 |
8 | from polygonembed.module import *
9 | from polygonembed.SpatialRelationEncoder import *
10 | from polygonembed.resnet import *
11 | from polygonembed.PolygonEncoder import *
12 | from polygonembed.utils import *
13 | from polygonembed.data_util import *
14 | from polygonembed.dataset import *
15 | from polygonembed.trainer import *
16 | from polygonembed.trainer_helper import *
17 |
18 |
19 |
20 | parser = make_args_parser()
21 | args = parser.parse_args()
22 |
23 |
24 | pgon_gdf = load_dataframe(args.data_dir, args.pgon_filename)
25 |
26 |
27 | trainer = Trainer(args, pgon_gdf, console = True)
28 |
29 |
30 | trainer.run_polygon_generative_train()
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Compiled source #
2 | ###################
3 | *.com
4 | *.class
5 | *.dll
6 | *.exe
7 | *.o
8 | *.so
9 |
10 | # Packages #
11 | ############
12 | # it's better to unpack these files and commit the raw source
13 | # git has its own built in compression methods
14 | *.7z
15 | *.dmg
16 | *.gz
17 | *.iso
18 | *.jar
19 | *.rar
20 | *.tar
21 | *.zip
22 |
23 | # Logs and databases #
24 | ######################
25 | *.log
26 | *.sql
27 | *.sqlite
28 |
29 | # OS generated files #
30 | ######################
31 | .DS_Store
32 | .DS_Store?
33 | ._*
34 | .Spotlight-V100
35 | .Trashes
36 | ehthumbs.db
37 | Thumbs.db
38 |
39 | # code #
40 | *.pyc
41 |
42 | # data #
43 | ########
44 | # *.pkl
45 | *.pth
46 | *.pth.tar
47 | *.log
48 | *.crdownload
49 |
50 | # ignore directory
51 | polygoncode/model_dir/*
52 | polygoncode/polygonembed/.ipynb_checkpoints
53 |
54 |
55 |
--------------------------------------------------------------------------------
/polygoncode/1_pgon_dbtopo.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python3 -m polygonembed.dbtopo.train_spa_rel \
4 |     --data_dir ./data_processing/dbtopo/output/ \
5 | --model_dir ./model_dir/dbtopo/ \
6 | --log_dir ./model_dir/dbtopo/ \
7 | --pgon_filename pgon_300_gdf_prj.pkl \
8 | --triple_filename pgon_triples_geom_300_norm_df.pkl \
9 | --geom_type_list norm \
10 | --data_split_num 0 \
11 | --task rel \
12 | --model_type cat \
13 | --pgon_enc nuft_ddsl \
14 | --pgon_embed_dim 512 \
15 | --nuft_freqXY 32 32 \
16 | --j 2 \
17 | --padding_mode circular \
18 | --do_polygon_random_start T \
19 | --do_data_augment F \
20 | --do_online_data_augment F \
21 | --data_augment_type none \
22 | --num_augment 0 \
23 | --dropout 0.1 \
24 | --spa_enc kdelta \
25 | --spa_embed_dim 26 \
26 | --freq 16 \
27 | --max_radius 2 \
28 | --min_radius 1e-6 \
29 | --spa_f_act relu \
30 | --freq_init geometric \
31 | --spa_enc_use_postmat F \
32 | --k_delta 12 \
33 | --num_hidden_layer 1 \
34 | --hidden_dim 512 \
35 | --use_layn T \
36 | --skip_connection T \
37 | --pgon_dec explicit_conv \
38 | --pgon_dec_grid_init circle \
39 | --pgon_dec_grid_enc_type spa_enc \
40 | --grt_loss_func LOOPL2 \
41 | --do_weight_norm F \
42 | --weight_decay 0.000 \
43 | --pgon_norm_reg_weight 0.02 \
44 | --task_loss_weight 0.95 \
45 | --grt_epoches 0 \
46 | --cla_epoches 5 \
47 | --log_every 100 \
48 | --val_every 100 \
49 | --batch_size 128 \
50 | --lr 0.01 \
51 | --opt adam \
52 | --act relu \
53 | --balanced_train_loader F \
54 | --device cuda:0 \
55 | --tb F
--------------------------------------------------------------------------------
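
The flags above are consumed by the `make_args_parser()` helper that every training script calls; its definition lives elsewhere in the `polygonembed` package and is not included in this dump. A minimal sketch of how such a parser might handle a few of these flags: the flag names come from the script, but the `str2bool` helper, types, and defaults are illustrative assumptions.

```python
import argparse

def str2bool(v):
    # the shell script passes booleans as T/F, so a converter like this is assumed
    return str(v).upper() in ("T", "TRUE", "1", "Y", "YES")

def make_args_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, default="./data_processing/dbtopo/output/")
    parser.add_argument("--model_dir", type=str, default="./model_dir/dbtopo/")
    parser.add_argument("--pgon_enc", type=str, default="nuft_ddsl")
    parser.add_argument("--pgon_embed_dim", type=int, default=512)
    parser.add_argument("--nuft_freqXY", nargs=2, type=int, default=[32, 32])
    parser.add_argument("--do_data_augment", type=str2bool, default=False)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--device", type=str, default="cuda:0")
    return parser
```
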
/polygoncode/polygonembed/animal/train_img_cla.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn import BatchNorm1d
5 |
6 |
7 | import functools
8 | from functools import partial
9 | from dataclasses import dataclass
10 | from collections import OrderedDict
11 |
12 |
13 | import fiona
14 | import geopandas as gpd
15 | import pandas as pd
16 |
17 | import math
18 | import os
19 | import json
20 | import pickle
21 | import numpy as np
22 | from tqdm import tqdm
23 | import matplotlib
24 | import matplotlib.pyplot as plt
25 | import random
26 | import re
27 | import requests
28 | import copy
29 |
30 | import shapely
31 | from shapely.ops import transform
32 |
33 | from shapely.geometry.point import Point
34 | from shapely.geometry.polygon import Polygon
35 | from shapely.geometry.multipolygon import MultiPolygon
36 | from shapely.geometry.linestring import LineString
37 | from shapely.geometry.polygon import LinearRing
38 | from shapely.geometry import box
39 |
40 | from fiona.crs import from_epsg
41 |
42 |
43 | from datetime import datetime
44 |
45 | import torch
46 | import torch.nn as nn
47 | from torch.nn import init
48 | import torch.nn.functional as F
49 |
50 |
51 | from polygonembed.dataset import *
52 | from polygonembed.resnet2d import *
53 | from polygonembed.model_utils import *
54 | from polygonembed.data_util import *
55 | from polygonembed.trainer_helper import *
56 |
57 | from polygonembed.trainer_img import *
58 |
59 |
60 | parser = make_args_parser()
61 | args = parser.parse_args()
62 |
63 | img_gdf = load_dataframe(args.data_dir, args.img_filename)
64 |
65 | trainer = Trainer(args, img_gdf, console = True)
66 |
67 | trainer.run_train()
--------------------------------------------------------------------------------
/polygoncode/polygonembed/sftp-config.json:
--------------------------------------------------------------------------------
1 | {
2 | // The tab key will cycle through the settings when first created
3 | // Visit http://wbond.net/sublime_packages/sftp/settings for help
4 |
5 | // sftp, ftp or ftps
6 | "type": "sftp",
7 |
8 | "save_before_upload": true,
9 | "upload_on_save": true,
10 | "sync_down_on_open": true,
11 | "sync_skip_deletes": false,
12 | "sync_same_age": true,
13 | "confirm_downloads": false,
14 | "confirm_sync": true,
15 | "confirm_overwrite_newer": false,
16 |
17 | "host": "stko-wrksrv.geog.ucsb.edu",
18 | "user": "gengchen",
19 | // "password": "",
20 | //"port": "22",
21 |
22 | "remote_path": "/home/gengchen/polygon2vec/polygonnet/polygoncode/polygonembed",
23 | "ignore_regexes": [
24 | "\\.sublime-(project|workspace)", "sftp-config(-alt\\d?)?\\.json",
25 | "sftp-settings\\.json", "/venv/", "\\.svn/", "\\.hg/", "\\.git/",
26 | "\\.bzr", "_darcs", "CVS", "\\.DS_Store", "Thumbs\\.db", "desktop\\.ini",
27 | "\\.pdf", "\\.txt", "/node_modules(_old)?", "/server/data(_old)?"
28 | ],
29 | //"file_permissions": "664",
30 | //"dir_permissions": "775",
31 |
32 | //"extra_list_connections": 0,
33 |
34 | "connect_timeout": 30,
35 | //"keepalive": 120,
36 | //"ftp_passive_mode": true,
37 | //"ftp_obey_passive_host": false,
38 | //"ssh_key_file": "~/.ssh/id_rsa",
39 | //"sftp_flags": ["-F", "/path/to/ssh_config"],
40 |
41 | //"preserve_modification_times": false,
42 | //"remote_time_offset_in_hours": 0,
43 | //"remote_encoding": "utf-8",
44 | //"remote_locale": "C",
45 | //"allow_config_upload": false,
46 | }
47 |
--------------------------------------------------------------------------------
/polygoncode/polygonembed/animal/train_pgon.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn import BatchNorm1d
5 |
6 |
7 | import functools
8 | from functools import partial
9 | from dataclasses import dataclass
10 | from collections import OrderedDict
11 |
12 |
13 | import fiona
14 | import geopandas as gpd
15 | import pandas as pd
16 |
17 | import math
18 | import os
19 | import json
20 | import pickle
21 | import numpy as np
22 | from tqdm import tqdm
23 | import matplotlib
24 | import matplotlib.pyplot as plt
25 | import random
26 | import re
27 | import requests
28 | import copy
29 |
30 | import shapely
31 | from shapely.ops import transform
32 |
33 | from shapely.geometry.point import Point
34 | from shapely.geometry.polygon import Polygon
35 | from shapely.geometry.multipolygon import MultiPolygon
36 | from shapely.geometry.linestring import LineString
37 | from shapely.geometry.polygon import LinearRing
38 | from shapely.geometry import box
39 |
40 | from fiona.crs import from_epsg
41 |
42 |
43 | from datetime import datetime
44 |
45 | import torch
46 | import torch.nn as nn
47 | from torch.nn import init
48 | import torch.nn.functional as F
49 |
50 | from polygonembed.module import *
51 | from polygonembed.SpatialRelationEncoder import *
52 | from polygonembed.resnet import *
53 | from polygonembed.PolygonEncoder import *
54 | from polygonembed.utils import *
55 | from polygonembed.data_util import *
56 | from polygonembed.dataset import *
57 | from polygonembed.trainer import *
58 | from polygonembed.trainer_helper import *
59 |
60 |
61 |
62 | parser = make_args_parser()
63 | args = parser.parse_args()
64 |
65 | pgon_gdf = load_dataframe(args.data_dir, args.pgon_filename)
66 |
67 | trainer = Trainer(args, pgon_gdf, console = True)
68 |
69 | trainer.run_train()
--------------------------------------------------------------------------------
/polygoncode/polygonembed/dbtopo/train_spa_rel.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | import fiona
5 | import geopandas as gpd
6 | import pandas as pd
7 |
8 | import math
9 | import os
10 | import json
11 | import pickle
12 | import numpy as np
13 | from tqdm import tqdm
14 | import matplotlib
15 | import matplotlib.pyplot as plt
16 | import random
17 | import re
18 | import requests
19 | import copy
20 |
21 | import shapely
22 | from shapely.ops import transform
23 |
24 | from shapely.geometry.point import Point
25 | from shapely.geometry.polygon import Polygon
26 | from shapely.geometry.multipolygon import MultiPolygon
27 | from shapely.geometry.linestring import LineString
28 | from shapely.geometry.polygon import LinearRing
29 | from shapely.geometry import box
30 |
31 | from fiona.crs import from_epsg
32 |
33 |
34 | from datetime import datetime
35 |
36 | import torch
37 | import torch.nn as nn
38 | from torch.nn import init
39 | import torch.nn.functional as F
40 |
41 | from polygonembed.module import *
42 | from polygonembed.SpatialRelationEncoder import *
43 | from polygonembed.resnet import *
44 | from polygonembed.PolygonEncoder import *
45 | from polygonembed.utils import *
46 | from polygonembed.data_util import *
47 | from polygonembed.dataset import *
48 | from polygonembed.trainer_helper import *
49 |
50 |
51 | from polygonembed.PolygonDecoder import *
52 | from polygonembed.enc_dec import *
53 | from polygonembed.model_utils import *
54 | from polygonembed.trainer_helper import *
55 |
56 | from polygonembed.trainer_dbtopo import *
57 |
58 |
59 | from polygonembed.ddsl_utils import *
60 | from polygonembed.ddsl import *
61 |
62 |
63 | parser = make_args_parser()
64 | args = parser.parse_args()
65 |
66 | triple_gdf = load_dataframe(args.data_dir, args.triple_filename)
67 |
68 | # pgon_gdf = load_dataframe(args.data_dir, args.pgon_filename)
69 |
70 | trainer = Trainer(args, triple_gdf, pgon_gdf = None, console = True)
71 |
72 | # trainer.run_train()
73 | trainer.run_eval(save_eval = True)
--------------------------------------------------------------------------------
/polygoncode/polygonembed/mnist/train_pgon.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn import BatchNorm1d
5 |
6 |
7 | import functools
8 | from functools import partial
9 | from dataclasses import dataclass
10 | from collections import OrderedDict
11 |
12 |
13 | import fiona
14 | import geopandas as gpd
15 | import pandas as pd
16 |
17 | import math
18 | import os
19 | import json
20 | import pickle
21 | import numpy as np
22 | from tqdm import tqdm
23 | import matplotlib
24 | import matplotlib.pyplot as plt
25 | import random
26 | import re
27 | import requests
28 | import copy
29 |
30 | import shapely
31 | from shapely.ops import transform
32 |
33 | from shapely.geometry.point import Point
34 | from shapely.geometry.polygon import Polygon
35 | from shapely.geometry.multipolygon import MultiPolygon
36 | from shapely.geometry.linestring import LineString
37 | from shapely.geometry.polygon import LinearRing
38 | from shapely.geometry import box
39 |
40 | from fiona.crs import from_epsg
41 |
42 |
43 | from datetime import datetime
44 |
45 | import torch
46 | import torch.nn as nn
47 | from torch.nn import init
48 | import torch.nn.functional as F
49 |
50 | from polygonembed.module import *
51 | from polygonembed.SpatialRelationEncoder import *
52 | from polygonembed.resnet import *
53 | from polygonembed.PolygonEncoder import *
54 | from polygonembed.utils import *
55 | from polygonembed.data_util import *
56 | from polygonembed.dataset import *
57 | from polygonembed.trainer import *
58 | from polygonembed.trainer_helper import *
59 |
60 | from polygonembed.ddsl_utils import *
61 | from polygonembed.ddsl import *
62 |
63 |
64 |
65 | parser = make_args_parser()
66 | args = parser.parse_args()
67 |
68 | pgon_gdf = load_dataframe(args.data_dir, args.pgon_filename)
69 |
70 | trainer = Trainer(args, pgon_gdf, console = True)
71 |
72 | if args.load_model:
73 | print("load_model...")
74 | trainer.load_model()
75 |
76 | if not args.eval_only:
77 | trainer.run_train(save_eval = args.save_eval)
78 | else:
79 | trainer.run_eval(save_eval = args.save_eval)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Towards General-Purpose Representation Learning of Polygonal Geometries
2 | Code for recreating the results in [our GeoInformatica 2023 paper](https://link.springer.com/article/10.1007/s10707-022-00481-2)
3 |
4 |
5 | ## Related Link
6 | 1. [Springer Paper](https://link.springer.com/article/10.1007/s10707-022-00481-2)
7 | 2. [arXiv Paper](https://arxiv.org/abs/2209.15458)
8 |
9 | ## Award
10 | 1. This paper won [AAG 2023 J. Warren Nystrom Award (1 award recipient every year)](https://www.aag.org/award-grant/nystrom/)
11 | 2. This paper won [AAG 2022 William L. Garrison Award for Best Dissertation in Computational Geography (1 award recipient every other year)](https://www.aag.org/aag-announces-2022-award-recipients/)
12 |
13 | ## Our Model Overview
14 |
15 | ![Our model overview](image/model.png)
16 |
17 |
18 | ## Dependencies
19 | - Python 3.7+
20 | - Torch 1.7.1+
21 | - Other required packages are summarized in `requirements.txt`.
22 |
23 | ## Data
24 | Download the required dbtopo datasets from [here](https://www.dropbox.com/scl/fo/ubokquibjibxqb71lduto/h?rlkey=gnex7g3gx51g06gmd1v1um9u1&dl=0) and put them in the `./data_processing/dbtopo/output/` folder. The folder contains two datasets:
25 | 1) DBSR-46K: the `pgon_triples_geom_300_norm_df.pkl` file, a GeoDataFrame containing the DBSR-46K spatial relation prediction dataset created from DBpedia and OpenStreetMap. Each row indicates a triple from DBpedia whose subject and object are each represented as a simple polygon with 300 vertices.
26 | 2) DBSR-cplx46K: the `pgon_triples_geom_300_norm_df_complex.pkl` file, a GeoDataFrame containing the same spatial relation prediction dataset. The only difference is that each row's subject and object are represented as complex polygons with 300 vertices.
27 |
28 |
29 |
30 | ## Train and Evaluation
31 | The main code is located in the `polygoncode` folder.
32 |
33 | 1) `1_pgon_dbtopo.sh` does supervised training on both the DBSR-46K and DBSR-cplx46K datasets.
34 |
35 |
36 |
37 | ### Reference
38 | If you find our work useful in your research, please consider citing [our GeoInformatica 2023 paper](https://link.springer.com/article/10.1007/s10707-022-00481-2).
39 | ```
40 | @article{mai2023towards,
41 | title={Towards general-purpose representation learning of polygonal geometries},
42 | author={Mai, Gengchen and Jiang, Chiyu and Sun, Weiwei and Zhu, Rui and Xuan, Yao and Cai, Ling and Janowicz, Krzysztof and Ermon, Stefano and Lao, Ni},
43 | journal={GeoInformatica},
44 | volume={27},
45 | number={2},
46 | pages={289--340},
47 | year={2023},
48 | publisher={Springer}
49 | }
50 | ```
51 |
52 |
53 | Please go to [Dr. Gengchen Mai's Homepage](https://gengchenmai.github.io/) for more information about Spatially Explicit Machine Learning and Artificial Intelligence.
--------------------------------------------------------------------------------
/polygoncode/polygonembed/lenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 |
6 | import numpy as np
7 |
8 | class Norm(nn.Module):
9 | def __init__(self, mean=0, std=1):
10 | super(Norm, self).__init__()
11 | self.mean = mean
12 | self.std = std
13 |
14 | def forward(self, x):
15 | return (x - self.mean) / self.std
16 |
17 | class LeNet5(nn.Module):
18 | def __init__(self, in_channels, num_classes, signal_sizes=(28,28), hidden_dim = 250, mean=0, std=1):
19 | super(LeNet5, self).__init__()
20 | self.in_channels = in_channels
21 | self.num_classes = num_classes
22 | self.hidden_dim = hidden_dim
23 |
24 | self.conv1 = nn.Conv2d(in_channels, 10, kernel_size=5)
25 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
26 | self.norm = Norm(mean, std)
27 | self.conv2_drop = nn.Dropout2d()
28 |
29 | self.conv_embed_dim = self.compute_conv_embed_dim(signal_sizes, curent_channels = 20)
30 | self.fc1 = nn.Linear(self.conv_embed_dim, self.hidden_dim)
31 | self.fc2 = nn.Linear(self.hidden_dim, self.num_classes)
32 | self.signal_sizes = signal_sizes
33 |
34 | def compute_conv_embed_dim(self, signal_sizes, curent_channels):
35 | fx, fy = signal_sizes
36 | return math.floor((fx-12)/4) * math.floor((fy-12)/4) * curent_channels
37 |
38 | def forward(self, x):
39 | '''
40 | Args:
41 | x: shape (batch_size, in_channels, fx, fy), image tensor
42 | signal_sizes = (fx, fy)
43 | Return:
44 | x: shape (batch_size, num_classes)
45 | image class distribution
46 |
47 | '''
48 | batch_size, n_c, fx, fy = x.shape
49 | assert n_c == self.in_channels
50 | assert fx == self.signal_sizes[0]
51 | assert fy == self.signal_sizes[1]
52 |
53 | # x = x.view(-1, 1, self.signal_sizes[0], self.signal_sizes[1])
54 |
55 | # x: shape (batch_size, in_channels, fx, fy)
56 | x = self.norm(x)
57 | # self.conv1(x): shape [batch_size, 10, fx-4, fy-4]
58 | # F.max_pool2d(self.conv1(x), 2): shape [batch_size, 10, (fx-4)/2, (fy-4)/2 ]
59 | # x: shape [batch_size, 10, (fx-4)/2, (fy-4)/2 ]
60 | x = F.relu(F.max_pool2d(self.conv1(x), 2))
61 |
62 | # self.conv2(x): [batch_size, 20, (fx-12)/2, (fy-12)/2 ]
63 | # self.conv2_drop(self.conv2(x)): [batch_size, 20, (fx-12)/2, (fy-12)/2 ]
64 | # F.max_pool2d(self.conv2_drop(self.conv2(x)), 2): [batch_size, 20, (fx-12)/4, (fy-12)/4 ]
65 | # x: [batch_size, 20, (fx-12)/4, (fy-12)/4 ]
66 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
67 |
68 |
69 | _, n_cc, fx_, fy_ = x.shape
70 | assert n_cc == 20
71 | assert fx_ == math.floor((fx-12)/4)
72 | assert fy_ == math.floor((fy-12)/4)
73 |
74 | # x: shape (batch_size, conv_embed_dim = 20 * floor((fx-12)/4) * floor((fy-12)/4) )
75 | x = x.reshape(batch_size, -1)
76 |         # x: shape (batch_size, hidden_dim)
77 | x = F.relu(self.fc1(x))
78 | x = F.dropout(x, training=self.training)
79 |
80 | # x: shape (batch_size, num_classes)
81 | x = self.fc2(x)
82 | return x
--------------------------------------------------------------------------------
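
A quick shape check for `LeNet5` (a usage sketch, not part of the repo; assumes it is run from the `polygoncode` directory so the package import resolves). With the default `signal_sizes = (28, 28)`, the flattened conv embedding is `20 * floor((28-12)/4) * floor((28-12)/4) = 320`:

```python
import torch
from polygonembed.lenet import LeNet5

model = LeNet5(in_channels=1, num_classes=10, signal_sizes=(28, 28))
x = torch.randn(4, 1, 28, 28)   # a batch of 4 single-channel 28x28 images
logits = model(x)               # internally flattens to conv_embed_dim = 320
assert logits.shape == (4, 10)
```
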
/polygoncode/polygonembed/resnet2d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import torchvision
6 |
7 | class ResNet2D(torchvision.models.resnet.ResNet):
8 | '''
9 | See the torchvision version of ResNet
10 | https://github.com/pytorch/vision/blob/21153802a3086558e9385788956b0f2808b50e51/torchvision/models/resnet.py
11 | '''
12 | def __init__(self, block, layers, in_channels = 1, num_classes=1000, zero_init_residual=False):
13 | super(ResNet2D, self).__init__(block, layers, num_classes, zero_init_residual)
14 | self.conv1 = torch.nn.Conv2d(in_channels, 64,
15 | kernel_size=(7, 7),
16 | stride=(2, 2),
17 | padding=(3, 3), bias=False)
18 |
19 |
20 | def get_resnet_model(resnet_type, num_classes = 20, in_channels = 1):
21 | if resnet_type == "resnet18":
22 | model = ResNet2D(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], in_channels = in_channels, num_classes = num_classes)
23 | elif resnet_type == "resnet34":
24 | model = ResNet2D(torchvision.models.resnet.BasicBlock, [3, 4, 6, 3], in_channels = in_channels, num_classes = num_classes)
25 | elif resnet_type == "resnet50":
26 | model = ResNet2D(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], in_channels = in_channels, num_classes = num_classes)
27 | elif resnet_type == "resnet101":
28 | model = ResNet2D(torchvision.models.resnet.Bottleneck, [3, 4, 23, 3], in_channels = in_channels, num_classes = num_classes)
29 | elif resnet_type == "resnet152":
30 | model = ResNet2D(torchvision.models.resnet.Bottleneck, [3, 8, 36, 3], in_channels = in_channels, num_classes = num_classes)
31 |     else:
32 |         raise ValueError("Unknown resnet_type: %s" % resnet_type)
33 |     return model
33 |
34 | # def resnet18(pretrained=False, **kwargs):
35 | # """Constructs a ResNet-18 model.
36 | # Args:
37 | # pretrained (bool): If True, returns a model pre-trained on ImageNet
38 | # """
39 | # model = ResNet2D(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], **kwargs)
40 | # # if pretrained:
41 | # # model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
42 | # return model
43 |
44 |
45 | # def resnet34(pretrained=False, **kwargs):
46 | # """Constructs a ResNet-34 model.
47 | # Args:
48 | # pretrained (bool): If True, returns a model pre-trained on ImageNet
49 | # """
50 | # model = ResNet2D(torchvision.models.resnet.BasicBlock, [3, 4, 6, 3], **kwargs)
51 | # # if pretrained:
52 | # # model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
53 | # return model
54 |
55 |
56 | # def resnet50(pretrained=False, **kwargs):
57 | # """Constructs a ResNet-50 model.
58 | # Args:
59 | # pretrained (bool): If True, returns a model pre-trained on ImageNet
60 | # """
61 | # model = ResNet2D(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], **kwargs)
62 | # # if pretrained:
63 | # # model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
64 | # return model
65 |
66 |
67 | # def resnet101(pretrained=False, **kwargs):
68 | # """Constructs a ResNet-101 model.
69 | # Args:
70 | # pretrained (bool): If True, returns a model pre-trained on ImageNet
71 | # """
72 | # model = ResNet2D(torchvision.models.resnet.Bottleneck, [3, 4, 23, 3], **kwargs)
73 | # # if pretrained:
74 | # # model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
75 | # return model
76 |
77 |
78 | # def resnet152(pretrained=False, **kwargs):
79 | # """Constructs a ResNet-152 model.
80 | # Args:
81 | # pretrained (bool): If True, returns a model pre-trained on ImageNet
82 | # """
83 | # model = ResNet2D(torchvision.models.resnet.Bottleneck, [3, 8, 36, 3], **kwargs)
84 | # # if pretrained:
85 | # # model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
86 | # return model
--------------------------------------------------------------------------------
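
`get_resnet_model` swaps the stem convolution so the torchvision ResNets accept `in_channels` other than 3; a usage sketch:

```python
import torch
from polygonembed.resnet2d import get_resnet_model

model = get_resnet_model("resnet18", num_classes=20, in_channels=1)
x = torch.randn(2, 1, 224, 224)   # single-channel images
out = model(x)
assert out.shape == (2, 20)
```
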
/polygoncode/polygonembed/ops.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 |
6 | class _ModeNormalization(nn.Module):
7 | def __init__(self, dim, n_components, eps):
8 | super(_ModeNormalization, self).__init__()
9 | self.eps = eps
10 | self.dim = dim
11 | self.n_components = n_components
12 |
13 | self.alpha = nn.Parameter(torch.ones(1, dim, 1, 1))
14 | self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1))
15 | self.phi = lambda x: x.mean(3).mean(2)
16 |
17 |
18 | class ModeNorm(_ModeNormalization):
19 | """
20 | An implementation of mode normalization. Input samples x are allocated into individual modes (their number is controlled by n_components) by a gating network; samples belonging together are jointly normalized and then passed back to the network.
21 | args:
22 | dim: int
23 | momentum: float
24 | n_components: int
25 | eps: float
26 | """
27 | def __init__(self, dim, momentum, n_components, eps=1.e-5):
28 | super(ModeNorm, self).__init__(dim, n_components, eps)
29 |
30 | self.momentum = momentum
31 |
32 |         self.x_ra = torch.zeros(n_components, 1, dim, 1, 1)   # moved to the input device when used
33 |         self.x2_ra = torch.zeros(n_components, 1, dim, 1, 1)
34 |
35 | self.W = torch.nn.Linear(dim, n_components)
36 | self.W.weight.data = torch.ones(n_components, dim) / n_components + .01 * torch.randn(n_components, dim)
37 | self.softmax = torch.nn.Softmax(dim=1)
38 |
39 | self.weighted_mean = lambda w, x, n: (w * x).mean(3, keepdim=True).mean(2, keepdim=True).sum(0, keepdim=True) / n
40 |
41 |
42 | def forward(self, x):
43 | g = self._g(x)
44 | n_k = torch.sum(g, dim=1).squeeze()
45 |
46 | if self.training:
47 | self._update_running_means(g.detach(), x.detach())
48 |
49 |         x_split = torch.zeros(x.size(), device = x.device)
50 |
51 | for k in range(self.n_components):
52 | if self.training:
53 | mu_k = self.weighted_mean(g[k], x, n_k[k])
54 | var_k = self.weighted_mean(g[k], (x - mu_k)**2, n_k[k])
55 | else:
56 | mu_k, var_k = self._mu_var(k)
57 | mu_k = mu_k.to(x.device)
58 | var_k = var_k.to(x.device)
59 |
60 | x_split += g[k] * ((x - mu_k) / torch.sqrt(var_k + self.eps))
61 |
62 | x = self.alpha * x_split + self.beta
63 |
64 | return x
65 |
66 |
67 | def _g(self, x):
68 | """
69 | Image inputs are first flattened along their height and width dimensions by phi(x), then mode memberships are determined via a linear transformation, followed by a softmax activation. The gates are returned with size (k, n, c, 1, 1).
70 | args:
71 | x: torch.Tensor
72 | returns:
73 | g: torch.Tensor
74 | """
75 | g = self.softmax(self.W(self.phi(x))).transpose(0, 1)[:, :, None, None, None]
76 | return g
77 |
78 |
79 | def _mu_var(self, k):
80 | """
81 | At test time, this function is used to compute the k'th mean and variance from weighted running averages of x and x^2.
82 | args:
83 | k: int
84 | returns:
85 | mu, var: torch.Tensor, torch.Tensor
86 | """
87 | mu = self.x_ra[k]
88 | var = self.x2_ra[k] - (self.x_ra[k] ** 2)
89 | return mu, var
90 |
91 |
92 | def _update_running_means(self, g, x):
93 | """
94 | Updates weighted running averages. These are kept and used to compute estimators at test time.
95 | args:
96 | g: torch.Tensor
97 | x: torch.Tensor
98 | """
99 | n_k = torch.sum(g, dim=1).squeeze()
100 |
101 | for k in range(self.n_components):
102 | x_new = self.weighted_mean(g[k], x, n_k[k])
103 | x2_new = self.weighted_mean(g[k], x**2, n_k[k])
104 |
105 | # ensure that tensors are on the right devices
106 | self.x_ra = self.x_ra.to(x_new.device)
107 | self.x2_ra = self.x2_ra.to(x2_new.device)
108 | self.x_ra[k] = self.momentum * x_new + (1-self.momentum) * self.x_ra[k]
109 | self.x2_ra[k] = self.momentum * x2_new + (1-self.momentum) * self.x2_ra[k]
110 |
--------------------------------------------------------------------------------
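
`ModeNorm` softly assigns each sample to one of `n_components` modes via the gating network `W`, normalizes per mode, and applies the shared affine parameters `alpha` and `beta`. A minimal CPU smoke test (a sketch, not part of the repo):

```python
import torch
from polygonembed.ops import ModeNorm

mn = ModeNorm(dim=16, momentum=0.1, n_components=2)
x = torch.randn(8, 16, 32, 32)   # (batch, channels, H, W)
mn.train()
y = mn(x)                        # per-mode normalization, then affine transform
assert y.shape == x.shape
```
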
/polygoncode/polygonembed/spec_pool.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Function
4 | from torch.autograd.function import once_differentiable
5 | import math
6 | import numpy as np
7 | from math import pi
8 |
9 | from polygonembed.ddsl_utils import *
10 |
11 |
12 | class SpecturalPooling(nn.Module):
13 | """
14 |     Do spectral pooling similar to https://arxiv.org/pdf/1506.03767.pdf
15 | """
16 | def __init__(self, min_freqXY_ratio = 0.5, max_pool_freqXY = [10, 10], freqXY = [16, 16], device = "cpu"):
17 | """
18 | Args:
19 | min_freqXY_ratio: c_m, the ratio to control the lowest fx, fy to keep,
20 | c_m = alpha + m/M * (beta - alpha)
21 |             max_pool_freqXY: H_m, the maximum frequency we want to keep,
22 | freqXY: res in DDSL_spec(), [fx, fy]
23 |                 fx, fy: number of frequencies in the X or Y dimension
24 | device:
25 | """
26 | super(SpecturalPooling, self).__init__()
27 | self.min_freqXY_ratio = min_freqXY_ratio
28 | self.max_pool_freqXY = max_pool_freqXY
29 | self.freqXY = freqXY
30 | self.device = device
31 |
32 |         # the maximum pool fx, fy should be smaller than or equal to freqXY
33 | assert len(freqXY) == len(max_pool_freqXY)
34 | for i in range(len(freqXY)):
35 | assert freqXY[i] >= max_pool_freqXY[i]
36 |
37 |
38 |
39 |
40 | def get_pool_freqXY(self, freqXY_ratio, max_pool_freqXY):
41 | min_pool_freqXY = []
42 | for f in max_pool_freqXY:
43 | min_pool_freqXY.append(math.floor(f * freqXY_ratio))
44 | return min_pool_freqXY
45 |
46 | def make_select_mask(self, fx, fy, maxfx, maxfy, y_dim):
47 | '''
48 |         mask the non-selected freq elements to 0
49 |         Args:
50 |             fx, fy: the selected dimensions along the x, y axes
51 |             maxfx, maxfy: the original freq dimensions
52 |             y_dim: the y dimension
53 | Return:
54 | mask: torch.FloatTensor(), shape (maxfx, y_dim, 1)
55 |
56 | '''
57 | assert maxfy > y_dim
58 | fxtop, fxlow, fytop = self.get_freqXY_select_idx(fx, fy, maxfx, maxfy)
59 |
60 | mask = np.zeros((maxfx, y_dim))
61 | mask[0:fxtop, 0:fytop] = 1
62 | mask[-fxlow:, 0:fytop] = 1
63 |
64 | mask = torch.FloatTensor(mask).unsqueeze(-1)
65 | return mask
66 |
67 | def get_freqXY_select_idx(self, fx, fy, maxfx, maxfy):
68 | '''
69 |         compute the selection indices of the kept frequencies
70 |         Args:
71 |             fx, fy: the selected dimensions along the x, y axes
72 |             maxfx, maxfy: the original freq dimensions
73 |         Return:
74 |             fxtop, fxlow: 0..fxtop and -fxlow ... -1 are selected
75 |             fytop: 0...fytop selected
76 |
77 | '''
78 | fxtop = math.ceil(fx/2)
79 | fxlow = fx - fxtop
80 |
81 | fytop = math.ceil(fy/2)
82 | return fxtop, fxlow, fytop
83 |
84 | def crop_freqmap(self, x, cur_pool_freqXY, freqXY):
85 | '''
86 | crop the frequency map to (fx, fy)
87 | Args:
88 |             x: torch.FloatTensor(), the input features (e.g., polygons) in the spectral domain
89 |                 shape (batch_size, n_channel = 1, maxfx, maxfy//2+1, 2)
90 |             cur_pool_freqXY: the pooled freq dimensions
91 |             freqXY: the original dimensions
92 | Return:
93 | spec_pool_res: torch.FloatTensor(),
94 | shape (batch_size, n_channel = 1, fx, ceil(fy/2), 2)
95 | '''
96 | maxfx, maxfy = freqXY
97 | fx, fy = cur_pool_freqXY
98 | fxtop, fxlow, fytop = self.get_freqXY_select_idx(fx, fy, maxfx, maxfy)
99 |
100 | upblock = x[:, :, 0:fxtop, 0:fytop, :]
101 | lowblock = x[:, :, -fxlow:, 0:fytop, :]
102 |
103 | spec_pool_res = torch.cat([upblock, lowblock], dim = 2)
104 | return spec_pool_res
105 |
106 | def forward(self, x):
107 | '''
108 | Args:
109 |             x: torch.FloatTensor(), the input features (e.g., polygons) in the spectral domain
110 | shape (batch_size, n_channel = 1, fx, fy//2+1, 2)
111 | '''
112 | freqXY_ratio = np.random.uniform(self.min_freqXY_ratio, 1)
113 |         # compute the current non-masked X, Y dimensions
114 |         cur_pool_freqXY = self.get_pool_freqXY(freqXY_ratio, self.max_pool_freqXY)
115 |
116 | batch_size, n_channel, fx, fy2, n_dim = x.shape
117 | assert fx == self.freqXY[0]
118 | assert fy2 == self.freqXY[1]//2 + 1
119 |
120 | # crop the input x into max_pool_freqXY
121 | # spec_pool_res: shape (batch_size, n_channel = 1, max_pool_freqXY[0], ceil(max_pool_freqXY[1]/2), 2)
122 |         spec_pool_res = self.crop_freqmap(x, self.max_pool_freqXY, self.freqXY)
123 |
124 | # mask the freq element outside of cur_pool_freqXY
125 | # mask: shape (max_pool_freqXY[0], ceil(max_pool_freqXY[1]/2), 1)
126 |         mask = self.make_select_mask(fx = cur_pool_freqXY[0],
127 |                                      fy = cur_pool_freqXY[1],
128 |                                      maxfx = self.max_pool_freqXY[0],
129 |                                      maxfy = self.max_pool_freqXY[1],
130 |                                      y_dim = spec_pool_res.shape[-2])
131 |
132 | # spec_pool_mask_res: shape (batch_size, n_channel = 1, max_pool_freqXY[0], ceil(max_pool_freqXY[1]/2), 2)
133 | spec_pool_mask_res = spec_pool_res * mask.to(x.device)
134 | return spec_pool_mask_res
135 |
136 |
137 |
138 |
--------------------------------------------------------------------------------
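
`SpecturalPooling` crops the spectrum of its input to `max_pool_freqXY` and then randomly masks out high frequencies; a usage sketch on a tensor shaped like a real-FFT output, `(batch, n_channel, fx, fy//2+1, 2)`:

```python
import torch
from polygonembed.spec_pool import SpecturalPooling

pool = SpecturalPooling(min_freqXY_ratio=0.5, max_pool_freqXY=[10, 10], freqXY=[16, 16])
x = torch.randn(4, 1, 16, 16 // 2 + 1, 2)   # spectral-domain input
out = pool(x)                                # cropped, then randomly low-pass masked
assert out.shape == (4, 1, 10, 5, 2)         # (10, ceil(10/2)) frequencies kept
```
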
/polygoncode/polygonembed/atten.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import pickle
4 | import math
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from math import sqrt
10 |
11 |
12 | class AttentionSet(nn.Module):
13 | def __init__(self, mode_dims, att_reg=0., att_tem=1., att_type="whole", bn='no', nat=1, name="Real"):
14 | '''
15 | Paper https://openreview.net/forum?id=BJgr4kSFDS
16 | Modified from https://github.com/hyren/query2box/blob/master/codes/model.py
17 | '''
18 | super(AttentionSet, self).__init__()
19 |
20 | self.att_reg = att_reg
21 | self.att_type = att_type
22 | self.att_tem = att_tem
23 | self.Attention_module = Attention(mode_dims, att_type=att_type, bn=bn, nat=nat, name = name)
24 |
25 | def forward(self, embeds):
26 | '''
27 | Args:
28 | embeds: shape (B, mode_dims, L), B: batch_size; L: number of embeddings we need to aggregate
29 | Return:
30 | combined: shape (B, mode_dims)
31 | '''
32 | # temp: shape (B, 1, L) or (B, mode_dims, L)
33 | temp = (self.Attention_module(embeds) + self.att_reg)/(self.att_tem+1e-4)
34 | if self.att_type == 'whole':
35 | # whole: we combine embeddings with a scalar attention coefficient
36 | # attens: shape (B, 1, L)
37 | attens = F.softmax(temp, dim=-1)
38 | # combined: shape (B, mode_dims, L)
39 | combined = embeds * attens
40 | # combined: shape (B, mode_dims)
41 | combined = torch.sum(combined, dim = -1)
42 |
43 | elif self.att_type == 'ele':
44 |             # ele: we combine embeddings with a vector attention coefficient
45 | # attens: shape (B, mode_dims, L)
46 | attens = F.softmax(temp, dim=-1)
47 | # combined: shape (B, mode_dims, L)
48 | combined = embeds * attens
49 | # combined: shape (B, mode_dims)
50 | combined = torch.sum(combined, dim = -1)
51 |
52 | return combined
53 |
54 |
55 |
56 | class Attention(nn.Module):
57 | def __init__(self, mode_dims, att_type="whole", bn = 'no', nat = 1, name="Real"):
58 | '''
59 | Paper https://openreview.net/forum?id=BJgr4kSFDS
60 | Modified from https://github.com/hyren/query2box/blob/master/codes/model.py
61 |
62 | Args:
63 |             mode_dims: the input embedding dimension we need to do attention over
64 |             att_type: the type of attention
65 |                 whole: we combine embeddings with a scalar attention coefficient
66 |                 ele: we combine embeddings with a vector attention coefficient
67 |             bn: the type of batch normalization
68 |                 no: no batch norm
69 |                 before: batch norm before ReLU
70 |                 after: batch norm after ReLU
71 |             nat: scalar in [1,2,3], the number of attention matrices to go through before atten_mats2
72 | '''
73 | super(Attention, self).__init__()
74 |
75 | self.bn = bn
76 | self.nat = nat
77 |
78 | self.atten_mats1 = nn.Parameter(torch.FloatTensor(mode_dims, mode_dims))
79 | nn.init.xavier_uniform_(self.atten_mats1)
80 | self.register_parameter("atten_mats1_%s"%name, self.atten_mats1)
81 | if self.nat >= 2:
82 | self.atten_mats1_1 = nn.Parameter(torch.FloatTensor(mode_dims, mode_dims))
83 | nn.init.xavier_uniform_(self.atten_mats1_1)
84 | self.register_parameter("atten_mats1_1_%s"%name, self.atten_mats1_1)
85 | if self.nat >= 3:
86 | self.atten_mats1_2 = nn.Parameter(torch.FloatTensor(mode_dims, mode_dims))
87 | nn.init.xavier_uniform_(self.atten_mats1_2)
88 | self.register_parameter("atten_mats1_2_%s"%name, self.atten_mats1_2)
89 | if bn != 'no':
90 | self.bn1 = nn.BatchNorm1d(mode_dims)
91 | self.bn1_1 = nn.BatchNorm1d(mode_dims)
92 | self.bn1_2 = nn.BatchNorm1d(mode_dims)
93 | if att_type == 'whole':
94 | self.atten_mats2 = nn.Parameter(torch.FloatTensor(1, mode_dims))
95 | elif att_type == 'ele':
96 | self.atten_mats2 = nn.Parameter(torch.FloatTensor(mode_dims, mode_dims))
97 | nn.init.xavier_uniform_(self.atten_mats2)
98 | self.register_parameter("atten_mats2_%s"%name, self.atten_mats2)
99 |
100 | def forward(self, center_embed):
101 | '''
102 | Args:
103 | center_embed: shape (B, mode_dims, L), B: batch_size; L: number of embeddings we need to aggregate
104 | Return:
105 | temp3:
106 | if att_type == 'whole':
107 | temp3: shape (B, 1, L)
108 | elif att_type == 'ele':
109 | temp3: shape (B, mode_dims, L)
110 | '''
111 | temp1 = center_embed
112 | if self.nat >= 1:
113 | # temp2: shape (B, mode_dims, L)
114 | temp2 = torch.einsum('kc,bcl->bkl', self.atten_mats1, temp1)
115 | if self.bn == 'no':
116 | temp2 = F.relu(temp2)
117 | elif self.bn == 'before':
118 | temp2 = F.relu(self.bn1(temp2))
119 | elif self.bn == 'after':
120 | temp2 = self.bn1(F.relu(temp2))
121 | # temp2: shape (B, mode_dims, L)
122 | if self.nat >= 2:
123 | temp2 = torch.einsum('kc,bcl->bkl', self.atten_mats1_1, temp2)
124 | if self.bn == 'no':
125 | temp2 = F.relu(temp2)
126 | elif self.bn == 'before':
127 | temp2 = F.relu(self.bn1_1(temp2))
128 | elif self.bn == 'after':
129 | temp2 = self.bn1_1(F.relu(temp2))
130 | # temp2: shape (B, mode_dims, L)
131 | if self.nat >= 3:
132 | temp2 = torch.einsum('kc,bcl->bkl', self.atten_mats1_2, temp2)
133 | if self.bn == 'no':
134 | temp2 = F.relu(temp2)
135 | elif self.bn == 'before':
136 | temp2 = F.relu(self.bn1_2(temp2))
137 | elif self.bn == 'after':
138 | temp2 = self.bn1_2(F.relu(temp2))
139 | # temp2: shape (B, mode_dims, L)
140 |
141 | temp3 = torch.einsum('kc,bcl->bkl', self.atten_mats2, temp2)
142 | '''
143 | if att_type == 'whole':
144 | temp3: shape (B, 1, L)
145 | elif att_type == 'ele':
146 | temp3: shape (B, mode_dims, L)
147 | '''
148 | return temp3
--------------------------------------------------------------------------------
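
`AttentionSet` collapses a variable number of embeddings into one by a softmax-weighted sum over the last axis; a usage sketch for the scalar (`whole`) attention type:

```python
import torch
from polygonembed.atten import AttentionSet

att = AttentionSet(mode_dims=64, att_type="whole")
embeds = torch.randn(8, 64, 5)   # batch of 8, five 64-d embeddings to aggregate
combined = att(embeds)           # softmax over L = 5, then weighted sum
assert combined.shape == (8, 64)
```
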
/polygoncode/polygonembed/data_util.py:
--------------------------------------------------------------------------------
1 | import fiona
2 | import geopandas as gpd
3 | import pandas as pd
4 |
5 | import math
6 | import os
7 | import json
8 | import pickle
9 | import numpy as np
10 | from tqdm import tqdm
11 | import matplotlib
12 | import matplotlib.pyplot as plt
13 | import random
14 | import re
15 | import requests
16 | import copy
17 |
18 | import shapely
19 | from shapely.ops import transform
20 |
21 | from shapely.geometry.point import Point
22 | from shapely.geometry.polygon import Polygon
23 | from shapely.geometry.multipolygon import MultiPolygon
24 | from shapely.geometry.linestring import LineString
25 | from shapely.geometry.polygon import LinearRing
26 | from shapely.geometry import box
27 |
28 | from fiona.crs import from_epsg
29 |
30 |
31 | from datetime import datetime
32 |
33 |
34 | from polygonembed.dataset import *
35 |
36 |
37 |
38 | def json_load(filepath):
39 | with open(filepath, "r") as json_file:
40 | data = json.load(json_file)
41 | return data
42 |
43 | def json_dump(data, filepath, pretty_format = True):
44 | with open(filepath, 'w') as fw:
45 | if pretty_format:
46 | json.dump(data, fw, indent=2, sort_keys=True)
47 | else:
48 | json.dump(data, fw)
49 |
50 | def pickle_dump(obj, pickle_filepath):
51 | with open(pickle_filepath, "wb") as f:
52 | pickle.dump(obj, f, protocol=2)
53 |
54 | def pickle_load(pickle_filepath):
55 | with open(pickle_filepath, "rb") as f:
56 | obj = pickle.load(f)
57 | return obj
58 |
59 |
60 |
61 | # ########################## GeoData Utility
62 |
63 | def make_projected_gdf(gdf, geometry_col = "geometry", epsg_code = 4326):
64 | gdf[geometry_col] = gdf[geometry_col].to_crs(epsg=epsg_code)
65 | gdf.crs = from_epsg(epsg_code)
66 | return gdf
67 |
68 | def explode(indata):
69 | if type(indata) == gpd.GeoDataFrame:
70 | indf = indata
71 | elif type(indata) == str:
72 | indf = gpd.GeoDataFrame.from_file(indata)
73 | else:
74 |         raise Exception("Input not recognized")
75 | outdf = gpd.GeoDataFrame(columns=indf.columns, crs = indf.crs)
76 | for idx, row in indf.iterrows():
77 | if type(row.geometry) == Polygon:
78 | outdf = outdf.append(row,ignore_index=True)
79 | if type(row.geometry) == MultiPolygon:
80 | multdf = gpd.GeoDataFrame(columns=indf.columns)
81 | recs = len(row.geometry)
82 | multdf = multdf.append([row]*recs,ignore_index=True)
83 | for geom in range(recs):
84 | multdf.loc[geom,'geometry'] = row.geometry[geom]
85 | outdf = outdf.append(multdf,ignore_index=True)
86 | return outdf
87 |
88 | def plot_multipolygon(multipolygon, figsize = (32, 32), edgecolor='r', ax = None):
89 |     show = ax is None   # only call plt.show() if we create the figure here
90 |     if show: fig, ax = plt.subplots(figsize = figsize)
91 |     p = gpd.GeoDataFrame(geometry = [multipolygon])
92 |     p.geometry.boundary.plot(edgecolor=edgecolor, ax = ax)
93 |     if show:
94 |         plt.show()
95 |
96 | def get_bbox_of_df(indf):
97 | bbox_df = indf["geometry"].bounds
98 |
99 | minx = np.min(bbox_df["minx"])
100 | miny = np.min(bbox_df["miny"])
101 | maxx = np.max(bbox_df["maxx"])
102 | maxy = np.max(bbox_df["maxy"])
103 | extent = (minx, miny, maxx, maxy)
104 | return bbox_df, extent
105 |
106 |
107 | def get_polygon_exterior_max_num_vert(indf, col = "geometry"):
108 | coord_len_list = list(indf[col].apply(lambda x: len(x.exterior.coords) ))
109 | return max(coord_len_list)
110 |
111 |
112 | def upsample_polygon_exterior_gdf_by_num_vert(indf, num_vert):
113 | return indf["geometry"].apply(
114 | lambda x: Polygon(line_interpolate_by_num_vert(x.exterior, num_vert = num_vert).coords ) )
115 |
116 | def line_interpolate_by_num_vert(geom, num_vert):
117 | if geom.geom_type in ['LineString', 'LinearRing']:
118 | num_vert_origin = len(geom.coords.xy[0])
119 |
120 | if num_vert_origin == num_vert:
121 | return geom
122 | elif num_vert_origin > num_vert:
123 |             raise Exception("The original line has a larger number of vertices than your input")
124 | num_vert_add = num_vert - num_vert_origin
125 |
126 | # print(num_vert, num_vert_origin, num_vert_add)
127 | pt_add_list = []
128 | dist_add_list = []
129 | for i in range(1, num_vert_add+1):
130 | pt_add = geom.interpolate(float(i) / (num_vert_add+1), normalized=True)
131 | dist_add = geom.project(pt_add)
132 |
133 | pt_add_list.append(pt_add)
134 | dist_add_list.append(dist_add)
135 |
136 | for idx in range(1, num_vert_origin - 1):
137 | pt = Point(geom.coords[idx])
138 | dist = geom.project(pt)
139 | insert_idx = np.searchsorted(dist_add_list, dist)
140 |
141 | dist_add_list = dist_add_list[:insert_idx] + [dist] + dist_add_list[insert_idx:]
142 | pt_add_list = pt_add_list[:insert_idx] + [pt] + pt_add_list[insert_idx:]
143 |
144 |
145 |         pt_add_list = [Point(geom.coords[0])] + pt_add_list + [Point(geom.coords[-1])]   # endpoints; coords[-1] == coords[0] for a LinearRing
146 | if geom.geom_type == 'LineString':
147 | return LineString(pt_add_list)
148 | elif geom.geom_type == 'LinearRing':
149 | line = LineString(pt_add_list)
150 | return LinearRing(line.coords)
151 | else:
152 |         raise ValueError('unhandled geometry %s' % geom.geom_type)
153 |
154 | def normalize_geometry_by_extent(geom, extent = None):
155 | '''
156 | Normalize the polygon coords to x: (-1, 1), y: (-1, 1)
157 | Args:
158 | geom: a geometry
159 | extent: (minx, miny, maxx, maxy)
160 | '''
161 | if extent is not None:
162 | minx, miny, maxx, maxy = extent
163 | else:
164 | minx, miny, maxx, maxy = geom.bounds
165 |
166 | assert minx < maxx
167 | assert miny < maxy
168 | # compute extent center
169 | x_c = (maxx + minx)/2
170 | y_c = (maxy + miny)/2
171 |     # 1. translate the geometry so the extent's center moves to the origin
172 | geom_aff = shapely.affinity.affine_transform(geom, matrix = [1, 0, 0, 1, -x_c, -y_c])
173 | # plot_multipolygon(geom_aff, figsize = (5, 5))
174 |
175 |
176 | deltax = maxx - minx
177 | deltay = maxy - miny
178 | if deltax >= deltay:
179 | max_len = deltax
180 | else:
181 | max_len = deltay
182 | # 2. scale to x: (-1, 1), y: (-1, 1)
183 | geom_scale = shapely.affinity.scale(geom_aff, xfact=2.0/max_len, yfact=2.0/max_len, zfact=0, origin='center')
184 | # plot_multipolygon(geom_scale, figsize = (5, 5))
185 | return geom_scale
186 |
187 |
188 |
189 | def load_dataframe(data_dir, filename):
190 | if filename.endswith(".pkl"):
191 | ingdf = pickle_load(os.path.join(data_dir, filename))
192 | else:
193 |         raise Exception('Unknown file type')
194 | return ingdf
195 |
196 |
--------------------------------------------------------------------------------
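
The geometry helpers compose: upsample a polygon exterior to a fixed vertex count, then normalize it into the (-1, 1) square (a usage sketch, not part of the repo):

```python
from shapely.geometry.polygon import Polygon
from polygonembed.data_util import line_interpolate_by_num_vert, normalize_geometry_by_extent

pgon = Polygon([(0, 0), (4, 0), (4, 2), (0, 2)])
# upsample the exterior ring so it has 16 coordinates in total
ring = line_interpolate_by_num_vert(pgon.exterior, num_vert=16)
pgon_norm = normalize_geometry_by_extent(Polygon(ring.coords))
print(pgon_norm.bounds)   # approximately (-1.0, -0.5, 1.0, 0.5)
```
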
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/polygoncode/polygonembed/PolygonDecoder.py:
--------------------------------------------------------------------------------
1 | import fiona
2 | import geopandas as gpd
3 | import pandas as pd
4 |
5 | import math
6 | import os
7 | import json
8 | import pickle
9 | import numpy as np
10 | from tqdm import tqdm
11 | import matplotlib
12 | import matplotlib.pyplot as plt
13 | import random
14 | import re
15 | import requests
16 | import copy
17 |
18 | import shapely
19 | from shapely.ops import transform
20 |
21 | from shapely.geometry.point import Point
22 | from shapely.geometry.polygon import Polygon
23 | from shapely.geometry.multipolygon import MultiPolygon
24 | from shapely.geometry.linestring import LineString
25 | from shapely.geometry.polygon import LinearRing
26 | from shapely.geometry import box
27 |
28 | from fiona.crs import from_epsg
29 |
30 |
31 | from datetime import datetime
32 |
33 | import torch
34 | import torch.nn as nn
35 | from torch.nn import init
36 | import torch.nn.functional as F
37 |
38 |
39 |
40 | from polygonembed.module import *
41 | from polygonembed.utils import *
42 |
43 |
44 | class ExplicitMLPPolygonDecoder(nn.Module):
45 |     """
46 |     Decode a polygon's vertex sequence from its embedding: a shared point-wise MLP (kernel-size-1 Conv1d) maps each (grid point, embedding) pair to one vertex.
47 |     """
48 | def __init__(self, spa_enc, pgon_embed_dim, num_vert, pgon_dec_grid_init = "uniform", pgon_dec_grid_enc_type = "none",
49 | coord_dim = 2, extent = (-1, 1, -1, 1), device = "cpu"):
50 | """
51 | Args:
52 | spa_enc: a spatial encoder
53 |             pgon_embed_dim: the output polygon embedding dimension
54 |             num_vert: number of unique vertices each generated polygon will have
55 | note that this does not include the extra point (same as the 1st one) to close the ring
56 | pgon_dec_grid_init: We generate a list of grid points for polygon decoder, the type of grid points are:
57 | uniform: points uniformly sampled from (-1, 1, -1, 1)
58 |                 circle: points sampled at equal distances on a circle whose radius is randomly sampled
59 | kdgrid: k-d regular grid, (num_vert - k^2) is uniformly sampled
60 | pgon_dec_grid_enc_type: the type to encode the grid point
61 | none: no encoding, use the original grid point
62 |                 spa_enc: use the space encoder to encode each grid point before decoding
63 | coord_dim: 2
64 | device:
65 | """
66 | super(ExplicitMLPPolygonDecoder, self).__init__()
67 |
68 | self.spa_enc = spa_enc
69 | self.spa_embed_dim = self.spa_enc.spa_embed_dim
70 | self.pgon_embed_dim = pgon_embed_dim
71 | self.num_vert = num_vert
72 | self.pgon_dec_grid_init = pgon_dec_grid_init
73 | self.coord_dim = coord_dim
74 | self.extent = extent
75 |
76 | self.device = device
77 | self.grid_dim = 2
78 |
79 | self.pgon_dec_grid_enc_type = pgon_dec_grid_enc_type
80 |
81 | # Partly borrowed from atlasnetV2
82 | # define point generator
83 | if self.pgon_dec_grid_enc_type == "none":
84 | self.nlatent = self.pgon_embed_dim + self.grid_dim
85 | elif self.pgon_dec_grid_enc_type == "spa_enc":
86 | self.nlatent = self.pgon_embed_dim + self.spa_embed_dim
87 | else:
88 | raise Exception("Unknown pgon_dec_grid_enc_type")
89 |
90 |         # bias = True by default, so each kernel-size-1 conv acts as an individual per-point MLP
91 | self.conv1 = torch.nn.Conv1d(in_channels = self.nlatent, out_channels = self.nlatent, kernel_size = 1)
92 | self.conv2 = torch.nn.Conv1d(in_channels = self.nlatent, out_channels = self.nlatent//2, kernel_size = 1)
93 | self.conv3 = torch.nn.Conv1d(in_channels = self.nlatent//2, out_channels = self.nlatent//4, kernel_size = 1)
94 | self.conv4 = torch.nn.Conv1d(in_channels = self.nlatent//4, out_channels = coord_dim, kernel_size = 1)
95 |
96 | self.th = nn.Tanh()
97 | self.bn1 = torch.nn.BatchNorm1d(self.nlatent)
98 | self.bn2 = torch.nn.BatchNorm1d(self.nlatent//2)
99 | self.bn3 = torch.nn.BatchNorm1d(self.nlatent//4)
100 |
101 |
102 |
103 | def forward(self, pgon_embeds):
104 | '''
105 | Args:
106 | pgon_embeds: tensor, shape (batch_size, pgon_embed_dim)
107 | Return:
108 |
109 | polygons: shape (batch_size, num_vert, coord_dim = 2)
110 | rand_grid: shape (batch_size, 2, num_vert)
111 |
112 | '''
113 | device = pgon_embeds.device
114 | batch_size, pgon_embed_dim = pgon_embeds.shape
115 |
116 | # Concat grids and pgon_embeds: Bx(C+2)xN
117 | # pgon_embeds_dup: shape (batch_size, pgon_embed_dim, num_vert)
118 | pgon_embeds_dup = pgon_embeds.unsqueeze(2).expand(
119 | batch_size, pgon_embed_dim, self.num_vert).contiguous() # BxCxN
120 |
121 | # generate rand grids
122 | # rand_grid: shape (batch_size, 2, num_vert)
123 | rand_grid = generate_rand_grid(self.pgon_dec_grid_init,
124 | batch_size = batch_size,
125 | num_vert = self.num_vert,
126 | extent = self.extent,
127 | grid_dim = self.grid_dim)
128 | if self.pgon_dec_grid_enc_type == "none":
129 | rand_grid = torch.FloatTensor(rand_grid).to(device)
130 | elif self.pgon_dec_grid_enc_type == "spa_enc":
131 | # rand_grid: np.array(), shape (batch_size, num_vert, 2)
132 | rand_grid = rand_grid.transpose(0,2,1)
133 | # rand_grid: torch.tensor(), shape (batch_size, num_vert, spa_embed_dim)
134 | rand_grid = self.spa_enc(rand_grid)
135 | # rand_grid: torch.tensor(), shape (batch_size, spa_embed_dim, num_vert)
136 | rand_grid = rand_grid.permute(0,2,1)
137 |
138 | else:
139 | raise Exception("Unknown pgon_dec_grid_enc_type")
140 |
141 |
142 |
143 | # x: shape (batch_size, 2 + pgon_embed_dim, num_vert) or (batch_size, spa_embed_dim + pgon_embed_dim, num_vert)
144 | x = torch.cat([rand_grid, pgon_embeds_dup], dim=1)
145 |
146 | # Generate points
147 | # x: shape (batch_size, 2 + pgon_embed_dim, num_vert)
148 | x = F.relu(self.bn1(self.conv1(x)))
149 |
150 | # x: shape (batch_size, (2 + pgon_embed_dim)/2, num_vert)
151 | x = F.relu(self.bn2(self.conv2(x)))
152 |
153 | # x: shape (batch_size, (2 + pgon_embed_dim)/4, num_vert)
154 | x = F.relu(self.bn3(self.conv3(x)))
155 |
156 | # x: shape (batch_size, coord_dim, num_vert)
157 | x = self.th(self.conv4(x))
158 |
159 | # polygons: shape (batch_size, num_vert, coord_dim = 2)
160 | polygons = x.transpose(2, 1)
161 |
162 | # rand_grid_: shape (batch_size, num_vert, grid_dim = 2)
163 | rand_grid_ = rand_grid.transpose(2, 1)
164 | return polygons, rand_grid_
165 |
166 |
167 |
168 |
169 |
170 | class ExplicitConvPolygonDecoder(nn.Module):
171 |     """
172 |     Same decoding scheme as ExplicitMLPPolygonDecoder, but with kernel-size-3 Conv1d layers (circular padding by default) so neighboring ring vertices share information.
173 |     """
174 | def __init__(self, spa_enc, pgon_embed_dim, num_vert, pgon_dec_grid_init = "uniform", pgon_dec_grid_enc_type = "none",
175 | coord_dim = 2, padding_mode = 'circular', extent = (-1, 1, -1, 1), device = "cpu"):
176 | """
177 | Args:
178 | spa_enc: a spatial encoder
179 |             pgon_embed_dim: the output polygon embedding dimension
180 |             num_vert: number of unique vertices each generated polygon will have
181 | note that this does not include the extra point (same as the 1st one) to close the ring
182 | pgon_dec_grid_init: We generate a list of grid points for polygon decoder, the type of grid points are:
183 | uniform: points uniformly sampled from (-1, 1, -1, 1)
184 |                 circle: points sampled at equal distances on a circle whose radius is randomly sampled
185 | kdgrid: k-d regular grid, (num_vert - k^2) is uniformly sampled
186 | pgon_dec_grid_enc_type: the type to encode the grid point
187 | none: no encoding, use the original grid point
188 |                 spa_enc: use the space encoder to encode each grid point before decoding
189 | coord_dim: 2
190 | padding_mode: 'circular'
191 | device:
192 | """
193 | super(ExplicitConvPolygonDecoder, self).__init__()
194 |
195 | self.spa_enc = spa_enc
196 | self.spa_embed_dim = self.spa_enc.spa_embed_dim
197 | self.pgon_embed_dim = pgon_embed_dim
198 | self.num_vert = num_vert
199 | self.pgon_dec_grid_init = pgon_dec_grid_init
200 | self.coord_dim = coord_dim
201 | self.extent = extent
202 |
203 | self.device = device
204 | self.grid_dim = 2
205 |
206 | self.pgon_dec_grid_enc_type = pgon_dec_grid_enc_type
207 |
208 | # Partly borrowed from atlasnetV2
209 | # define point generator
210 | if self.pgon_dec_grid_enc_type == "none":
211 | self.nlatent = self.pgon_embed_dim + self.grid_dim
212 | elif self.pgon_dec_grid_enc_type == "spa_enc":
213 | self.nlatent = self.pgon_embed_dim + self.spa_embed_dim
214 | else:
215 | raise Exception("Unknown pgon_dec_grid_enc_type")
216 |
217 |         # bias = True by default; kernel_size = 3 with circular padding lets each vertex see its ring neighbors
218 | self.conv1 = torch.nn.Conv1d(in_channels = self.nlatent, out_channels = self.nlatent, kernel_size = 3,
219 | stride=1, padding=1, padding_mode = padding_mode)
220 | self.conv2 = torch.nn.Conv1d(in_channels = self.nlatent, out_channels = self.nlatent//2, kernel_size = 3,
221 | stride=1, padding=1, padding_mode = padding_mode)
222 | self.conv3 = torch.nn.Conv1d(in_channels = self.nlatent//2, out_channels = self.nlatent//4, kernel_size = 3,
223 | stride=1, padding=1, padding_mode = padding_mode)
224 | self.conv4 = torch.nn.Conv1d(in_channels = self.nlatent//4, out_channels = coord_dim, kernel_size = 1)
225 |
226 | self.th = nn.Tanh()
227 | self.bn1 = torch.nn.BatchNorm1d(self.nlatent)
228 | self.bn2 = torch.nn.BatchNorm1d(self.nlatent//2)
229 | self.bn3 = torch.nn.BatchNorm1d(self.nlatent//4)
230 |
231 |
232 |
233 | def forward(self, pgon_embeds):
234 | '''
235 | Args:
236 | pgon_embeds: tensor, shape (batch_size, pgon_embed_dim)
237 | Return:
238 |
239 | polygons: shape (batch_size, num_vert, coord_dim = 2)
240 | rand_grid: shape (batch_size, 2, num_vert)
241 |
242 | '''
243 | device = pgon_embeds.device
244 | batch_size, pgon_embed_dim = pgon_embeds.shape
245 |
246 | # Concat grids and pgon_embeds: Bx(C+2)xN
247 | # pgon_embeds_dup: shape (batch_size, pgon_embed_dim, num_vert)
248 | pgon_embeds_dup = pgon_embeds.unsqueeze(2).expand(
249 | batch_size, pgon_embed_dim, self.num_vert).contiguous() # BxCxN
250 |
251 | # generate rand grids
252 | # rand_grid: shape (batch_size, 2, num_vert)
253 | rand_grid = generate_rand_grid(self.pgon_dec_grid_init,
254 | batch_size = batch_size,
255 | num_vert = self.num_vert,
256 | extent = self.extent,
257 | grid_dim = self.grid_dim)
258 | if self.pgon_dec_grid_enc_type == "none":
259 | rand_grid = torch.FloatTensor(rand_grid).to(device)
260 | elif self.pgon_dec_grid_enc_type == "spa_enc":
261 | # rand_grid: np.array(), shape (batch_size, num_vert, 2)
262 | rand_grid = rand_grid.transpose(0,2,1)
263 | # rand_grid: torch.tensor(), shape (batch_size, num_vert, spa_embed_dim)
264 | rand_grid = self.spa_enc(rand_grid)
265 | # rand_grid: torch.tensor(), shape (batch_size, spa_embed_dim, num_vert)
266 | rand_grid = rand_grid.permute(0,2,1)
267 |
268 | else:
269 | raise Exception("Unknown pgon_dec_grid_enc_type")
270 |
271 |
272 |
273 | # x: shape (batch_size, 2 + pgon_embed_dim, num_vert) or (batch_size, spa_embed_dim + pgon_embed_dim, num_vert)
274 | x = torch.cat([rand_grid, pgon_embeds_dup], dim=1)
275 |
276 | # Generate points
277 | # x: shape (batch_size, 2 + pgon_embed_dim, num_vert)
278 | x = F.relu(self.bn1(self.conv1(x)))
279 |
280 | # x: shape (batch_size, (2 + pgon_embed_dim)/2, num_vert)
281 | x = F.relu(self.bn2(self.conv2(x)))
282 |
283 | # x: shape (batch_size, (2 + pgon_embed_dim)/4, num_vert)
284 | x = F.relu(self.bn3(self.conv3(x)))
285 |
286 | # x: shape (batch_size, coord_dim, num_vert)
287 | x = self.th(self.conv4(x))
288 |
289 | # polygons: shape (batch_size, num_vert, coord_dim = 2)
290 | polygons = x.transpose(2, 1)
291 |
292 | # rand_grid_: shape (batch_size, num_vert, grid_dim = 2)
293 | rand_grid_ = rand_grid.transpose(2, 1)
294 | return polygons, rand_grid_
--------------------------------------------------------------------------------
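Both decoders above follow the same recipe: duplicate the polygon embedding across a set of grid points, then map each (grid point, embedding) pair to one vertex. Below is a minimal, self-contained sketch of that point-wise-MLP idea; the sizes `B`, `C`, `N` and the two-layer head are illustrative placeholders, not the repo's exact architecture.

import torch
import torch.nn as nn

B, C, N = 4, 64, 32                       # batch size, pgon_embed_dim, num_vert
pgon_embeds = torch.randn(B, C)           # polygon embeddings
grid = torch.rand(B, 2, N) * 2 - 1        # "uniform" grid points in (-1, 1)^2

# duplicate the embedding across the N grid points and concatenate: (B, C + 2, N)
x = torch.cat([grid, pgon_embeds.unsqueeze(2).expand(B, C, N)], dim=1)

# kernel-size-1 Conv1d layers act as one shared MLP applied independently per point
head = nn.Sequential(
    nn.Conv1d(C + 2, (C + 2) // 2, kernel_size=1), nn.ReLU(),
    nn.Conv1d((C + 2) // 2, 2, kernel_size=1), nn.Tanh(),
)
polygons = head(x).transpose(2, 1)        # (B, N, 2) vertex coordinates in (-1, 1)
print(polygons.shape)                     # torch.Size([4, 32, 2])

The Conv variant swaps the kernel-size-1 layers for kernel-size-3 ones with padding_mode = 'circular', so each predicted vertex can also see its neighbors on the closed ring.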
/polygoncode/polygonembed/trainer_img.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | from torch import optim
3 | from torch.utils.tensorboard import SummaryWriter
4 |
5 | from polygonembed.dataset import *
6 | from polygonembed.resnet2d import *
7 | from polygonembed.model_utils import *
8 | from polygonembed.data_util import *
9 | from polygonembed.trainer_helper import *
10 |
11 |
12 |
13 |
14 | def make_args_parser():
15 | parser = ArgumentParser()
16 | # dir
17 | parser.add_argument("--data_dir", type=str, default="./animal/")
18 | parser.add_argument("--model_dir", type=str, default="./model_dir/animal_img/")
19 | parser.add_argument("--log_dir", type=str, default="./model_dir/animal_img/")
20 |
21 | #data
22 | parser.add_argument("--img_filename", type=str, default="animal_img_mats.pkl")
23 |
24 | parser.add_argument("--data_split_num", type=int, default=0,
25 |                         help='''we might do multiple train/valid/test splits;
26 |                         this indicates which split we will use for training.
27 | Note that we use 1, 0, -1 to indicate train/test/valid
28 | 1: train
29 | 0: test
30 | -1: valid ''')
31 | parser.add_argument("--num_worker", type=int, default=0,
32 |                         help='the number of workers for the dataloader')
33 |
34 |
35 |
36 | # model type
37 | parser.add_argument("--model_type", type=str, default="",
38 | help='''the type of image classification model we use,
39 | resnet18 / resnet34 / resnet50 / resnet101 / resnet152
40 | ''')
41 |
42 |
43 | # model
44 | # parser.add_argument("--embed_dim", type=int, default=64,
45 | # help='Point feature embedding dim')
46 | # parser.add_argument("--dropout", type=float, default=0.5,
47 | # help='The dropout rate used in all fully connected layer')
48 | # parser.add_argument("--act", type=str, default='sigmoid',
49 | # help='the activation function for the encoder decoder')
50 |
51 |
52 |
53 | # # # encoder decoder
54 | # # parser.add_argument("--join_dec_type", type=str, default='max',
55 | # # help='the type of join_dec, min/max/mean/cat')
56 |
57 | # # polygon encoder
58 | # parser.add_argument("--pgon_enc", type=str, default="resnet",
59 | # help='''the type of polygon encoder:
60 | # resnet: ResNet based encoder
61 | # veercnn: the CNN model proposed in https://arxiv.org/pdf/1806.03857.pdf''')
62 | # parser.add_argument("--pgon_embed_dim", type=int, default=64,
63 | # help='the embedding dimention of polygon')
64 | # parser.add_argument("--padding_mode", type=str, default="circular",
65 | # help='the type of padding method for Conv1D: circular / zeros / reflect / replicate')
66 | # parser.add_argument("--resnet_add_middle_pool", type=str, default='F',
67 | # help='whether to add MaxPool1D between the middle layers of ResNet')
68 | # parser.add_argument("--resnet_fl_pool_type", type=str, default="mean",
69 | # help='the type of final pooling method: mean / min /max')
70 | # parser.add_argument("--resnet_block_type", type=str, default="basic",
71 | # help='the type of ResNet block we will use: basic / bottleneck')
72 | # parser.add_argument("--resnet_layers_per_block", nargs='+', type=int, default=[],
73 | # help='the number of layers per resnet block, must be 3 layers')
74 |
75 |
76 |
77 | parser.add_argument("--do_data_augment", type=str, default="F",
78 |                         help = "whether to do polygon data augmentation: flip, rotate, and scale polygons in each batch")
79 |
80 |
81 |
82 |
83 | # train
84 | parser.add_argument("--opt", type=str, default="adam")
85 | parser.add_argument("--lr", type=float, default=0.01,
86 | help='learning rate')
87 |
88 | parser.add_argument("--weight_decay", type=float, default=0.001,
89 | help='weight decay of adam optimizer')
90 | # parser.add_argument("--task_loss_weight", type=float, default=0.8,
91 | # help='the weight of classification loss when we do join training')
92 | # parser.add_argument("--pgon_norm_reg_weight", type=float, default=0.1,
93 | # help='the weight of polygon embedding norm regularizer')
94 |
95 | # parser.add_argument("--grt_epoches", type=int, default=50000000,
96 | # help='the maximum epoches for generative model converge')
97 | parser.add_argument("--cla_epoches", type=int, default=50000000,
98 |                         help='the maximum number of epochs for the polygon classifier model to converge')
99 | # parser.add_argument("--max_burn_in", type=int, default=5000,
100 | # help='the maximum iterator for relative/global model converge')
101 | parser.add_argument("--batch_size", type=int, default=512)
102 | # parser.add_argument("--tol", type=float, default=0.000001)
103 |
104 | parser.add_argument("--balanced_train_loader", type=str, default="T",
105 |                         help = "whether to use a BalancedSampler for polygon classification")
106 |
107 |
108 |
109 | # eval
110 | parser.add_argument("--log_every", type=int, default=50)
111 | parser.add_argument("--val_every", type=int, default=5000)
112 |
113 |
114 | # load old model
115 | parser.add_argument("--load_model", action='store_true')
116 |
117 | # cuda
118 | parser.add_argument("--device", type=str, default="cpu")
119 |
120 | return parser
121 |
122 |
123 | def bool_arg_handler(arg):
124 |     return arg == "T"
125 |
126 | def update_args(args):
127 | select_args = ["balanced_train_loader", "do_data_augment"]
128 | for arg in select_args:
129 | args.__dict__[arg] = bool_arg_handler(getattr(args, arg))
130 |
131 | return args
132 |
133 |
134 | def make_args_combine(args):
135 | args_data = "{data:s}-{data_split_num:d}".format(
136 | data=args.data_dir.strip().split("/")[-2],
137 | data_split_num = args.data_split_num
138 | )
139 |
140 | args_train = "-{batch_size:d}-{lr:.6f}-{opt:s}-{weight_decay:.2f}-{balanced_train_loader:s}".format(
141 | # act = args.act,
142 | # dropout=args.dropout,
143 | batch_size=args.batch_size,
144 | lr=args.lr,
145 | opt = args.opt,
146 | weight_decay = args.weight_decay,
147 | # task_loss_weight = args.task_loss_weight,
148 | # pgon_norm_reg_weight = args.pgon_norm_reg_weight,
149 | balanced_train_loader = args.balanced_train_loader,
150 | # do_polygon_random_start = args.do_polygon_random_start
151 | )
152 |
153 | args_combine = "/{args_data:s}-{model_type:s}-{args_train:s}".format(
154 | args_data = args_data,
155 | model_type=args.model_type,
156 | args_train = args_train
157 |
158 | )
159 | return args_combine
160 |
161 |
162 |
163 | class Trainer():
164 | """
165 | Trainer
166 | """
167 | def __init__(self, args, img_gdf, console = True):
168 |
169 |
170 | self.args_combine = make_args_combine(args) #+ ".L2"
171 |
172 | self.log_file = args.log_dir + self.args_combine + ".log"
173 | self.model_file = args.model_dir + self.args_combine + ".pth"
174 | # tensorboard log directory
175 | # self.tb_log_dir = args.model_dir + self.args_combine
176 | self.tb_log_dir = args.model_dir + "/tb"
177 |
178 | if not os.path.exists(args.model_dir):
179 | os.makedirs(args.model_dir)
180 | if not os.path.exists(args.log_dir):
181 | os.makedirs(args.log_dir)
182 |
183 | self.logger = setup_logging(self.log_file, console = console, filemode='a')
184 |
185 |
186 |
187 | self.img_gdf = img_gdf
188 |
189 | args = update_args(args)
190 | self.args = args
191 |
192 | self.img_cla_dataset, self.img_cla_dataloader, self.split2num, self.data_split_col, self.train_sampler, self.num_classes = self.load_image_cla_dataset_dataloader(
193 | img_gdf,
194 | num_worker = args.num_worker,
195 | batch_size = args.batch_size,
196 | balanced_train_loader = args.balanced_train_loader,
197 | id_col = "ID",
198 | img_col = "IMAGE",
199 | class_col = "TYPEID",
200 | data_split_num = args.data_split_num,
201 | do_data_augment = args.do_data_augment,
202 | device = args.device)
203 |
204 | self.model = get_resnet_model(resnet_type = args.model_type, num_classes = self.num_classes).to(args.device)
205 |
206 |
207 | if args.opt == "sgd":
208 | self.optimizer = optim.SGD(filter(lambda p : p.requires_grad, self.model.parameters()), lr=args.lr, momentum=0)
209 | elif args.opt == "adam":
210 | self.optimizer = optim.Adam(filter(lambda p : p.requires_grad, self.model.parameters()), lr=args.lr, weight_decay = args.weight_decay)
211 |
212 |         print("create model; save file at {}".format(self.args_combine + ".pth"))
213 | self.logger.info("Save file at {}".format(self.args_combine + ".pth"))
214 |
215 | self.tb_writer = SummaryWriter(self.tb_log_dir)
216 | self.global_batch_idx = 0
217 |
218 | def load_image_cla_dataset_dataloader(self, img_gdf,
219 | num_worker, batch_size, balanced_train_loader,
220 | id_col = "ID", img_col = "IMAGE", class_col = "TYPEID", data_split_num = 0,
221 | do_data_augment = False, device = "cpu"):
222 | '''
223 |         load the image classification dataset, including the training, validation, and testing splits
224 | '''
225 | img_cla_dataset = dict()
226 | img_cla_dataloader = dict()
227 | data_split_col = "SPLIT_{:d}".format(data_split_num)
228 | split2num = {"TRAIN": 1, "TEST": 0, "VALID": -1}
229 | split_nums = np.unique(np.array(img_gdf[data_split_col]))
230 | assert 1 in split_nums and 0 in split_nums
231 | dup_test = False
232 | if -1 not in split_nums:
233 | # we will make valid and test the same
234 | dup_test = True
235 |
236 | un_class = np.unique(np.array(img_gdf[class_col]))
237 | num_classes = len(un_class)
238 | max_num_exs_per_class = math.ceil(batch_size/num_classes)
239 |
240 |
241 | for split in ["TRAIN", "TEST", "VALID"]:
242 | # make dataset
243 |
244 | if split == "VALID" and dup_test:
245 | img_cla_dataset[split] = img_cla_dataset["TEST"]
246 | else:
247 | img_split_gdf = img_gdf[ img_gdf[data_split_col] == split2num[split] ]
248 |
249 | if split == "TRAIN":
250 | img_cla_dataset[split] = ImageDataset(img_gdf = img_split_gdf,
251 | id_col = id_col,
252 | img_col = img_col,
253 | class_col = class_col,
254 | do_data_augment = do_data_augment,
255 | device = device)
256 | else:
257 | img_cla_dataset[split] = ImageDataset(img_gdf = img_split_gdf,
258 | id_col = id_col,
259 | img_col = img_col,
260 | class_col = class_col,
261 | do_data_augment = False,
262 | device = device)
263 |
264 | # make dataloader
265 | if split == "TRAIN":
266 | if balanced_train_loader:
267 | train_sampler = BalancedSampler(classes = img_cla_dataset["TRAIN"].class_list.cpu().numpy(),
268 | num_per_class = max_num_exs_per_class,
269 | use_replace=False,
270 | multi_label=False)
271 | img_cla_dataloader[split] = torch.utils.data.DataLoader(img_cla_dataset[split],
272 | num_workers = num_worker,
273 | batch_size = batch_size,
274 | sampler=train_sampler,
275 | shuffle = False)
276 | else:
277 | train_sampler = None
278 | img_cla_dataloader[split] = torch.utils.data.DataLoader(img_cla_dataset[split],
279 | num_workers = num_worker,
280 | batch_size = batch_size,
281 | shuffle = True)
282 | elif split == "VALID" and dup_test:
283 | img_cla_dataloader[split] = img_cla_dataloader['TEST']
284 | else:
285 | img_cla_dataloader[split] = torch.utils.data.DataLoader(img_cla_dataset[split],
286 | num_workers = num_worker,
287 | batch_size = batch_size,
288 | shuffle = False)
289 |
290 |
291 | return img_cla_dataset, img_cla_dataloader, split2num, data_split_col, train_sampler, num_classes
292 |
293 | def run_train(self):
294 | # assert "norm" in self.geom_type_list
295 | self.global_batch_idx = train_image_model(
296 | self.args,
297 | model = self.model,
298 | img_cla_dataloader = self.img_cla_dataloader,
299 | optimizer = self.optimizer,
300 | tb_writer = self.tb_writer,
301 | logger = self.logger,
302 | model_file = self.model_file,
303 | cla_epoches = self.args.cla_epoches,
304 | log_every = self.args.log_every,
305 | val_every = self.args.val_every,
306 | global_batch_idx = self.global_batch_idx)
307 |
308 | def load_model(self):
309 | self.model, self.optimizer, self.args = load_model(self.model, self.optimizer, self.model_file)
--------------------------------------------------------------------------------
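A hypothetical driver for the Trainer above, mirroring the pattern used by the repo's training scripts. It assumes polygonembed is importable and that load_dataframe from polygonembed.data_util can read the image pickle; the flag values are placeholders.

from polygonembed.trainer_img import make_args_parser, Trainer
from polygonembed.data_util import load_dataframe   # assumed to handle the image pickle

parser = make_args_parser()
args = parser.parse_args(["--model_type", "resnet18", "--device", "cpu"])

img_gdf = load_dataframe(args.data_dir, args.img_filename)
trainer = Trainer(args, img_gdf, console=True)   # update_args() is applied inside __init__
trainer.run_train()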
/polygoncode/polygonembed/dla_resnext.py:
--------------------------------------------------------------------------------
1 | """
2 | ResNeXt for ImageNet-1K, implemented in PyTorch.
3 | Original paper: 'Aggregated Residual Transformations for Deep Neural Networks,' http://arxiv.org/abs/1611.05431.
4 | """
5 |
6 | # __all__ = ['ResNeXt', 'resnext14_16x4d', 'resnext14_32x2d', 'resnext14_32x4d', 'resnext26_16x4d', 'resnext26_32x2d',
7 | # 'resnext26_32x4d', 'resnext38_32x4d', 'resnext50_32x4d', 'resnext101_32x4d', 'resnext101_64x4d',
8 | # 'ResNeXtBottleneck', 'ResNeXtUnit']
9 | __all__ = ['ResNeXtBottleneck']
10 |
11 | import os
12 | import math
13 | import torch.nn as nn
14 | import torch.nn.init as init
15 | from polygonembed.dla_common import conv1x1_block, conv3x3_block
16 | from polygonembed.dla_resnet import ResInitBlock
17 |
18 |
19 | class ResNeXtBottleneck(nn.Module):
20 | """
21 | ResNeXt bottleneck block for residual path in ResNeXt unit.
22 |
23 | Parameters:
24 | ----------
25 | in_channels : int
26 | Number of input channels.
27 | out_channels : int
28 | Number of output channels.
29 | stride : int or tuple/list of 2 int
30 | Strides of the convolution.
31 | cardinality: int
32 | Number of groups.
33 | bottleneck_width: int
34 | Width of bottleneck block.
35 | bottleneck_factor : int, default 4
36 | Bottleneck factor.
37 | """
38 | def __init__(self,
39 | in_channels,
40 | out_channels,
41 | stride,
42 | cardinality,
43 | bottleneck_width,
44 | bottleneck_factor=4):
45 | super(ResNeXtBottleneck, self).__init__()
46 | mid_channels = out_channels // bottleneck_factor
47 | D = int(math.floor(mid_channels * (bottleneck_width / 64.0)))
48 | group_width = cardinality * D
49 |
50 | self.conv1 = conv1x1_block(
51 | in_channels=in_channels,
52 | out_channels=group_width)
53 | self.conv2 = conv3x3_block(
54 | in_channels=group_width,
55 | out_channels=group_width,
56 | stride=stride,
57 | groups=cardinality)
58 | self.conv3 = conv1x1_block(
59 | in_channels=group_width,
60 | out_channels=out_channels,
61 | activation=None)
62 |
63 | def forward(self, x):
64 | x = self.conv1(x)
65 | x = self.conv2(x)
66 | x = self.conv3(x)
67 | return x
68 |
69 |
70 | # class ResNeXtUnit(nn.Module):
71 | # """
72 | # ResNeXt unit with residual connection.
73 |
74 | # Parameters:
75 | # ----------
76 | # in_channels : int
77 | # Number of input channels.
78 | # out_channels : int
79 | # Number of output channels.
80 | # stride : int or tuple/list of 2 int
81 | # Strides of the convolution.
82 | # cardinality: int
83 | # Number of groups.
84 | # bottleneck_width: int
85 | # Width of bottleneck block.
86 | # """
87 | # def __init__(self,
88 | # in_channels,
89 | # out_channels,
90 | # stride,
91 | # cardinality,
92 | # bottleneck_width):
93 | # super(ResNeXtUnit, self).__init__()
94 | # self.resize_identity = (in_channels != out_channels) or (stride != 1)
95 |
96 | # self.body = ResNeXtBottleneck(
97 | # in_channels=in_channels,
98 | # out_channels=out_channels,
99 | # stride=stride,
100 | # cardinality=cardinality,
101 | # bottleneck_width=bottleneck_width)
102 | # if self.resize_identity:
103 | # self.identity_conv = conv1x1_block(
104 | # in_channels=in_channels,
105 | # out_channels=out_channels,
106 | # stride=stride,
107 | # activation=None)
108 | # self.activ = nn.ReLU(inplace=True)
109 |
110 | # def forward(self, x):
111 | # if self.resize_identity:
112 | # identity = self.identity_conv(x)
113 | # else:
114 | # identity = x
115 | # x = self.body(x)
116 | # x = x + identity
117 | # x = self.activ(x)
118 | # return x
119 |
120 |
121 | # class ResNeXt(nn.Module):
122 | # """
123 | # ResNeXt model from 'Aggregated Residual Transformations for Deep Neural Networks,' http://arxiv.org/abs/1611.05431.
124 |
125 | # Parameters:
126 | # ----------
127 | # channels : list of list of int
128 | # Number of output channels for each unit.
129 | # init_block_channels : int
130 | # Number of output channels for the initial unit.
131 | # cardinality: int
132 | # Number of groups.
133 | # bottleneck_width: int
134 | # Width of bottleneck block.
135 | # in_channels : int, default 3
136 | # Number of input channels.
137 | # in_size : tuple of two ints, default (224, 224)
138 | # Spatial size of the expected input image.
139 | # num_classes : int, default 1000
140 | # Number of classification classes.
141 | # """
142 | # def __init__(self,
143 | # channels,
144 | # init_block_channels,
145 | # cardinality,
146 | # bottleneck_width,
147 | # in_channels=3,
148 | # in_size=(224, 224),
149 | # num_classes=1000):
150 | # super(ResNeXt, self).__init__()
151 | # self.in_size = in_size
152 | # self.num_classes = num_classes
153 |
154 | # self.features = nn.Sequential()
155 | # self.features.add_module("init_block", ResInitBlock(
156 | # in_channels=in_channels,
157 | # out_channels=init_block_channels))
158 | # in_channels = init_block_channels
159 | # for i, channels_per_stage in enumerate(channels):
160 | # stage = nn.Sequential()
161 | # for j, out_channels in enumerate(channels_per_stage):
162 | # stride = 2 if (j == 0) and (i != 0) else 1
163 | # stage.add_module("unit{}".format(j + 1), ResNeXtUnit(
164 | # in_channels=in_channels,
165 | # out_channels=out_channels,
166 | # stride=stride,
167 | # cardinality=cardinality,
168 | # bottleneck_width=bottleneck_width))
169 | # in_channels = out_channels
170 | # self.features.add_module("stage{}".format(i + 1), stage)
171 | # self.features.add_module("final_pool", nn.AvgPool2d(
172 | # kernel_size=7,
173 | # stride=1))
174 |
175 | # self.output = nn.Linear(
176 | # in_features=in_channels,
177 | # out_features=num_classes)
178 |
179 | # self._init_params()
180 |
181 | # def _init_params(self):
182 | # for name, module in self.named_modules():
183 | # if isinstance(module, nn.Conv2d):
184 | # init.kaiming_uniform_(module.weight)
185 | # if module.bias is not None:
186 | # init.constant_(module.bias, 0)
187 |
188 | # def forward(self, x):
189 | # x = self.features(x)
190 | # x = x.view(x.size(0), -1)
191 | # x = self.output(x)
192 | # return x
193 |
194 |
195 | # def get_resnext(blocks,
196 | # cardinality,
197 | # bottleneck_width,
198 | # model_name=None,
199 | # pretrained=False,
200 | # root=os.path.join("~", ".torch", "models"),
201 | # **kwargs):
202 | # """
203 | # Create ResNeXt model with specific parameters.
204 |
205 | # Parameters:
206 | # ----------
207 | # blocks : int
208 | # Number of blocks.
209 | # cardinality: int
210 | # Number of groups.
211 | # bottleneck_width: int
212 | # Width of bottleneck block.
213 | # model_name : str or None, default None
214 | # Model name for loading pretrained model.
215 | # pretrained : bool, default False
216 | # Whether to load the pretrained weights for model.
217 | # root : str, default '~/.torch/models'
218 | # Location for keeping the model parameters.
219 | # """
220 |
221 | # if blocks == 14:
222 | # layers = [1, 1, 1, 1]
223 | # elif blocks == 26:
224 | # layers = [2, 2, 2, 2]
225 | # elif blocks == 38:
226 | # layers = [3, 3, 3, 3]
227 | # elif blocks == 50:
228 | # layers = [3, 4, 6, 3]
229 | # elif blocks == 101:
230 | # layers = [3, 4, 23, 3]
231 | # else:
232 | # raise ValueError("Unsupported ResNeXt with number of blocks: {}".format(blocks))
233 |
234 | # assert (sum(layers) * 3 + 2 == blocks)
235 |
236 | # init_block_channels = 64
237 | # channels_per_layers = [256, 512, 1024, 2048]
238 |
239 | # channels = [[ci] * li for (ci, li) in zip(channels_per_layers, layers)]
240 |
241 | # net = ResNeXt(
242 | # channels=channels,
243 | # init_block_channels=init_block_channels,
244 | # cardinality=cardinality,
245 | # bottleneck_width=bottleneck_width,
246 | # **kwargs)
247 |
248 | # if pretrained:
249 | # if (model_name is None) or (not model_name):
250 | # raise ValueError("Parameter `model_name` should be properly initialized for loading pretrained model.")
251 | # from .model_store import download_model
252 | # download_model(
253 | # net=net,
254 | # model_name=model_name,
255 | # local_model_store_dir_path=root)
256 |
257 | # return net
258 |
259 |
260 | # def resnext14_16x4d(**kwargs):
261 | # """
262 | # ResNeXt-14 (16x4d) model from 'Aggregated Residual Transformations for Deep Neural Networks,'
263 | # http://arxiv.org/abs/1611.05431.
264 |
265 | # Parameters:
266 | # ----------
267 | # pretrained : bool, default False
268 | # Whether to load the pretrained weights for model.
269 | # root : str, default '~/.torch/models'
270 | # Location for keeping the model parameters.
271 | # """
272 | # return get_resnext(blocks=14, cardinality=16, bottleneck_width=4, model_name="resnext14_16x4d", **kwargs)
273 |
274 |
275 | # def resnext14_32x2d(**kwargs):
276 | # """
277 | # ResNeXt-14 (32x2d) model from 'Aggregated Residual Transformations for Deep Neural Networks,'
278 | # http://arxiv.org/abs/1611.05431.
279 |
280 | # Parameters:
281 | # ----------
282 | # pretrained : bool, default False
283 | # Whether to load the pretrained weights for model.
284 | # root : str, default '~/.torch/models'
285 | # Location for keeping the model parameters.
286 | # """
287 | # return get_resnext(blocks=14, cardinality=32, bottleneck_width=2, model_name="resnext14_32x2d", **kwargs)
288 |
289 |
290 | # def resnext14_32x4d(**kwargs):
291 | # """
292 | # ResNeXt-14 (32x4d) model from 'Aggregated Residual Transformations for Deep Neural Networks,'
293 | # http://arxiv.org/abs/1611.05431.
294 |
295 | # Parameters:
296 | # ----------
297 | # pretrained : bool, default False
298 | # Whether to load the pretrained weights for model.
299 | # root : str, default '~/.torch/models'
300 | # Location for keeping the model parameters.
301 | # """
302 | # return get_resnext(blocks=14, cardinality=32, bottleneck_width=4, model_name="resnext14_32x4d", **kwargs)
303 |
304 |
305 | # def resnext26_16x4d(**kwargs):
306 | # """
307 | # ResNeXt-26 (16x4d) model from 'Aggregated Residual Transformations for Deep Neural Networks,'
308 | # http://arxiv.org/abs/1611.05431.
309 |
310 | # Parameters:
311 | # ----------
312 | # pretrained : bool, default False
313 | # Whether to load the pretrained weights for model.
314 | # root : str, default '~/.torch/models'
315 | # Location for keeping the model parameters.
316 | # """
317 | # return get_resnext(blocks=26, cardinality=16, bottleneck_width=4, model_name="resnext26_16x4d", **kwargs)
318 |
319 |
320 | # def resnext26_32x2d(**kwargs):
321 | # """
322 | # ResNeXt-26 (32x2d) model from 'Aggregated Residual Transformations for Deep Neural Networks,'
323 | # http://arxiv.org/abs/1611.05431.
324 |
325 | # Parameters:
326 | # ----------
327 | # pretrained : bool, default False
328 | # Whether to load the pretrained weights for model.
329 | # root : str, default '~/.torch/models'
330 | # Location for keeping the model parameters.
331 | # """
332 | # return get_resnext(blocks=26, cardinality=32, bottleneck_width=2, model_name="resnext26_32x2d", **kwargs)
333 |
334 |
335 | # def resnext26_32x4d(**kwargs):
336 | # """
337 | # ResNeXt-26 (32x4d) model from 'Aggregated Residual Transformations for Deep Neural Networks,'
338 | # http://arxiv.org/abs/1611.05431.
339 |
340 | # Parameters:
341 | # ----------
342 | # pretrained : bool, default False
343 | # Whether to load the pretrained weights for model.
344 | # root : str, default '~/.torch/models'
345 | # Location for keeping the model parameters.
346 | # """
347 | # return get_resnext(blocks=26, cardinality=32, bottleneck_width=4, model_name="resnext26_32x4d", **kwargs)
348 |
349 |
350 | # def resnext38_32x4d(**kwargs):
351 | # """
352 | # ResNeXt-38 (32x4d) model from 'Aggregated Residual Transformations for Deep Neural Networks,'
353 | # http://arxiv.org/abs/1611.05431.
354 |
355 | # Parameters:
356 | # ----------
357 | # pretrained : bool, default False
358 | # Whether to load the pretrained weights for model.
359 | # root : str, default '~/.torch/models'
360 | # Location for keeping the model parameters.
361 | # """
362 | # return get_resnext(blocks=38, cardinality=32, bottleneck_width=4, model_name="resnext38_32x4d", **kwargs)
363 |
364 |
365 | # def resnext50_32x4d(**kwargs):
366 | # """
367 | # ResNeXt-50 (32x4d) model from 'Aggregated Residual Transformations for Deep Neural Networks,'
368 | # http://arxiv.org/abs/1611.05431.
369 |
370 | # Parameters:
371 | # ----------
372 | # pretrained : bool, default False
373 | # Whether to load the pretrained weights for model.
374 | # root : str, default '~/.torch/models'
375 | # Location for keeping the model parameters.
376 | # """
377 | # return get_resnext(blocks=50, cardinality=32, bottleneck_width=4, model_name="resnext50_32x4d", **kwargs)
378 |
379 |
380 | # def resnext101_32x4d(**kwargs):
381 | # """
382 | # ResNeXt-101 (32x4d) model from 'Aggregated Residual Transformations for Deep Neural Networks,'
383 | # http://arxiv.org/abs/1611.05431.
384 |
385 | # Parameters:
386 | # ----------
387 | # pretrained : bool, default False
388 | # Whether to load the pretrained weights for model.
389 | # root : str, default '~/.torch/models'
390 | # Location for keeping the model parameters.
391 | # """
392 | # return get_resnext(blocks=101, cardinality=32, bottleneck_width=4, model_name="resnext101_32x4d", **kwargs)
393 |
394 |
395 | # def resnext101_64x4d(**kwargs):
396 | # """
397 | # ResNeXt-101 (64x4d) model from 'Aggregated Residual Transformations for Deep Neural Networks,'
398 | # http://arxiv.org/abs/1611.05431.
399 |
400 | # Parameters:
401 | # ----------
402 | # pretrained : bool, default False
403 | # Whether to load the pretrained weights for model.
404 | # root : str, default '~/.torch/models'
405 | # Location for keeping the model parameters.
406 | # """
407 | # return get_resnext(blocks=101, cardinality=64, bottleneck_width=4, model_name="resnext101_64x4d", **kwargs)
408 |
409 |
410 | # def _calc_width(net):
411 | # import numpy as np
412 | # net_params = filter(lambda p: p.requires_grad, net.parameters())
413 | # weight_count = 0
414 | # for param in net_params:
415 | # weight_count += np.prod(param.size())
416 | # return weight_count
417 |
418 |
419 | # def _test():
420 | # import torch
421 |
422 | # pretrained = False
423 |
424 | # models = [
425 | # resnext14_16x4d,
426 | # resnext14_32x2d,
427 | # resnext14_32x4d,
428 | # resnext26_16x4d,
429 | # resnext26_32x2d,
430 | # resnext26_32x4d,
431 | # resnext38_32x4d,
432 | # resnext50_32x4d,
433 | # resnext101_32x4d,
434 | # resnext101_64x4d,
435 | # ]
436 |
437 | # for model in models:
438 |
439 | # net = model(pretrained=pretrained)
440 |
441 | # # net.train()
442 | # net.eval()
443 | # weight_count = _calc_width(net)
444 | # print("m={}, {}".format(model.__name__, weight_count))
445 | # assert (model != resnext14_16x4d or weight_count == 7127336)
446 | # assert (model != resnext14_32x2d or weight_count == 7029416)
447 | # assert (model != resnext14_32x4d or weight_count == 9411880)
448 | # assert (model != resnext26_16x4d or weight_count == 10119976)
449 | # assert (model != resnext26_32x2d or weight_count == 9924136)
450 | # assert (model != resnext26_32x4d or weight_count == 15389480)
451 | # assert (model != resnext38_32x4d or weight_count == 21367080)
452 | # assert (model != resnext50_32x4d or weight_count == 25028904)
453 | # assert (model != resnext101_32x4d or weight_count == 44177704)
454 | # assert (model != resnext101_64x4d or weight_count == 83455272)
455 |
456 | # x = torch.randn(1, 3, 224, 224)
457 | # y = net(x)
458 | # y.sum().backward()
459 | # assert (tuple(y.size()) == (1, 1000))
460 |
461 |
462 | # if __name__ == "__main__":  # disabled: _test() is commented out above and would raise a NameError
463 | #     _test()
464 |
--------------------------------------------------------------------------------
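The live class in this file is ResNeXtBottleneck; its key ingredient is the grouped 3x3 convolution, where cardinality is the number of parallel residual paths. A self-contained sketch of the width computation and grouping follows; the widths reproduce the formula in __init__ above, but the input size is illustrative, and the repo's conv1x1_block/conv3x3_block also add batch norm and activation, omitted here.

import math
import torch
import torch.nn as nn

in_ch, out_ch, cardinality, bottleneck_width = 256, 256, 32, 4
mid = out_ch // 4                                   # bottleneck_factor = 4
D = int(math.floor(mid * (bottleneck_width / 64.0)))
group_width = cardinality * D                       # 32 * 4 = 128

block = nn.Sequential(
    nn.Conv2d(in_ch, group_width, kernel_size=1, bias=False),
    nn.Conv2d(group_width, group_width, kernel_size=3, padding=1,
              groups=cardinality, bias=False),      # 32 aggregated residual paths
    nn.Conv2d(group_width, out_ch, kernel_size=1, bias=False),
)
y = block(torch.randn(1, in_ch, 14, 14))
print(y.shape)   # torch.Size([1, 256, 14, 14])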
/polygoncode/polygonembed/ddsl.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Function
4 | from torch.autograd.function import once_differentiable
5 | from math import pi
6 |
7 | from polygonembed.ddsl_utils import *
8 |
9 |
10 |
11 | class SimplexFT(Function):
12 | """
13 | Fourier transform for signal defined on a j-simplex set in R^n space
14 | :param V: vertex tensor. float tensor of shape (n_vertex, n_dims)
15 | :param E: element tensor. int tensor of shape (n_elem, j or j+1)
16 | if j cols, triangulate/tetrahedronize interior first.
17 | :param D: int ndarray of shape (n_elem, n_channel)
18 | :param res: n_dims int tuple of number of frequency modes
19 | :param t: n_dims tuple of period in each dimension
20 | :param j: dimension of simplex set
21 | :param mode: normalization mode.
22 | 'density' for preserving density, 'mass' for preserving mass
23 | :return: F: ndarray of shape (res[0], res[1], ..., res[-1]/2, n_channel)
24 |         last dimension is halved since the signal is assumed to be real
25 | """
26 | @staticmethod
27 | def forward(ctx, V, E, D, res, t, j,
28 | min_freqXY, max_freqXY, mid_freqXY = None, freq_init = "fft",
29 | elem_batch=100, mode='density'):
30 | '''
31 | Args:
32 | V: vertex tensor. float tensor of shape (batch_size, num_vert, n_dims = 2)
33 | E: element tensor. int tensor of shape (batch_size, n_elem = num_vert, j or j+1)
34 | if j cols, triangulate/tetrahedronize interior first.
35 | (num_vert, 2), indicate the connectivity
36 |
37 |
38 | D: int ndarray of shape (batch_size, n_elem = num_vert, n_channel = 1), all 1
39 | res: n_dims int tuple of number of frequency modes, (fx, fy) for polygon
40 | t: n_dims tuple of period in each dimension, (tx, ty) for polygon
41 | j: dimension of simplex set, 2 for polygon
42 | max_freqXY: the maximum frequency
43 | min_freqXY: the minimum frequency
44 |             freq_init: frequency generation method,
45 | "geometric": geometric series
46 | "fft": fast fourier transformation
47 | mode: normalization mode.
48 | 'density' for preserving density, 'mass' for preserving mass
49 | Return:
50 | F: ndarray of shape (res[0], res[1], ..., res[-1]/2, n_channel)
51 |                 last dimension is halved since the signal is assumed to be real
52 | Shape (fx, fy//2+1, n_channel = 1, 2) for polygon case
53 | '''
54 | ## check if E is subdim
55 | '''
56 | subdim: See ddsl_illustration.jpg
57 |         True: V is a sequence of vertices and E is a sequence of lines
58 |         False: V is a sequence of vertices plus one auxiliary vertex and E is a sequence of triangles
59 | '''
60 | assert V.shape[0] == E.shape[0] == D.shape[0] # batch_size
61 | batch_size = V.shape[0]
62 | subdim = E.shape[-1] == j and V.shape[-1] == j
63 | assert (E.shape[-1] == j+1 or subdim)
64 | # assert V.shape[0] == E.shape[0] == D.shape[0] # batch_size dim should match
65 |
66 | if subdim:
67 | # make sure all D has same density
68 | D_repeat = torch.repeat_interleave(D[:, 0].unsqueeze(1), repeats = D.shape[1], dim=1)
69 | assert((D == D_repeat).sum().item() == D.numel()) # assert same densities for all simplices (homogeneous filling)
70 |             # add (0,0) as an auxiliary vertex in V
71 | # aux_vert_mat: shape (batch_size, 1, n_dims = 2)
72 | aux_vert_mat = torch.zeros((batch_size, 1, V.shape[-1]), device=V.device, dtype=V.dtype)
73 | # V: (batch_size, num_vert + 1, n_dims = 2)
74 | V = torch.cat((V, aux_vert_mat), dim=1)
75 |
76 |             # the index of the added auxiliary vertex (0, 0)
77 | aux_vert_idx = V.shape[1] - 1
78 | # add_aux_vert_mat: (batch_size, n_elem = num_vert, 1)
79 |             # values are the index of the added auxiliary vertex
80 | add_aux_vert_mat = torch.zeros((batch_size, E.shape[1], 1), device=E.device, dtype=E.dtype) + aux_vert_idx
81 |             # E: (batch_size, n_elem = num_vert, j+1 = 3), add (0, 0) as the 3rd vertex of each line in E; this constructs all the 2-simplices
82 | E = torch.cat((E, add_aux_vert_mat), dim=-1)
83 |
84 | n_elem = E.shape[-2] # n_elem = num_vert, number of 2-simplex, the triangles
85 | n_vert = V.shape[-2] # n_vert = num_vert + 1, number of vertex
86 | n_channel = D.shape[-1] # n_channel = 1
87 |
88 | ## save context info for backwards
89 | ctx.mark_non_differentiable(E, D) # mark non-differentiable
90 | ctx.res = res
91 | ctx.t = t
92 | ctx.j = j
93 | ctx.mode = mode
94 | ctx.n_dims = V.shape[-1] # n_dims = 2
95 | ctx.elem_batch = elem_batch # elem_batch = 100
96 | ctx.subdim = subdim
97 |
98 | ctx.min_freqXY = min_freqXY
99 | ctx.max_freqXY = max_freqXY
100 |         ctx.mid_freqXY = mid_freqXY
101 | ctx.freq_init = freq_init
102 |
103 |
104 |
105 |
106 | # compute content array
107 |         # C: normalized simplex content, \gamma_n^j, shape (batch_size, n_elem = num_vert, 1)
108 | # unsigned: Equation 6 in https://openreview.net/pdf?id=B1G5ViAqFm
109 | # signed: Equation 8 in https://openreview.net/pdf?id=B1G5ViAqFm
110 | C = math.factorial(j) * simplex_content(V, E, signed=subdim) # [n_elem, 1]
111 | ctx.save_for_backward(V, E, D, C)
112 |
113 | ## compute frequencies F
114 | n_dims = ctx.n_dims
115 |         assert(n_dims == len(res)) # consistent spatial dimensionality
116 | assert(E.shape[1] == D.shape[1]) # consistent vertex numbers
117 | assert(mode in ['density', 'mass'])
118 |
119 |
120 |
121 | # frequency tensor
122 | '''
123 |         omega: the FFT frequency matrix;
124 |         the last spatial dimension is halved since the signal is real:
125 | for res = (fx, fy) => shape (fx, fy//2+1, 2)
126 | for res = (f1, f2, ..., fn) => shape (f1, f2, ..., f_{n-1}, fn//2+1, n = n_dims)
127 | '''
128 | # omega = fftfreqs(res, dtype=V.dtype).to(V.device) # [dim0, dim1, dim2, d]
129 | # omega: [dim0, dim1, dim2, d]
130 | omega = get_fourier_freqs(res = res,
131 | min_freqXY = min_freqXY,
132 | max_freqXY = max_freqXY,
133 | mid_freqXY = mid_freqXY,
134 | freq_init = freq_init,
135 | dtype=V.dtype).to(V.device)
136 |
137 | # normalize frequencies
138 | for dim in range(n_dims):
139 | omega[..., dim] *= 2 * math.pi / t[dim]
140 |
141 |
142 |
143 | # initialize output F
144 | # F_shape = [fx, fy//2+1]
145 | F_shape = list(omega.shape)[:-1]
146 | # F_shape = [fx, fy//2+1, n_channel = 1, 2]
147 | F_shape += [n_channel, 2]
148 | # F_shape = [batch_size, fx, fy//2+1, n_channel = 1, 2]
149 | F_shape = [batch_size] + F_shape
150 | # F: shape (batch_size, fx, fy//2+1, n_channel = 1, 2)
151 | F = torch.zeros(*F_shape, dtype=V.dtype, device=V.device) # [dimX, dimY, dimZ, n_chan, 2] 2: real/imag
152 |
153 | # compute element-point tensor
154 | # P: point tensor. float, shape (batch_size, n_elem = num_vert, j+1 = 3, n_dims = 2)
155 | # P = V[E]
156 | P = coord_P_lookup_from_V_E(V, E)
157 |
158 |
159 |
160 | # loop over element/simplex batches
161 | for idx in range(math.ceil(n_elem/elem_batch)):
162 | id_start = idx * elem_batch
163 | id_end = min((idx+1) * elem_batch, n_elem)
164 | # Xi: coordinate mat, shape [batch_size, elem_batch, j+1, n_dims = 2]
165 | Xi = P[:, id_start:id_end]
166 |             # Di: simplex density matrix, shape [batch_size, elem_batch, n_channel = 1]
167 | Di = D[:, id_start:id_end]
168 |             # Ci: normalized simplex content, \gamma_n^j, Eq. 6, shape [batch_size, elem_batch, 1]
169 | Ci = C[:, id_start:id_end]
170 |             # CDi: \gamma_n^j * \rho_n
171 | # CDi: shape [batch_size, elem_batch, n_channel]
172 | CDi = Ci.expand_as(Di) * Di
173 | # sig: shape (batch_size, elem_batch, j+1, fx, fy//2+1)
174 | # k dot x, Equation 3 in https://openreview.net/pdf?id=B1G5ViAqFm
175 | sig = torch.einsum('lbjd,...d->lbj...', (Xi, omega))
176 |
177 |
178 | # sig: shape (batch_size, elem_batch, j+1, fx, fy//2+1, 1)
179 | sig = torch.unsqueeze(sig, dim=-1) # [elem_batch, j+1, dimX, dimY, dimZ, 1]
180 |             # esig: e^{-i*\sigma_t}, Euler's formula,
181 | # shape (batch_size, elem_batch, j+1, fx, fy//2+1, 1, 2),
182 | esig = torch.stack((torch.cos(sig), -torch.sin(sig)), dim=-1) # [elem_batch, j+1, dimX, dimY, dimZ, 1, 2]
183 | # sig: shape (batch_size, elem_batch, j+1, fx, fy//2+1, 1, 1)
184 | sig = torch.unsqueeze(sig, dim=-1) # [elem_batch, j+1, dimX, dimY, dimZ, 1, 1]
185 |
186 |
187 |
188 |             # denom: prod(sigma_t - sigma_l), the denominator of Equation 7 in https://openreview.net/pdf?id=B1G5ViAqFm
189 | # denom: shape (batch_size, elem_batch, j+1, fx, fy//2+1, 1, 1)
190 | denom = torch.ones_like(sig) # [elem_batch, j+1, dimX, dimY, dimZ, 1, 1]
191 | for dim in range(1, j+1):
192 | seq = permute_seq(dim, j+1)
193 | denom *= sig - sig[:, :, seq]
194 | # tmp: shape (batch_size, elem_batch, fx, fy//2+1, 1, 2)
195 |             # sum e^{-i*\sigma_t}/prod(sigma_t - sigma_l),
196 | # signed context, Equation 4 in https://openreview.net/pdf?id=B1G5ViAqFm
197 | tmp = torch.sum(esig / denom, dim=2) # [elem_batch, dimX, dimY, dimZ, 1, 2]
198 |
199 | # select all cases when denom == 0
200 | mask = ((denom == 0).repeat_interleave(repeats = 2, dim=-1).sum(dim = 2) > 0)
201 | # mask all cases as 0
202 | tmp[mask] = 0
203 |
204 | # CDi: shape (batch_size, elem_batch, n_channel, 1)
205 | CDi.unsqueeze_(-1) # [elem_batch, n_channel, 1]
206 | # CDi: shape (batch_size, elem_batch, 1, 1, n_channel, 1)
207 | for _ in range(n_dims): # unsqueeze to broadcast
208 | CDi.unsqueeze_(dim=2) # [elem_batch, 1, 1, 1, n_channel, 2]
209 |
210 | # shape_ = (batch_size, elem_batch, fx, fy//2+1, n_channel, 2)
211 | shape_ = (list(tmp.shape[:-2])+[n_channel, 2])
212 | # tmp: (batch_size, elem_batch, fx, fy//2+1, n_channel, 2)
213 |             # \gamma_n^j * \rho_n * sum e^{-i*\sigma_t}/prod(sigma_t - sigma_l)
214 | tmp = tmp * CDi # [elem_batch, dimX, dimY, dimZ, n_channel, 2]
215 | # Fi: shape (batch_size, fx, fy//2+1, n_channel, 2)
216 | Fi = torch.sum(tmp, dim=1, keepdim=False) # [dimX, dimY, dimZ, n_channel, 2]
217 |
218 |
219 | # CDi_: shape (batch_size, 1, 1, n_channel, 1)
220 | # the sum of the polygon content
221 | CDi_ = torch.sum(CDi, dim=1)
222 | # CDi_: shape (batch_size, n_channel, 1)
223 | for _ in range(n_dims): # squeeze dims
224 | CDi_.squeeze_(dim=1) # [n_channel, 2]
225 |
226 |
227 |         # Fi[:, tuple([0] * n_dims)] = - 1 / factorial(j) * CDi_  # replaced by the explicit n_dims cases below
228 |
229 | # Fi: shape (batch_size, fx, fy//2+1, n_channel = 1, 2)
230 | if n_dims == 2:
231 | Fi[:, 0, 0] = - 1 / math.factorial(j) * CDi_
232 | elif n_dims == 3:
233 | Fi[:, 0, 0, 0] = - 1 / math.factorial(j) * CDi_
234 | else:
235 | raise Exception("n_dims is not 2 or 3")
236 | F += Fi
237 |
238 | # F: shape (batch_size, fx, fy//2+1, n_channel = 1, 2)
239 | # multiply tensor F by i ** j
240 | # see Equation 4 and 7 in https://openreview.net/pdf?id=B1G5ViAqFm
241 | F = img(F, deg=j) # Fi *= i**j [dimX, dimY, dimZ, n_chan, 2] 2: real/imag
242 |
243 | if mode == 'density':
244 | res_t = torch.tensor(res)
245 | if not torch.equal(res_t, res[0]*torch.ones(len(res), dtype=res_t.dtype)):
246 | print("WARNING: density preserving mode not correctly implemented if not all res are equal")
247 | F *= res[0] ** j
248 | return F
249 |
250 |
251 | class DDSL_spec(nn.Module):
252 | """
253 | Module for DDSL layer. Takes in a simplex mesh and returns the spectral raster.
254 | """
255 | def __init__(self, res, t, j,
256 | min_freqXY, max_freqXY, mid_freqXY = None, freq_init = "fft",
257 | elem_batch=100, mode='density'):
258 | """
259 | Args:
260 | res: n_dims int tuple of number of frequency modes
261 | t: n_dims tuple of period in each dimension
262 | j: dimension of simplex set
263 | max_freqXY: the maximum frequency
264 | min_freqXY: the minimum frequency
265 |             freq_init: frequency generation method,
266 | "geometric": geometric series
267 | "fft": fast fourier transformation
268 | elem_batch: element-wise batch size.
269 |             mode: 'density' for density conserving, or 'mass' for mass conserving. Defaults to 'density'
270 | """
271 | super(DDSL_spec, self).__init__()
272 | self.res = res
273 | self.t = t
274 | self.j = j
275 | self.min_freqXY = min_freqXY
276 | self.max_freqXY = max_freqXY
277 | self.mid_freqXY = mid_freqXY
278 | self.freq_init = freq_init
279 | self.elem_batch = elem_batch
280 | self.mode = mode
281 | def forward(self, V, E, D):
282 | """
283 | V: vertex tensor. float tensor of shape (batch_size, num_vert, n_dims = 2)
284 | E: element tensor. int tensor of shape (batch_size, n_elem = num_vert, j or j+1)
285 | if j cols, triangulate/tetrahedronize interior first.
286 | (num_vert, 2), indicate the connectivity
287 | D: int ndarray of shape (batch_size, n_elem, n_channel)
288 | :return F: ndarray of shape (res[0], res[1], ..., res[-1]/2, n_channel)
289 |             last dimension is halved since the signal is assumed to be real
290 | F: shape (batch_size, fx, fy//2+1, n_channel = 1, 2)
291 | """
292 | V, D = V.double(), D.double()
293 | return SimplexFT.apply(V,E,D,self.res,self.t,self.j,
294 | self.min_freqXY, self.max_freqXY, self.mid_freqXY, self.freq_init,
295 | self.elem_batch,self.mode)
296 |
297 |
298 | class DDSL_phys(nn.Module):
299 | """
300 | Module for DDSL layer. Takes in a simplex mesh and returns a dealiased raster image (in physical domain).
301 | """
302 | def __init__(self, res, t, j,
303 | min_freqXY, max_freqXY, mid_freqXY = None, freq_init = "fft",
304 | smoothing='gaussian', sig=2.0, elem_batch=100, mode='density'):
305 | """
306 | Args:
307 | res: n_dims int tuple of number of frequency modes
308 | t: n_dims tuple of period in each dimension
309 | j: dimension of simplex set
310 | max_freqXY: the maximum frequency
311 | min_freqXY: the minimum frequency
312 |             freq_init: frequency generation method,
313 | "geometric": geometric series
314 | "fft": fast fourier transformation
315 |             smoothing: str, choice of spectral smoothing function. Defaults to 'gaussian'
316 | sig: sigma of gaussian at highest frequency
317 | elem_batch: element-wise batch size.
318 |             mode: 'density' for density conserving, or 'mass' for mass conserving. Defaults to 'density'
319 | """
320 | super(DDSL_phys, self).__init__()
321 | self.res = res
322 | self.t = t
323 | self.j = j
324 | self.min_freqXY = min_freqXY
325 | self.max_freqXY = max_freqXY
326 | self.mid_freqXY = mid_freqXY
327 | self.freq_init = freq_init
328 | self.elem_batch = elem_batch
329 | self.mode = mode
330 | self.filter = None
331 | self.sig = sig
332 |
333 | if isinstance(smoothing, str):
334 | assert(smoothing in ["gaussian"])
335 | if smoothing == 'gaussian':
336 | # filter: shape (fx, fy//2+1, 1, 1)
337 | self.filter = self._gaussian_filter()
338 |
339 | def forward(self, V, E, D):
340 | """
341 | V: vertex tensor. float tensor of shape (batch_size, num_vert, n_dims = 2)
342 | E: element tensor. int tensor of shape (batch_size, n_elem = num_vert, j or j+1)
343 | if j cols, triangulate/tetrahedronize interior first.
344 | (num_vert, 2), indicate the connectivity
345 | :param D: int ndarray of shape (batch_size, n_elem, n_channel)
346 | Return:
347 | f: dealiased raster image in physical domain of shape (batch_size, res[0], res[1], ..., res[-1], n_channel)
348 | shape (batch_size, fx, fy, n_channel = 1)
349 | """
350 |
351 | V, D = V.double(), D.double()
352 | # F: shape (batch_size, fx, fy//2+1, n_channel = 1, 2) for polygon case
353 | F = SimplexFT.apply(V,E,D,self.res,self.t,self.j,
354 | self.min_freqXY, self.max_freqXY, self.mid_freqXY, self.freq_init,
355 | self.elem_batch,self.mode)
356 | F[torch.isnan(F)] = 0 # pad nans to 0
357 | if self.filter is not None:
358 | # filter: shape (fx, fy//2+1, 1, 1)
359 | self.filter = self.filter.to(F.device)
360 | # filter_: shape (fx, fy//2+1, 1, 2)
361 | filter_ = torch.repeat_interleave(self.filter, repeats = F.shape[-1], dim=-1)
362 | # F: shape (batch_size, fx, fy//2+1, n_channel = 1, 2) for polygon case
363 | F *= filter_ # [dim0, dim1, dim2, n_channel, 2]
364 | dim = len(self.res)
365 | # F: shape (batch_size, n_channel = 1, fx, fy//2+1, 2) for polygon case
366 | F = F.permute(*([0, dim+1] + list(range(1, dim+1)) + [dim+2])) # [n_channel, dim0, dim1, dim2, 2]
367 | # f: shape (batch_size, n_channel = 1, fx, fy)
368 | f = torch.irfft(F, dim, signal_sizes=self.res)
369 | # f: shape (batch_size, fx, fy, n_channel = 1)
370 | f = f.permute(*([0] + list(range(2, 2+dim)) + [1]))
371 |
372 | return f
373 |
374 | def _gaussian_filter(self):
375 | '''
376 | Return:
377 | filter_: shape (fx, fy//2+1, 1, 1)
378 | '''
379 |
380 | # omega = fftfreqs(self.res, dtype=torch.float64) # [dim0, dim1, dim2, d]
381 |
382 | # omega: shape (fx, fy//2+1, 2)
383 | omega = get_fourier_freqs(res = self.res,
384 | min_freqXY = self.min_freqXY,
385 | max_freqXY = self.max_freqXY,
386 | mid_freqXY = self.mid_freqXY,
387 | freq_init = self.freq_init,
388 | dtype=torch.float64)
389 | # dis: shape (fx, fy//2+1)
390 | dis = torch.sqrt(torch.sum(omega ** 2, dim=-1))
391 | # filter_: shape (fx, fy//2+1, 1, 1)
392 | filter_ = torch.exp(-0.5*((self.sig*2*dis/self.res[0])**2)).unsqueeze(-1).unsqueeze(-1)
393 | filter_.requires_grad = False
394 | return filter_
--------------------------------------------------------------------------------
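A hypothetical end-to-end call of DDSL_phys above, rasterizing a single triangle given as a closed ring of line segments (j = 2 columns in E triggers the subdim path, which adds the auxiliary vertex internally). It assumes polygonembed.ddsl_utils is importable and torch==1.7.1, since torch.irfft was removed in later releases; the frequency bounds are placeholders.

import torch
from polygonembed.ddsl import DDSL_phys

V = torch.tensor([[[0.2, 0.2], [0.8, 0.2], [0.5, 0.7]]])   # (batch=1, num_vert=3, 2)
E = torch.tensor([[[0, 1], [1, 2], [2, 0]]])               # (1, n_elem=3, j=2) ring edges
D = torch.ones(1, 3, 1)                                    # homogeneous unit density

ddsl = DDSL_phys(res=(16, 16), t=(1, 1), j=2,
                 min_freqXY=1, max_freqXY=8, freq_init="fft")
raster = ddsl(V, E, D)
print(raster.shape)   # expected: (1, 16, 16, 1) physical-domain raster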
/polygoncode/polygonembed/module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | import torch.nn.functional as F
5 |
6 | import torch.utils.data
7 | import math
8 | import numpy as np
9 |
10 | class LayerNorm(nn.Module):
11 | """
12 | layer normalization
13 | Simple layer norm object optionally used with the convolutional encoder.
14 | """
15 |
16 | def __init__(self, feature_dim, eps=1e-6):
17 | super(LayerNorm, self).__init__()
18 |         self.gamma = nn.Parameter(torch.ones((feature_dim,)))
19 |         self.beta = nn.Parameter(torch.zeros((feature_dim,)))
20 |         # note: assigning nn.Parameter attributes registers them automatically,
21 |         # so no explicit register_parameter() call is needed
22 | self.eps = eps
23 |
24 | def forward(self, x):
25 | # x: [batch_size, embed_dim]
26 | # normalize for each embedding
27 | mean = x.mean(-1, keepdim=True)
28 | std = x.std(-1, keepdim=True)
29 | # output shape is the same as x
30 |         # TODO: check that the dtypes of self.gamma / self.beta match x
31 | # output: [batch_size, embed_dim]
32 | return self.gamma * (x - mean) / (std + self.eps) + self.beta
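A quick illustrative check, a sketch only: the module normalizes each embedding vector independently and preserves the input shape.

```python
x = torch.randn(4, 16)          # [batch_size, embed_dim]
ln = LayerNorm(feature_dim=16)
y = ln(x)                       # per-row mean ~0, std ~1; shape unchanged
assert y.shape == (4, 16)
```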
33 |
34 | def get_activation_function(activation, context_str):
35 | if activation == "leakyrelu":
36 | return nn.LeakyReLU(negative_slope=0.2)
37 | elif activation == "relu":
38 | return nn.ReLU()
39 | elif activation == "sigmoid":
40 | return nn.Sigmoid()
41 | elif activation == 'tanh':
42 | return nn.Tanh()
43 | else:
44 |         raise Exception("activation '{}' not recognized in {}.".format(activation, context_str))
45 |
46 |
47 | def coord_normalize(coords, extent = (-1710000, -1690000, 1610000, 1640000)):
48 | """
49 | Given a list of coords (X, Y), normalize them to [-1, 1]
50 | Args:
51 | coords: a python list with shape (batch_size, num_context_pt, coord_dim)
52 | extent: (x_min, x_max, y_min, y_max)
53 | Return:
54 | coords_mat: np tensor shape (batch_size, num_context_pt, coord_dim)
55 | """
56 | if type(coords) == list:
57 | coords_mat = np.asarray(coords).astype(float)
58 | elif type(coords) == np.ndarray:
59 | coords_mat = coords
60 |
61 | # x => [0,1] min_max normalize
62 | x = (coords_mat[:,:,0] - extent[0])*1.0/(extent[1] - extent[0])
63 | # x => [-1,1]
64 | coords_mat[:,:,0] = (x * 2) - 1
65 |
66 | # y => [0,1] min_max normalize
67 | y = (coords_mat[:,:,1] - extent[2])*1.0/(extent[3] - extent[2])
68 | # y => [-1,1]
69 | coords_mat[:,:,1] = (y * 2) - 1
70 |
71 | return coords_mat
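A worked example (illustrative values): with extent (0, 100, 0, 50), the point (25, 25) sits a quarter of the way along X and halfway along Y, so it maps to (-0.5, 0.0).

```python
coords = [[[25.0, 25.0], [100.0, 50.0]]]   # (batch_size=1, num_context_pt=2, 2)
out = coord_normalize(coords, extent=(0, 100, 0, 50))
# out[0] -> [[-0.5, 0.0], [1.0, 1.0]]
```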
72 |
73 | class SingleFeedForwardNN(nn.Module):
74 | """
75 |     Creates a single fully connected feed forward layer.
76 |     It applies non-linearity, layer normalization and dropout,
77 |     and is meant for the hidden layers, not the last layer, of a feed forward NN.
78 | """
79 |
80 | def __init__(self, input_dim,
81 | output_dim,
82 | dropout_rate=None,
83 | activation="sigmoid",
84 | use_layernormalize=False,
85 | skip_connection = False,
86 | context_str = ''):
87 | '''
88 |
89 | Args:
90 | input_dim (int32): the input embedding dim
91 | output_dim (int32): dimension of the output of the network.
92 |             dropout_rate (scalar tensor or float): probability of dropping a unit (passed to nn.Dropout)
93 | activation (string): tanh or relu or leakyrelu or sigmoid
94 | use_layernormalize (bool): do layer normalization or not
95 | skip_connection (bool): do skip connection or not
96 | context_str (string): indicate which spatial relation encoder is using the current FFN
97 |
98 | '''
99 | super(SingleFeedForwardNN, self).__init__()
100 | self.input_dim = input_dim
101 | self.output_dim = output_dim
102 |
103 | if dropout_rate is not None:
104 | self.dropout = nn.Dropout(p=dropout_rate)
105 | else:
106 | self.dropout = None
107 |
108 | self.act = get_activation_function(activation, context_str)
109 |
110 | if use_layernormalize:
111 | # the layer normalization is only used in the hidden layer, not the last layer
112 | self.layernorm = nn.LayerNorm(self.output_dim)
113 | self.layernorm.weight.data.fill_(1)
114 | self.layernorm.bias.data.zero_()
115 | else:
116 | self.layernorm = None
117 |
118 |         # a skip connection is only possible if the input and output dimensions are the same
119 | if self.input_dim == self.output_dim:
120 | self.skip_connection = skip_connection
121 | else:
122 | self.skip_connection = False
123 |
124 | self.linear = nn.Linear(self.input_dim, self.output_dim)
125 | nn.init.xavier_uniform_(self.linear.weight)
126 | nn.init.zeros_(self.linear.bias)
127 |
128 |
129 |
130 |
131 |
132 | def forward(self, input_tensor):
133 | '''
134 | Args:
135 | input_tensor: shape [batch_size, ..., input_dim]
136 | Returns:
137 | tensor of shape [batch_size,..., output_dim]
138 | note there is no non-linearity applied to the output.
139 |
140 | Raises:
141 | Exception: If given activation or normalizer not supported.
142 | '''
143 | assert input_tensor.size()[-1] == self.input_dim
144 | # Linear layer
145 | output = self.linear(input_tensor)
146 | # non-linearity
147 | output = self.act(output)
148 | # dropout
149 | if self.dropout is not None:
150 | output = self.dropout(output)
151 |
152 | # skip connection
153 | if self.skip_connection:
154 | output = output + input_tensor
155 |
156 | # layer normalization
157 | if self.layernorm is not None:
158 | output = self.layernorm(output)
159 |
160 | return output
161 |
162 | # num_rbf_anchor_pts = 100, rbf_kernal_size = 10e2, frequency_num = 16,
163 | # max_radius = 10000, dropout = 0.5, f_act = "sigmoid", freq_init = "geometric",
164 | # num_hidden_layer = 3, hidden_dim = 128, use_layn = "F", skip_connection = "F", use_post_mat = "T"):
165 | # if use_layn == "T":
166 | # use_layn = True
167 | # else:
168 | # use_layn = False
169 | # if skip_connection == "T":
170 | # skip_connection = True
171 | # else:
172 | # skip_connection = False
173 | # if use_post_mat == "T":
174 | # use_post_mat = True
175 | # else:
176 | # use_post_mat = False
177 |
178 | class MultiLayerFeedForwardNN(nn.Module):
179 | """
180 | Creates a fully connected feed forward neural network.
181 |     Each hidden layer applies non-linearity, layer normalization and dropout;
182 |     the last layer uses none of these.
183 | """
184 |
185 | def __init__(self, input_dim,
186 | output_dim,
187 | num_hidden_layers=0,
188 | dropout_rate=None,
189 | hidden_dim=-1,
190 | activation="sigmoid",
191 | use_layernormalize=False,
192 | skip_connection = False,
193 | context_str = None):
194 | '''
195 |
196 | Args:
197 | input_dim (int32): the input embedding dim
198 | num_hidden_layers (int32): number of hidden layers in the network, set to 0 for a linear network.
199 | output_dim (int32): dimension of the output of the network.
200 |             dropout_rate (scalar tensor or float): probability of dropping a unit (passed to nn.Dropout)
201 | hidden_dim (int32): size of the hidden layers
202 |             activation (string): tanh, relu, leakyrelu, or sigmoid
203 | use_layernormalize (bool): do layer normalization or not
204 | context_str (string): indicate which spatial relation encoder is using the current FFN
205 |
206 | '''
207 | super(MultiLayerFeedForwardNN, self).__init__()
208 | self.input_dim = input_dim
209 | self.output_dim = output_dim
210 | self.num_hidden_layers = num_hidden_layers
211 | self.dropout_rate = dropout_rate
212 | self.hidden_dim = hidden_dim
213 | self.activation = activation
214 | self.use_layernormalize = use_layernormalize
215 | self.skip_connection = skip_connection
216 | self.context_str = context_str
217 |
218 | self.layers = nn.ModuleList()
219 | if self.num_hidden_layers <= 0:
220 | self.layers.append( SingleFeedForwardNN(input_dim = self.input_dim,
221 | output_dim = self.output_dim,
222 | dropout_rate = self.dropout_rate,
223 | activation = self.activation,
224 | use_layernormalize = False,
225 | skip_connection = False,
226 | context_str = self.context_str))
227 | else:
228 | self.layers.append( SingleFeedForwardNN(input_dim = self.input_dim,
229 | output_dim = self.hidden_dim,
230 | dropout_rate = self.dropout_rate,
231 | activation = self.activation,
232 | use_layernormalize = self.use_layernormalize,
233 | skip_connection = self.skip_connection,
234 | context_str = self.context_str))
235 |
236 | for i in range(self.num_hidden_layers-1):
237 | self.layers.append( SingleFeedForwardNN(input_dim = self.hidden_dim,
238 | output_dim = self.hidden_dim,
239 | dropout_rate = self.dropout_rate,
240 | activation = self.activation,
241 | use_layernormalize = self.use_layernormalize,
242 | skip_connection = self.skip_connection,
243 | context_str = self.context_str))
244 |
245 | self.layers.append( SingleFeedForwardNN(input_dim = self.hidden_dim,
246 | output_dim = self.output_dim,
247 | dropout_rate = self.dropout_rate,
248 | activation = self.activation,
249 | use_layernormalize = False,
250 | skip_connection = False,
251 | context_str = self.context_str))
252 |
253 |
254 |
255 | def forward(self, input_tensor):
256 | '''
257 | Args:
258 | input_tensor: shape [batch_size, ..., input_dim]
259 | Returns:
260 | tensor of shape [batch_size, ..., output_dim]
261 | note there is no non-linearity applied to the output.
262 |
263 | Raises:
264 | Exception: If given activation or normalizer not supported.
265 | '''
266 | assert input_tensor.size()[-1] == self.input_dim
267 | output = input_tensor
268 | for i in range(len(self.layers)):
269 | output = self.layers[i](output)
270 |
271 | return output
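A usage sketch with illustrative hyperparameters: a three-hidden-layer FFN mapping 32-d features to 64-d embeddings; nn.Linear broadcasts over any leading dimensions.

```python
ffn = MultiLayerFeedForwardNN(input_dim=32, output_dim=64,
                              num_hidden_layers=3, hidden_dim=128,
                              dropout_rate=0.5, activation="relu",
                              use_layernormalize=True, skip_connection=True,
                              context_str="demo")
emb = ffn(torch.randn(8, 10, 32))   # -> shape (8, 10, 64)
```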
272 |
273 |
274 |
275 | class MultiLayerFeedForwardNNFlexible(nn.Module):
276 | """
277 | Creates a fully connected feed forward neural network.
278 |     Each hidden layer applies non-linearity, layer normalization and dropout;
279 |     the last layer uses none of these.
280 | """
281 |
282 | def __init__(self, input_dim,
283 | output_dim,
284 | hidden_layers=[],
285 | dropout_rate=None,
286 | activation="sigmoid",
287 | use_layernormalize=False,
288 | skip_connection = False,
289 | context_str = None):
290 | '''
291 |
292 | Args:
293 | input_dim (int32): the input embedding dim
294 |             hidden_layers (list): a list of hidden layer dimensions
295 | output_dim (int32): dimension of the output of the network.
296 |             dropout_rate (scalar tensor or float): probability of dropping a unit (passed to nn.Dropout)
297 |             activation (string): tanh, relu, leakyrelu, or sigmoid
298 | use_layernormalize (bool): do layer normalization or not
299 | context_str (string): indicate which spatial relation encoder is using the current FFN
300 |
301 | '''
302 | super(MultiLayerFeedForwardNNFlexible, self).__init__()
303 | self.input_dim = input_dim
304 | self.output_dim = output_dim
305 | self.hidden_layers = hidden_layers
306 | self.dropout_rate = dropout_rate
307 | self.activation = activation
308 | self.use_layernormalize = use_layernormalize
309 | self.skip_connection = skip_connection
310 | self.context_str = context_str
311 |
312 | self.num_hidden_layers = len(self.hidden_layers)
313 |
314 | for dim in self.hidden_layers:
315 | assert type(dim) == int
316 |
317 | self.layers = nn.ModuleList()
318 | if self.num_hidden_layers == 0:
319 | self.layers.append( SingleFeedForwardNN(input_dim = self.input_dim,
320 | output_dim = self.output_dim,
321 | dropout_rate = self.dropout_rate,
322 | activation = self.activation,
323 | use_layernormalize = False,
324 | skip_connection = False,
325 | context_str = self.context_str))
326 | else:
327 | self.layers.append( SingleFeedForwardNN(input_dim = self.input_dim,
328 | output_dim = self.hidden_layers[0],
329 | dropout_rate = self.dropout_rate,
330 | activation = self.activation,
331 | use_layernormalize = self.use_layernormalize,
332 | skip_connection = self.skip_connection,
333 | context_str = self.context_str))
334 |
335 | for i in range(self.num_hidden_layers-1):
336 | self.layers.append( SingleFeedForwardNN(input_dim = self.hidden_layers[i],
337 | output_dim = self.hidden_layers[i+1],
338 | dropout_rate = self.dropout_rate,
339 | activation = self.activation,
340 | use_layernormalize = self.use_layernormalize,
341 | skip_connection = self.skip_connection,
342 | context_str = self.context_str))
343 |
344 | self.layers.append( SingleFeedForwardNN(input_dim = self.hidden_layers[-1],
345 | output_dim = self.output_dim,
346 | dropout_rate = self.dropout_rate,
347 | activation = self.activation,
348 | use_layernormalize = False,
349 | skip_connection = False,
350 | context_str = self.context_str))
351 |
352 |
353 |
354 | def forward(self, input_tensor):
355 | '''
356 | Args:
357 | input_tensor: shape [batch_size, ..., input_dim]
358 | Returns:
359 | tensor of shape [batch_size, ..., output_dim]
360 | note there is no non-linearity applied to the output.
361 |
362 | Raises:
363 | Exception: If given activation or normalizer not supported.
364 | '''
365 | assert input_tensor.size()[-1] == self.input_dim
366 | output = input_tensor
367 | for i in range(len(self.layers)):
368 | output = self.layers[i](output)
369 |
370 | return output
371 |
372 |
373 | # from Presence-Only Geographical Priors for Fine-Grained Image Classification
374 | # www.vision.caltech.edu/~macaodha/projects/geopriors
375 |
376 | class ResLayer(nn.Module):
377 | def __init__(self, linear_size):
378 | super(ResLayer, self).__init__()
379 | self.l_size = linear_size
380 | self.nonlin1 = nn.ReLU(inplace=True)
381 | self.nonlin2 = nn.ReLU(inplace=True)
382 | self.dropout1 = nn.Dropout()
383 | self.w1 = nn.Linear(self.l_size, self.l_size)
384 | self.w2 = nn.Linear(self.l_size, self.l_size)
385 |
386 | def forward(self, x):
387 | y = self.w1(x)
388 | y = self.nonlin1(y)
389 | y = self.dropout1(y)
390 | y = self.w2(y)
391 | y = self.nonlin2(y)
392 | out = x + y
393 |
394 | return out
395 |
396 |
397 | class FCNet(nn.Module):
398 | # def __init__(self, num_inputs, num_classes, num_filts, num_users=1):
399 | def __init__(self, num_inputs, num_filts, num_hidden_layers):
400 | '''
401 | Args:
402 |             num_inputs: input embedding dimension
403 |             num_filts: hidden embedding dimension
404 |             num_hidden_layers: number of hidden (residual) layers
405 | '''
406 | super(FCNet, self).__init__()
407 | # self.inc_bias = False
408 | # self.class_emb = nn.Linear(num_filts, num_classes, bias=self.inc_bias)
409 | # self.user_emb = nn.Linear(num_filts, num_users, bias=self.inc_bias)
410 |
411 | # self.feats = nn.Sequential(nn.Linear(num_inputs, num_filts),
412 | # nn.ReLU(inplace=True),
413 | # ResLayer(num_filts),
414 | # ResLayer(num_filts),
415 | # ResLayer(num_filts),
416 | # ResLayer(num_filts))
417 | self.num_hidden_layers = num_hidden_layers
418 | self.feats = nn.Sequential()
419 | self.feats.add_module("ln_1", nn.Linear(num_inputs, num_filts))
420 | self.feats.add_module("relu_1", nn.ReLU(inplace=True))
421 | for i in range(num_hidden_layers):
422 | self.feats.add_module("resnet_{}".format(i+1), ResLayer(num_filts))
423 |
424 | # def forward(self, x, class_of_interest=None, return_feats=False):
425 | def forward(self, x):
426 | loc_emb = self.feats(x)
427 | # if return_feats:
428 | # return loc_emb
429 | # if class_of_interest is None:
430 | # class_pred = self.class_emb(loc_emb)
431 | # else:
432 | # class_pred = self.eval_single_class(loc_emb, class_of_interest)
433 |
434 | # return torch.sigmoid(class_pred)
435 | return loc_emb
436 |
437 | # def eval_single_class(self, x, class_of_interest):
438 | # if self.inc_bias:
439 | # return torch.matmul(x, self.class_emb.weight[class_of_interest, :]) + self.class_emb.bias[class_of_interest]
440 | # else:
441 | # return torch.matmul(x, self.class_emb.weight[class_of_interest, :])
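A usage sketch (illustrative sizes): encode 4-d location features into 256-d embeddings with two residual hidden blocks.

```python
net = FCNet(num_inputs=4, num_filts=256, num_hidden_layers=2)
loc_emb = net(torch.randn(32, 4))   # -> shape (32, 256)
```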
--------------------------------------------------------------------------------
/polygoncode/polygonembed/ddsl_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import math
4 | from math import factorial
5 | import os
6 |
7 |
8 | def get_fourier_freqs(res, min_freqXY, max_freqXY, mid_freqXY = None, freq_init = "geometric", dtype=torch.float32, exact=True, eps = 1e-4):
9 | """
10 | Helper function to return frequency tensors
11 | This is a generalization of fftfreqs()
12 | Args:
13 | res: n_dims int tuple of number of frequency modes, (fx, fy) for polygon
14 | max_freqXY: the maximum frequency
15 | min_freqXY: the minimum frequency
16 |         freq_init: frequency generation method,
17 |             "geometric": geometric series
18 |             "fft": fast Fourier transform
19 |     :return: omega:
20 |             if exact = True
21 |                 for res = (fx, fy) => shape (fx, fy//2+1, 2)
22 |                 for res = (f1, f2, ..., fn) => shape (f1, f2, ..., f_{n-1}, fn//2+1, n)
23 |             if exact = False
24 | for res = (fx, fy) => shape (fx, fy//2, 2)
25 | for res = (f1, f2, ..., fn) => shape (f1, f2, ..., f_{n-1}, fn//2, n)
26 | """
27 | n_dims = len(res)
28 | freqs = []
29 | for dim in range(n_dims - 1):
30 | r_ = res[dim]
31 | freq = make_fourier_freq_vector(res_dim = r_,
32 | min_freqXY = min_freqXY,
33 | max_freqXY = max_freqXY,
34 | mid_freqXY = mid_freqXY,
35 | freq_init = freq_init,
36 | get_full_vector = True)
37 | freqs.append(torch.tensor(freq, dtype=dtype))
38 | r_ = res[-1]
39 | freq = make_fourier_freq_vector(res_dim = r_,
40 | min_freqXY = min_freqXY,
41 | max_freqXY = max_freqXY,
42 | mid_freqXY = mid_freqXY,
43 | freq_init = freq_init,
44 | get_full_vector = False)
45 | if exact:
46 | freqs.append(torch.tensor(freq, dtype=dtype))
47 | else:
48 | freqs.append(torch.tensor(freq[:-1], dtype=dtype))
49 | omega = torch.meshgrid(freqs)
50 | omega = list(omega)
51 | omega = torch.stack(omega, dim=-1)
52 |
53 | omega[0, 0, :] = torch.FloatTensor(np.random.rand(2)*eps).to(omega.dtype)
54 | return omega
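A shape check (illustrative): for res = (16, 16) the grid keeps the full frequency axis along X and the half (rfft-style) axis along Y, matching the docstring.

```python
omega = get_fourier_freqs(res=(16, 16), min_freqXY=1.0, max_freqXY=8.0,
                          freq_init="geometric")
assert omega.shape == (16, 9, 2)    # (fx, fy//2 + 1, 2)
```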
55 |
56 |
57 | def compute_geoemtric_series(min_, max_, num):
58 | log_timescale_increment = (math.log(float(max_) / float(min_)) /
59 | (num*1.0 - 1))
60 |
61 | timescales = min_ * np.exp(
62 | np.arange(num).astype(float) * log_timescale_increment)
63 | return timescales
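A worked example: four frequencies between 1 and 8 give the doubling series, since the log-space increment is ln(8/1)/3 = ln 2. (The misspelled function name is kept as-is, since other modules may import it.)

```python
freqs = compute_geoemtric_series(min_=1.0, max_=8.0, num=4)
# freqs -> array([1., 2., 4., 8.])
```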
64 |
65 | def make_freq_series(min_freqXY, max_freqXY, frequency_num, mid_freqXY = None, freq_init = "geometric"):
66 | if freq_init == "geometric":
67 | return compute_geoemtric_series(min_freqXY, max_freqXY, frequency_num)
68 | elif freq_init == "arith_geometric":
69 | assert mid_freqXY is not None
70 | assert min_freqXY < mid_freqXY < max_freqXY
71 | left_freq_num = int(frequency_num/2)
72 | right_freq_num = int(frequency_num - left_freq_num)
73 |
74 |         # left: arithmetic
75 | left_freqs = np.linspace(start = min_freqXY, stop=mid_freqXY, num=left_freq_num, endpoint=False)
76 |
77 | # right: geometric
78 | right_freqs = compute_geoemtric_series(min_ = mid_freqXY, max_ = max_freqXY, num = right_freq_num)
79 |
80 | freqs = np.concatenate([left_freqs, right_freqs], axis = -1)
81 | return freqs
82 | elif freq_init == "geometric_arith":
83 | assert mid_freqXY is not None
84 | assert min_freqXY < mid_freqXY < max_freqXY
85 | left_freq_num = int(frequency_num/2)
86 | right_freq_num = int(frequency_num - left_freq_num)
87 |
88 | # left: geometric
89 | left_freqs = compute_geoemtric_series(min_ = min_freqXY, max_ = mid_freqXY, num = left_freq_num)
90 |
91 |         # right: arithmetic
92 | right_freqs = np.linspace(start = mid_freqXY, stop=max_freqXY, num=right_freq_num+1, endpoint=True)
93 |
94 |
95 |
96 | freqs = np.concatenate([left_freqs, right_freqs[1:]], axis = -1)
97 | return freqs
98 | else:
99 | raise Exception(f"freq_init = {freq_init} is not implemented" )
100 |
101 | def make_fourier_freq_vector(res_dim, min_freqXY, max_freqXY, mid_freqXY = None, freq_init = "geometric", get_full_vector = True):
102 | '''
103 |     make the frequency vector for the X or Y dimension
104 | Args:
105 | res_dim: the total frequency we want
106 | max_freqXY: the maximum frequency
107 | min_freqXY: the minimum frequency
108 |         get_full_vector: get the full frequency vector, or only the half (rfft-style) vector (Y dimension)
109 | '''
110 | if freq_init == "fft":
111 | if get_full_vector:
112 | freq = np.fft.fftfreq(res_dim, d=1/res_dim)
113 | else:
114 | freq = np.fft.rfftfreq(res_dim, d=1/res_dim)
115 | else:
116 | half_freqs = make_freq_series(min_freqXY, max_freqXY, frequency_num = res_dim//2,
117 | mid_freqXY = mid_freqXY, freq_init = freq_init)
118 |
119 | if get_full_vector:
120 | neg_half_freqs = -np.flip(half_freqs, axis = -1)
121 | if res_dim % 2 == 0:
122 | freq = np.concatenate([np.array([0.0]), half_freqs[:-1], neg_half_freqs], axis = -1)
123 | else:
124 | freq = np.concatenate([np.array([0.0]), half_freqs, neg_half_freqs], axis = -1)
125 |         else:
126 |             # for the half (rfft-style) vector the even and odd cases coincide
127 |             freq = np.concatenate([np.array([0.0]), half_freqs], axis = -1)
130 |
131 | if get_full_vector:
132 | assert freq.shape[0] == res_dim
133 | else:
134 | assert freq.shape[0] == math.floor(res_dim*1.0/2)+1
135 | return freq
136 |
137 | def fftfreqs(res, dtype=torch.float32, exact=True, eps = 1e-4):
138 | """
139 | Helper function to return frequency tensors
140 | :param res: n_dims int tuple of number of frequency modes, (fx, fy) for polygon
141 | :return: omega:
142 |             if exact = True
143 |                 for res = (fx, fy) => shape (fx, fy//2+1, 2)
144 |                 for res = (f1, f2, ..., fn) => shape (f1, f2, ..., f_{n-1}, fn//2+1, n)
145 |             if exact = False
146 | for res = (fx, fy) => shape (fx, fy//2, 2)
147 | for res = (f1, f2, ..., fn) => shape (f1, f2, ..., f_{n-1}, fn//2, n)
148 | """
149 |
150 | n_dims = len(res)
151 | freqs = []
152 | for dim in range(n_dims - 1):
153 | r_ = res[dim]
154 | freq = np.fft.fftfreq(r_, d=1/r_)
155 | freqs.append(torch.tensor(freq, dtype=dtype))
156 | r_ = res[-1]
157 | if exact:
158 | freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1/r_), dtype=dtype))
159 | else:
160 | freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1/r_)[:-1], dtype=dtype))
161 | omega = torch.meshgrid(freqs)
162 | omega = list(omega)
163 | omega = torch.stack(omega, dim=-1)
164 |
165 | omega[0, 0, :] = torch.FloatTensor(np.random.rand(2)*eps).to(omega.dtype)
166 |
167 | return omega
168 |
169 | # def fftfreqs(res, dtype=torch.float32, exact=True):
170 | # """
171 | # Helper function to return frequency tensors
172 | # :param res: n_dims int tuple of number of frequency modes, (fx, fy) for polygon
173 | # :return: omega:
174 | # if extract = True
175 | # for res = (fx, fy) => shape (fx, fy//2+1, 2)
176 | # for res = (f1, f2, ..., fn) => shape (f1, f2, ..., f_{n-1}, fn//2+1, n)
177 | # if extract = False
178 | # for res = (fx, fy) => shape (fx, fy//2, 2)
179 | # for res = (f1, f2, ..., fn) => shape (f1, f2, ..., f_{n-1}, fn//2, n)
180 | # """
181 |
182 | # n_dims = len(res)
183 | # freqs = []
184 | # for dim in range(n_dims - 1):
185 | # r_ = res[dim]
186 | # freq = np.fft.fftfreq(r_, d=1/r_)
187 | # freqs.append(torch.tensor(freq, dtype=dtype))
188 | # r_ = res[-1]
189 | # if exact:
190 | # freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1/r_), dtype=dtype))
191 | # else:
192 | # freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1/r_)[:-1], dtype=dtype))
193 | # omega = torch.meshgrid(freqs)
194 | # omega = list(omega)
195 | # omega = torch.stack(omega, dim=-1)
196 |
197 | # return omega
198 |
199 |
200 | def permute_seq(i, len):
201 | """
202 | Permute the ordering of integer sequences
203 | """
204 |     assert(i < len)
321 |         if torch.sum(neg_mask) > 0:
322 | vol2[:, neg_mask] = 0
323 | print("[!]Warning: zeroing {0} small negative number".format(torch.sum(neg_mask).item()))
324 |         vol = torch.sqrt(vol2)
325 | else:
326 | # Equation 8 in https://openreview.net/pdf?id=B1G5ViAqFm
327 | # compute based on: https://en.m.wikipedia.org/wiki/Simplex#Volume
328 | assert(j == nd)
329 | # matrix determinant
330 |         # x_{j+1} is an auxiliary node with (0, 0) coordinates
331 | # [x1,x2,...,x_{j}] - [x_{j+1}]
332 | # P[:, :, :-1]: [x1,x2,...,x_{j}], shape (batch_size, n_elem = num_vert, j = 2, n_dims = 2),
333 | # P[:, :, -1:]: [x_{j+1}], shape (batch_size, n_elem = num_vert, 1, n_dims = 2)
334 | # mat: shape (batch_size, n_elem = num_vert, j = 2, n_dims = 2)
335 | mat = P[:, :, :-1] - P[:, :, -1:]
336 | # vol: shape (batch_size, n_elem = num_vert)
337 | vol = torch.linalg.det(mat) / math.factorial(j)
338 |
339 | return vol.unsqueeze(-1)
340 |
341 |
342 | def batch_det(A):
343 | """
344 | (No use) Batch compute determinant of square matrix A of shape (*, N, N)
345 |
346 | We can use torch.linalg.det() directly
347 | Return:
348 | Tensor of shape (*)
349 |
350 |     Elementary operation 1 - swapping two rows or columns - flips the sign, so only an even number of swaps leaves the determinant unchanged
351 |     Elementary operation 2 - multiplying a row (column) by a nonzero real number - scales the determinant; the product of all such factors must be 1 to leave it unchanged
352 |     Elementary operation 3 - adding a real multiple of one row (column) to another row (column) - does not change the determinant
353 | """
354 |     # LU-factorize the matrix; the product of the diagonal entries of the factors is the determinant of the original matrix
355 |     # LU factorization may perform row exchanges, and every row exchange flips the sign of the determinant,
356 |     # so we track the parity of the number of exchanges: odd flips the sign, even leaves it unchanged (see det_P below)
357 | # LU: the packed LU factorization matrix, (*, N, N)
358 | # pivots: (*, N)
359 | LU, pivots = torch.lu(A)
360 |     # torch.einsum('...ii->...i', LU): the diagonal vector
361 |     # det_LU: the product of all diagonal values
362 | det_LU = torch.einsum('...ii->...i', LU).prod(-1)
363 | pivots -= 1
364 | d = pivots.shape[-1]
365 | perm = pivots - torch.arange(d, dtype=pivots.dtype, device=pivots.device).expand(pivots.shape)
366 | det_P = (-1) ** ((perm != 0).sum(-1))
367 | det = det_LU * det_P.type(det_LU.dtype)
368 |
369 | return det
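A sanity check (illustrative): despite the "(No use)" note, the routine agrees with torch.det on a random batch.

```python
A = torch.randn(5, 3, 3, dtype=torch.float64)
assert torch.allclose(batch_det(A), torch.det(A))
```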
370 |
371 |
372 |
373 | def make_E(V):
374 | '''
375 | here, we assume V comes from a simple polygon
376 |     Given a polygon vertex tensor V with shape (batch_size, num_vert, 2),
377 |     generate its edge matrix E.
378 |     Note: num_vert counts unique vertices only; remove the repeated last/first vertex beforehand.
379 |     Here, num_vert = number of polygon vertices = number of edges = number of 2-simplices (with the auxiliary node)
380 |
381 | Args:
382 | V with shape (batch_size, num_vert, 2)
383 | Return:
384 | E: torch.LongTensor(), shape (batch_size, num_vert, 2)
385 | '''
386 | batch_size, num_vert, n_dims = V.shape
387 | a = torch.arange(0, num_vert)
388 | E = torch.stack((a, a+1), dim = 0).permute(1,0)
389 | E[-1, -1] = 0
390 | E = torch.repeat_interleave(E.unsqueeze(0), repeats = batch_size, dim=0).to(V.device)
391 | return E
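An illustrative example: the edge tensor wraps around so that each polygon ring is closed.

```python
V = torch.zeros(2, 4, 2)   # (batch_size=2, num_vert=4, 2)
E = make_E(V)
# E[0] -> [[0, 1], [1, 2], [2, 3], [3, 0]]
```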
392 |
393 | def make_D(V):
394 | '''
395 |     Given a polygon vertex tensor V with shape (batch_size, num_vert, 2),
396 |     generate its density matrix D.
397 |     Note: num_vert counts unique vertices only; remove the repeated last/first vertex beforehand.
398 |     Here, num_vert = number of polygon vertices = number of edges = number of 2-simplices (with the auxiliary node)
399 |
400 | Args:
401 | V: shape (batch_size, num_vert, 2)
402 | Return:
403 |         D: torch.FloatTensor() of ones, shape (batch_size, num_vert, 1)
404 | '''
405 | batch_size, num_vert, n_dims = V.shape
406 | D = torch.ones(batch_size, num_vert, 1).to(V.device)
407 | return D
408 |
409 | def affinity_V(V, extent):
410 | '''
411 |     shift the vertex tensor into [0, periodX] x [0, periodY]
412 |
413 | Args:
414 | V: torch.FloatTensor(), shape (batch_size, num_vert, n_dims = 2)
415 | vertex tensor
416 | extent: the maximum spatial extent of all polygons, (minx, maxx, miny, maxy)
417 |
418 |     Return:
419 |         V: torch.FloatTensor(), shape (batch_size, num_vert, n_dims = 2)
421 | '''
422 | device = V.device
423 | minx, maxx, miny, maxy = extent
424 |
425 | # assert maxx - minx == periodXY[0]
426 | # assert maxy - miny == periodXY[1]
427 |
428 |     # shift all polygons so they have positive coordinates,
429 |     # e.g. an extent of (-1, 1, -1, 1) moves them into (0, 2, 0, 2)
430 | V = V + torch.FloatTensor([-minx, -miny]).to(device)
431 |
432 | return V
433 |
434 |
435 | def add_noise_V(V, eps):
436 | '''
437 |     add small noise to each vertex to make the NUFT more robust
438 | Args:
439 | V: torch.FloatTensor(), shape (batch_size, num_vert, n_dims = 2)
440 | vertex tensor
441 |         eps: the maximum noise added to each polygon vertex
442 |     Return:
443 | V: torch.FloatTensor(), shape (batch_size, num_vert, n_dims = 2)
444 | '''
445 | # add small noise
446 | V = V + torch.rand(V.shape, device = V.device)*eps
447 | return V
448 |
449 |
450 |
451 |
452 |
453 |
454 | def make_periodXY(extent):
455 | '''
456 | Make periodXY based on the spatial extent
457 | Args:
458 | extent: (minx, maxx, miny, maxy)
459 | Return:
460 | periodXY: t in DDSL_spec(), [periodX, periodY]
461 |         periodX, periodY: the spatial extents along X and Y, i.e. [0, periodX] and [0, periodY]
462 | '''
463 | minx, maxx, miny, maxy = extent
464 |
465 | periodX = maxx - minx
466 | periodY = maxy - miny
467 |
468 | periodXY = [periodX, periodY]
469 | return periodXY
470 |
471 |
472 | def polygon_nuft_input(polygons, extent, V = None, E = None):
473 | '''
474 | polygons: torch.FloatTensor(), shape (batch_size, num_vert, n_dims = 2)
475 |         the last point must not repeat the first one
476 | extent: the maximum spatial extent of all polygons, (minx, maxx, miny, maxy)
477 | '''
478 | assert (polygons is None and V is not None and E is not None) or (polygons is not None and V is None and E is None)
479 | if polygons is not None:
480 | # V: torch.FloatTensor(), shape (batch_size, num_vert, n_dims = 2)
481 | # vertex tensor
482 | V = polygons
483 |         # shift the vertex tensor into [0, periodX] x [0, periodY]
484 | V = affinity_V(V, extent)
485 | if E is None:
486 | # add noise
487 | # V = add_noise_V(V, self.eps)
488 | # E: torch.LongTensor(), shape (batch_size, num_vert, 2)
489 | # element tensor, each element [[0,1], [1,2],...,[num_vert-1, 0]]
490 | E = make_E(V)
491 | # D: torch.LongTensor(), shape (batch_size, num_vert, 1)
492 | # all be one tensor
493 | D = make_D(V)
494 |
495 |
496 | return V, E, D
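A usage sketch (illustrative extent): preparing NUFT inputs for a batch of polygons with coordinates in (-1, 1, -1, 1); V is shifted into [0, 2] x [0, 2], E closes each ring, and D is all ones.

```python
polygons = torch.rand(8, 6, 2) * 2 - 1   # 8 hexagons, coords in [-1, 1]
V, E, D = polygon_nuft_input(polygons, extent=(-1, 1, -1, 1))
```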
--------------------------------------------------------------------------------
/polygoncode/polygonembed/dla.py:
--------------------------------------------------------------------------------
1 | """
2 | DLA for ImageNet-1K, implemented in PyTorch.
3 | Original paper: 'Deep Layer Aggregation,' https://arxiv.org/abs/1707.06484.
4 | """
5 |
6 | __all__ = ['DLA', 'dla34', 'dla46c', 'dla46xc', 'dla60', 'dla60x', 'dla60xc', 'dla102', 'dla102x', 'dla102x2', 'dla169']
7 |
8 | import os
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.init as init
12 | from polygonembed.dla_common import conv1x1, conv1x1_block, conv3x3_block, conv7x7_block
13 | from polygonembed.dla_resnet import ResBlock, ResBottleneck
14 | from polygonembed.dla_resnext import ResNeXtBottleneck
15 |
16 |
17 | class DLABottleneck(ResBottleneck):
18 | """
19 | DLA bottleneck block for residual path in residual block.
20 |
21 | Parameters:
22 | ----------
23 | in_channels : int
24 | Number of input channels.
25 | out_channels : int
26 | Number of output channels.
27 | stride : int or tuple/list of 2 int
28 | Strides of the convolution.
29 | bottleneck_factor : int, default 2
30 | Bottleneck factor.
31 | """
32 | def __init__(self,
33 | in_channels,
34 | out_channels,
35 | stride,
36 | bottleneck_factor=2):
37 | super(DLABottleneck, self).__init__(
38 | in_channels=in_channels,
39 | out_channels=out_channels,
40 | stride=stride,
41 | bottleneck_factor=bottleneck_factor)
42 |
43 |
44 | class DLABottleneckX(ResNeXtBottleneck):
45 | """
46 | DLA ResNeXt-like bottleneck block for residual path in residual block.
47 |
48 | Parameters:
49 | ----------
50 | in_channels : int
51 | Number of input channels.
52 | out_channels : int
53 | Number of output channels.
54 | stride : int or tuple/list of 2 int
55 | Strides of the convolution.
56 | cardinality: int, default 32
57 | Number of groups.
58 | bottleneck_width: int, default 8
59 | Width of bottleneck block.
60 | """
61 | def __init__(self,
62 | in_channels,
63 | out_channels,
64 | stride,
65 | cardinality=32,
66 | bottleneck_width=8):
67 | super(DLABottleneckX, self).__init__(
68 | in_channels=in_channels,
69 | out_channels=out_channels,
70 | stride=stride,
71 | cardinality=cardinality,
72 | bottleneck_width=bottleneck_width)
73 |
74 |
75 | class DLAResBlock(nn.Module):
76 | """
77 | DLA residual block with residual connection.
78 |
79 | Parameters:
80 | ----------
81 | in_channels : int
82 | Number of input channels.
83 | out_channels : int
84 | Number of output channels.
85 | stride : int or tuple/list of 2 int
86 | Strides of the convolution.
87 | body_class : nn.Module, default ResBlock
88 | Residual block body class.
89 | return_down : bool, default False
90 |         Whether to return the downsampled result as well.
91 | """
92 | def __init__(self,
93 | in_channels,
94 | out_channels,
95 | stride,
96 | body_class=ResBlock,
97 | return_down=False):
98 | super(DLAResBlock, self).__init__()
99 | self.return_down = return_down
100 | self.downsample = (stride > 1)
101 | self.project = (in_channels != out_channels)
102 |
103 | self.body = body_class(
104 | in_channels=in_channels,
105 | out_channels=out_channels,
106 | stride=stride)
107 | self.activ = nn.ReLU(inplace=True)
108 | if self.downsample:
109 | self.downsample_pool = nn.MaxPool2d(
110 | kernel_size=stride,
111 | stride=stride)
112 | if self.project:
113 | self.project_conv = conv1x1_block(
114 | in_channels=in_channels,
115 | out_channels=out_channels,
116 | activation=None)
117 |
118 | def forward(self, x):
119 | down = self.downsample_pool(x) if self.downsample else x
120 | identity = self.project_conv(down) if self.project else down
121 | if identity is None:
122 | identity = x
123 | x = self.body(x)
124 | x += identity
125 | x = self.activ(x)
126 | if self.return_down:
127 | return x, down
128 | else:
129 | return x
130 |
131 |
132 | class DLARoot(nn.Module):
133 | """
134 | DLA root block.
135 |
136 | Parameters:
137 | ----------
138 | in_channels : int
139 | Number of input channels.
140 | out_channels : int
141 | Number of output channels.
142 | residual : bool
143 |         Whether to use a residual connection.
144 | """
145 | def __init__(self,
146 | in_channels,
147 | out_channels,
148 | residual):
149 | super(DLARoot, self).__init__()
150 | self.residual = residual
151 |
152 | self.conv = conv1x1_block(
153 | in_channels=in_channels,
154 | out_channels=out_channels,
155 | activation=None)
156 | self.activ = nn.ReLU(inplace=True)
157 |
158 | def forward(self, x2, x1, extra):
159 | last_branch = x2
160 | x = torch.cat((x2, x1) + tuple(extra), dim=1)
161 | x = self.conv(x)
162 | if self.residual:
163 | x += last_branch
164 | x = self.activ(x)
165 | return x
166 |
167 |
168 | class DLATree(nn.Module):
169 | """
170 |     DLA tree unit. It acts as an iterative stage.
171 |
172 | Parameters:
173 | ----------
174 | levels : int
175 | Number of levels in the stage.
176 | in_channels : int
177 | Number of input channels.
178 | out_channels : int
179 | Number of output channels.
180 | res_body_class : nn.Module
181 | Residual block body class.
182 | stride : int or tuple/list of 2 int
183 | Strides of the convolution in a residual block.
184 | root_residual : bool
185 |         Whether to use a residual connection in the root.
186 | root_dim : int
187 | Number of input channels in the root block.
188 | first_tree : bool, default False
189 | Is this tree stage the first stage in the net.
190 | input_level : bool, default True
191 | Is this tree unit the first unit in the stage.
192 | return_down : bool, default False
193 |         Whether to return the downsampled result as well.
194 | """
195 | def __init__(self,
196 | levels,
197 | in_channels,
198 | out_channels,
199 | res_body_class,
200 | stride,
201 | root_residual,
202 | root_dim=0,
203 | first_tree=False,
204 | input_level=True,
205 | return_down=False):
206 | super(DLATree, self).__init__()
207 | self.return_down = return_down
208 | self.add_down = (input_level and not first_tree)
209 | self.root_level = (levels == 1)
210 |
211 | if root_dim == 0:
212 | root_dim = 2 * out_channels
213 | if self.add_down:
214 | root_dim += in_channels
215 |
216 | if self.root_level:
217 | self.tree1 = DLAResBlock(
218 | in_channels=in_channels,
219 | out_channels=out_channels,
220 | stride=stride,
221 | body_class=res_body_class,
222 | return_down=True)
223 | self.tree2 = DLAResBlock(
224 | in_channels=out_channels,
225 | out_channels=out_channels,
226 | stride=1,
227 | body_class=res_body_class,
228 | return_down=False)
229 | else:
230 | self.tree1 = DLATree(
231 | levels=levels - 1,
232 | in_channels=in_channels,
233 | out_channels=out_channels,
234 | res_body_class=res_body_class,
235 | stride=stride,
236 | root_residual=root_residual,
237 | root_dim=0,
238 | input_level=False,
239 | return_down=True)
240 | self.tree2 = DLATree(
241 | levels=levels - 1,
242 | in_channels=out_channels,
243 | out_channels=out_channels,
244 | res_body_class=res_body_class,
245 | stride=1,
246 | root_residual=root_residual,
247 | root_dim=root_dim + out_channels,
248 | input_level=False,
249 | return_down=False)
250 | if self.root_level:
251 | self.root = DLARoot(
252 | in_channels=root_dim,
253 | out_channels=out_channels,
254 | residual=root_residual)
255 |
256 | def forward(self, x, extra=None):
257 | extra = [] if extra is None else extra
258 | x1, down = self.tree1(x)
259 | if self.add_down:
260 | extra.append(down)
261 | if self.root_level:
262 | x2 = self.tree2(x1)
263 | x = self.root(x2, x1, extra)
264 | else:
265 | extra.append(x1)
266 | x = self.tree2(x1, extra)
267 | if self.return_down:
268 | return x, down
269 | else:
270 | return x
271 |
272 |
273 | class DLAInitBlock(nn.Module):
274 | """
275 | DLA specific initial block.
276 |
277 | Parameters:
278 | ----------
279 | in_channels : int
280 | Number of input channels.
281 | out_channels : int
282 | Number of output channels.
283 | """
284 | def __init__(self,
285 | in_channels,
286 | out_channels):
287 | super(DLAInitBlock, self).__init__()
288 | mid_channels = out_channels // 2
289 |
290 | self.conv1 = conv7x7_block(
291 | in_channels=in_channels,
292 | out_channels=mid_channels)
293 | self.conv2 = conv3x3_block(
294 | in_channels=mid_channels,
295 | out_channels=mid_channels)
296 | self.conv3 = conv3x3_block(
297 | in_channels=mid_channels,
298 | out_channels=out_channels,
299 | stride=2)
300 |
301 | def forward(self, x):
302 | '''
303 | Args:
304 | x: shape (N, in_channels, H, W)
305 | '''
306 | # x: shape (N, mid_channels, H, W)
307 | x = self.conv1(x)
308 | # x: shape (N, mid_channels, H, W)
309 | x = self.conv2(x)
310 | # x: shape (N, out_channels, (H+1)/2, (W+1)/2)
311 | x = self.conv3(x)
312 | return x
313 |
314 |
315 | class DLA(nn.Module):
316 | """
317 | DLA model from 'Deep Layer Aggregation,' https://arxiv.org/abs/1707.06484.
318 |
319 | Parameters:
320 | ----------
321 | levels : int
322 | Number of levels in each stage.
323 | channels : list of int
324 | Number of output channels for each stage.
325 | init_block_channels : int
326 | Number of output channels for the initial unit.
327 | res_body_class : nn.Module
328 | Residual block body class.
329 | residual_root : bool
330 |         Whether to use residual connections in the root blocks.
331 | in_channels : int, default 3
332 | Number of input channels.
333 | in_size : tuple of two ints, default (224, 224)
334 | Spatial size of the expected input image.
335 | num_classes : int, default 1000
336 | Number of classification classes.
337 | """
338 | def __init__(self,
339 | levels,
340 | channels,
341 | init_block_channels,
342 | res_body_class,
343 | residual_root,
344 | in_channels=3,
345 | in_size=(224, 224),
346 | num_classes=1000):
347 | super(DLA, self).__init__()
348 | self.in_size = in_size
349 | self.num_classes = num_classes
350 |
351 | self.features = nn.Sequential()
352 | self.features.add_module("init_block", DLAInitBlock(
353 | in_channels=in_channels,
354 | out_channels=init_block_channels))
355 | in_channels = init_block_channels
356 |
357 | for i in range(len(levels)):
358 | levels_i = levels[i]
359 | out_channels = channels[i]
360 | first_tree = (i == 0)
361 | self.features.add_module("stage{}".format(i + 1), DLATree(
362 | levels=levels_i,
363 | in_channels=in_channels,
364 | out_channels=out_channels,
365 | res_body_class=res_body_class,
366 | stride=2,
367 | root_residual=residual_root,
368 | first_tree=first_tree))
369 | in_channels = out_channels
370 |
371 | self.features.add_module("final_pool", nn.AvgPool2d(
372 | kernel_size=7,
373 | stride=1))
374 |
375 | self.output = conv1x1(
376 | in_channels=in_channels,
377 | out_channels=num_classes,
378 | bias=True)
379 |
380 | self._init_params()
381 |
382 | def _init_params(self):
383 | for name, module in self.named_modules():
384 | if isinstance(module, nn.Conv2d):
385 | init.kaiming_uniform_(module.weight)
386 | if module.bias is not None:
387 | init.constant_(module.bias, 0)
388 |
389 | def forward(self, x):
390 | x = self.features(x)
391 |         # x: shape (N, C, 1, 1) after the final pool
392 |         x = self.output(x)
393 |         # x: shape (N, num_classes, 1, 1)
394 |         x = x.view(x.size(0), -1)
395 |         # x: shape (N, num_classes)
396 | return x
397 |
398 |
399 | def get_dla(levels,
400 | channels,
401 | res_body_class,
402 | residual_root=False,
403 | model_name=None,
404 | pretrained=False,
405 | root=os.path.join("~", ".torch", "models"),
406 | **kwargs):
407 | """
408 | Create DLA model with specific parameters.
409 |
410 | Parameters:
411 | ----------
412 | levels : int
413 | Number of levels in each stage.
414 | channels : list of int
415 | Number of output channels for each stage.
416 | res_body_class : nn.Module
417 | Residual block body class.
418 | residual_root : bool, default False
419 |         Whether to use residual connections in the root blocks.
420 | model_name : str or None, default None
421 | Model name for loading pretrained model.
422 | pretrained : bool, default False
423 | Whether to load the pretrained weights for model.
424 | root : str, default '~/.torch/models'
425 | Location for keeping the model parameters.
426 | """
427 | init_block_channels = 32
428 |
429 | net = DLA(
430 | levels=levels,
431 | channels=channels,
432 | init_block_channels=init_block_channels,
433 | res_body_class=res_body_class,
434 | residual_root=residual_root,
435 | **kwargs)
436 |
437 | if pretrained:
438 | if (model_name is None) or (not model_name):
439 | raise ValueError("Parameter `model_name` should be properly initialized for loading pretrained model.")
440 | from .model_store import download_model
441 | download_model(
442 | net=net,
443 | model_name=model_name,
444 | local_model_store_dir_path=root)
445 |
446 | return net
447 |
448 |
449 | def dla34(**kwargs):
450 | """
451 | DLA-34 model from 'Deep Layer Aggregation,' https://arxiv.org/abs/1707.06484.
452 |
453 | Parameters:
454 | ----------
455 | pretrained : bool, default False
456 | Whether to load the pretrained weights for model.
457 | root : str, default '~/.torch/models'
458 | Location for keeping the model parameters.
459 | """
460 | return get_dla(levels=[1, 2, 2, 1], channels=[64, 128, 256, 512], res_body_class=ResBlock, model_name="dla34",
461 | **kwargs)
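A usage sketch; the self-test _test() at the bottom of this file exercises the same path for all variants.

```python
net = dla34()
net.eval()
y = net(torch.randn(1, 3, 224, 224))   # -> logits of shape (1, 1000)
```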
462 |
463 |
464 | def dla46c(**kwargs):
465 | """
466 | DLA-46-C model from 'Deep Layer Aggregation,' https://arxiv.org/abs/1707.06484.
467 |
468 | Parameters:
469 | ----------
470 | pretrained : bool, default False
471 | Whether to load the pretrained weights for model.
472 | root : str, default '~/.torch/models'
473 | Location for keeping the model parameters.
474 | """
475 | return get_dla(levels=[1, 2, 2, 1], channels=[64, 64, 128, 256], res_body_class=DLABottleneck, model_name="dla46c",
476 | **kwargs)
477 |
478 |
479 | def dla46xc(**kwargs):
480 | """
481 | DLA-X-46-C model from 'Deep Layer Aggregation,' https://arxiv.org/abs/1707.06484.
482 |
483 | Parameters:
484 | ----------
485 | pretrained : bool, default False
486 | Whether to load the pretrained weights for model.
487 | root : str, default '~/.torch/models'
488 | Location for keeping the model parameters.
489 | """
490 | return get_dla(levels=[1, 2, 2, 1], channels=[64, 64, 128, 256], res_body_class=DLABottleneckX,
491 | model_name="dla46xc", **kwargs)
492 |
493 |
494 | def dla60(**kwargs):
495 | """
496 | DLA-60 model from 'Deep Layer Aggregation,' https://arxiv.org/abs/1707.06484.
497 |
498 | Parameters:
499 | ----------
500 | pretrained : bool, default False
501 | Whether to load the pretrained weights for model.
502 | root : str, default '~/.torch/models'
503 | Location for keeping the model parameters.
504 | """
505 | return get_dla(levels=[1, 2, 3, 1], channels=[128, 256, 512, 1024], res_body_class=DLABottleneck,
506 | model_name="dla60", **kwargs)
507 |
508 |
509 | def dla60x(**kwargs):
510 | """
511 | DLA-X-60 model from 'Deep Layer Aggregation,' https://arxiv.org/abs/1707.06484.
512 |
513 | Parameters:
514 | ----------
515 | pretrained : bool, default False
516 | Whether to load the pretrained weights for model.
517 | root : str, default '~/.torch/models'
518 | Location for keeping the model parameters.
519 | """
520 | return get_dla(levels=[1, 2, 3, 1], channels=[128, 256, 512, 1024], res_body_class=DLABottleneckX,
521 | model_name="dla60x", **kwargs)
522 |
523 |
524 | def dla60xc(**kwargs):
525 | """
526 | DLA-X-60-C model from 'Deep Layer Aggregation,' https://arxiv.org/abs/1707.06484.
527 |
528 | Parameters:
529 | ----------
530 | pretrained : bool, default False
531 | Whether to load the pretrained weights for model.
532 | root : str, default '~/.torch/models'
533 | Location for keeping the model parameters.
534 | """
535 | return get_dla(levels=[1, 2, 3, 1], channels=[64, 64, 128, 256], res_body_class=DLABottleneckX,
536 | model_name="dla60xc", **kwargs)
537 |
538 |
539 | def dla102(**kwargs):
540 | """
541 | DLA-102 model from 'Deep Layer Aggregation,' https://arxiv.org/abs/1707.06484.
542 |
543 | Parameters:
544 | ----------
545 | pretrained : bool, default False
546 | Whether to load the pretrained weights for model.
547 | root : str, default '~/.torch/models'
548 | Location for keeping the model parameters.
549 | """
550 | return get_dla(levels=[1, 3, 4, 1], channels=[128, 256, 512, 1024], res_body_class=DLABottleneck,
551 | residual_root=True, model_name="dla102", **kwargs)
552 |
553 |
554 | def dla102x(**kwargs):
555 | """
556 | DLA-X-102 model from 'Deep Layer Aggregation,' https://arxiv.org/abs/1707.06484.
557 |
558 | Parameters:
559 | ----------
560 | pretrained : bool, default False
561 | Whether to load the pretrained weights for model.
562 | root : str, default '~/.torch/models'
563 | Location for keeping the model parameters.
564 | """
565 | return get_dla(levels=[1, 3, 4, 1], channels=[128, 256, 512, 1024], res_body_class=DLABottleneckX,
566 | residual_root=True, model_name="dla102x", **kwargs)
567 |
568 |
569 | def dla102x2(**kwargs):
570 | """
571 | DLA-X2-102 model from 'Deep Layer Aggregation,' https://arxiv.org/abs/1707.06484.
572 |
573 | Parameters:
574 | ----------
575 | pretrained : bool, default False
576 | Whether to load the pretrained weights for model.
577 | root : str, default '~/.torch/models'
578 | Location for keeping the model parameters.
579 | """
580 | class DLABottleneckX64(DLABottleneckX):
581 | def __init__(self, in_channels, out_channels, stride):
582 | super(DLABottleneckX64, self).__init__(in_channels, out_channels, stride, cardinality=64)
583 |
584 | return get_dla(levels=[1, 3, 4, 1], channels=[128, 256, 512, 1024], res_body_class=DLABottleneckX64,
585 | residual_root=True, model_name="dla102x2", **kwargs)
586 |
587 |
588 | def dla169(**kwargs):
589 | """
590 | DLA-169 model from 'Deep Layer Aggregation,' https://arxiv.org/abs/1707.06484.
591 |
592 | Parameters:
593 | ----------
594 | pretrained : bool, default False
595 | Whether to load the pretrained weights for model.
596 | root : str, default '~/.torch/models'
597 | Location for keeping the model parameters.
598 | """
599 | return get_dla(levels=[2, 3, 5, 1], channels=[128, 256, 512, 1024], res_body_class=DLABottleneck,
600 | residual_root=True, model_name="dla169", **kwargs)
601 |
602 |
603 | def _calc_width(net):
604 | import numpy as np
605 | net_params = filter(lambda p: p.requires_grad, net.parameters())
606 | weight_count = 0
607 | for param in net_params:
608 | weight_count += np.prod(param.size())
609 | return weight_count
610 |
611 |
612 | def _test():
613 | import torch
614 |
615 | pretrained = False
616 |
617 | models = [
618 | dla34,
619 | dla46c,
620 | dla46xc,
621 | dla60,
622 | dla60x,
623 | dla60xc,
624 | dla102,
625 | dla102x,
626 | dla102x2,
627 | dla169,
628 | ]
629 |
630 | for model in models:
631 |
632 | net = model(pretrained=pretrained)
633 |
634 | # net.train()
635 | net.eval()
636 | weight_count = _calc_width(net)
637 | print("m={}, {}".format(model.__name__, weight_count))
638 | assert (model != dla34 or weight_count == 15742104)
639 | assert (model != dla46c or weight_count == 1301400)
640 | assert (model != dla46xc or weight_count == 1068440)
641 | assert (model != dla60 or weight_count == 22036632)
642 | assert (model != dla60x or weight_count == 17352344)
643 | assert (model != dla60xc or weight_count == 1319832)
644 | assert (model != dla102 or weight_count == 33268888)
645 | assert (model != dla102x or weight_count == 26309272)
646 | assert (model != dla102x2 or weight_count == 41282200)
647 | assert (model != dla169 or weight_count == 53389720)
648 |
649 | x = torch.randn(1, 3, 224, 224)
650 | y = net(x)
651 | y.sum().backward()
652 | assert (tuple(y.size()) == (1, 1000))
653 |
654 |
655 | if __name__ == "__main__":
656 | _test()
657 |
--------------------------------------------------------------------------------
/polygoncode/polygonembed/resnet.py:
--------------------------------------------------------------------------------
1 | '''
2 | Code rewritten based on https://github.com/ldeecke/mn-torch/blob/master/nn/resnet.py
3 | '''
4 |
5 |
6 | import functools
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | from math import sqrt
12 | from torch.nn import BatchNorm1d
13 |
14 | from polygonembed.atten import *
15 |
16 | # from polygonembed.ops import ModeNorm
17 |
18 | # resnet20 = lambda config: ResNet(BasicBlock, [3, 3, 3], config)
19 | # resnet56 = lambda config: ResNet(BasicBlock, [9, 9, 9], config)
20 | # resnet110 = lambda config: ResNet(BasicBlock, [18, 18, 18], config)
21 |
22 |
23 | def get_agg_func(agg_func_type, out_channels, name = "resnet1d"):
24 | if agg_func_type == "mean":
25 | # global average pool
26 | return torch.mean
27 | elif agg_func_type == "min":
28 | # global min pool
29 | return torch.min
30 | elif agg_func_type == "max":
31 | # global max pool
32 | return torch.max
33 | elif agg_func_type.startswith("atten"):
34 | # agg_func_type: atten_whole_no_1
35 | atten_flag, att_type, bn, nat = agg_func_type.split("_")
36 | assert atten_flag == "atten"
37 | return AttentionSet(mode_dims = out_channels,
38 | att_reg = 0.,
39 | att_tem = 1.,
40 | att_type = att_type,
41 | bn = bn,
42 | nat= int(nat),
43 | name = name)
44 |
45 |
46 | class ResNet1D(nn.Module):
47 | def __init__(self, block, num_layer_list, in_channels, out_channels, add_middle_pool = False, final_pool = "mean", padding_mode = 'circular', dropout_rate = 0.5):
48 | '''
49 | Args:
50 | block: BasicBlock() or BottleneckBlock()
51 | num_layer_list: [num_blocks0, num_blocks1, num_blocks2]
52 |             in_channels: number of input channels
53 | '''
54 | super(ResNet1D, self).__init__()
55 |
56 | Norm = functools.partial(BatchNorm1d)
57 |
58 | self.num_layer_list = num_layer_list
59 |
60 | self.in_channels = in_channels
61 |         # For simplicity, make out_channels divisible by block.expansion
62 | assert out_channels % block.expansion == 0
63 | self.out_channels = out_channels
64 | planes = int(out_channels / block.expansion)
65 |
66 | self.inplanes = out_channels
67 |
68 | self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode = padding_mode, bias=False)
69 |
70 |
71 | self.norm1 = Norm(out_channels)
72 | self.relu = nn.ReLU(inplace=True)
73 |
74 | self.maxpool = nn.MaxPool1d(kernel_size = 2, stride=2, padding = 0)
75 |
76 | resnet_layers = []
77 |
78 | if len(self.num_layer_list) >= 1:
79 | # you have at least one number in self.num_layer_list
80 | layer1 = self._make_layer(block, planes, self.num_layer_list[0], Norm, padding_mode = padding_mode)
81 | resnet_layers.append(layer1)
82 | if add_middle_pool and len(self.num_layer_list) > 1:
83 | maxpool = nn.MaxPool1d(kernel_size = 2, stride=2, padding = 0)
84 | resnet_layers.append(maxpool)
85 |
86 | if len(self.num_layer_list) >= 2:
87 | # you have at least two numbers in self.num_layer_list
88 | for i in range(1, len(self.num_layer_list)):
89 | layerk = self._make_layer(block, planes, self.num_layer_list[i], Norm, stride=2, padding_mode = padding_mode)
90 | resnet_layers.append(layerk)
91 | if add_middle_pool and i < len(self.num_layer_list) - 1:
92 | maxpool = nn.MaxPool1d(kernel_size = 2, stride=2, padding = 0)
93 | resnet_layers.append(maxpool)
94 |
95 | self.resnet_layers = nn.Sequential(*resnet_layers)
96 |
97 | self.final_pool = final_pool
98 |
99 | self.final_pool_func = get_agg_func(agg_func_type = final_pool,
100 | out_channels = out_channels,
101 | name = "resnet1d")
102 |
103 | self.dropout = nn.Dropout(p=dropout_rate)
104 |
105 | # if len(self.num_layer_list) == 3:
106 | # # you have three numbers in self.num_layer_list
107 | # self.layer3 = self._make_layer(block, planes, self.num_layer_list[2], Norm, stride=2, padding_mode = padding_mode)
108 |
109 | # self.avgpool = nn.AvgPool1d(8, stride=1)
110 | # self.fc = nn.Linear(64 * block.expansion, num_classes)
111 |
112 | self._init_weights()
113 |
114 | # def finalPool1d_func(self, final_pool, out_channels, name = "resnet1d"):
115 | # if final_pool == "mean":
116 | # # global average pool
117 | # return torch.mean
118 | # elif final_pool == "min":
119 | # # global min pool
120 | # return torch.min
121 | # elif final_pool == "max":
122 | # # global max pool
123 | # return torch.max
124 | # elif final_pool.startswith("atten"):
125 | # # final_pool: atten_whole_no_1
126 | # atten_flag, att_type, bn, nat = final_pool.split("_")
127 | # assert atten_flag == "atten"
128 | # return AttentionSet(mode_dims = out_channels,
129 | # att_reg = 0.,
130 | # att_tem = 1.,
131 | # att_type = att_type,
132 | # bn = bn,
133 | # nat= int(nat),
134 | # name = name)
135 |
136 | def finalPool1d(self, x, final_pool = "mean"):
137 | '''
138 | Args:
139 | x: shape (batch_size, out_channels, (seq_len+k-2)/2^k )
140 | Return:
141 |             x: shape (batch_size, out_channels)
142 | '''
143 | if final_pool == "mean":
144 | # global average pool
145 | # x: shape (batch_size, out_channels)
146 | x = self.final_pool_func(x, dim = -1, keepdim = False)
147 | elif final_pool == "min":
148 | # global min pool
149 | # x: shape (batch_size, out_channels)
150 | x, indice = self.final_pool_func(x, dim = -1, keepdim = False)
151 | elif final_pool == "max":
152 | # global max pool
153 | # x: shape (batch_size, out_channels)
154 | x, indice = self.final_pool_func(x, dim = -1, keepdim = False)
155 | elif final_pool.startswith("atten"):
156 | # attenion based aggregation
157 | # x: shape (batch_size, out_channels)
158 | x = self.final_pool_func(x)
159 | return x
160 |
161 | def forward(self, x):
162 | '''
163 | Args:
164 | x: shape (batch_size, in_channels, seq_len)
165 | Return:
166 | x: shape (batch_size, out_channels)
167 | '''
168 | # x: shape (batch_size, out_channels, seq_len)
169 | x = self.conv1(x)
170 | x = self.norm1(x)
171 | x = self.relu(x)
172 | # print("conv1:", x.shape)
173 |
174 | # x: shape (batch_size, out_channels, seq_len/2)
175 | x = self.maxpool(x)
176 |
177 | # x: shape (batch_size, out_channels, (seq_len+k-2)/2^k )
178 | x = self.resnet_layers(x)
179 |
180 | # if len(self.num_layer_list) >= 1:
181 | # # After 1st block: shape (batch_size, out_channels, seq_len/2)
182 | # # x: shape (batch_size, out_channels, seq_len/2)
183 | # x = self.layer1(x)
184 | # # print("layer1:", x.shape)
185 |
186 | # if len(self.num_layer_list) >= 2:
187 | # # After 1st block: shape (batch_size, out_channels, (seq_len+2)/4 )
188 | # # x: shape (batch_size, out_channels, (seq_len+2)/4 )
189 | # x = self.layer2(x)
190 | # # print("layer2:", x.shape)
191 |
192 | # if len(self.num_layer_list) == 3:
193 | # # After 1st block: shape (batch_size, out_channels, (seq_len+6)/8 )
194 | # # x: shape (batch_size, out_channels, (seq_len+6)/8 )
195 | # x = self.layer3(x)
196 | # # print("layer3:", x.shape)
197 |
198 |
199 |
200 | # global pool
201 | # x: shape (batch_size, out_channels)
202 | x = self.finalPool1d(x, self.final_pool)
203 | # x = torch.mean(x, dim = -1, keepdim = False)
204 | # print("avgpool:", x.shape)
205 |
206 |
207 | x = self.dropout(x)
208 |
209 | # x: shape (batch_size, 64*expansion, (seq_len-25)/4 )
210 | # x = self.avgpool(x)
211 | # x: shape (batch_size, 64*expansion * (seq_len-25)/4 )
212 | # x = x.view(x.size(0), -1)
213 | # x: shape (batch_size, 64*expansion * (seq_len-25)/4 )
214 | # x = self.fc(x)
215 | # print("output:", x.shape)
216 |
217 | return x
218 |
219 |
220 | def _init_weights(self):
221 | for m in self.modules():
222 | if isinstance(m, nn.Conv1d):
223 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
224 | n = m.kernel_size[0] * m.out_channels
225 | m.weight.data.normal_(0, sqrt(2./n))
226 | elif isinstance(m, nn.BatchNorm1d):
227 | m.weight.data.fill_(1)
228 | m.bias.data.zero_()
229 | # elif isinstance(m, ModeNorm):
230 | # m.alpha.data.fill_(1)
231 | # m.beta.data.zero_()
232 |
233 |
234 | def _make_layer(self, block, planes, blocks, norm, stride=1, padding_mode = 'circular'):
235 | downsample = None
236 | if (stride != 1) or (self.inplanes != planes * block.expansion):
237 | downsample = nn.Sequential(
238 | nn.Conv1d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
239 | norm(planes * block.expansion)
240 | )
241 |
242 | layers = []
243 | layers.append(block(self.inplanes, planes, norm, stride, downsample, padding_mode))
244 |         # self.inplanes = planes * block.expansion  # unnecessary here: inplanes already equals planes * block.expansion
245 | for _ in range(1, blocks):
246 | layers.append(block(self.inplanes, planes, norm, padding_mode = padding_mode))
247 |
248 | return nn.Sequential(*layers)
249 |
250 |
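# Editor's sketch (illustration only): with BasicBlock (expansion == 1),
# planes == 64, blocks == 2 and stride == 2, `_make_layer` above builds
# roughly:
#
#     nn.Sequential(
#         BasicBlock(64, 64, Norm, stride=2, downsample=down),  # halves seq_len
#         BasicBlock(64, 64, Norm),                             # keeps seq_len
#     )
#
# where `down` is the 1x1 strided conv + norm branch that keeps the residual
# addition shape-compatible; it is only built when stride != 1 or the channel
# count changes.
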
251 | class ResNet1D3(nn.Module):
252 | def __init__(self, block, num_layer_list, in_channels, out_channels, padding_mode = 'circular'):
253 | '''
254 | Args:
255 |             block: BasicBlock or BottleneckBlock (the class itself, not an instance)
256 |             num_layer_list: [num_blocks0, num_blocks1, num_blocks2]
257 |             in_channels / out_channels: number of input / output channels
258 | '''
259 | super(ResNet1D3, self).__init__()
260 |
261 | Norm = functools.partial(BatchNorm1d)
262 |
263 | self.num_layer_list = num_layer_list
264 |
265 | self.in_channels = in_channels
266 |         # For simplicity, require out_channels to be divisible by block.expansion
267 | assert out_channels % block.expansion == 0
268 | self.out_channels = out_channels
269 |         planes = out_channels // block.expansion
270 |
271 | self.inplanes = out_channels
272 |
273 | self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode = padding_mode, bias=False)
274 |
275 |
276 | self.norm1 = Norm(out_channels)
277 | self.relu = nn.ReLU(inplace=True)
278 |
279 | self.maxpool = nn.MaxPool1d(kernel_size = 2, stride=2, padding = 0)
280 |
281 | if len(self.num_layer_list) >= 1:
282 | # you have at least one number in self.num_layer_list
283 | self.layer1 = self._make_layer(block, planes, self.num_layer_list[0], Norm, padding_mode = padding_mode)
284 |
285 | if len(self.num_layer_list) >= 2:
286 | # you have at least two numbers in self.num_layer_list
287 | self.layer2 = self._make_layer(block, planes, self.num_layer_list[1], Norm, stride=2, padding_mode = padding_mode)
288 |
289 | if len(self.num_layer_list) == 3:
290 | # you have three numbers in self.num_layer_list
291 | self.layer3 = self._make_layer(block, planes, self.num_layer_list[2], Norm, stride=2, padding_mode = padding_mode)
292 |
293 | # self.avgpool = nn.AvgPool1d(8, stride=1)
294 | # self.fc = nn.Linear(64 * block.expansion, num_classes)
295 |
296 | self._init_weights()
297 |
298 |
299 | def forward(self, x):
300 | '''
301 | Args:
302 | x: shape (batch_size, in_channels, seq_len)
303 | Return:
304 | x: shape (batch_size, out_channels)
305 | '''
306 | # x: shape (batch_size, out_channels, seq_len)
307 | x = self.conv1(x)
308 | x = self.norm1(x)
309 | x = self.relu(x)
310 | # print("conv1:", x.shape)
311 |
312 | # x: shape (batch_size, out_channels, seq_len/2)
313 | x = self.maxpool(x)
314 |
315 | if len(self.num_layer_list) >= 1:
316 |             # layer1 (stride 1) keeps the length: shape (batch_size, out_channels, seq_len/2)
317 | # x: shape (batch_size, out_channels, seq_len/2)
318 | x = self.layer1(x)
319 | # print("layer1:", x.shape)
320 |
321 | if len(self.num_layer_list) >= 2:
322 |             # after the 1st (strided) block of layer2: shape (batch_size, out_channels, (seq_len+2)/4 )
323 | # x: shape (batch_size, out_channels, (seq_len+2)/4 )
324 | x = self.layer2(x)
325 | # print("layer2:", x.shape)
326 |
327 | if len(self.num_layer_list) == 3:
328 |             # after the 1st (strided) block of layer3: shape (batch_size, out_channels, (seq_len+6)/8 )
329 | # x: shape (batch_size, out_channels, (seq_len+6)/8 )
330 | x = self.layer3(x)
331 | # print("layer3:", x.shape)
332 |
333 | # global average pool
334 | # x: shape (batch_size, out_channels)
335 | x = torch.mean(x, dim = -1, keepdim = False)
336 | # print("avgpool:", x.shape)
337 |
338 | # x: shape (batch_size, 64*expansion, (seq_len-25)/4 )
339 | # x = self.avgpool(x)
340 | # x: shape (batch_size, 64*expansion * (seq_len-25)/4 )
341 | # x = x.view(x.size(0), -1)
342 | # x: shape (batch_size, 64*expansion * (seq_len-25)/4 )
343 | # x = self.fc(x)
344 | # print("output:", x.shape)
345 |
346 | return x
347 |
348 |
349 | def _init_weights(self):
350 | for m in self.modules():
351 | if isinstance(m, nn.Conv1d):
352 |                 # He (Kaiming) initialization, fan_out mode; 2D analogue: n = kernel_h * kernel_w * out_channels
353 | n = m.kernel_size[0] * m.out_channels
354 | m.weight.data.normal_(0, sqrt(2./n))
355 | elif isinstance(m, nn.BatchNorm1d):
356 | m.weight.data.fill_(1)
357 | m.bias.data.zero_()
358 | # elif isinstance(m, ModeNorm):
359 | # m.alpha.data.fill_(1)
360 | # m.beta.data.zero_()
361 |
362 |
363 | def _make_layer(self, block, planes, blocks, norm, stride=1, padding_mode = 'circular'):
364 | downsample = None
365 | if (stride != 1) or (self.inplanes != planes * block.expansion):
366 | downsample = nn.Sequential(
367 | nn.Conv1d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
368 | norm(planes * block.expansion)
369 | )
370 |
371 | layers = []
372 | layers.append(block(self.inplanes, planes, norm, stride, downsample, padding_mode))
373 |         # self.inplanes = planes * block.expansion  # unnecessary here: it already equals out_channels == planes * block.expansion
374 | for _ in range(1, blocks):
375 | layers.append(block(self.inplanes, planes, norm, padding_mode = padding_mode))
376 |
377 | return nn.Sequential(*layers)
378 |
379 |
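# Editor's usage sketch for ResNet1D3 (hedged; assumes the BasicBlock defined
# below, with expansion == 1):
#
#     encoder = ResNet1D3(BasicBlock, num_layer_list = [2, 2, 2],
#                         in_channels = 2, out_channels = 64)
#     pts = torch.randn(8, 2, 100)   # e.g. 100 polygon vertices as (x, y)
#     emb = encoder(pts)             # emb.shape == (8, 64)
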
380 | class ResNet1DLucas(nn.Module):
381 | def __init__(self, block, layers, inplanes, num_classes, padding_mode = 'circular'):
382 | '''
383 | Args:
384 |             layers: [num_blocks0, num_blocks1, num_blocks2]; inplanes: number of input channels; num_classes: number of output classes
385 | '''
386 | super(ResNet1DLucas, self).__init__()
387 | # self.mn = config.mn
388 |
389 | # if config.mn == "full":
390 | # Norm = functools.partial(ModeNorm, momentum=config.momentum, n_components=config.num_components)
391 | # elif config.mn == "init":
392 | # InitNorm = functools.partial(ModeNorm, momentum=config.momentum, n_components=config.num_components)
393 | # Norm = functools.partial(BatchNorm1d, momentum=config.momentum)
394 | Norm = functools.partial(BatchNorm1d)
395 |
396 | self.inplanes = 16
397 | self.conv1 = nn.Conv1d(inplanes, 16, kernel_size=3, stride=1, padding=1, padding_mode = padding_mode, bias=False)
398 |
399 | # self.norm1 = InitNorm(16) if config.mn == "init" else Norm(16)
400 | self.norm1 = Norm(16)
401 | self.relu = nn.ReLU(inplace=True)
402 |
403 | self.layer1 = self._make_layer(block, 16, layers[0], Norm, padding_mode = padding_mode)
404 | self.layer2 = self._make_layer(block, 32, layers[1], Norm, stride=2, padding_mode = padding_mode)
405 | self.layer3 = self._make_layer(block, 64, layers[2], Norm, stride=2, padding_mode = padding_mode)
406 | # self.avgpool = nn.AvgPool1d(8, stride=1)
407 | self.fc = nn.Linear(64 * block.expansion, num_classes)
408 |
409 | self._init_weights()
410 |
411 |
412 | def forward(self, x):
413 | '''
414 | Args:
415 |             x: shape (batch_size, inplanes, seq_len); returns logits of shape (batch_size, num_classes)
416 | '''
417 | # x: shape (batch_size, 16, seq_len)
418 | x = self.conv1(x)
419 | x = self.norm1(x)
420 | x = self.relu(x)
421 |
422 |         # layer1 (stride 1) keeps the length: shape (batch_size, 16*expansion, seq_len)
423 | # x: shape (batch_size, 16*expansion, seq_len)
424 | x = self.layer1(x)
425 |
426 |         # after the 1st (strided) block of layer2: shape (batch_size, 32*expansion, (seq_len+1)/2 )
427 | # x: shape (batch_size, 32*expansion, (seq_len+1)/2 )
428 | x = self.layer2(x)
429 |
430 |         # after the 1st (strided) block of layer3: shape (batch_size, 64*expansion, (seq_len+3)/4 )
431 | # x: shape (batch_size, 64*expansion, (seq_len+3)/4 )
432 | x = self.layer3(x)
433 |
434 | # global average pool
435 | # x: shape (batch_size, 64*expansion)
436 | x = torch.mean(x, dim = -1, keepdim = False)
437 |
438 | # x: shape (batch_size, 64*expansion, (seq_len-25)/4 )
439 | # x = self.avgpool(x)
440 | # x: shape (batch_size, 64*expansion * (seq_len-25)/4 )
441 | # x = x.view(x.size(0), -1)
442 | # x: shape (batch_size, 64*expansion * (seq_len-25)/4 )
443 | x = self.fc(x)
444 |
445 | return x
446 |
447 |
448 | def _init_weights(self):
449 | for m in self.modules():
450 | if isinstance(m, nn.Conv1d):
451 |                 # He (Kaiming) initialization, fan_out mode; 2D analogue: n = kernel_h * kernel_w * out_channels
452 | n = m.kernel_size[0] * m.out_channels
453 | m.weight.data.normal_(0, sqrt(2./n))
454 | elif isinstance(m, nn.BatchNorm1d):
455 | m.weight.data.fill_(1)
456 | m.bias.data.zero_()
457 | # elif isinstance(m, ModeNorm):
458 | # m.alpha.data.fill_(1)
459 | # m.beta.data.zero_()
460 |
461 |
462 | def _make_layer(self, block, planes, blocks, norm, stride=1, padding_mode = 'circular'):
463 | downsample = None
464 | if (stride != 1) or (self.inplanes != planes * block.expansion):
465 | downsample = nn.Sequential(
466 | nn.Conv1d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
467 | norm(planes * block.expansion)
468 | )
469 |
470 | layers = []
471 | layers.append(block(self.inplanes, planes, norm, stride, downsample, padding_mode))
472 | self.inplanes = planes * block.expansion
473 | for _ in range(1, blocks):
474 | layers.append(block(self.inplanes, planes, norm, padding_mode = padding_mode))
475 |
476 | return nn.Sequential(*layers)
477 |
478 |
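# Editor's usage sketch for ResNet1DLucas (hedged; assumes the BasicBlock
# defined below):
#
#     model = ResNet1DLucas(BasicBlock, layers = [2, 2, 2],
#                           inplanes = 2, num_classes = 10)
#     logits = model(torch.randn(8, 2, 100))   # logits.shape == (8, 10)
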
479 | class BasicBlock(nn.Module):
480 | expansion=1
481 |
482 | def __init__(self, inplanes, planes, norm, stride=1, downsample=None, padding_mode = 'circular'):
483 | super(BasicBlock, self).__init__()
484 | self.conv1 = self._conv3(inplanes, planes, stride, padding_mode = padding_mode)
485 | self.norm1 = norm(planes)
486 | self.relu = nn.ReLU(inplace=True)
487 | self.conv2 = self._conv3(planes, planes, padding_mode = padding_mode)
488 | self.norm2 = norm(planes)
489 | self.downsample = downsample
490 | self.stride = stride
491 | self.padding_mode = padding_mode
492 |
493 |
494 | def forward(self, x):
495 | '''
496 | Args:
497 | x: shape (batch_size, in_planes, seq_len)
498 | return:
499 | out: shape (batch_size, planes, (seq_len-1)/stride + 1 )
500 | '''
501 | residual = x
502 |
503 | # out: shape (batch_size, planes, (seq_len-1)/stride + 1 )
504 | out = self.conv1(x)
505 | out = self.norm1(out)
506 | out = self.relu(out)
507 |
508 | # out: shape (batch_size, planes, (seq_len-1)/stride + 1 )
509 | out = self.conv2(out)
510 | out = self.norm2(out)
511 |
512 | if self.downsample is not None:
513 | # residual: shape (batch_size, planes, (seq_len-1)/stride + 1 )
514 | residual = self.downsample(x)
515 |
516 | out += residual
517 | out = self.relu(out)
518 |
519 | return out
520 |
521 |
522 | def _conv3(self, in_planes, out_planes, stride=1, padding_mode = 'circular'):
523 |         '''1D convolution with kernel size 3 and padding (the 1D analogue of ResNet's 3x3 conv)'''
524 | return nn.Conv1d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, padding_mode = padding_mode, bias=False)
525 |
526 |
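# Editor's shape-check sketch for BasicBlock (hedged illustration): with
# stride == 2 the main path halves seq_len, so a matching downsample branch
# must be supplied for the residual addition to stay valid:
#
#     down = nn.Sequential(
#         nn.Conv1d(16, 16, kernel_size=1, stride=2, bias=False),
#         nn.BatchNorm1d(16))
#     blk = BasicBlock(16, 16, nn.BatchNorm1d, stride=2, downsample=down)
#     out = blk(torch.randn(4, 16, 100))   # out.shape == (4, 16, 50)
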
527 | class BottleneckBlock(nn.Module):
528 | expansion = 4
529 |
530 |     def __init__(self, in_planes, planes, norm, stride=1, downsample=None, padding_mode = 'circular'):  # downsample is unused here; the block builds its own shortcut
531 | super(BottleneckBlock, self).__init__()
532 | self.conv1 = nn.Conv1d(in_planes, planes, kernel_size=1, bias=False)
533 | self.norm1 = norm(planes)
534 | self.conv2 = nn.Conv1d(planes, planes, kernel_size=3, stride=stride, padding=1, padding_mode = padding_mode, bias=False)
535 | self.norm2 = norm(planes)
536 | self.conv3 = nn.Conv1d(planes, self.expansion*planes, kernel_size=1, bias=False)
537 | self.norm3 = norm(self.expansion*planes)
538 |
539 | self.shortcut = nn.Sequential()
540 | if stride != 1 or in_planes != self.expansion*planes:
541 | self.shortcut = nn.Sequential(
542 | nn.Conv1d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
543 | norm(self.expansion*planes)
544 | )
545 |
546 |
547 | def forward(self, x):
548 | '''
549 | Args:
550 | x: shape (batch_size, in_planes, seq_len)
551 | return:
552 | out: shape (batch_size, expansion*planes, (seq_len-1)/stride + 1 )
553 | '''
554 | # out: shape (batch_size, planes, seq_len)
555 | out = F.relu(self.norm1(self.conv1(x)))
556 | # out: shape (batch_size, planes, (seq_len-1)/stride + 1 )
557 | out = F.relu(self.norm2(self.conv2(out)))
558 | # out: shape (batch_size, expansion*planes, (seq_len-1)/stride + 1 )
559 | out = self.norm3(self.conv3(out))
560 | # self.shortcut(x): shape (batch_size, expansion*planes, (seq_len-1)/stride + 1 )
561 | # out: shape (batch_size, expansion*planes, (seq_len-1)/stride + 1 )
562 | out += self.shortcut(x)
563 | out = F.relu(out)
564 | return out
565 |
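if __name__ == "__main__":
    # Editor's smoke test (hedged addition, not in the original file):
    # exercise both block types end to end through the classifier variant.
    net_basic = ResNet1DLucas(BasicBlock, layers = [1, 1, 1], inplanes = 3, num_classes = 5)
    net_bottle = ResNet1DLucas(BottleneckBlock, layers = [1, 1, 1], inplanes = 3, num_classes = 5)
    x = torch.randn(2, 3, 64)
    print(net_basic(x).shape)    # expected: torch.Size([2, 5])
    print(net_bottle(x).shape)   # expected: torch.Size([2, 5])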
--------------------------------------------------------------------------------