├── .gitignore ├── DatasetConstruction ├── README.md ├── data │ ├── img_path.json │ └── img_path2size.json ├── download.sh ├── move.sh ├── split.sh ├── split_dataset.py ├── split_list.py ├── stat │ ├── objlist.json │ └── sortedID.json └── transform.py ├── README.md ├── data └── ADE │ └── permutations_1000.npy ├── dataset ├── __init__.py ├── base_dataset.py ├── collate.py ├── dataloader.py ├── distributed.py ├── novel_dataset.py ├── proto_dataset.py ├── sampler.py └── transform.py ├── evaluate ├── novel.py ├── predict.py ├── save_features.py └── vote.py ├── logger.py ├── model ├── __init__.py ├── base_model.py ├── builder.py ├── component │ ├── __init__.py │ ├── attr.py │ ├── bbox.py │ ├── bkg.py │ ├── classifier.py │ ├── hierarchy.py │ ├── part.py │ ├── patch_location.py │ ├── resnet.py │ ├── roi_align │ │ ├── roi_align │ │ │ ├── __init__.py │ │ │ ├── crop_and_resize.py │ │ │ ├── roi_align.py │ │ │ └── src │ │ │ │ ├── crop_and_resize.cpp │ │ │ │ ├── crop_and_resize.h │ │ │ │ ├── crop_and_resize_gpu.cpp │ │ │ │ ├── crop_and_resize_gpu.h │ │ │ │ └── cuda │ │ │ │ ├── crop_and_resize_kernel.cu │ │ │ │ └── crop_and_resize_kernel.h │ │ ├── setup.py │ │ ├── test.sh │ │ └── tests │ │ │ ├── crop_and_resize_example.py │ │ │ ├── images │ │ │ ├── choco.png │ │ │ └── snow.png │ │ │ ├── test.py │ │ │ └── test2.py │ ├── rotation.py │ ├── scene.py │ └── seg.py ├── novel_model.py └── parallel │ ├── __init__.py │ ├── batchnorm.py │ ├── comm.py │ └── replicate.py ├── preprocessing ├── addcontext.py ├── generate_list.py ├── supervision.json └── supervison_generation │ ├── attr.py │ ├── bbox.py │ ├── bkg.py │ ├── hierarchy.py │ ├── scene.py │ └── seg.py ├── supervision.json ├── train.py ├── train.sh └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | __pycache__/ 3 | data/ADE/ADE* 4 | history* 5 | checkpoint/ 6 | ckpt*/ 7 | *.swp 8 | */*.swp 9 | log*/ 10 | model/prtrain/ 11 | run.sh 12 | test.sh 13 | val.sh 14 | sleep.sh 15 | feature.sh 16 | model.log 17 | train.log 18 | diagnosis/*.json 19 | diagnosis/*.pth 20 | data 21 | merge 22 | -------------------------------------------------------------------------------- /DatasetConstruction/README.md: -------------------------------------------------------------------------------- 1 | ### Dataset Construction: Transform ADE20K into ADE-FewShot 2 | 3 | These are instructions for constructing the core part of the ADE-FewShot dataset. 4 | 5 | 6 | 7 | > The framework of this part is not very polished yet. Since some code is missing, we simply provide the required data files in JSON format in the data/ and stat/ directories. We will restructure this part soon. 8 | 9 | 10 | 11 | #### step1. Download ADE20K 12 | 13 | ``` 14 | bash download.sh 15 | ``` 16 | 17 | This downloads and unzips the ADE20K dataset into the **parent directory** of the ADE-FewShot directory. If you want to store the dataset elsewhere, you need to modify the directory parameter in the bash script and **any other code that refers to the original ADE20K dataset**. 18 | 19 | #### step2. Transform and split the dataset 20 | 21 | ``` 22 | bash split.sh 23 | ``` 24 | 25 | This automatically detects the objects in each image of ADE20K and saves their locations and annotations. It then splits the dataset into base and novel categories based on how often each category occurs. You can change the base/novel threshold by modifying the parameters in split\_list.py. 26 | 
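The split itself is driven by the per-category occurrence counts that transform.py dumps to stat/occurrence.json. As a quick sanity check before running the full pipeline, here is a minimal sketch (not one of the repository scripts) that mirrors the thresholding rule used by split\_list.py and reports how many categories would land in each split:

```
# Sketch only: mirrors the low/threshold/high rule in split_list.py.
import json

occurrence = json.load(open('stat/occurrence.json'))  # {category_id: count}
low, threshold, high = 15, 100, 100000                # split_list.py defaults

base = [int(c) for c, n in occurrence.items() if threshold <= n < high]
novel = [int(c) for c, n in occurrence.items() if low <= n < threshold]
print('{} base categories, {} novel categories'.format(len(base), len(novel)))
```
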
27 | #### step3. Move relevant files into the right position 28 | 29 | ``` 30 | bash move.sh 31 | ``` 32 | 33 | This moves all the data needed for the project to ../data/ADE/ADE\_Origin. The preprocessing code will look for the data files in this folder. 34 | 35 | 36 | 37 | -------------------------------------------------------------------------------- /DatasetConstruction/download.sh: -------------------------------------------------------------------------------- 1 | wget -P ../../ https://groups.csail.mit.edu/vision/datasets/ADE20K/ADE20K_2016_07_26.zip 2 | unzip ../../ADE20K_2016_07_26.zip -d ../../ 3 | rm -rf ../../ADE20K_2016_07_26.zip 4 | -------------------------------------------------------------------------------- /DatasetConstruction/move.sh: -------------------------------------------------------------------------------- 1 | cp data/img_path.json ../data/ADE/ADE_Origin/ 2 | cp data/img_path2size.json ../data/ADE/ADE_Origin/ 3 | mv data/base_set.json ../data/ADE/ADE_Origin/ 4 | mv data/novel_set.json ../data/ADE/ADE_Origin/ 5 | mv stat/base_list.json ../data/ADE/ADE_Origin/ 6 | mv stat/novel_list.json ../data/ADE/ADE_Origin/ 7 | -------------------------------------------------------------------------------- /DatasetConstruction/split.sh: -------------------------------------------------------------------------------- 1 | python transform.py 2 | python split_list.py 3 | python split_dataset.py 4 | -------------------------------------------------------------------------------- /DatasetConstruction/split_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | if __name__ == '__main__': 5 | data = json.load(open('data/data_all.json')) 6 | train_set = [] 7 | test_set = [] 8 | train_list = set(json.load(open('stat/base_list.json'))) 9 | test_list = set(json.load(open('stat/novel_list.json'))) 10 | for item in data: 11 | if item['obj'] in train_list: 12 | train_set.append(item) 13 | elif item['obj'] in test_list: 14 | test_set.append(item) 15 | else: 16 | # class occurs fewer than 15 times, so it is dropped 17 | continue 18 | json.dump(train_set, open("data/base_set.json", 'w')) 19 | json.dump(test_set, open("data/novel_set.json", 'w')) 20 | -------------------------------------------------------------------------------- /DatasetConstruction/split_list.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate train list and test list according to threshold 3 | """ 4 | import json 5 | import numpy as np 6 | import argparse 7 | from collections import defaultdict 8 | 9 | 10 | if __name__ == "__main__": 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('-low', default=15, type=int, help='lowest occurrences') 13 | parser.add_argument('-high', default=100000, type=int, help='highest occurrences') 14 | parser.add_argument('-threshold', default=100, type=int, help='threshold to split') 15 | parser.add_argument('-sorted_index', default='stat/sortedID.json') 16 | parser.add_argument('-stat', default='stat/occurrence.json', help='occurrences') 17 | parser.add_argument('-mode', default='threshold', help='mode of split') 18 | args = parser.parse_args() 19 | 20 | f = open(args.sorted_index) 21 | index_list = json.load(f) 22 | f.close() 23 | 24 | f = open(args.stat) 25 | stat = json.load(f) 26 | f.close() 27 | 28 | stat = sorted(stat.items(), key=lambda x: x[1]) 29 | stat_dict = defaultdict(list) 30 | for k, v in stat: 31 | stat_dict[int(k)] = int(v) 32 | 33 | if args.mode == 'threshold': 34 | train_list = [] 35 | test_list = [] 36 | for sample in 
stat_dict.items(): 37 | if sample[1] >= args.low and sample[1] < int(args.threshold): 38 | test_list.append(int(sample[0])) 39 | elif sample[1] >= int(args.threshold) and sample[1] < args.high: 40 | train_list.append(int(sample[0])) 41 | 42 | f = open('stat/base_list.json', 'w') 43 | json.dump(train_list, f) 44 | f.close() 45 | 46 | f = open('stat/novel_list.json', 'w') 47 | json.dump(test_list, f) 48 | f.close() 49 | -------------------------------------------------------------------------------- /DatasetConstruction/stat/objlist.json: -------------------------------------------------------------------------------- 1 | [1, 3, 4, 10, 14, 15, 18, 24, 29, 32, 33, 35, 38, 39, 42, 57, 61, 64, 65, 66, 71, 72, 73, 74, 75, 76, 78, 79, 80, 81, 82, 90, 91, 93, 95, 96, 97, 99, 100, 103, 104, 107, 108, 109, 110, 111, 114, 115, 117, 121, 125, 126, 127, 130, 131, 132, 133, 137, 138, 139, 143, 145, 146, 147, 149, 150, 151, 152, 153, 154, 155, 159, 163, 164, 165, 168, 170, 173, 174, 175, 177, 178, 179, 180, 181, 183, 185, 186, 187, 188, 189, 191, 195, 196, 197, 198, 199, 200, 201, 202, 204, 206, 207, 210, 216, 217, 219, 223, 225, 227, 228, 229, 230, 231, 232, 233, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 249, 250, 255, 257, 259, 265, 266, 268, 270, 271, 272, 273, 274, 276, 279, 280, 282, 283, 285, 286, 287, 288, 289, 294, 295, 296, 297, 298, 299, 300, 301, 303, 304, 305, 306, 309, 314, 315, 316, 317, 320, 321, 322, 323, 324, 325, 327, 328, 330, 331, 332, 335, 336, 337, 340, 345, 348, 349, 350, 352, 353, 354, 355, 356, 359, 360, 361, 362, 363, 364, 365, 366, 370, 371, 373, 374, 376, 377, 378, 379, 380, 382, 383, 384, 385, 386, 387, 388, 394, 395, 397, 399, 400, 401, 402, 403, 404, 405, 406, 407, 409, 412, 413, 415, 418, 419, 421, 422, 425, 426, 427, 428, 430, 431, 433, 435, 437, 438, 439, 440, 441, 442, 443, 444, 445, 454, 455, 457, 459, 462, 463, 464, 465, 466, 467, 469, 470, 471, 472, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 487, 488, 489, 490, 492, 493, 494, 495, 496, 497, 498, 500, 501, 502, 503, 504, 505, 506, 507, 508, 511, 512, 513, 514, 515, 516, 522, 523, 526, 527, 528, 529, 530, 532, 533, 534, 536, 537, 538, 539, 540, 542, 545, 546, 547, 552, 553, 554, 555, 556, 558, 560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 570, 571, 572, 573, 575, 577, 583, 585, 586, 588, 590, 591, 592, 594, 595, 596, 597, 598, 599, 603, 610, 611, 612, 613, 614, 616, 617, 619, 622, 628, 630, 632, 633, 639, 640, 641, 642, 643, 645, 646, 647, 649, 651, 652, 653, 654, 657, 658, 659, 660, 661, 662, 664, 666, 667, 669, 670, 673, 674, 675, 676, 677, 678, 679, 689, 690, 691, 692, 693, 694, 696, 697, 699, 701, 703, 704, 705, 708, 713, 714, 715, 717, 718, 719, 721, 722, 724, 725, 726, 727, 730, 731, 734, 735, 736, 737, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 753, 755, 757, 760, 762, 763, 764, 767, 768, 770, 771, 773, 774, 775, 777, 780, 782, 784, 785, 786, 789, 791, 792, 793, 794, 795, 798, 799, 800, 803, 804, 805, 807, 809, 810, 812, 813, 814, 815, 816, 817, 818, 820, 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 833, 834, 835, 836, 839, 841, 842, 844, 845, 846, 847, 848, 849, 851, 852, 853, 854, 856, 858, 860, 862, 863, 864, 866, 871, 873, 874, 875, 878, 879, 880, 881, 882, 883, 886, 894, 897, 898, 899, 900, 901, 903, 912, 917, 918, 919, 920, 921, 922, 923, 924, 925, 926, 927, 929, 931, 932, 933, 935, 937, 939, 940, 941, 942, 943, 944, 946, 948, 950, 952, 954, 955, 956, 958, 959, 960, 961, 963, 964, 965, 969, 970, 973, 974, 975, 978, 979, 982, 983, 984, 989, 
991, 992, 993, 994, 995, 998, 999, 1000, 1001, 1002, 1003, 1006, 1007, 1010, 1011, 1013, 1015, 1019, 1021, 1023, 1025, 1026, 1033, 1034, 1036, 1037, 1038, 1039, 1040, 1041, 1042, 1043, 1045, 1046, 1053, 1054, 1055, 1057, 1058, 1061, 1067, 1068, 1069, 1070, 1071, 1072, 1073, 1074, 1075, 1076, 1085, 1088, 1089, 1090, 1091, 1092, 1093, 1094, 1095, 1096, 1097, 1099, 1100, 1101, 1102, 1104, 1105, 1106, 1107, 1108, 1109, 1110, 1114, 1115, 1116, 1117, 1119, 1120, 1123, 1124, 1127, 1128, 1131, 1133, 1136, 1137, 1138, 1139, 1142, 1143, 1144, 1148, 1150, 1151, 1152, 1153, 1154, 1155, 1157, 1158, 1160, 1161, 1162, 1163, 1164, 1165, 1166, 1167, 1169, 1170, 1171, 1172, 1173, 1175, 1176, 1177, 1178, 1181, 1184, 1186, 1187, 1188, 1189, 1190, 1191, 1193, 1194, 1195, 1196, 1197, 1198, 1200, 1201, 1210, 1211, 1212, 1214, 1215, 1217, 1218, 1220, 1221, 1223, 1224, 1229, 1231, 1232, 1234, 1235, 1239, 1248, 1250, 1252, 1253, 1254, 1255, 1256, 1257, 1258, 1259, 1260, 1262, 1264, 1268, 1269, 1270, 1271, 1273, 1274, 1275, 1277, 1278, 1283, 1285, 1288, 1289, 1290, 1291, 1292, 1297, 1301, 1302, 1303, 1305, 1306, 1307, 1308, 1309, 1311, 1312, 1313, 1314, 1315, 1316, 1318, 1320, 1321, 1322, 1323, 1324, 1326, 1327, 1330, 1335, 1337, 1340, 1341, 1342, 1344, 1346, 1347, 1349, 1356, 1357, 1358, 1359, 1360, 1364, 1367, 1368, 1369, 1370, 1372, 1373, 1375, 1376, 1377, 1378, 1380, 1381, 1383, 1385, 1386, 1387, 1388, 1389, 1390, 1391, 1392, 1394, 1395, 1396, 1398, 1399, 1400, 1404, 1405, 1406, 1407, 1408, 1409, 1411, 1413, 1414, 1419, 1420, 1421, 1422, 1424, 1425, 1433, 1435, 1436, 1437, 1438, 1441, 1442, 1445, 1446, 1447, 1449, 1450, 1451, 1452, 1453, 1455, 1456, 1457, 1458, 1459, 1460, 1461, 1462, 1463, 1466, 1468, 1469, 1471, 1472, 1473, 1474, 1477, 1478, 1480, 1481, 1482, 1483, 1484, 1485, 1486, 1487, 1490, 1491, 1492, 1494, 1495, 1496, 1501, 1502, 1505, 1507, 1510, 1511, 1512, 1513, 1514, 1515, 1516, 1517, 1518, 1519, 1520, 1522, 1524, 1526, 1527, 1528, 1529, 1530, 1531, 1532, 1536, 1537, 1538, 1539, 1541, 1542, 1543, 1544, 1545, 1549, 1550, 1552, 1553, 1557, 1558, 1560, 1561, 1562, 1563, 1564, 1566, 1567, 1568, 1569, 1570, 1571, 1572, 1573, 1574, 1576, 1577, 1581, 1583, 1584, 1585, 1587, 1588, 1591, 1593, 1594, 1598, 1599, 1600, 1601, 1602, 1603, 1612, 1613, 1614, 1616, 1617, 1619, 1623, 1624, 1628, 1629, 1630, 1631, 1632, 1633, 1634, 1636, 1639, 1640, 1641, 1642, 1643, 1644, 1647, 1648, 1649, 1650, 1651, 1653, 1654, 1655, 1657, 1658, 1661, 1662, 1663, 1664, 1665, 1666, 1667, 1669, 1670, 1671, 1672, 1673, 1675, 1676, 1677, 1678, 1680, 1681, 1683, 1684, 1685, 1686, 1687, 1688, 1689, 1691, 1692, 1694, 1696, 1697, 1698, 1699, 1700, 1701, 1702, 1708, 1711, 1712, 1714, 1715, 1716, 1717, 1719, 1720, 1721, 1722, 1723, 1724, 1725, 1727, 1728, 1729, 1730, 1731, 1732, 1733, 1734, 1735, 1736, 1737, 1739, 1740, 1741, 1744, 1745, 1749, 1750, 1751, 1752, 1755, 1756, 1757, 1758, 1759, 1760, 1761, 1762, 1763, 1764, 1765, 1766, 1767, 1768, 1769, 1770, 1771, 1772, 1773, 1775, 1777, 1779, 1780, 1785, 1791, 1795, 1796, 1801, 1805, 1808, 1809, 1810, 1811, 1812, 1813, 1814, 1815, 1816, 1817, 1818, 1819, 1822, 1823, 1824, 1825, 1826, 1827, 1831, 1832, 1833, 1834, 1835, 1837, 1838, 1840, 1841, 1842, 1843, 1844, 1846, 1847, 1849, 1850, 1852, 1853, 1854, 1855, 1856, 1858, 1859, 1861, 1862, 1864, 1867, 1868, 1869, 1870, 1871, 1873, 1875, 1876, 1877, 1878, 1880, 1881, 1884, 1885, 1886, 1887, 1892, 1893, 1894, 1895, 1896, 1897, 1898, 1899, 1900, 1903, 1905, 1907, 1908, 1909, 1910, 1911, 1913, 1914, 1915, 1916, 1917, 1918, 1919, 1920, 1921, 1923, 
1924, 1925, 1926, 1927, 1928, 1930, 1931, 1932, 1933, 1934, 1935, 1936, 1937, 1938, 1939, 1940, 1944, 1945, 1949, 1950, 1953, 1955, 1956, 1957, 1958, 1959, 1960, 1961, 1962, 1963, 1964, 1965, 1966, 1967, 1970, 1971, 1972, 1973, 1974, 1975, 1976, 1977, 1978, 1979, 1981, 1982, 1983, 1984, 1985, 1986, 1987, 1988, 1989, 1990, 1991, 1992, 1993, 1994, 1995, 1996, 1997, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2009, 2011, 2012, 2013, 2014, 2017, 2018, 2020, 2021, 2022, 2023, 2024, 2025, 2026, 2027, 2028, 2029, 2032, 2034, 2035, 2037, 2038, 2039, 2040, 2042, 2043, 2045, 2046, 2047, 2048, 2049, 2062, 2065, 2066, 2070, 2074, 2075, 2076, 2077, 2078, 2079, 2081, 2082, 2083, 2084, 2085, 2086, 2088, 2089, 2090, 2093, 2094, 2096, 2097, 2098, 2099, 2100, 2101, 2105, 2106, 2107, 2111, 2115, 2122, 2123, 2124, 2125, 2126, 2132, 2134, 2138, 2139, 2140, 2143, 2148, 2149, 2150, 2152, 2154, 2155, 2156, 2157, 2158, 2160, 2162, 2163, 2164, 2165, 2166, 2167, 2169, 2170, 2171, 2172, 2173, 2174, 2177, 2180, 2183, 2184, 2187, 2188, 2190, 2191, 2193, 2194, 2198, 2199, 2202, 2203, 2204, 2205, 2207, 2209, 2210, 2211, 2218, 2220, 2221, 2222, 2223, 2227, 2228, 2229, 2230, 2233, 2234, 2236, 2237, 2238, 2239, 2240, 2242, 2243, 2244, 2245, 2246, 2247, 2248, 2249, 2252, 2255, 2257, 2258, 2259, 2260, 2262, 2263, 2265, 2267, 2268, 2269, 2270, 2272, 2273, 2274, 2275, 2276, 2278, 2280, 2281, 2282, 2283, 2286, 2288, 2290, 2291, 2292, 2293, 2296, 2297, 2298, 2301, 2302, 2304, 2305, 2306, 2307, 2308, 2309, 2310, 2313, 2314, 2318, 2319, 2320, 2321, 2322, 2324, 2325, 2326, 2327, 2328, 2331, 2332, 2333, 2336, 2337, 2339, 2341, 2342, 2344, 2345, 2348, 2349, 2350, 2353, 2354, 2355, 2359, 2361, 2362, 2363, 2364, 2365, 2370, 2371, 2372, 2378, 2381, 2382, 2383, 2384, 2385, 2386, 2390, 2393, 2394, 2395, 2396, 2398, 2399, 2400, 2401, 2402, 2403, 2404, 2406, 2409, 2410, 2411, 2413, 2414, 2418, 2419, 2422, 2423, 2424, 2429, 2430, 2431, 2432, 2434, 2435, 2437, 2440, 2441, 2442, 2443, 2444, 2445, 2447, 2448, 2450, 2451, 2452, 2455, 2456, 2457, 2458, 2461, 2462, 2463, 2464, 2465, 2466, 2467, 2468, 2471, 2472, 2473, 2475, 2476, 2477, 2478, 2479, 2481, 2482, 2483, 2484, 2485, 2486, 2489, 2490, 2491, 2492, 2493, 2494, 2495, 2496, 2498, 2499, 2500, 2501, 2502, 2503, 2504, 2505, 2506, 2507, 2508, 2509, 2510, 2512, 2513, 2514, 2515, 2516, 2517, 2518, 2519, 2520, 2521, 2522, 2527, 2532, 2533, 2534, 2535, 2536, 2540, 2542, 2546, 2547, 2548, 2549, 2551, 2553, 2554, 2555, 2556, 2557, 2558, 2559, 2561, 2563, 2565, 2570, 2571, 2572, 2573, 2574, 2575, 2576, 2577, 2579, 2580, 2581, 2583, 2585, 2586, 2587, 2590, 2591, 2592, 2594, 2595, 2596, 2599, 2600, 2601, 2602, 2603, 2604, 2605, 2607, 2611, 2612, 2613, 2614, 2615, 2616, 2618, 2619, 2620, 2621, 2622, 2623, 2624, 2627, 2628, 2629, 2630, 2631, 2632, 2633, 2635, 2636, 2637, 2638, 2639, 2641, 2644, 2646, 2649, 2650, 2651, 2654, 2655, 2657, 2661, 2662, 2663, 2664, 2665, 2666, 2667, 2669, 2670, 2671, 2673, 2675, 2676, 2677, 2678, 2679, 2680, 2681, 2682, 2684, 2685, 2690, 2691, 2693, 2694, 2696, 2697, 2701, 2703, 2704, 2705, 2706, 2707, 2709, 2710, 2711, 2712, 2713, 2715, 2719, 2721, 2722, 2723, 2725, 2727, 2728, 2729, 2730, 2731, 2732, 2733, 2734, 2735, 2737, 2740, 2745, 2746, 2747, 2748, 2749, 2751, 2752, 2753, 2754, 2755, 2756, 2757, 2758, 2760, 2762, 2764, 2766, 2767, 2768, 2769, 2771, 2773, 2774, 2776, 2777, 2778, 2779, 2780, 2781, 2782, 2783, 2785, 2786, 2787, 2788, 2790, 2791, 2792, 2793, 2798, 2799, 2801, 2802, 2803, 2804, 2805, 2806, 2807, 2808, 2809, 2810, 2811, 2812, 2813, 2814, 2816, 2817, 
2821, 2822, 2823, 2824, 2825, 2826, 2827, 2828, 2832, 2833, 2834, 2835, 2836, 2837, 2838, 2839, 2840, 2841, 2842, 2844, 2845, 2846, 2847, 2848, 2851, 2853, 2854, 2855, 2858, 2862, 2863, 2864, 2865, 2866, 2867, 2868, 2869, 2870, 2871, 2872, 2873, 2874, 2875, 2877, 2878, 2879, 2880, 2881, 2883, 2887, 2888, 2889, 2890, 2891, 2894, 2895, 2896, 2897, 2898, 2899, 2900, 2901, 2902, 2903, 2904, 2913, 2914, 2915, 2918, 2923, 2927, 2928, 2931, 2932, 2933, 2936, 2939, 2941, 2942, 2948, 2950, 2951, 2952, 2953, 2955, 2958, 2959, 2960, 2963, 2964, 2965, 2966, 2971, 2972, 2973, 2974, 2975, 2976, 2979, 2982, 2983, 2984, 2985, 2987, 2988, 2989, 2991, 2992, 2993, 2997, 2998, 3000, 3001, 3002, 3003, 3004, 3005, 3007, 3008, 3010, 3011, 3013, 3014, 3017, 3023, 3024, 3025, 3026, 3029, 3030, 3031, 3032, 3033, 3035, 3036, 3037, 3039, 3040, 3041, 3042, 3043, 3044, 3045, 3046, 3047, 3048, 3049, 3051, 3054, 3058, 3059, 3062, 3066, 3067, 3069, 3070, 3071, 3072, 3075, 3076, 3077, 3078, 3079, 3080, 3081, 3083, 3084, 3085, 3086, 3087, 3088, 3090, 3091, 3092, 3093, 3094, 3095, 3097, 3098, 3101, 3102, 3103, 3104, 3105, 3107, 3110, 3111, 3112, 3113, 3119, 3121, 3123, 3125, 3127, 3129, 3132, 3133, 3136, 3138, 3139, 3140, 3141, 3142, 3143, 3144, 3146, 3147] -------------------------------------------------------------------------------- /DatasetConstruction/stat/sortedID.json: -------------------------------------------------------------------------------- 1 | [[2855, 21574], [1831, 18241], [471, 11386], [1910, 10589], [2684, 10158], [401, 9669], [1735, 8568], [1451, 7473], [350, 6693], [1395, 6233], [2616, 5006], [165, 4477], [774, 4460], [236, 4333], [266, 3419], [689, 3031], [249, 2737], [2473, 2698], [1981, 2403], [978, 2394], [57, 2260], [2509, 2199], [2138, 2124], [2932, 2023], [2243, 1785], [724, 1755], [1564, 1734], [1919, 1618], [1869, 1526], [1936, 1516], [181, 1495], [2982, 1406], [2836, 1313], [3087, 1311], [2272, 1293], [187, 1230], [1744, 1154], [146, 1142], [1974, 1066], [571, 1052], [95, 1037], [982, 972], [259, 938], [64, 937], [530, 928], [2586, 913], [591, 910], [1930, 898], [2821, 893], [2679, 891], [137, 889], [480, 881], [2928, 827], [954, 819], [378, 735], [1756, 725], [2833, 716], [2733, 706], [1349, 699], [223, 689], [1563, 681], [2423, 669], [2880, 668], [2676, 646], [2341, 643], [1002, 623], [943, 613], [239, 590], [918, 574], [10, 569], [155, 565], [2730, 548], [2985, 544], [1583, 540], [2793, 533], [894, 524], [2262, 523], [2901, 509], [2096, 490], [1033, 487], [303, 478], [1485, 469], [1644, 467], [371, 454], [294, 441], [29, 430], [2748, 424], [1702, 420], [327, 401], [1494, 375], [38, 338], [1619, 325], [1474, 324], [103, 319], [96, 313], [1260, 301], [597, 301], [14, 297], [206, 295], [1197, 294], [1791, 290], [747, 277], [376, 270], [677, 262], [1708, 259], [1131, 259], [1884, 258], [1139, 256], [2046, 255], [1892, 255], [1896, 248], [97, 241], [791, 241], [897, 239], [1662, 232], [90, 226], [242, 224], [762, 219], [1779, 219], [569, 218], [1957, 217], [1944, 216], [948, 213], [2723, 206], [1378, 203], [2005, 203], [704, 193], [198, 189], [1391, 188], [2547, 187], [2792, 186], [2828, 183], [1158, 183], [839, 182], [1407, 181], [2834, 180], [1223, 179], [32, 177], [2955, 176], [3067, 174], [2148, 168], [1445, 167], [944, 166], [2941, 165], [1346, 163], [1023, 161], [2465, 159], [643, 155], [1858, 154], [1810, 153], [373, 147], [490, 143], [2339, 142], [1367, 140], [3105, 139], [1481, 137], [2824, 136], [2099, 133], [2722, 133], [418, 132], [385, 129], [746, 128], [2939, 127], [594, 125], 
[1963, 124], [1549, 123], [1186, 121], [1924, 119], [818, 117], [3107, 116], [2989, 113], [1740, 113], [2370, 111], [286, 110], [2505, 108], [1013, 108], [789, 107], [1019, 107], [2007, 106], [2704, 106], [1769, 105], [2228, 105], [3035, 105], [1927, 105], [1986, 105], [2422, 103], [245, 103], [2877, 101], [131, 100], [314, 99], [675, 97], [132, 97], [662, 96], [3102, 94], [394, 94], [1849, 93], [2727, 93], [2615, 93], [642, 92], [1370, 92], [743, 92], [2337, 91], [2160, 89], [1805, 89], [1714, 88], [2443, 88], [196, 86], [2661, 85], [2462, 83], [2464, 82], [65, 82], [2362, 81], [1151, 81], [2490, 79], [1386, 78], [126, 78], [2782, 77], [1119, 75], [2840, 75], [2123, 74], [2115, 72], [79, 71], [533, 70], [939, 69], [2517, 68], [3104, 67], [495, 67], [202, 66], [925, 66], [1670, 66], [1614, 66], [366, 64], [565, 61], [3049, 61], [1911, 61], [2418, 61], [422, 60], [853, 59], [145, 59], [1560, 59], [296, 58], [1181, 57], [3139, 56], [2012, 56], [300, 55], [2236, 55], [1655, 55], [1885, 54], [2574, 54], [1572, 53], [1232, 52], [1419, 51], [356, 50], [175, 49], [2838, 48], [1543, 48], [3101, 47], [3129, 47], [382, 46], [426, 45], [1424, 43], [108, 43], [1301, 43], [361, 42], [2769, 42], [1513, 41], [2875, 41], [1231, 41], [929, 41], [831, 41], [552, 40], [992, 40], [2734, 40], [3088, 40], [537, 40], [753, 40], [1189, 40], [3127, 40], [2047, 39], [349, 39], [387, 39], [2155, 39], [301, 39], [2028, 38], [2600, 38], [2442, 38], [1852, 38], [3125, 38], [1613, 38], [3070, 38], [188, 37], [1914, 37], [1039, 37], [1188, 37], [107, 36], [503, 36], [2774, 36], [2246, 36], [2950, 36], [2732, 36], [2398, 36], [1657, 36], [740, 35], [813, 35], [2411, 35], [2218, 35], [3111, 35], [1229, 34], [1587, 34], [3039, 34], [1036, 34], [1425, 33], [2510, 33], [727, 33], [1647, 33], [1664, 33], [1001, 32], [2641, 32], [1483, 32], [736, 31], [658, 31], [457, 31], [2858, 31], [1905, 31], [1463, 31], [693, 31], [1945, 31], [1517, 31], [2565, 31], [2613, 31], [851, 31], [2827, 30], [2239, 30], [299, 30], [1215, 30], [1161, 30], [3058, 30], [708, 30], [2786, 29], [1274, 29], [210, 29], [3010, 29], [1934, 29], [288, 29], [265, 29], [2745, 28], [444, 28], [612, 28], [2993, 28], [2062, 28], [917, 27], [463, 27], [1262, 27], [482, 27], [283, 27], [1529, 27], [1341, 27], [610, 26], [1933, 26], [1486, 26], [374, 26], [2889, 26], [404, 26], [2149, 26], [1505, 26], [297, 26], [402, 26], [922, 25], [822, 25], [2936, 25], [2162, 25], [2894, 24], [2649, 24], [2349, 24], [1633, 24], [3113, 24], [2355, 24], [2194, 24], [1672, 24], [2811, 23], [2863, 23], [3037, 23], [2017, 23], [238, 23], [2663, 23], [1643, 23], [2991, 23], [3121, 23], [2715, 23], [3030, 23], [3000, 22], [995, 22], [2540, 22], [317, 22], [352, 22], [782, 22], [2086, 22], [2682, 22], [2210, 22], [2022, 22], [2915, 22], [2440, 21], [1745, 21], [2690, 21], [3036, 21], [965, 21], [2667, 21], [649, 21], [886, 21], [2282, 21], [1908, 21], [383, 21], [1327, 21], [464, 20], [2039, 20], [2494, 20], [430, 20], [149, 20], [2636, 20], [2696, 20], [125, 20], [2803, 20], [1741, 20], [1895, 20], [489, 19], [849, 19], [2746, 19], [640, 19], [2020, 19], [2964, 19], [2553, 19], [2203, 19], [1322, 19], [2401, 19], [130, 19], [2788, 18], [1724, 18], [1808, 18], [185, 18], [871, 18], [1915, 18], [545, 18], [1515, 18], [534, 18], [882, 17], [1759, 17], [2832, 17], [1368, 17], [93, 17], [718, 17], [2025, 17], [2309, 17], [2500, 17], [1983, 16], [1404, 16], [2731, 16], [2278, 16], [745, 16], [1527, 16], [3013, 16], [1763, 16], [926, 16], [2482, 16], [598, 16], [1824, 16], [975, 16], [2187, 
16], [2983, 15], [1477, 15], [409, 15], [963, 15], [784, 15], [2220, 15], [2242, 15], [2301, 15], [1482, 15], [282, 15], [1835, 15], [1964, 15], [2324, 15], [540, 15], [1003, 15]] 2 | -------------------------------------------------------------------------------- /DatasetConstruction/transform.py: -------------------------------------------------------------------------------- 1 | """ 2 | Transform the annotation of a single image into bounding boxes 3 | """ 4 | import cv2 5 | import numpy as np 6 | import json 7 | import os 8 | import sys 9 | 10 | fgd = set(json.load(open('stat/objlist.json'))) 11 | img_paths = json.load(open('data/img_path.json')) 12 | objstat = dict() 13 | IMGNUM = 22210 14 | 15 | def transform_annotation(dir_path, img_id): 16 | """ 17 | read the image and generate bounding box annotations into a json file 18 | json file is specified by [{obj:id, box:[]}] 19 | :param dir_path: 20 | :param img_path: 21 | :return: None 22 | """ 23 | seg_path = img_paths[img_id][:-4] + "_seg.png" 24 | img_path = os.path.join(dir_path, seg_path) 25 | result = {} 26 | result['img'] = img_id 27 | if not os.path.exists(img_path): 28 | print(img_path) 29 | print(img_id) 30 | print('') 31 | result['annotation'] = [] 32 | return result 33 | img = cv2.imread(img_path) 34 | img = np.transpose(img, (2, 0, 1)).astype(np.int) 35 | [_, G, R] = img 36 | seg_maps = ((R/10 * 256) + G).astype(np.int) 37 | annotation = search_object(seg_maps) 38 | result['annotation'] = annotation 39 | return result 40 | 41 | 42 | def search_object(seg_map): 43 | """ 44 | search for objects in the map 45 | format is specified by [{obj:id, box:[]}] 46 | :param seg_map: annotation map 47 | :return: a list of objects with their locations 48 | """ 49 | visiting_queue = [] 50 | directions = [[0, -1], [0, 1], [1, 0], [-1, 0]] 51 | annotation_list = [] 52 | cur_obj = -1 53 | H, W = seg_map.shape 54 | area = H * W 55 | 56 | for h in range(H): 57 | for w in range(W): 58 | if seg_map[h, w] == -1: 59 | continue 60 | if int(seg_map[h, w]) not in fgd: 61 | continue 62 | visiting_queue.append([h, w]) 63 | left = right = w 64 | up = down = h 65 | while visiting_queue: 66 | cur_position = visiting_queue.pop() 67 | if seg_map[cur_position[0], cur_position[1]] == -1: 68 | continue 69 | cur_obj = seg_map[cur_position[0], cur_position[1]] 70 | if int(cur_obj) not in fgd: 71 | continue 72 | seg_map[cur_position[0], cur_position[1]] = -1 73 | 74 | left = min(left, cur_position[1]) 75 | right = max(right, cur_position[1]) 76 | up = min(up, cur_position[0]) 77 | down = max(down, cur_position[0]) 78 | 79 | for direction in directions: 80 | new_position = [cur_position[0] + direction[0], cur_position[1] + direction[1]] 81 | if new_position[0] >= 0 and new_position[0] < H and new_position[1] >= 0 and new_position[1] < W: 82 | if seg_map[new_position[0], new_position[1]] == cur_obj: 83 | visiting_queue.append(new_position) 84 | sample_area = (right - left) * (down - up) 85 | if sample_area <= min(100, area / 900): 86 | continue 87 | if (right - left) / (down - up) < 10 and (right - left) / (down - up) > 1 / 10: 88 | annotation_list.append({'obj': int(cur_obj), 'box': [left, right, up, down]}) 89 | cur_obj = int(cur_obj) 90 | if cur_obj not in objstat: 91 | objstat[cur_obj] = 0 92 | objstat[cur_obj] += 1 93 | 94 | return annotation_list 95 | 96 | if __name__ == '__main__': 97 | res = [] 98 | dir = "../../" 99 | for i in range(IMGNUM): 100 | res.append(transform_annotation(dir, i)) 101 | 102 | json.dump(res, open('data/data_all.json', 'w')) 103 | 
json.dump(objstat, open('stat/occurrence.json', 'w')) 104 | 105 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ADE-FewShot 2 | 3 | Code for paper [Unlocking the Full Potential of Small Data with Diverse Supervision](https://arxiv.org/abs/1911.12911). Ziqi Pang\*, Zhiyuan Hu\*, Pavel Tokmakov, Yuxiong Wang, Martial Hebert. (\* indicates equal contribution) 4 | 5 | We are still working to create a clean repository and will finish the job recently. If you are interested, please refer to the branch dev for the time being. 6 | -------------------------------------------------------------------------------- /data/ADE/permutations_1000.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinahHu/ADE-FewShot/41dc9cc481bfaf3bd9fb8bd76c1e63fcf127339d/data/ADE/permutations_1000.npy -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinahHu/ADE-FewShot/41dc9cc481bfaf3bd9fb8bd76c1e63fcf127339d/dataset/__init__.py -------------------------------------------------------------------------------- /dataset/collate.py: -------------------------------------------------------------------------------- 1 | import torch.cuda as cuda 2 | import torch.nn as nn 3 | import torch 4 | import collections 5 | from torch.nn.parallel._functions import Gather 6 | 7 | 8 | def async_copy_to(obj, dev, main_stream=None): 9 | if torch.is_tensor(obj): 10 | v = obj.cuda(dev, non_blocking=True) 11 | if main_stream is not None: 12 | v.data.record_stream(main_stream) 13 | return v 14 | elif isinstance(obj, collections.Mapping): 15 | return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()} 16 | elif isinstance(obj, collections.Sequence): 17 | return [async_copy_to(o, dev, main_stream) for o in obj] 18 | else: 19 | return obj 20 | 21 | 22 | def dict_gather(outputs, target_device, dim=0): 23 | """ 24 | Gathers variables from different GPUs on a specified device 25 | (-1 means the CPU), with dictionary support. 
26 | """ 27 | def gather_map(outputs): 28 | out = outputs[0] 29 | if torch.is_tensor(out): 30 | # MJY(20180330) HACK:: force nr_dims > 0 31 | if out.dim() == 0: 32 | outputs = [o.unsqueeze(0) for o in outputs] 33 | return Gather.apply(target_device, dim, *outputs) 34 | elif out is None: 35 | return None 36 | elif isinstance(out, collections.Mapping): 37 | return {k: gather_map([o[k] for o in outputs]) for k in out} 38 | elif isinstance(out, collections.Sequence): 39 | return type(out)(map(gather_map, zip(*outputs))) 40 | return gather_map(outputs) 41 | 42 | 43 | class DictGatherDataParallel(nn.DataParallel): 44 | def gather(self, outputs, output_device): 45 | return dict_gather(outputs, output_device, dim=self.dim) 46 | 47 | 48 | class UserScatteredDataParallel(DictGatherDataParallel): 49 | def scatter(self, inputs, kwargs, device_ids): 50 | assert len(inputs) == 1 51 | inputs = inputs[0] 52 | inputs = _async_copy_stream(inputs, device_ids) 53 | inputs = [[i] for i in inputs] 54 | assert len(kwargs) == 0 55 | kwargs = [{} for _ in range(len(inputs))] 56 | 57 | return inputs, kwargs 58 | 59 | 60 | def user_scattered_collate(batch): 61 | return batch 62 | 63 | 64 | def _async_copy(inputs, device_ids): 65 | nr_devs = len(device_ids) 66 | assert type(inputs) in (tuple, list) 67 | assert len(inputs) == nr_devs 68 | 69 | outputs = [] 70 | for i, dev in zip(inputs, device_ids): 71 | with cuda.device(dev): 72 | outputs.append(async_copy_to(i, dev)) 73 | 74 | return tuple(outputs) 75 | 76 | 77 | def _async_copy_stream(inputs, device_ids): 78 | nr_devs = len(device_ids) 79 | assert type(inputs) in (tuple, list) 80 | assert len(inputs) == nr_devs 81 | 82 | outputs = [] 83 | streams = [_get_stream(d) for d in device_ids] 84 | for i, dev, stream in zip(inputs, device_ids, streams): 85 | with cuda.device(dev): 86 | main_stream = cuda.current_stream() 87 | with cuda.stream(stream): 88 | outputs.append(async_copy_to(i, dev, main_stream=main_stream)) 89 | main_stream.wait_stream(stream) 90 | 91 | return outputs 92 | 93 | 94 | """Adapted from: torch/nn/parallel/_functions.py""" 95 | # background streams used for copying 96 | _streams = None 97 | 98 | 99 | def _get_stream(device): 100 | """Gets a background stream for copying between CPU and GPU""" 101 | global _streams 102 | if device == -1: 103 | return None 104 | if _streams is None: 105 | _streams = [None] * cuda.device_count() 106 | if _streams[device] is None: _streams[device] = cuda.Stream(device) 107 | return _streams[device] 108 | -------------------------------------------------------------------------------- /dataset/distributed.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from .sampler import Sampler 4 | from torch.distributed import get_world_size, get_rank 5 | 6 | 7 | class DistributedSampler(Sampler): 8 | """Sampler that restricts data loading to a subset of the dataset. 9 | 10 | It is especially useful in conjunction with 11 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 12 | process can pass a DistributedSampler instance as a DataLoader sampler, 13 | and load a subset of the original dataset that is exclusive to it. 14 | 15 | .. note:: 16 | Dataset is assumed to be of constant size. 17 | 18 | Arguments: 19 | dataset: Dataset used for sampling. 20 | num_replicas (optional): Number of processes participating in 21 | distributed training. 22 | rank (optional): Rank of the current process within num_replicas. 
23 | """ 24 | 25 | def __init__(self, dataset, num_replicas=None, rank=None): 26 | if num_replicas is None: 27 | num_replicas = get_world_size() 28 | if rank is None: 29 | rank = get_rank() 30 | self.dataset = dataset 31 | self.num_replicas = num_replicas 32 | self.rank = rank 33 | self.epoch = 0 34 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 35 | self.total_size = self.num_samples * self.num_replicas 36 | 37 | def __iter__(self): 38 | # deterministically shuffle based on epoch 39 | g = torch.Generator() 40 | g.manual_seed(self.epoch) 41 | indices = list(torch.randperm(len(self.dataset), generator=g)) 42 | 43 | # add extra samples to make it evenly divisible 44 | indices += indices[:(self.total_size - len(indices))] 45 | assert len(indices) == self.total_size 46 | 47 | # subsample 48 | offset = self.num_samples * self.rank 49 | indices = indices[offset:offset + self.num_samples] 50 | assert len(indices) == self.num_samples 51 | 52 | return iter(indices) 53 | 54 | def __len__(self): 55 | return self.num_samples 56 | 57 | def set_epoch(self, epoch): 58 | self.epoch = epoch 59 | -------------------------------------------------------------------------------- /dataset/novel_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataset for novel classes 3 | """ 4 | import os 5 | import json 6 | import torch 7 | from dataset.proto_dataset import NovelProtoDataset 8 | import cv2 9 | import math 10 | import numpy as np 11 | 12 | 13 | class NovelDataset(NovelProtoDataset): 14 | """ 15 | Form batch at object level 16 | """ 17 | def __init__(self, h5path, opt, batch_per_gpu=1, **kwargs): 18 | super(NovelDataset, self).__init__(h5path, opt, **kwargs) 19 | self.batch_per_gpu = batch_per_gpu 20 | self.batch_record_list = [] 21 | 22 | # override dataset length when trainig with batch_per_gpu > 1 23 | self.cur_idx = 0 24 | self.if_shuffled = False 25 | 26 | self.crop_height = opt.crop_height 27 | self.crop_width = opt.crop_width 28 | self.feat_dim = opt.feat_dim 29 | 30 | def _get_sub_batch(self): 31 | while True: 32 | # get a sample record 33 | this_sample = self.data[self.cur_idx] 34 | self.batch_record_list.append(this_sample) 35 | 36 | # update current sample pointer 37 | self.cur_idx += 1 38 | if self.cur_idx >= self.num_sample: 39 | self.cur_idx = 0 40 | np.random.shuffle(self.data) 41 | 42 | if len(self.batch_record_list) == self.batch_per_gpu: 43 | batch_records = self.batch_record_list 44 | self.batch_record_list = [] 45 | break 46 | 47 | return batch_records 48 | 49 | def __getitem__(self, index): 50 | # NOTE: random shuffle for the first time. 
shuffle in __init__ is useless 51 | if not self.if_shuffled: 52 | np.random.shuffle(self.data) 53 | self.if_shuffled = True 54 | 55 | # get sub-batch candidates 56 | batch_records = self._get_sub_batch() 57 | 58 | # calculate the BATCH's height and width 59 | # since we concat more than one samples, the batch's h and w shall be larger than EACH sample 60 | batch_features = torch.zeros(self.batch_per_gpu, self.feat_dim * self.crop_width * self.crop_height) 61 | batch_labels = torch.zeros(self.batch_per_gpu).int() 62 | for i in range(self.batch_per_gpu): 63 | batch_features[i] = torch.tensor(batch_records[i]['feature']) 64 | batch_labels[i] = torch.tensor(batch_records[i]['label'].astype(np.float)).int() 65 | # batch_anchors[i] = torch.tensor(batch_records[i]['anchors']) 66 | # batch_scales[i] = torch.tensor(batch_records[i]['scales']) 67 | 68 | output = dict() 69 | output['feature'] = batch_features 70 | output['label'] = batch_labels 71 | # output['anchors'] = batch_anchors 72 | # output['scales'] = batch_scales 73 | return output 74 | 75 | def __len__(self): 76 | return int(1e10) # It's a fake length due to the trick that every loader maintains its own list 77 | # return self.num_sampleclass 78 | -------------------------------------------------------------------------------- /dataset/proto_dataset.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import warnings 3 | from torch._utils import _accumulate 4 | from torch import randperm 5 | import torch 6 | from torchvision import transforms 7 | import numpy as np 8 | import json 9 | import random 10 | import h5py 11 | import cv2 12 | from skimage.color import rgb2lab 13 | 14 | 15 | class Dataset(object): 16 | """An abstract class representing a Dataset. 17 | All other datasets should subclass it. All subclasses should override 18 | ``__len__``, that provides the size of the dataset, and ``__getitem__``, 19 | supporting integer indexing in range from 0 to len(self) exclusive. 20 | """ 21 | 22 | def __getitem__(self, index): 23 | raise NotImplementedError 24 | 25 | def __len__(self): 26 | raise NotImplementedError 27 | 28 | def __add__(self, other): 29 | return ConcatDataset([self, other]) 30 | 31 | 32 | class TensorDataset(Dataset): 33 | """Dataset wrapping data and target tensors. 34 | Each sample will be retrieved by indexing both tensors along the first 35 | dimension. 36 | Arguments: 37 | data_tensor (Tensor): contains sample data. 38 | target_tensor (Tensor): contains sample targets (labels). 39 | """ 40 | 41 | def __init__(self, data_tensor, target_tensor): 42 | assert data_tensor.size(0) == target_tensor.size(0) 43 | self.data_tensor = data_tensor 44 | self.target_tensor = target_tensor 45 | 46 | def __getitem__(self, index): 47 | return self.data_tensor[index], self.target_tensor[index] 48 | 49 | def __len__(self): 50 | return self.data_tensor.size(0) 51 | 52 | 53 | class ConcatDataset(Dataset): 54 | """ 55 | Dataset to concatenate multiple datasets. 56 | Purpose: useful to assemble different existing datasets, possibly 57 | large-scale datasets as the concatenation operation is done in an 58 | on-the-fly manner. 
59 | Arguments: 60 | datasets (iterable): List of datasets to be concatenated 61 | """ 62 | 63 | @staticmethod 64 | def cumsum(sequence): 65 | r, s = [], 0 66 | for e in sequence: 67 | l = len(e) 68 | r.append(l + s) 69 | s += l 70 | return r 71 | 72 | def __init__(self, datasets): 73 | super(ConcatDataset, self).__init__() 74 | assert len(datasets) > 0, 'datasets should not be an empty iterable' 75 | self.datasets = list(datasets) 76 | self.cumulative_sizes = self.cumsum(self.datasets) 77 | 78 | def __len__(self): 79 | return self.cumulative_sizes[-1] 80 | 81 | def __getitem__(self, idx): 82 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 83 | if dataset_idx == 0: 84 | sample_idx = idx 85 | else: 86 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 87 | return self.datasets[dataset_idx][sample_idx] 88 | 89 | @property 90 | def cummulative_sizes(self): 91 | warnings.warn("cummulative_sizes attribute is renamed to " 92 | "cumulative_sizes", DeprecationWarning, stacklevel=2) 93 | return self.cumulative_sizes 94 | 95 | 96 | class Subset(Dataset): 97 | def __init__(self, dataset, indices): 98 | self.dataset = dataset 99 | self.indices = indices 100 | 101 | def __getitem__(self, idx): 102 | return self.dataset[self.indices[idx]] 103 | 104 | def __len__(self): 105 | return len(self.indices) 106 | 107 | 108 | def random_split(dataset, lengths): 109 | """ 110 | Randomly split a dataset into non-overlapping new datasets of given lengths 111 | ds 112 | Arguments: 113 | dataset (Dataset): Dataset to be split 114 | lengths (iterable): lengths of splits to be produced 115 | """ 116 | if sum(lengths) != len(dataset): 117 | raise ValueError("Sum of input lengths does not equal the length of the input dataset!") 118 | 119 | indices = randperm(sum(lengths)) 120 | return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)] 121 | 122 | 123 | class BaseProtoDataset(Dataset): 124 | def __init__(self, data_file, args): 125 | # parse options 126 | self.imgShortSize = args.imgShortSize 127 | self.imgMaxSize = args.imgMaxSize 128 | self.list_sample = None 129 | self.num_sample = 0 130 | self.args = args 131 | 132 | # max down sampling rate of network to avoid rounding during conv or pooling 133 | self.padding_constant = args.padding_constant 134 | # parse the input list 135 | self.parse_input_list(data_file) 136 | 137 | # mean and std 138 | self.normalize = transforms.Normalize( 139 | mean=[102.9801, 115.9465, 122.7717], 140 | std=[1., 1., 1.]) 141 | 142 | def parse_input_list(self, data_file): 143 | f = open(data_file, 'r') 144 | old_list_sample = json.load(f) 145 | f.close() 146 | 147 | self.list_sample = [] 148 | iterator = filter(lambda x: (len(x['anchors']) != 0) and (len(x['anchors']) <= 100), old_list_sample) 149 | for sample in iterator: 150 | self.list_sample.append(sample) 151 | 152 | self.num_sample = len(self.list_sample) 153 | print('# samples: {}'.format(self.num_sample)) 154 | 155 | def img_transform(self, img): 156 | # image to float 157 | img = img.astype(np.float32) 158 | img = img.transpose((2, 0, 1)) 159 | img = self.normalize(torch.from_numpy(img.copy())) 160 | return img 161 | 162 | # Round x to the nearest multiple of p and x' >= x 163 | @staticmethod 164 | def round2nearest_multiple(x, p): 165 | return ((x - 1) // p + 1) * p 166 | 167 | def __getitem__(self, index): 168 | return NotImplementedError 169 | 170 | def __len__(self): 171 | return NotImplementedError 172 | 173 | def _get_sub_batch(self): 174 | return 
NotImplementedError 175 | 176 | 177 | class NovelProtoDataset(Dataset): 178 | def __init__(self, h5path, args): 179 | self.feat_dim = args.feat_dim * args.crop_height * args.crop_width 180 | self.data_path = h5path 181 | self.features = None 182 | self.labels = None 183 | self.num_sample = 0 184 | self.data = None 185 | self.args = args 186 | self._get_feat_data() 187 | 188 | def _get_feat_data(self): 189 | f = h5py.File(self.data_path, 'r') 190 | self.features = np.array(f['feature_map']) 191 | self.labels = np.array(f['labels']) 192 | self.num_sample = self.labels.size 193 | 194 | print('This dataset has {} samples'.format(self.num_sample)) 195 | 196 | self.data = [dict() for i in range(self.num_sample)] 197 | for i in range(self.num_sample): 198 | self.data[i] = {'feature': self.features[i], 199 | 'label': self.labels[i]} 200 | 201 | def __getitem__(self, index): 202 | return NotImplementedError 203 | 204 | def __len__(self): 205 | return NotImplementedError 206 | 207 | def _get_sub_batch(self): 208 | return NotImplementedError 209 | -------------------------------------------------------------------------------- /dataset/sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Sampler(object): 5 | """Base class for all Samplers. 6 | 7 | Every Sampler subclass has to provide an __iter__ method, providing a way 8 | to iterate over indices of dataset elements, and a __len__ method that 9 | returns the length of the returned iterators. 10 | """ 11 | 12 | def __init__(self, data_source): 13 | pass 14 | 15 | def __iter__(self): 16 | raise NotImplementedError 17 | 18 | def __len__(self): 19 | raise NotImplementedError 20 | 21 | 22 | class SequentialSampler(Sampler): 23 | """Samples elements sequentially, always in the same order. 24 | 25 | Arguments: 26 | data_source (Dataset): dataset to sample from 27 | """ 28 | 29 | def __init__(self, data_source): 30 | self.data_source = data_source 31 | 32 | def __iter__(self): 33 | return iter(range(len(self.data_source))) 34 | 35 | def __len__(self): 36 | return len(self.data_source) 37 | 38 | 39 | class RandomSampler(Sampler): 40 | """Samples elements randomly, without replacement. 41 | 42 | Arguments: 43 | data_source (Dataset): dataset to sample from 44 | """ 45 | 46 | def __init__(self, data_source): 47 | self.data_source = data_source 48 | 49 | def __iter__(self): 50 | return iter(torch.randperm(len(self.data_source)).long()) 51 | 52 | def __len__(self): 53 | return len(self.data_source) 54 | 55 | 56 | class SubsetRandomSampler(Sampler): 57 | """Samples elements randomly from a given list of indices, without replacement. 58 | 59 | Arguments: 60 | indices (list): a list of indices 61 | """ 62 | 63 | def __init__(self, indices): 64 | self.indices = indices 65 | 66 | def __iter__(self): 67 | return (self.indices[i] for i in torch.randperm(len(self.indices))) 68 | 69 | def __len__(self): 70 | return len(self.indices) 71 | 72 | 73 | class WeightedRandomSampler(Sampler): 74 | """Samples elements from [0,..,len(weights)-1] with given probabilities (weights). 75 | 76 | Arguments: 77 | weights (list) : a list of weights, not necessary summing up to one 78 | num_samples (int): number of samples to draw 79 | replacement (bool): if ``True``, samples are drawn with replacement. 80 | If not, they are drawn without replacement, which means that when a 81 | sample index is drawn for a row, it cannot be drawn again for that row. 
82 | """ 83 | 84 | def __init__(self, weights, num_samples, replacement=True): 85 | self.weights = torch.DoubleTensor(weights) 86 | self.num_samples = num_samples 87 | self.replacement = replacement 88 | 89 | def __iter__(self): 90 | return iter(torch.multinomial(self.weights, self.num_samples, self.replacement)) 91 | 92 | def __len__(self): 93 | return self.num_samples 94 | 95 | 96 | class BatchSampler(object): 97 | """Wraps another sampler to yield a mini-batch of indices. 98 | 99 | Args: 100 | sampler (Sampler): Base sampler. 101 | batch_size (int): Size of mini-batch. 102 | drop_last (bool): If ``True``, the sampler will drop the last batch if 103 | its size would be less than ``batch_size`` 104 | 105 | Example: 106 | >>> list(BatchSampler(range(10), batch_size=3, drop_last=False)) 107 | [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] 108 | >>> list(BatchSampler(range(10), batch_size=3, drop_last=True)) 109 | [[0, 1, 2], [3, 4, 5], [6, 7, 8]] 110 | """ 111 | 112 | def __init__(self, sampler, batch_size, drop_last): 113 | self.sampler = sampler 114 | self.batch_size = batch_size 115 | self.drop_last = drop_last 116 | 117 | def __iter__(self): 118 | batch = [] 119 | for idx in self.sampler: 120 | batch.append(idx) 121 | if len(batch) == self.batch_size: 122 | yield batch 123 | batch = [] 124 | if len(batch) > 0 and not self.drop_last: 125 | yield batch 126 | 127 | def __len__(self): 128 | if self.drop_last: 129 | return len(self.sampler) // self.batch_size 130 | else: 131 | return (len(self.sampler) + self.batch_size - 1) // self.batch_size 132 | -------------------------------------------------------------------------------- /dataset/transform.py: -------------------------------------------------------------------------------- 1 | """ 2 | Transform data in supervision from raw form into usable data 3 | """ 4 | import numpy as np 5 | import cv2 6 | import torch 7 | import os 8 | import json 9 | 10 | 11 | class Transform: 12 | def __init__(self, args): 13 | self.args = args 14 | 15 | def seg_transform(self, path, other=None): 16 | """ 17 | segmentation transform 18 | :param path: segmentation map path 19 | :return: segmentation map in the original size 20 | """ 21 | path = os.path.join(self.args.root_dataset, path) 22 | img = cv2.imread(path, 0) 23 | return img 24 | 25 | def attr_transform(self, tensor, other=None): 26 | """ 27 | attribute transform 28 | :param tensor: input attribute list 29 | :param other: other information needed for transformation 30 | :return: hot result 31 | """ 32 | if other is None: 33 | raise Exception('No attribute num for attribute supervision') 34 | attr_num = other['num_attr'] 35 | result = np.zeros(attr_num).astype(np.int) 36 | for i in tensor: 37 | result[i] = 1 38 | return result 39 | 40 | def part_transform(self, tensor, other=None): 41 | """ 42 | attribute transform 43 | :param tensor: input attribute list 44 | :param other: other information needed for transformation 45 | :return: hot result 46 | """ 47 | if other is None: 48 | raise Exception('No part num for attribute supervision') 49 | attr_num = other['num_attr'] 50 | result = np.zeros(attr_num).astype(np.int) 51 | for i in tensor: 52 | result[i] = 1 53 | return result 54 | 55 | def hierarchy_transform(self, tensor, other=None): 56 | """ 57 | hierarchy transform 58 | :param tensor: input attribute list 59 | :param other: other information needed for transformation 60 | :return: hot result 61 | """ 62 | return np.array(tensor) 63 | 64 | def fgbg_transform(self, path, other=None): 65 | """ 66 | transform the 
fg bg data 67 | :param path: mask path 68 | :param other: other needed information 69 | :return: map 70 | """ 71 | path = os.path.join(self.args.root_dataset, path) 72 | img = cv2.imread(path, 0) 73 | return img 74 | 75 | 76 | def scene_transform(self, tensor, other=None): 77 | """ 78 | hierarchy transform 79 | :param tensor: input attribute list 80 | :param other: other information needed for transformation 81 | :return: hot result 82 | """ 83 | #if other is None: 84 | # raise Exception('No attribute num for attribute supervision') 85 | #scene_num = other['scene_num'] 86 | #result = np.zeros(scene_num).astype(np.int) 87 | #result[tensor] = 1 88 | return np.array(tensor) 89 | 90 | def bbox_transform(self, bbox,other=None): 91 | return np.array(bbox) 92 | 93 | def bkg_transform(self, path, other=None): 94 | """ 95 | segmentation transform 96 | :param path: segmentation map path 97 | :return: segmentation map in the original size 98 | """ 99 | path = os.path.join(self.args.root_dataset, path) 100 | img = cv2.imread(path) 101 | p0, p1, p2 = np.transpose(img, (2, 0, 1)) 102 | bkg = p0 + 256 * p1 + 256 * 256 * p2 103 | bkg = (bkg - 999) * (bkg < 999) + 999 104 | return bkg.astype('int32') 105 | 106 | def hierarchy_transform(self, tensor, other=None): 107 | """ 108 | hierarchy transform 109 | :param tensor: input attribute list 110 | :param other: other information needed for transformation 111 | :return: hot result 112 | """ 113 | return np.array(tensor) 114 | 115 | -------------------------------------------------------------------------------- /evaluate/novel.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import random 4 | import argparse 5 | import json 6 | import math 7 | import numpy as np 8 | 9 | import sys 10 | sys.path.append('../') 11 | 12 | import torch 13 | import torch.nn as nn 14 | from dataset.novel_dataset import NovelDataset 15 | from dataset.collate import UserScatteredDataParallel, user_scattered_collate 16 | from dataset.dataloader import DataLoader, DataLoaderIter 17 | from utils import AverageMeter, category_acc 18 | from model.parallel.replicate import patch_replication_callback 19 | from model.novel_model import NovelClassifier, NovelCosClassifier 20 | import copy 21 | 22 | 23 | def train(module, iterator, optimizers, epoch, args): 24 | batch_time = AverageMeter() 25 | data_time = AverageMeter() 26 | ave_total_loss = AverageMeter() 27 | ave_acc = AverageMeter() 28 | 29 | module.train() 30 | module.module.mode = 'train' 31 | # main loop 32 | tic = time.time() 33 | acc_iter = 0 34 | acc_iter_num = 0 35 | for i in range(args.train_epoch_iters): 36 | batch_data = next(iterator) 37 | data_time.update(time.time() - tic) 38 | 39 | module.zero_grad() 40 | loss, acc = module(batch_data) 41 | loss = loss.mean() 42 | acc = acc.mean() 43 | acc_iter += acc.data.item() * 100 44 | acc_iter_num += 1 45 | 46 | # Backward 47 | loss.backward() 48 | for optimizer in optimizers: 49 | optimizer.step() 50 | 51 | # measure elapsed time 52 | batch_time.update(time.time() - tic) 53 | tic = time.time() 54 | 55 | # update average loss and acc 56 | ave_total_loss.update(loss.data.item()) 57 | ave_acc.update(acc.data.item() * 100) 58 | 59 | if i % args.disp_iter == 0: 60 | print('Epoch: [{}][{}/{}], Time: {:.2f}, Data: {:.2f}, ' 61 | 'lr_cls: {:.6f}, ' 62 | 'Accuracy: {:4.2f}, Loss: {:.6f}, Acc-Iter: {:4.2f}' 63 | .format(epoch, i, args.train_epoch_iters, 64 | batch_time.average(), data_time.average(), 65 | args.lr_cls, 66 | 
ave_acc.average(), ave_total_loss.average(), acc_iter / acc_iter_num)) 67 | acc_iter = 0 68 | acc_iter_num = 0 69 | 70 | 71 | def validate(module, iterator, epoch, args): 72 | batch_time = AverageMeter() 73 | data_time = AverageMeter() 74 | ave_acc = AverageMeter() 75 | 76 | module.eval() 77 | module.module.mode = 'val' 78 | # main loop 79 | tic = time.time() 80 | acc_iter = 0 81 | acc_iter_num = 0 82 | category_accuracy = torch.zeros(2, args.num_novel_class) 83 | for i in range(args.val_epoch_iters): 84 | batch_data = next(iterator) 85 | data_time.update(time.time() - tic) 86 | 87 | acc, category_batch_acc = module(batch_data) 88 | acc = acc.mean() 89 | acc_iter += acc.data.item() * 100 90 | acc_iter_num += 1 91 | category_batch_acc = category_batch_acc.cpu() 92 | # print(category_batch_acc[:, :10]) 93 | for j in range(len(args.gpus)): 94 | category_accuracy += category_batch_acc[2 * j:2 * j + 2, :] 95 | 96 | # measure elapsed time 97 | batch_time.update(time.time() - tic) 98 | tic = time.time() 99 | # update average loss and acc 100 | ave_acc.update(acc.data.item() * 100) 101 | 102 | if i % args.disp_iter == 0: 103 | print('Epoch: [{}][{}/{}], Time: {:.2f}, Data: {:.2f}, ' 104 | 'Accuracy: {:4.2f}, Acc-Iter: {:4.2f}' 105 | .format(epoch, i, args.val_epoch_iters, 106 | batch_time.average(), data_time.average(), 107 | ave_acc.average(), acc_iter / acc_iter_num)) 108 | 109 | acc_iter = 0 110 | acc_iter_num = 0 111 | # print(category_accuracy) 112 | print('Epoch: [{}], Accuracy: {:4.2f}'.format(epoch, ave_acc.average())) 113 | acc = category_acc(category_accuracy, args) 114 | print('Ave Category Acc: {:4.2f}'.format(acc.item() * 100)) 115 | return [ave_acc.average(), acc] 116 | 117 | 118 | def checkpoint(nets, args, epoch_num): 119 | print('Saving checkpoints...') 120 | suffix_latest = 'epoch_{}.pth'.format(epoch_num) 121 | 122 | if not os.path.exists('ckpt/'): 123 | os.makedirs('ckpt/') 124 | 125 | torch.save(nets.module.state_dict(), 126 | 'ckpt/net_{}'.format(suffix_latest)) 127 | 128 | 129 | def slide_window_ave(acc_list, window_size=10): 130 | category = [] 131 | inst = [] 132 | for sample in acc_list: 133 | category.append(sample[1]) 134 | inst.append(sample[0]) 135 | category = np.array(category) 136 | inst = np.array(inst) 137 | epoch = category.size 138 | 139 | start_location = 0 140 | best_shot = -1 141 | for i in range(0, epoch - window_size): 142 | cur_value = category[i:i + window_size].mean() 143 | if cur_value > best_shot: 144 | start_location = i 145 | best_shot = cur_value 146 | 147 | best_inst = inst[start_location:start_location + window_size].mean() 148 | print('Best Category {}'.format(best_shot)) 149 | print('Best Inst {}'.format(best_inst)) 150 | print('Best Shot {}'.format(start_location)) 151 | 152 | 153 | def main(args): 154 | dataset_train = NovelDataset( 155 | args.list_train, args, batch_per_gpu=args.batch_size_per_gpu) 156 | loader_train = DataLoader( 157 | dataset_train, batch_size=len(args.gpus), shuffle=False, 158 | collate_fn=user_scattered_collate, 159 | num_workers=int(args.workers), 160 | drop_last=True, 161 | pin_memory=True 162 | ) 163 | 164 | vargs = copy.deepcopy(args) 165 | vargs.gpus = [0, 1, 2, 3] 166 | vargs.batch_size_per_gpu = 299 167 | vargs.disp_iter = 1 168 | dataset_val = NovelDataset( 169 | args.list_val, args, batch_per_gpu=vargs.batch_size_per_gpu) 170 | loader_val = DataLoader( 171 | dataset_val, batch_size=len(args.gpus), shuffle=False, 172 | collate_fn=user_scattered_collate, 173 | num_workers=int(args.workers), 174 | drop_last=True, 
175 | pin_memory=True 176 | ) 177 | 178 | args.train_epoch_iters = \ 179 | math.ceil(dataset_train.num_sample / (args.batch_size_per_gpu * len(args.gpus))) 180 | vargs.val_epoch_iters = \ 181 | math.ceil(dataset_val.num_sample / (vargs.batch_size_per_gpu * len(vargs.gpus))) 182 | print('1 Train Epoch = {} iters'.format(args.train_epoch_iters)) 183 | print('1 Val Epoch = {} iters'.format(vargs.val_epoch_iters)) 184 | 185 | iterator_train = iter(loader_train) 186 | iterator_val = iter(loader_val) 187 | 188 | if args.cls == 'novel_cls': 189 | classifier = NovelClassifier(args) 190 | elif args.cls == 'novel_coscls': 191 | classifier = NovelCosClassifier(args) 192 | optimizer_cls = torch.optim.SGD(classifier.parameters(), 193 | lr=args.lr_cls, momentum=0.5) 194 | optimizers = [optimizer_cls] 195 | network = UserScatteredDataParallel(classifier, device_ids=args.gpus) 196 | patch_replication_callback(network) 197 | network.cuda() 198 | 199 | if args.start_epoch != 0: 200 | network.load_state_dict( 201 | torch.load('{}/net_epoch_{}.pth'.format(args.ckpt, args.log))) 202 | 203 | accuracy = [] 204 | for epoch in range(args.start_epoch, args.num_epoch): 205 | train(network, iterator_train, optimizers, epoch, args) 206 | accuracy.append(validate(network, iterator_val, epoch, vargs)) 207 | checkpoint(network, args, epoch) 208 | 209 | slide_window_ave(accuracy) 210 | print('Training Done') 211 | 212 | 213 | if __name__ == '__main__': 214 | parser = argparse.ArgumentParser() 215 | # Model related arguments 216 | parser.add_argument('--id', default='baseline', 217 | help="a name for identifying the model") 218 | parser.add_argument('--arch', default='resnet10') 219 | parser.add_argument('--cls', default='novel_cls') 220 | parser.add_argument('--feat_dim', default=512) 221 | parser.add_argument('--crop_height', default=3, type=int) 222 | parser.add_argument('--crop_width', default=3, type=int) 223 | parser.add_argument('--range_of_compute', default=5, type=int) 224 | 225 | # Path related arguments 226 | parser.add_argument('--list_train', 227 | default='data/img_train_feat_baseline.h5') 228 | parser.add_argument('--list_val', 229 | default='data/img_val_feat_baseline.h5') 230 | 231 | # optimization related arguments 232 | parser.add_argument('--gpus', default=[0, 1, 2, 3], 233 | help='gpus to use, e.g. 0-3 or 0,1,2,3') 234 | parser.add_argument('--batch_size_per_gpu', default=256, type=int, 235 | help='input batch size') 236 | parser.add_argument('--num_epoch', default=100, type=int, 237 | help='epochs to train for') 238 | parser.add_argument('--start_epoch', default=0, type=int, 239 | help='epoch to start training. 
useful if continue from a checkpoint') 240 | parser.add_argument('--train_epoch_iters', default=20, type=int, 241 | help='iterations of each epoch (irrelevant to batch size)') 242 | parser.add_argument('--val_epoch_iters', default=20, type=int) 243 | parser.add_argument('--optim', default='SGD', help='optimizer') 244 | parser.add_argument('--lr_cls', default=5.0 * 1e-1, type=float, help='LR') 245 | parser.add_argument('--weight_init', default='') 246 | 247 | # Data related arguments 248 | parser.add_argument('--num_novel_class', default=193, type=int, 249 | help='number of classes') 250 | parser.add_argument('--workers', default=8, type=int, 251 | help='number of data loading workers') 252 | parser.add_argument('--imgSize', default=[200, 250], 253 | nargs='+', type=int, 254 | help='input image size of short edge (int or list)') 255 | parser.add_argument('--imgMaxSize', default=1500, type=int, 256 | help='maximum input image size of long edge') 257 | parser.add_argument('--padding_constant', default=8, type=int, 258 | help='maxmimum downsampling rate of the network') 259 | parser.add_argument('--segm_downsampling_rate', default=8, type=int, 260 | help='downsampling rate of the segmentation label') 261 | parser.add_argument('--random_flip', default=True, type=bool, 262 | help='if horizontally flip images when training') 263 | 264 | # Misc arguments 265 | parser.add_argument('--seed', default=304, type=int, help='manual seed') 266 | parser.add_argument('--ckpt', default='./ckpt/', 267 | help='folder to output checkpoints') 268 | parser.add_argument('--disp_iter', type=int, default=1, 269 | help='frequency to display') 270 | parser.add_argument('--log_dir', default="./log_novel/", 271 | help='dir to save train and val log') 272 | parser.add_argument('--comment', default="this_child_may_save_the_world", 273 | help='add comment to this test') 274 | 275 | args = parser.parse_args() 276 | args.list_train = 'data/img_test_train_feat_{}.h5'.format(args.id) 277 | args.list_val = 'data/img_test_val_feat_{}.h5'.format(args.id) 278 | 279 | main(args) 280 | -------------------------------------------------------------------------------- /evaluate/predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import random 4 | import argparse 5 | import json 6 | import math 7 | import numpy as np 8 | import h5py 9 | 10 | import torch 11 | import torch.nn as nn 12 | 13 | import sys 14 | sys.path.append('../') 15 | from dataset.novel_dataset import NovelDataset 16 | from dataset.collate import UserScatteredDataParallel, user_scattered_collate 17 | from dataset.dataloader import DataLoader, DataLoaderIter 18 | from utils import AverageMeter, category_acc 19 | from model.parallel.replicate import patch_replication_callback 20 | from model.novel_model import NovelClassifier, NovelCosClassifier 21 | from utils import selective_load_weights 22 | import copy 23 | 24 | 25 | def train(module, iterator, optimizers, epoch, args): 26 | batch_time = AverageMeter() 27 | data_time = AverageMeter() 28 | ave_total_loss = AverageMeter() 29 | ave_acc = AverageMeter() 30 | 31 | module.train() 32 | module.module.mode = 'train' 33 | # main loop 34 | tic = time.time() 35 | acc_iter = 0 36 | acc_iter_num = 0 37 | for i in range(args.train_epoch_iters): 38 | batch_data = next(iterator) 39 | data_time.update(time.time() - tic) 40 | 41 | module.zero_grad() 42 | loss, acc = module(batch_data) 43 | loss = loss.mean() 44 | acc = acc.mean() 45 | acc_iter += acc.data.item() 
* 100 46 | acc_iter_num += 1 47 | 48 | # Backward 49 | loss.backward() 50 | for optimizer in optimizers: 51 | optimizer.step() 52 | 53 | # measure elapsed time 54 | batch_time.update(time.time() - tic) 55 | tic = time.time() 56 | 57 | # update average loss and acc 58 | ave_total_loss.update(loss.data.item()) 59 | ave_acc.update(acc.data.item() * 100) 60 | 61 | if i % args.disp_iter == 0: 62 | print('Epoch: [{}][{}/{}], Time: {:.2f}, Data: {:.2f}, ' 63 | 'lr_cls: {:.6f}, ' 64 | 'Accuracy: {:4.2f}, Loss: {:.6f}, Acc-Iter: {:4.2f}' 65 | .format(epoch, i, args.train_epoch_iters, 66 | batch_time.average(), data_time.average(), 67 | args.lr_cls, 68 | ave_acc.average(), ave_total_loss.average(), acc_iter / acc_iter_num)) 69 | acc_iter = 0 70 | acc_iter_num = 0 71 | 72 | 73 | def validate(module, iterator, epoch, args): 74 | batch_time = AverageMeter() 75 | data_time = AverageMeter() 76 | ave_acc = AverageMeter() 77 | 78 | module.eval() 79 | module.module.mode = 'val' 80 | # main loop 81 | tic = time.time() 82 | acc_iter = 0 83 | acc_iter_num = 0 84 | category_accuracy = torch.zeros(2, args.num_novel_class) 85 | for i in range(args.val_epoch_iters): 86 | batch_data = next(iterator) 87 | data_time.update(time.time() - tic) 88 | 89 | acc, category_batch_acc = module(batch_data) 90 | acc = acc.mean() 91 | acc_iter += acc.data.item() * 100 92 | acc_iter_num += 1 93 | category_batch_acc = category_batch_acc.cpu() 94 | # print(category_batch_acc[:, :10]) 95 | for j in range(len(args.gpus)): 96 | category_accuracy += category_batch_acc[2 * j:2 * j + 2, :] 97 | 98 | # measure elapsed time 99 | batch_time.update(time.time() - tic) 100 | tic = time.time() 101 | # update average loss and acc 102 | ave_acc.update(acc.data.item() * 100) 103 | 104 | if i % args.disp_iter == 0: 105 | print('Epoch: [{}][{}/{}], Time: {:.2f}, Data: {:.2f}, ' 106 | 'Accuracy: {:4.2f}, Acc-Iter: {:4.2f}' 107 | .format(epoch, i, args.val_epoch_iters, 108 | batch_time.average(), data_time.average(), 109 | ave_acc.average(), acc_iter / acc_iter_num)) 110 | 111 | acc_iter = 0 112 | acc_iter_num = 0 113 | # print(category_accuracy) 114 | print('Epoch: [{}], Accuracy: {:4.2f}'.format(epoch, ave_acc.average())) 115 | acc = category_acc(category_accuracy, args) 116 | print('Ave Category Acc: {:4.2f}'.format(acc.item() * 100)) 117 | return [ave_acc.average(), acc] 118 | 119 | 120 | def checkpoint(nets, args, epoch_num): 121 | print('Saving checkpoints...') 122 | suffix_latest = 'epoch_{}.pth'.format(epoch_num) 123 | 124 | torch.save(nets.module.state_dict(), 125 | '{}/net_{}'.format(args.ckpt, suffix_latest)) 126 | 127 | 128 | def slide_window_ave(acc_list, window_size=10): 129 | category = [] 130 | inst = [] 131 | for sample in acc_list: 132 | category.append(sample[1]) 133 | inst.append(sample[0]) 134 | category = np.array(category) 135 | inst = np.array(inst) 136 | epoch = category.size 137 | 138 | start_location = 0 139 | best_shot = -1 140 | for i in range(0, epoch - window_size): 141 | cur_value = category[i:i + window_size].mean() 142 | if cur_value > best_shot: 143 | start_location = i 144 | best_shot = cur_value 145 | 146 | best_inst = inst[start_location:start_location + window_size].mean() 147 | print('Best Category {}'.format(best_shot)) 148 | print('Best Inst {}'.format(best_inst)) 149 | 150 | 151 | def main(args): 152 | if not os.path.exists('pred/'): 153 | os.makedirs('pred/') 154 | 155 | dataset = NovelDataset( 156 | args.list_val, args, batch_per_gpu=args.batch_size_per_gpu) 157 | dataset.if_shuffled = True 158 | loader = 
DataLoader( 159 | dataset, batch_size=len(args.gpus), shuffle=False, 160 | collate_fn=user_scattered_collate, 161 | num_workers=int(args.workers), 162 | drop_last=True, 163 | pin_memory=True 164 | ) 165 | 166 | args.epoch_iters = \ 167 | math.ceil(dataset.num_sample / (args.batch_size_per_gpu * len(args.gpus))) 168 | print('1 Train Epoch = {} iters'.format(args.epoch_iters)) 169 | 170 | iterator = iter(loader) 171 | 172 | classifier = NovelClassifier(args) 173 | selective_load_weights(classifier, args.model_weight) 174 | network = UserScatteredDataParallel(classifier, device_ids=args.gpus) 175 | patch_replication_callback(network) 176 | network.cuda() 177 | network.module.mode = 'prob' 178 | 179 | iterations = 0 180 | preds = np.zeros((40000, args.num_novel_class)) 181 | labels = np.zeros(40000) 182 | flag = 0 183 | network.eval() 184 | while iterations < args.epoch_iters: 185 | batch_data = next(iterator) 186 | if iterations % 1000 == 0: 187 | print('{} / {}'.format(iterations, args.epoch_iters)) 188 | pred, label = network(batch_data) 189 | pred = np.array(pred.detach().cpu()) 190 | label = np.array(label.cpu()) 191 | preds[flag:flag + pred.shape[0], :] = pred 192 | labels[flag:flag + pred.shape[0]] = label 193 | flag += pred.shape[0] 194 | iterations += 1 195 | 196 | preds = preds[:flag, :] 197 | labels = labels[:flag] 198 | f = h5py.File('pred/img_test_pred_{}.h5'.format(args.id), 'w') 199 | f.create_dataset('preds', data=preds) 200 | f.create_dataset('labels', data=labels) 201 | f.close() 202 | 203 | 204 | if __name__ == '__main__': 205 | parser = argparse.ArgumentParser() 206 | # Model related arguments 207 | parser.add_argument('--id', default='baseline', 208 | help="a name for identifying the model") 209 | parser.add_argument('--cls', default='novel_cls') 210 | parser.add_argument('--feat_dim', default=512) 211 | parser.add_argument('--crop_height', default=3, type=int) 212 | parser.add_argument('--crop_width', default=3, type=int) 213 | parser.add_argument('--range_of_compute', default=5, type=int) 214 | parser.add_argument('--model_weight', default='') 215 | parser.add_argument('--epoch', default=-1, type=int) 216 | parser.add_argument('--mode', default='val') 217 | 218 | parser.add_argument('--list_val', 219 | default='data/img_test_val_feat.h5') 220 | 221 | # optimization related arguments 222 | parser.add_argument('--gpus', default=[0], 223 | help='gpus to use, e.g. 0-3 or 0,1,2,3') 224 | parser.add_argument('--batch_size_per_gpu', default=1, type=int, 225 | help='input batch size') 226 | parser.add_argument('--num_epoch', default=100, type=int, 227 | help='epochs to train for') 228 | parser.add_argument('--start_epoch', default=0, type=int, 229 | help='epoch to start training. 
useful if continue from a checkpoint') 230 | parser.add_argument('--train_epoch_iters', default=20, type=int, 231 | help='iterations of each epoch (irrelevant to batch size)') 232 | parser.add_argument('--val_epoch_iters', default=20, type=int) 233 | parser.add_argument('--optim', default='SGD', help='optimizer') 234 | 235 | # Data related arguments 236 | parser.add_argument('--num_novel_class', default=193, type=int, 237 | help='number of classes') 238 | parser.add_argument('--workers', default=0, type=int, 239 | help='number of data loading workers') 240 | parser.add_argument('--disp_iter', type=int, default=1, 241 | help='frequency to display') 242 | 243 | args = parser.parse_args() 244 | args.list_val = 'data/img_test_val_feat_{}.h5'.format(args.id) 245 | args.model_weight = 'ckpt/net_epoch_{}.pth'.format(args.epoch) 246 | 247 | main(args) 248 | -------------------------------------------------------------------------------- /evaluate/save_features.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../') 3 | import torch 4 | from dataset.base_dataset import BaseDataset 5 | from dataset.dataloader import DataLoader 6 | from dataset.collate import UserScatteredDataParallel, user_scattered_collate 7 | from model.builder import ModelBuilder 8 | from model.base_model import BaseLearningModule 9 | from model.parallel.replicate import patch_replication_callback 10 | import os 11 | import argparse 12 | import h5py 13 | import numpy as np 14 | import math 15 | from utils import selective_load_weights 16 | 17 | 18 | def save_feature(args): 19 | if not os.path.exists('data/'): 20 | os.makedirs('data/') 21 | 22 | # Network Builders 23 | builder = ModelBuilder(args) 24 | feature_extractor = builder.build_backbone() 25 | network = BaseLearningModule(args, feature_extractor, classifier=None) 26 | selective_load_weights(network, args.model_weight) 27 | network = UserScatteredDataParallel(network, device_ids=args.gpus) 28 | patch_replication_callback(network) 29 | network.cuda() 30 | network.eval() 31 | network.module.mode = 'feature' 32 | 33 | dataset_train = BaseDataset(args.data_train, args) 34 | dataset_val = BaseDataset(args.data_val, args) 35 | dataset_train.if_shuffled = True 36 | dataset_val.if_shuffled = True 37 | loader_train = DataLoader( 38 | dataset_train, batch_size=len(args.gpus), shuffle=False, 39 | collate_fn=user_scattered_collate, 40 | num_workers=int(args.workers), 41 | drop_last=True, 42 | pin_memory=True 43 | ) 44 | loader_val = DataLoader( 45 | dataset_val, batch_size=len(args.gpus), shuffle=False, 46 | collate_fn=user_scattered_collate, 47 | num_workers=int(args.workers), 48 | drop_last=True, 49 | pin_memory=True 50 | ) 51 | iter_train = iter(loader_train) 52 | iter_val = iter(loader_val) 53 | 54 | args.train_epoch_iters = \ 55 | math.ceil(dataset_train.num_sample / (args.batch_size_per_gpu * len(args.gpus))) 56 | args.val_epoch_iters = \ 57 | math.ceil(dataset_val.num_sample / (args.batch_size_per_gpu * len(args.gpus))) 58 | print('1 Train Epoch = {} iters'.format(args.train_epoch_iters)) 59 | print('1 Val Epoch = {} iters'.format(args.val_epoch_iters)) 60 | 61 | iterations = 0 62 | features = np.zeros((240000, args.feat_dim * args.crop_height * args.crop_width)) 63 | labels = np.zeros(240000) 64 | flag = 0 65 | while iterations < args.train_epoch_iters: 66 | batch_data = next(iter_train) 67 | if iterations % 10 == 0: 68 | print('{} / {}'.format(iterations, args.train_epoch_iters)) 69 | feature, label = 
network(batch_data) 70 | feature = np.array(feature.detach().cpu()) 71 | label = np.array(label.cpu()) 72 | features[flag:flag+feature.shape[0], :] = feature 73 | labels[flag:flag+label.size] = label 74 | flag += feature.shape[0] 75 | iterations += 1 76 | features = features[:flag, :] 77 | labels = labels[:flag] 78 | f = h5py.File('data/img_test_train_feat_{}.h5'.format(args.id), 'w') 79 | f.create_dataset('feature_map', data=features) 80 | f.create_dataset('labels', data=labels) 81 | f.close() 82 | 83 | iterations = 0 84 | features = np.zeros((40000, args.feat_dim * args.crop_height * args.crop_width)) 85 | labels = np.zeros(40000) 86 | flag = 0 87 | while iterations < args.val_epoch_iters: 88 | batch_data = next(iter_val) 89 | if iterations % 10 == 0: 90 | print('{} / {}'.format(iterations, args.val_epoch_iters)) 91 | feature, label = network(batch_data) 92 | feature = np.array(feature.detach().cpu()) 93 | label = np.array(label.cpu()) 94 | features[flag:flag + feature.shape[0], :] = feature 95 | labels[flag:flag + feature.shape[0]] = label 96 | flag += feature.shape[0] 97 | iterations += 1 98 | 99 | features = features[:flag, :] 100 | labels = labels[:flag] 101 | f = h5py.File('data/img_test_val_feat_{}.h5'.format(args.id), 'w') 102 | f.create_dataset('feature_map', data=features) 103 | f.create_dataset('labels', data=labels) 104 | f.close() 105 | 106 | return None 107 | 108 | 109 | if __name__ == '__main__': 110 | parser = argparse.ArgumentParser() 111 | # Model related arguments 112 | parser.add_argument('--id', default='', 113 | help="a name for identifying the model") 114 | parser.add_argument('--architecture', default='resnet10') 115 | parser.add_argument('--feat_dim', default=512) 116 | parser.add_argument('--crop_height', default=3, type=int) 117 | parser.add_argument('--crop_width', default=3, type=int) 118 | 119 | # Path related arguments 120 | parser.add_argument('--data_train', 121 | default='../data/ADE/ADE_Novel/novel_img_test_train.json') 122 | parser.add_argument('--data_val', 123 | default='../data/ADE/ADE_Novel/novel_img_test_val.json') 124 | parser.add_argument('--root_dataset', 125 | default='../../') 126 | 127 | # optimization related argument 128 | parser.add_argument('--gpus', default=[0], 129 | help='gpus to use, e.g. 
0-3 or 0,1,2,3') 130 | parser.add_argument('--batch_size_per_gpu', default=1, type=int, 131 | help='input batch size') 132 | parser.add_argument('--train_epoch_iters', default=20, type=int, 133 | help='iterations of each epoch (irrelevant to batch size)') 134 | parser.add_argument('--val_epoch_iters', default=20, type=int) 135 | parser.add_argument('--model_weight', default='') 136 | parser.add_argument('--mode', default='val') 137 | 138 | # Data related arguments 139 | parser.add_argument('--workers', default=0, type=int, 140 | help='number of data loading workers') 141 | parser.add_argument('--imgShortSize', default=800, type=int, 142 | help='input image size of short edge (int or list)') 143 | parser.add_argument('--imgMaxSize', default=1500, type=int, 144 | help='maximum input image size of long edge') 145 | parser.add_argument('--padding_constant', default=8, type=int, 146 | help='max down sampling rate of the network') 147 | parser.add_argument('--down_sampling_rate', default=8, type=int, 148 | help='down sampling rate of the segmentation label') 149 | parser.add_argument('--sample_type', default='inst', 150 | help='instance level or category level sampling') 151 | 152 | # Misc arguments 153 | parser.add_argument('--seed', default=304, type=int, help='manual seed') 154 | parser.add_argument('--ckpt', default='./checkpoint', 155 | help='folder to output checkpoints') 156 | parser.add_argument('--disp_iter', type=int, default=10, 157 | help='frequency to display') 158 | parser.add_argument('--log_dir', default="./log_base/", 159 | help='dir to save train and val log') 160 | parser.add_argument('--comment', default="", 161 | help='add comment to this train') 162 | parser.add_argument('--max_anchor_per_img', default=100) 163 | 164 | args = parser.parse_args() 165 | 166 | save_feature(args) 167 | -------------------------------------------------------------------------------- /evaluate/vote.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../') 3 | import h5py 4 | import numpy as np 5 | import argparse 6 | import utils 7 | import torch 8 | 9 | 10 | def acc(preds, label, range_of_compute): 11 | category_acc = np.zeros((2, preds.shape[1])) 12 | acc_sum = 0 13 | num = preds.shape[0] 14 | preds = np.argsort(preds) 15 | label = label.astype(np.int) 16 | for i in range(num): 17 | category_acc[1, label[i]] += 1 18 | if label[i] in preds[i, -range_of_compute:]: 19 | acc_sum += 1 20 | category_acc[0, label[i]] += 1 21 | acc = np.array(acc_sum / (num + 1e-10)) 22 | cat_acc = utils.category_acc(torch.tensor(category_acc), 1) 23 | return acc, cat_acc 24 | 25 | 26 | def main(args): 27 | models = args.models 28 | preds = [] 29 | labels = None 30 | for model in models: 31 | file_path = 'pred/img_test_pred_{}.h5'.format(model) 32 | f = h5py.File(file_path, 'r') 33 | preds.append(np.array(f['preds'])) 34 | if labels is None: 35 | labels = np.array(f['labels']) 36 | else: 37 | if not np.array_equal(labels, np.array(f['labels'])): 38 | raise RuntimeError('mismatch list') 39 | 40 | model_num = len(models) 41 | instance_num = preds[0].shape[0] 42 | class_num = preds[0].shape[1] 43 | weight = args.weight 44 | 45 | pred_vote = np.zeros((instance_num, class_num)) 46 | for i in range(model_num): 47 | pred_vote += weight[i] * preds[i] 48 | 49 | acc_1 = acc(pred_vote, labels, range_of_compute=1) 50 | acc_5 = acc(pred_vote, labels, range_of_compute=5) 51 | 52 | print(acc_1) 53 | print(acc_5) 54 | 55 | 56 | if __name__ == '__main__': 57 | parser 
= argparse.ArgumentParser() 58 | parser.add_argument('--models', default=['baseline']) 59 | parser.add_argument('-weight', default=[1.0]) 60 | parser.add_argument('--mode', default='val') 61 | args = parser.parse_args() 62 | main(args) 63 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import scipy.misc 4 | 5 | import os 6 | 7 | try: 8 | from StringIO import StringIO # Python 2.7 9 | except ImportError: 10 | from io import BytesIO # Python 3.x 11 | class Logger(object): 12 | 13 | def __init__(self, log_dir): 14 | """Create a summary writer logging to log_dir.""" 15 | # 创建一个指向log文件夹的summary writer 16 | if os.path.exists(log_dir): 17 | None 18 | #for f in os.listdir(log_dir): 19 | #os.remove(os.path.join(log_dir, f)) 20 | else: 21 | os.makedirs(log_dir) 22 | 23 | self.writer = tf.summary.FileWriter(log_dir) 24 | 25 | 26 | def scalar_summary(self, tag, value, step): 27 | """Log a scalar variable.""" 28 | # 标量信息 日志 29 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]) 30 | self.writer.add_summary(summary, step) 31 | 32 | def image_summary(self, tag, images, step): 33 | """Log a list of images.""" 34 | # 图像信息 日志 35 | img_summaries = [] 36 | for i, img in enumerate(images): 37 | # Write the image to a string 38 | try: 39 | s = StringIO() 40 | except: 41 | s = BytesIO() 42 | scipy.misc.toimage(img).save(s, format="png") 43 | 44 | # Create an Image object 45 | img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(), 46 | height=img.shape[0], 47 | width=img.shape[1]) 48 | # Create a Summary value 49 | img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum)) 50 | 51 | # Create and write Summary 52 | summary = tf.Summary(value=img_summaries) 53 | self.writer.add_summary(summary, step) 54 | 55 | def histo_summary(self, tag, values, step, bins=1000): 56 | """Log a histogram of the tensor of values.""" 57 | # 直方图信息 日志 58 | # Create a histogram using numpy 59 | counts, bin_edges = np.histogram(values, bins=bins) 60 | 61 | # Fill the fields of the histogram proto 62 | hist = tf.HistogramProto() 63 | hist.min = float(np.min(values)) 64 | hist.max = float(np.max(values)) 65 | hist.num = int(np.prod(values.shape)) 66 | hist.sum = float(np.sum(values)) 67 | hist.sum_squares = float(np.sum(values**2)) 68 | 69 | # Drop the start of the first bin 70 | bin_edges = bin_edges[1:] 71 | 72 | # Add bin edges and counts 73 | for edge in bin_edges: 74 | hist.bucket_limit.append(edge) 75 | for c in counts: 76 | hist.bucket.append(c) 77 | 78 | # Create and write Summary 79 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)]) 80 | self.writer.add_summary(summary, step) 81 | self.writer.flush() 82 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinahHu/ADE-FewShot/41dc9cc481bfaf3bd9fb8bd76c1e63fcf127339d/model/__init__.py -------------------------------------------------------------------------------- /model/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from roi_align.roi_align import RoIAlign 5 | from torch.autograd import Variable 6 | import random 7 | 8 | 9 | def to_variable(arr, 
requires_grad=False, is_cuda=True): 10 | tensor = torch.from_numpy(arr) 11 | if is_cuda: 12 | tensor = tensor.cuda() 13 | var = Variable(tensor, requires_grad=requires_grad) 14 | return var 15 | 16 | 17 | class BaseLearningModule(nn.Module): 18 | def __init__(self, args, backbone, classifier): 19 | super(BaseLearningModule, self).__init__() 20 | self.args = args 21 | self.backbone = backbone 22 | self.classifier = classifier 23 | 24 | self.crop_height = int(args.crop_height) 25 | self.crop_width = args.crop_width 26 | self.roi_align = RoIAlign(args.crop_height, args.crop_width, transform_fpcoor=True) 27 | self.down_sampling_rate = args.down_sampling_rate 28 | 29 | # supervision modules are generated in train 30 | # args.module example: 31 | # [{'name': 'seg', 'module': seg_module}] 32 | if hasattr(args, 'module'): 33 | for module in args.module: 34 | setattr(self, module['name'], module['module']) 35 | 36 | self.mode = 'train' 37 | if self.classifier is not None: 38 | self.classifier.mode = self.mode 39 | 40 | def process_in_roi_layer(self, feature_map, scale, anchors, anchor_num): 41 | """ 42 | process the data in roi_layer and get the feature 43 | :param feature_map: C * H * W 44 | :param scale: anchor_num * 2 45 | :param anchors: anchor_num * 4 46 | :param anchor_num: int 47 | :return: feature C * crop_height * crop_width 48 | """ 49 | anchors = np.array(anchors.detach().cpu()) 50 | scale = np.array(scale.detach().cpu()) 51 | anchors = anchors / self.down_sampling_rate 52 | anchor_num = int(anchor_num) 53 | anchors[:, 2] = anchors[:, 2] * scale[0] 54 | anchors[:, 3] = anchors[:, 3] * scale[0] 55 | anchors[:, 0] = anchors[:, 0] * scale[1] 56 | anchors[:, 1] = anchors[:, 1] * scale[1] 57 | anchors[:, [1, 2]] = anchors[:, [2, 1]] 58 | anchor_index = np.zeros(anchor_num) 59 | anchor_index = to_variable(anchor_index).int() 60 | anchors = to_variable(anchors[:anchor_num, :]).float() 61 | feature_map = feature_map.unsqueeze(0) 62 | feature = self.roi_align(feature_map, anchors, anchor_index) 63 | feature = feature.view(-1, self.args.feat_dim * self.crop_height * self.crop_width) 64 | return feature 65 | 66 | def predict(self, feed_dict): 67 | feature_map = self.backbone(feed_dict['img_data']) 68 | batch_img_num = feature_map.shape[0] 69 | features = None 70 | labels = None 71 | for i in range(batch_img_num): 72 | anchor_num = int(feed_dict['anchor_num'][i].detach().cpu()) 73 | if anchor_num == 0 or anchor_num > 100: 74 | continue 75 | feature = self.process_in_roi_layer(feature_map[i], feed_dict['scales'][i], 76 | feed_dict['anchors'][i], anchor_num) 77 | label = feed_dict['label'][i][:anchor_num].long() 78 | 79 | if features is None: 80 | features = feature.clone() 81 | labels = label.clone() 82 | else: 83 | features = torch.stack((features, feature), dim=0) 84 | labels = torch.stack((labels, label), dim=0) 85 | if features.shape[0] != labels.shape[0]: 86 | print(features.shape[0]) 87 | print(anchor_num) 88 | return features, labels 89 | 90 | def forward(self, feed_dict): 91 | if self.mode == 'feature': 92 | return self.predict(feed_dict) 93 | elif self.mode == 'diagnosis': 94 | return self.diagnosis(feed_dict) 95 | 96 | category_accuracy = torch.zeros(2, self.args.num_base_class).cuda() 97 | 98 | feature_map = self.backbone(feed_dict['img_data']) 99 | acc = 0 100 | loss = 0 101 | batch_img_num = feature_map.shape[0] 102 | 103 | instance_sum = torch.tensor([0]).cuda() 104 | loss_classification = torch.zeros(1) 105 | loss_supervision = torch.zeros(len(self.args.supervision)) 106 | for i in 
range(batch_img_num): 107 | anchor_num = int(feed_dict['anchor_num'][i].detach().cpu()) 108 | if anchor_num == 0 or anchor_num > 100: 109 | continue 110 | feature = self.process_in_roi_layer(feature_map[i], feed_dict['scales'][i], 111 | feed_dict['anchors'][i], anchor_num) 112 | labels = feed_dict['label'][i, : anchor_num].long() 113 | loss_cls, acc_cls, category_acc_img = self.classifier([feature, labels]) 114 | instance_sum[0] += labels.shape[0] 115 | loss += loss_cls * labels.shape[0] 116 | 117 | acc += acc_cls * labels.shape[0] 118 | loss_classification += loss_cls.item() * labels.shape[0] 119 | category_accuracy += category_acc_img.cuda() 120 | # do not contain other supervision 121 | if not hasattr(self.args, 'module'): 122 | continue 123 | if self.mode == 'val': 124 | continue 125 | 126 | # form generic data input for all supervision branch 127 | input_agg = dict() 128 | input_agg['features'] = feature 129 | input_agg['feature_map'] = feature_map[i] 130 | input_agg['anchors'] = feed_dict['anchors'][i][:anchor_num] 131 | input_agg['scales'] = feed_dict['scales'][i] 132 | input_agg['labels'] = feed_dict['label'][i][:anchor_num] 133 | 134 | for key in feed_dict.keys(): 135 | if key not in ['img_data', 'patch_location_label', 'patch_location__img', 136 | 'rotation_img', 'rotation_label']: 137 | supervision = next((x for x in self.args.supervision if x['name'] == key), None) 138 | if supervision is not None: 139 | input_agg[key] = feed_dict[key][i] 140 | 141 | for j, supervision in enumerate(self.args.supervision): 142 | if supervision['type'] != 'self': 143 | loss_branch = getattr(self, supervision['name'])(input_agg) * labels.shape[0] 144 | elif supervision['name'] == 'patch_location': 145 | input_patch_location = feed_dict['patch_location_img'] 146 | _, _, _, height, width = input_patch_location.shape 147 | patch_location_label = feed_dict['patch_location_label'] 148 | patch_location_feature_map = self.backbone(input_patch_location.view(-1, 3, height, width)) 149 | _, C, H, W = patch_location_feature_map.shape 150 | patch_location_feature_map = patch_location_feature_map.reshape(batch_img_num, 2, C, H, W) 151 | loss_branch = getattr(self, 'patch_location')([patch_location_feature_map, patch_location_label]) 152 | elif supervision['name'] == 'rotation': 153 | input_img = feed_dict['rotation_img'] 154 | input_label = feed_dict['rotation_label'] 155 | rotation_feature_map = self.backbone(input_img) 156 | loss_branch = getattr(self, 'rotation')([rotation_feature_map, input_label]) 157 | loss += (loss_branch * supervision['weight']) 158 | loss_supervision[j] += loss_branch.item() 159 | 160 | if self.mode == 'val': 161 | return category_accuracy, loss / (instance_sum[0] + 1e-10), acc / (instance_sum[0] + 1e-10), instance_sum 162 | if hasattr(self.args, 'module'): 163 | loss_supervision = loss_supervision.cuda() 164 | loss_classification = loss_classification.cuda() 165 | return category_accuracy, loss / (instance_sum[0] + 1e-10), acc / (instance_sum[0] + 1e-10), instance_sum, \ 166 | loss_supervision / (instance_sum[0] + 1e-10), loss_classification / (instance_sum[0] + 1e-10) 167 | else: 168 | return category_accuracy, loss / (instance_sum[0] + 1e-10), acc / (instance_sum[0] + 1e-10), \ 169 | instance_sum, None, None -------------------------------------------------------------------------------- /model/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from model.component.classifier import Classifier, 
CosClassifier 4 | from model.component.resnet import resnet18, resnet34, resnet10 5 | from model.component.attr import AttrClassifier 6 | from model.component.scene import SceneClassifier 7 | from model.component.seg import MaskPredictor 8 | from model.component.bbox import BBoxModule 9 | from model.component.bkg import FullMaskPredictor 10 | from model.component.hierarchy import HierarchyClassifier 11 | from model.component.part import PartClassifier 12 | from model.component.patch_location import PatchLocationClassifier 13 | from model.component.rotation import RotationClassifier 14 | 15 | 16 | class ModelBuilder: 17 | # weight initialization 18 | def __init__(self, args): 19 | self.args = args 20 | 21 | def weight_init(self, m): 22 | class_name = m.__class__.__name__ 23 | if class_name.find('Conv') != -1: 24 | nn.init.kaiming_normal_(m.weight.data) 25 | elif class_name.find('BatchNorm') != -1: 26 | m.weight.data.fill_(1.) 27 | m.bias.data.fill_(1e-4) 28 | elif class_name.find('Linear') != -1: 29 | m.weight.data.normal_(0.0, 0.01) 30 | 31 | def build_backbone(self): 32 | if self.args.architecture == 'resnet18': 33 | backbone = resnet18() 34 | elif self.args.architecture == 'resnet10': 35 | backbone = resnet10() 36 | elif self.args.architecture == 'resnet34': 37 | backbone = resnet34() 38 | 39 | backbone.apply(self.weight_init) 40 | return backbone 41 | 42 | def build_classifier(self): 43 | if self.args.cls == 'Linear': 44 | classifier = Classifier(self.args) 45 | elif self.args.cls == 'Cos': 46 | classifier = CosClassifier(self.args) 47 | classifier.apply(self.weight_init) 48 | return classifier 49 | 50 | def build_attr(self): 51 | attr_classifier = AttrClassifier(self.args) 52 | attr_classifier.apply(self.weight_init) 53 | return attr_classifier 54 | 55 | def build_part(self): 56 | part_classifier = PartClassifier(self.args) 57 | part_classifier.apply(self.weight_init) 58 | return part_classifier 59 | 60 | def build_scene(self): 61 | scene_classifier = SceneClassifier(self.args) 62 | scene_classifier.apply(self.weight_init) 63 | return scene_classifier 64 | 65 | def build_seg(self): 66 | segment_module = MaskPredictor(self.args) 67 | segment_module.apply(self.weight_init) 68 | return segment_module 69 | 70 | def build_bkg(self): 71 | background_module = FullMaskPredictor(self.args) 72 | background_module.apply(self.weight_init) 73 | return background_module 74 | 75 | def build_bbox(self): 76 | bbox_module = BBoxModule(self.args) 77 | bbox_module.apply(self.weight_init) 78 | return bbox_module 79 | 80 | def build_hierarchy(self): 81 | hierarchy_classifier = HierarchyClassifier(self.args) 82 | hierarchy_classifier.apply(self.weight_init) 83 | return hierarchy_classifier 84 | 85 | def build_patch_location(self): 86 | patch_location_classifier = PatchLocationClassifier(self.args) 87 | patch_location_classifier.apply(self.weight_init) 88 | return patch_location_classifier 89 | 90 | def build_rotation(self): 91 | rotation_classifier = RotationClassifier(self.args) 92 | rotation_classifier.apply(self.weight_init) 93 | return rotation_classifier 94 | -------------------------------------------------------------------------------- /model/component/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinahHu/ADE-FewShot/41dc9cc481bfaf3bd9fb8bd76c1e63fcf127339d/model/component/__init__.py -------------------------------------------------------------------------------- /model/component/attr.py: 
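A note on the ModelBuilder factory above, before the per-branch supervision modules that follow: build_backbone selects the ResNet variant named by args.architecture, build_classifier selects the linear or cosine head, and each build_* helper constructs one auxiliary-supervision branch, all initialized through weight_init. The snippet below is a minimal sketch (not taken from the repo) of how these pieces are typically wired into BaseLearningModule, mirroring the pattern in evaluate/save_features.py; the feat_dim and crop sizes follow the defaults seen in the evaluation scripts, while the class count and other values are illustrative placeholders, with the real configuration supplied by train.py and the shell scripts.

```python
# Sketch only: assembling backbone + classifier the way save_features.py does.
# num_base_class below is a placeholder, not a value taken from the repo configs.
from argparse import Namespace
from model.builder import ModelBuilder
from model.base_model import BaseLearningModule

args = Namespace(architecture='resnet10', cls='Linear',
                 feat_dim=512, crop_height=3, crop_width=3,
                 num_base_class=189,        # placeholder class count
                 down_sampling_rate=8, supervision=[])

builder = ModelBuilder(args)
backbone = builder.build_backbone()        # resnet10 / resnet18 / resnet34
classifier = builder.build_classifier()    # Classifier or CosClassifier over RoI features
model = BaseLearningModule(args, backbone, classifier)
```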
-------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import numpy as np 4 | import torch.nn.functional as F 5 | 6 | 7 | class AttrSoftLoss(nn.Module): 8 | """ 9 | soft margin loss for this 10 | """ 11 | def __init__(self): 12 | super(AttrSoftLoss, self).__init__() 13 | 14 | def forward(self, x): 15 | """ 16 | consider attributes as multiple labels 17 | carry out multi-label classification loss 18 | :param x: score and attributes 19 | :return: loss 20 | """ 21 | scores, attributes = x 22 | attr_loss = 0.0 23 | attributes = attributes.float().cuda() 24 | for i in range(attributes.shape[0]): 25 | loss_mask = torch.ones(attributes.shape[1]).cuda() 26 | zeros = (attributes[i, :] == 0).nonzero().cpu().numpy() 27 | indices = np.random.choice(zeros.squeeze(), int(round(len(zeros) * 0.95)), False) 28 | loss_mask[indices] = 0 29 | 30 | attr_loss += F.multilabel_soft_margin_loss(scores[i].unsqueeze(0), 31 | attributes[i].unsqueeze(0), weight=loss_mask) 32 | attr_loss /= attributes.shape[0] 33 | return attr_loss 34 | 35 | 36 | class AttrClassifier(nn.Module): 37 | """ 38 | Linear Classifier 39 | """ 40 | def __init__(self, args): 41 | super(AttrClassifier, self).__init__() 42 | self.in_dim = args.feat_dim * args.crop_height * args.crop_width 43 | for supervision in args.supervision: 44 | if supervision['name'] == 'attr': 45 | self.num_class = supervision['other']['num_attr'] 46 | # self.mid_layer = nn.Linear(self.in_dim, self.in_dim) 47 | self.classifier = nn.Linear(self.in_dim, self.num_class) 48 | self.sigmoid = nn.Sigmoid() 49 | self.loss = AttrSoftLoss() 50 | self.mode = 'train' 51 | 52 | def forward(self, agg_data): 53 | """ 54 | forward pipeline, compute loss function 55 | :param agg_data: refer to ../base_model.py 56 | :return: loss, acc 57 | """ 58 | if self.mode == 'diagnosis': 59 | return self.diagnosis(agg_data) 60 | 61 | x = agg_data['features'] 62 | attributes = agg_data['attr'] 63 | x = self.classifier(x) 64 | # x = self.sigmoid(x) 65 | attributes = attributes[:x.shape[0]].long() 66 | loss = self.loss([x, attributes]) 67 | return loss 68 | -------------------------------------------------------------------------------- /model/component/bbox.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import numpy as np 4 | import torch.nn.functional as F 5 | from roi_align.roi_align import RoIAlign 6 | from torch.autograd import Variable 7 | 8 | 9 | def to_variable(arr, requires_grad=False, is_cuda=True): 10 | tensor = torch.from_numpy(arr) 11 | if is_cuda: 12 | tensor = tensor.cuda() 13 | var = Variable(tensor, requires_grad=requires_grad) 14 | return var 15 | 16 | 17 | class BBoxModule(nn.Module): 18 | def __init__(self, args): 19 | super(BBoxModule, self).__init__() 20 | self.args = args 21 | for supervision in args.supervision: 22 | if supervision['name'] == 'bbox': 23 | self.crop_height = int(supervision['other']['pool_size']) 24 | self.crop_width = int(supervision['other']['pool_size']) 25 | self.roi_align = RoIAlign(self.crop_height, self.crop_width, transform_fpcoor=True) 26 | 27 | self.down_sampling_rate = self.args.down_sampling_rate 28 | 29 | self.feat_dim = args.feat_dim * self.crop_width * self.crop_height 30 | self.num_class = args.num_base_class 31 | self.regress = nn.Linear(self.feat_dim, 4) 32 | 33 | def process_in_roi_layer(self, feature_map, scale, anchors, anchor_num): 34 | """ 35 | process the data in roi_layer and get the 
feature 36 | :param feature_map: C * H * W 37 | :param scale: anchor_num * 2 38 | :param anchors: anchor_num * 4 39 | :param anchor_num: int 40 | :param change after processed in the network 41 | :return: feature C * crop_height * crop_width 42 | """ 43 | anchors = np.array(anchors.detach().cpu()) 44 | scale = np.array(scale.detach().cpu()) 45 | anchors = anchors / self.down_sampling_rate 46 | anchor_num = int(anchor_num) 47 | anchors[:, 2] = anchors[:, 2] * scale[0] 48 | anchors[:, 3] = anchors[:, 3] * scale[0] 49 | anchors[:, 0] = anchors[:, 0] * scale[1] 50 | anchors[:, 1] = anchors[:, 1] * scale[1] 51 | anchors[:, [1, 2]] = anchors[:, [2, 1]] 52 | anchor_index = np.zeros(anchor_num) 53 | anchor_index = to_variable(anchor_index).int() 54 | anchors = to_variable(anchors[:anchor_num, :]).float() 55 | feature_map = feature_map.unsqueeze(0) 56 | feature = self.roi_align(feature_map, anchors, anchor_index) 57 | feature = feature.view(-1, self.args.feat_dim * self.crop_height * self.crop_width) 58 | return feature 59 | 60 | @staticmethod 61 | def compute_anchor_location(anchor, scale, original_scale): 62 | """ 63 | compute the anchor location after resize operation 64 | :param anchor: input anchor 65 | :param scale: scale 66 | :param original_scale: the scale of original data loading 67 | :return: anchor on the feature map 68 | """ 69 | anchor = np.array(anchor.detach().cpu()) 70 | original_scale = np.array(original_scale) 71 | scale = np.array(scale.cpu()) 72 | anchor[:, 2] = anchor[:, 2] * scale[0] * original_scale[0] 73 | anchor[:, 3] = anchor[:, 3] * scale[0] * original_scale[0] 74 | anchor[:, 0] = anchor[:, 0] * scale[1] * original_scale[1] 75 | anchor[:, 1] = anchor[:, 1] * scale[1] * original_scale[1] 76 | return anchor 77 | 78 | @staticmethod 79 | def prepare_target_value(crop_anchor, tgt_anchor): 80 | crop_anchor = np.array(crop_anchor.cpu()) 81 | tgt_anchor = np.array(tgt_anchor.cpu()) 82 | # transpose the anchors, then we can directly get the location on each axis 83 | crop_anchor = np.transpose(crop_anchor) 84 | tgt_anchor = np.transpose(tgt_anchor) 85 | # left, right, up, down 86 | x_crop, r_crop, y_crop, d_crop = crop_anchor 87 | x_tgt, r_tgt, y_tgt, d_tgt = tgt_anchor 88 | h_crop = d_crop - y_crop 89 | h_tgt = d_tgt - y_tgt 90 | w_crop = r_crop - x_crop 91 | w_tgt = r_tgt - x_tgt 92 | 93 | # compute the value 94 | dx = (x_tgt - x_crop) / (w_crop + 1e-10) 95 | dy = (y_crop - y_tgt) / (h_crop + 1e-10) 96 | dw = np.log(w_tgt / (w_crop + 1e-10)) 97 | dh = np.log(h_tgt / (h_crop + 1e-10)) 98 | target_value = np.stack((dx, dy, dw, dh), axis=0) 99 | # restore the shape 100 | target_value = np.transpose(target_value) 101 | return target_value 102 | 103 | def forward(self, agg_data): 104 | feature_map = agg_data['feature_map'] 105 | crop_anchors = agg_data['anchors'] 106 | anchor_num = crop_anchors.shape[0] 107 | tgt_anchors = agg_data['bbox'][:anchor_num, :] 108 | scale = agg_data['scales'] 109 | 110 | features = self.process_in_roi_layer(feature_map, scale, 111 | crop_anchors, anchor_num) 112 | 113 | target_value = self.prepare_target_value(crop_anchor=crop_anchors, tgt_anchor=tgt_anchors) 114 | target_value = torch.tensor(target_value).float().cuda() 115 | 116 | pred_value = self.regress(features) 117 | loss = F.smooth_l1_loss(pred_value, target_value) 118 | return loss 119 | 120 | 121 | 122 | 123 | -------------------------------------------------------------------------------- /model/component/bkg.py: -------------------------------------------------------------------------------- 
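Before the background (bkg) branch below, a note on the bounding-box supervision defined just above: prepare_target_value encodes the ground-truth box relative to the jittered crop anchor as corner offsets normalised by the crop's width and height (note the flipped sign on dy in the code) plus log width/height ratios, and the regress head is trained against that 4-vector with smooth L1. A standalone numeric check of the encoding, a sketch with made-up boxes that follows the same (left, right, top, bottom) ordering and sign conventions as the code above:

```python
import numpy as np

# Hypothetical boxes, in the (x_left, x_right, y_top, y_bottom) layout BBoxModule uses.
crop = np.array([[10., 50., 20., 60.]])   # jittered anchor that was cropped by RoIAlign
tgt  = np.array([[12., 60., 16., 64.]])   # ground-truth box from agg_data['bbox']

x_c, r_c, y_c, d_c = crop.T
x_t, r_t, y_t, d_t = tgt.T
w_c, h_c = r_c - x_c, d_c - y_c           # 40, 40
w_t, h_t = r_t - x_t, d_t - y_t           # 48, 48

dx = (x_t - x_c) / (w_c + 1e-10)          # 0.05
dy = (y_c - y_t) / (h_c + 1e-10)          # 0.10 (sign convention matches the code)
dw = np.log(w_t / (w_c + 1e-10))          # ~0.182
dh = np.log(h_t / (h_c + 1e-10))          # ~0.182
target = np.transpose(np.stack((dx, dy, dw, dh), axis=0))
print(target)                             # [[0.05  0.10  0.182  0.182]]
```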
1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import json 6 | 7 | 8 | class FullMaskPredictor(nn.Module): 9 | def __init__(self, args): 10 | super(FullMaskPredictor, self).__init__() 11 | self.in_dim = args.feat_dim 12 | self.args = args 13 | self.down_sampling_rate = args.down_sampling_rate 14 | 15 | self.fc1 = nn.Conv2d(self.in_dim, self.args.num_all_class, kernel_size=3, stride=1, padding=1) 16 | 17 | self.base_classes = json.load(open('data/ADE/ADE_Origin/base_list.json', 'r')) 18 | 19 | @staticmethod 20 | def compute_anchor_location(anchor, scale, original_scale): 21 | anchor = np.array(anchor.detach().cpu()) 22 | original_scale = np.array(original_scale) 23 | scale = np.array(scale.cpu()) 24 | anchor[:, 2] = np.floor(anchor[:, 2] * scale[0] * original_scale[0]) 25 | anchor[:, 3] = np.ceil(anchor[:, 3] * scale[0] * original_scale[0]) 26 | anchor[:, 0] = np.floor(anchor[:, 0] * scale[1] * original_scale[1]) 27 | anchor[:, 1] = np.ceil(anchor[:, 1] * scale[1] * original_scale[1]) 28 | return anchor.astype(np.int) 29 | 30 | @staticmethod 31 | def binary_transform(mask, label): 32 | return mask[:, int(label.item()), :, :] 33 | 34 | def forward(self, agg_input): 35 | """ 36 | take in the feature map and make predictions 37 | :param agg_input: input data 38 | :return: loss averaged over instances 39 | """ 40 | feature_map = agg_input['feature_map'] 41 | mask = agg_input['bkg'] 42 | 43 | feature_map = feature_map.unsqueeze(0) 44 | predicted_map = self.fc1(feature_map) 45 | mask = mask.unsqueeze(0) 46 | mask = mask.unsqueeze(0) 47 | mask = F.interpolate(mask, size=(predicted_map.shape[2], predicted_map.shape[3]), mode='nearest') 48 | mask = mask.squeeze(0) 49 | weight = torch.ones(self.args.num_all_class).cuda() 50 | weight[self.args.num_base_class:] = 0.1 51 | 52 | loss = F.cross_entropy(predicted_map, mask.long(), weight=weight) 53 | return loss 54 | -------------------------------------------------------------------------------- /model/component/classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | 7 | class Classifier(nn.Module): 8 | def __init__(self, args): 9 | super(Classifier, self).__init__() 10 | self.num_class = args.num_base_class 11 | self.down_sampling_rate = args.down_sampling_rate 12 | self.fc_512 = nn.Linear(args.feat_dim * args.crop_width * args.crop_height, self.num_class) 13 | self.loss = nn.CrossEntropyLoss(ignore_index=-1) 14 | self.mode = 'train' 15 | 16 | def _acc(self, pred, label): 17 | category_accuracy = torch.zeros(2, self.num_class) 18 | _, preds = torch.max(pred, dim=1) 19 | valid = (label >= 0).long() 20 | acc_sum = torch.sum(valid * (preds == label).long()) 21 | instance_sum = torch.sum(valid) 22 | acc = acc_sum.float() / (instance_sum.float() + 1e-10) 23 | for i, label_instance in enumerate(label): 24 | category_accuracy[1, label_instance] += 1 25 | if preds[i] == label_instance: 26 | category_accuracy[0, label_instance] += 1 27 | del pred 28 | return acc, category_accuracy 29 | 30 | def forward(self, x): 31 | if self.mode == 'diagnosis': 32 | return self.diagnosis(x) 33 | 34 | feature, labels = x 35 | pred = self.fc_512(feature) 36 | loss = self.loss(pred, labels) 37 | acc, category_accuracy = self._acc(pred, labels) 38 | 39 | return loss, acc, category_accuracy 40 | 41 | 42 | class CosClassifier(nn.Module): 43 | def __init__(self, args): 44 | 
super(CosClassifier, self).__init__() 45 | 46 | self.num_class = args.num_base_class 47 | self.indim = args.feat_dim * args.crop_width * args.crop_height 48 | self.outdim = args.num_base_class 49 | 50 | self.t = torch.ones(1).cuda() * 10 51 | self.weight = nn.Parameter(torch.Tensor(self.outdim , self.indim)) 52 | self.reset_parameters() 53 | self.loss = nn.CrossEntropyLoss(ignore_index=-1) 54 | self.mode = 'train' 55 | 56 | def reset_parameters(self): 57 | stdv = 1. / math.sqrt(self.weight.size(1)) 58 | self.weight.data.uniform_(-stdv, stdv) 59 | 60 | def _acc(self, pred, label): 61 | category_accuracy = torch.zeros(2, self.num_class) 62 | _, preds = torch.max(pred, dim=1) 63 | valid = (label >= 0).long() 64 | acc_sum = torch.sum(valid * (preds == label).long()) 65 | instance_sum = torch.sum(valid) 66 | acc = acc_sum.float() / (instance_sum.float() + 1e-10) 67 | for i, label_instance in enumerate(label): 68 | category_accuracy[1, label_instance] += 1 69 | if preds[i] == label_instance: 70 | category_accuracy[0, label_instance] += 1 71 | del pred 72 | return acc, category_accuracy 73 | 74 | def forward(self, data): 75 | if self.mode == 'diagnosis': 76 | return self.diagnosis(data) 77 | loss = 0 78 | feature, labels = data 79 | feat_layers = len(feature) 80 | pred = None 81 | 82 | for i in range(feat_layers-1, feat_layers): 83 | x = feature[i] 84 | batch_size = x.size(0) 85 | pred = self.t.cuda() * F.cosine_similarity( 86 | x.unsqueeze(1).expand(batch_size, self.outdim, self.indim), 87 | self.weight.unsqueeze(0).expand(batch_size, self.outdim, self.indim).cuda(), 2) 88 | loss += self.loss(pred, labels) 89 | 90 | acc, category_accuracy = self._acc(pred, labels) 91 | 92 | return loss, acc, category_accuracy 93 | -------------------------------------------------------------------------------- /model/component/hierarchy.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import numpy as np 4 | import torch.nn.functional as F 5 | 6 | 7 | class HierarchyClassifier(nn.Module): 8 | def __init__(self, args): 9 | super(HierarchyClassifier, self).__init__() 10 | self.layer_width = [] 11 | self.in_dim = args.feat_dim * args.crop_height * args.crop_width 12 | for supervision in args.supervision: 13 | if supervision['name'] == 'hierarchy': 14 | self.layer_width = supervision['other']['layer_width'] 15 | self.fcs = nn.ModuleList() 16 | for width in self.layer_width: 17 | self.fcs.append(nn.Linear(self.in_dim, width)) 18 | self.loss = nn.CrossEntropyLoss() 19 | self.mode = 'train' 20 | 21 | def forward(self, agg_data): 22 | """ 23 | forward pipeline, compute loss function 24 | :param agg_data: refer to ../base_model.py 25 | :return: loss 26 | """ 27 | if self.mode == 'diagnosis': 28 | return self.diagnosis(agg_data) 29 | loss_sum = 0 30 | x = agg_data['features'] 31 | hierarchy = agg_data['hierarchy'].long() 32 | hierarchy = hierarchy[:x.shape[0]] 33 | 34 | # Shallow supervision only 35 | losses = [] 36 | for i in range(len(self.fcs)): 37 | fc = self.fcs[i] 38 | label = hierarchy[:, i] 39 | score = fc(x) 40 | losses.append(self.loss(score, label)) 41 | for loss in losses: 42 | loss_sum += loss 43 | return loss_sum 44 | -------------------------------------------------------------------------------- /model/component/part.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import numpy as np 4 | import torch.nn.functional as F 5 | 6 | 7 | class 
AttrSoftLoss(nn.Module): 8 | """ 9 | soft margin loss for this 10 | """ 11 | def __init__(self): 12 | super(AttrSoftLoss, self).__init__() 13 | 14 | def forward(self, x): 15 | """ 16 | consider attributes as multiple labels 17 | carry out multi-label classification loss 18 | :param x: score and attributes 19 | :return: loss 20 | """ 21 | scores, attributes = x 22 | attr_loss = 0.0 23 | attributes = attributes.float().cuda() 24 | for i in range(attributes.shape[0]): 25 | loss_mask = torch.ones(attributes.shape[1]).cuda() 26 | zeros = (attributes[i, :] == 0).nonzero().cpu().numpy() 27 | indices = np.random.choice(zeros.squeeze(), int(round(len(zeros) * 0.95)), False) 28 | loss_mask[indices] = 0 29 | 30 | attr_loss += F.multilabel_soft_margin_loss(scores[i].unsqueeze(0), 31 | attributes[i].unsqueeze(0), weight=loss_mask) 32 | attr_loss /= attributes.shape[0] 33 | return attr_loss 34 | 35 | 36 | class PartClassifier(nn.Module): 37 | """ 38 | Linear Classifier 39 | """ 40 | def __init__(self, args): 41 | super(PartClassifier, self).__init__() 42 | self.in_dim = args.feat_dim * args.crop_height * args.crop_width 43 | for supervision in args.supervision: 44 | if supervision['name'] == 'part': 45 | self.num_class = supervision['other']['num_attr'] 46 | # self.mid_layer = nn.Linear(self.in_dim, self.in_dim) 47 | self.classifier = nn.Linear(self.in_dim, self.num_class) 48 | self.sigmoid = nn.Sigmoid() 49 | self.loss = AttrSoftLoss() 50 | self.mode = 'train' 51 | 52 | def forward(self, agg_data): 53 | """ 54 | forward pipeline, compute loss function 55 | :param agg_data: refer to ../base_model.py 56 | :return: loss, acc 57 | """ 58 | if self.mode == 'diagnosis': 59 | return self.diagnosis(agg_data) 60 | 61 | x = agg_data['features'] 62 | attributes = agg_data['part'] 63 | # x = self.mid_layer(x) 64 | x = self.classifier(x) 65 | # x = self.sigmoid(x) 66 | attributes = attributes[:x.shape[0]].long() 67 | loss = self.loss([x, attributes]) 68 | return loss 69 | -------------------------------------------------------------------------------- /model/component/patch_location.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | 7 | class PatchLocationClassifier(nn.Module): 8 | def __init__(self, args): 9 | super(PatchLocationClassifier, self).__init__() 10 | self.num_class = 8 11 | self.fc = nn.Linear(2 * 512, self.num_class) 12 | self.loss = nn.CrossEntropyLoss(ignore_index=-1) 13 | self.mode = 'train' 14 | self.global_pool = nn.AdaptiveAvgPool2d((1, 1)) 15 | 16 | def forward(self, x): 17 | """ 18 | the input has shape [batch_img_num, 9, C, H, W] 19 | :param x: input 20 | :return: loss 21 | """ 22 | feature, labels = x 23 | batch_img_num, _, channel, _, _ = feature.shape 24 | pooled_features = [] 25 | for j in range(2): 26 | pooled_features.append(self.global_pool(feature[:, j, :, :, :])) 27 | concat_feature = torch.cat(pooled_features, 1).view(batch_img_num, -1) 28 | pred = self.fc(concat_feature) 29 | loss = self.loss(pred, labels) 30 | 31 | return loss 32 | -------------------------------------------------------------------------------- /model/component/resnet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import torch.nn as nn 5 | import math 6 | from model.\ 7 | parallel.batchnorm import SynchronizedBatchNorm2d 8 | 9 | try: 10 | from urllib import urlretrieve 11 | except ImportError: 
12 | from urllib.request import urlretrieve 13 | 14 | 15 | __all__ = ['ResNet', 'resnet18'] # resnet101 is coming soon! 16 | 17 | 18 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 19 | """3x3 convolution with padding""" 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=dilation, groups=groups, bias=False, dilation=dilation) 22 | 23 | 24 | def conv1x1(in_planes, out_planes, stride=1): 25 | """1x1 convolution""" 26 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 27 | 28 | 29 | class BasicBlock(nn.Module): 30 | expansion = 1 31 | 32 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 33 | base_width=64, dilation=1, norm_layer=None): 34 | super(BasicBlock, self).__init__() 35 | if norm_layer is None: 36 | norm_layer = SynchronizedBatchNorm2d 37 | if groups != 1 or base_width != 64: 38 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 39 | if dilation > 1: 40 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 41 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 42 | self.conv1 = conv3x3(inplanes, planes, stride) 43 | self.bn1 = norm_layer(planes) 44 | self.relu = nn.ReLU(inplace=True) 45 | self.conv2 = conv3x3(planes, planes) 46 | self.bn2 = norm_layer(planes) 47 | self.downsample = downsample 48 | self.stride = stride 49 | 50 | def forward(self, x): 51 | identity = x 52 | 53 | out = self.conv1(x) 54 | out = self.bn1(out) 55 | out = self.relu(out) 56 | 57 | out = self.conv2(out) 58 | out = self.bn2(out) 59 | 60 | if self.downsample is not None: 61 | identity = self.downsample(x) 62 | 63 | out += identity 64 | out = self.relu(out) 65 | 66 | return out 67 | 68 | 69 | class Bottleneck(nn.Module): 70 | expansion = 4 71 | 72 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 73 | base_width=64, dilation=1, norm_layer=SynchronizedBatchNorm2d): 74 | super(Bottleneck, self).__init__() 75 | width = int(planes * (base_width / 64.)) * groups 76 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 77 | self.conv1 = conv1x1(inplanes, width) 78 | self.bn1 = norm_layer(width) 79 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 80 | self.bn2 = norm_layer(width) 81 | self.conv3 = conv1x1(width, planes * self.expansion) 82 | self.bn3 = norm_layer(planes * self.expansion) 83 | self.relu = nn.ReLU(inplace=True) 84 | self.downsample = downsample 85 | self.stride = stride 86 | 87 | def forward(self, x): 88 | identity = x 89 | 90 | out = self.conv1(x) 91 | out = self.bn1(out) 92 | out = self.relu(out) 93 | 94 | out = self.conv2(out) 95 | out = self.bn2(out) 96 | out = self.relu(out) 97 | 98 | out = self.conv3(out) 99 | out = self.bn3(out) 100 | 101 | if self.downsample is not None: 102 | identity = self.downsample(x) 103 | 104 | out += identity 105 | out = self.relu(out) 106 | 107 | return out 108 | 109 | 110 | class ResNet(nn.Module): 111 | 112 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 113 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 114 | norm_layer=None): 115 | super(ResNet, self).__init__() 116 | if norm_layer is None: 117 | norm_layer = SynchronizedBatchNorm2d 118 | self._norm_layer = norm_layer 119 | 120 | self.inplanes = 64 121 | self.dilation = 1 122 | if replace_stride_with_dilation is None: 123 | # each element in the tuple indicates if we should replace 124 | # the 2x2 stride with 
a dilated convolution instead 125 | replace_stride_with_dilation = [False, False, False] 126 | if len(replace_stride_with_dilation) != 3: 127 | raise ValueError("replace_stride_with_dilation should be None " 128 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 129 | self.groups = groups 130 | self.base_width = width_per_group 131 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 132 | bias=False) 133 | self.bn1 = norm_layer(self.inplanes) 134 | self.relu = nn.ReLU(inplace=True) 135 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 136 | self.layer1 = self._make_layer(block, 64, layers[0], stride=2) 137 | self.layer2 = self._make_layer(block, 128, layers[1], stride=1, 138 | dilate=replace_stride_with_dilation[0]) 139 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, 140 | dilate=replace_stride_with_dilation[1]) 141 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, 142 | dilate=replace_stride_with_dilation[2], grow=True) 143 | self.avgpool = nn.AvgPool2d(kernel_size=3, stride=1) 144 | 145 | for m in self.modules(): 146 | if isinstance(m, nn.Conv2d): 147 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 148 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 149 | nn.init.constant_(m.weight, 1) 150 | nn.init.constant_(m.bias, 0) 151 | 152 | # Zero-initialize the last BN in each residual branch, 153 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 154 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 155 | if zero_init_residual: 156 | for m in self.modules(): 157 | if isinstance(m, Bottleneck): 158 | nn.init.constant_(m.bn3.weight, 0) 159 | elif isinstance(m, BasicBlock): 160 | nn.init.constant_(m.bn2.weight, 0) 161 | 162 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False, grow=False): 163 | norm_layer = self._norm_layer 164 | downsample = None 165 | previous_dilation = self.dilation 166 | if dilate: 167 | self.dilation *= stride 168 | stride = 1 169 | if stride != 1 or self.inplanes != planes * block.expansion: 170 | downsample = nn.Sequential( 171 | conv1x1(self.inplanes, planes * block.expansion, stride), 172 | norm_layer(planes * block.expansion), 173 | ) 174 | 175 | layers = [] 176 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 177 | self.base_width, previous_dilation, norm_layer)) 178 | self.inplanes = planes * block.expansion 179 | for _ in range(1, blocks): 180 | layers.append(block(self.inplanes, planes, groups=self.groups, 181 | base_width=self.base_width, dilation=self.dilation, 182 | norm_layer=norm_layer)) 183 | 184 | return nn.Sequential(*layers) 185 | 186 | def forward(self, x): 187 | x = self.conv1(x) 188 | x = self.bn1(x) 189 | x = self.relu(x) 190 | x = self.maxpool(x) 191 | 192 | feat1 = self.layer1(x) 193 | feat2 = self.layer2(feat1) 194 | feat3 = self.layer3(feat2) 195 | feat4 = self.layer4(feat3) 196 | 197 | feat4 = self.avgpool(feat4) 198 | return feat4 199 | 200 | 201 | def resnet10(pretrained=False, progress=True, **kwargs): 202 | r"""ResNet-10 model from 203 | `"Deep Residual Learning for Image Recognition" '_ 204 | Args: 205 | pretrained (bool): If True, returns a model pre-trained on ImageNet 206 | progress (bool): If True, displays a progress bar of the download to stderr 207 | """ 208 | model = ResNet(BasicBlock, [1, 1, 1, 1], **kwargs) 209 | return model 210 | 211 | 212 | def resnet18(pretrained=False, 
progress=True, **kwargs): 213 | r"""ResNet-18 model from 214 | `"Deep Residual Learning for Image Recognition" '_ 215 | Args: 216 | pretrained (bool): If True, returns a model pre-trained on ImageNet 217 | progress (bool): If True, displays a progress bar of the download to stderr 218 | """ 219 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 220 | return model 221 | 222 | 223 | def resnet34(pretrained=False, progress=True, **kwargs): 224 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 225 | return model 226 | -------------------------------------------------------------------------------- /model/component/roi_align/roi_align/__init__.py: -------------------------------------------------------------------------------- 1 | from .roi_align import RoIAlign, CropAndResizeFunction -------------------------------------------------------------------------------- /model/component/roi_align/roi_align/crop_and_resize.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Function 6 | 7 | import roi_align.crop_and_resize_cpu as crop_and_resize_cpu 8 | if torch.cuda.is_available(): 9 | import roi_align.crop_and_resize_gpu as crop_and_resize_gpu 10 | 11 | 12 | 13 | class CropAndResizeFunction(Function): 14 | 15 | def __init__(self, crop_height, crop_width, extrapolation_value=0): 16 | self.crop_height = crop_height 17 | self.crop_width = crop_width 18 | self.extrapolation_value = extrapolation_value 19 | 20 | def forward(self, image, boxes, box_ind): 21 | crops = torch.zeros_like(image) 22 | 23 | if image.is_cuda: 24 | crop_and_resize_gpu.forward( 25 | image, boxes, box_ind, 26 | self.extrapolation_value, self.crop_height, self.crop_width, crops) 27 | else: 28 | crop_and_resize_cpu.forward( 29 | image, boxes, box_ind, 30 | self.extrapolation_value, self.crop_height, self.crop_width, crops) 31 | 32 | # save for backward 33 | self.im_size = image.size() 34 | self.save_for_backward(boxes, box_ind) 35 | 36 | return crops 37 | 38 | def backward(self, grad_outputs): 39 | boxes, box_ind = self.saved_tensors 40 | 41 | grad_outputs = grad_outputs.contiguous() 42 | grad_image = torch.zeros_like(grad_outputs).resize_(*self.im_size) 43 | 44 | if grad_outputs.is_cuda: 45 | crop_and_resize_gpu.backward( 46 | grad_outputs, boxes, box_ind, grad_image 47 | ) 48 | else: 49 | crop_and_resize_cpu.backward( 50 | grad_outputs, boxes, box_ind, grad_image 51 | ) 52 | 53 | return grad_image, None, None 54 | 55 | 56 | class CropAndResize(nn.Module): 57 | """ 58 | Crop and resize ported from tensorflow 59 | See more details on https://www.tensorflow.org/api_docs/python/tf/image/crop_and_resize 60 | """ 61 | 62 | def __init__(self, crop_height, crop_width, extrapolation_value=0): 63 | super(CropAndResize, self).__init__() 64 | 65 | self.crop_height = crop_height 66 | self.crop_width = crop_width 67 | self.extrapolation_value = extrapolation_value 68 | 69 | def forward(self, image, boxes, box_ind): 70 | return CropAndResizeFunction(self.crop_height, self.crop_width, self.extrapolation_value)(image, boxes, box_ind) 71 | -------------------------------------------------------------------------------- /model/component/roi_align/roi_align/roi_align.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from roi_align.crop_and_resize import CropAndResizeFunction, CropAndResize 5 | 6 | 7 | class 
RoIAlign(nn.Module): 8 | 9 | def __init__(self, crop_height, crop_width, extrapolation_value=0, transform_fpcoor=True): 10 | super(RoIAlign, self).__init__() 11 | 12 | self.crop_height = crop_height 13 | self.crop_width = crop_width 14 | self.extrapolation_value = extrapolation_value 15 | self.transform_fpcoor = transform_fpcoor 16 | 17 | def forward(self, featuremap, boxes, box_ind): 18 | """ 19 | RoIAlign based on crop_and_resize. 20 | See more details on https://github.com/ppwwyyxx/tensorpack/blob/6d5ba6a970710eaaa14b89d24aace179eb8ee1af/examples/FasterRCNN/model.py#L301 21 | :param featuremap: NxCxHxW 22 | :param boxes: Mx4 float box with (x1, y1, x2, y2) **without normalization** 23 | :param box_ind: M 24 | :return: MxCxoHxoW 25 | """ 26 | x1, y1, x2, y2 = torch.split(boxes, 1, dim=1) 27 | image_height, image_width = featuremap.size()[2:4] 28 | 29 | if self.transform_fpcoor: 30 | spacing_w = (x2 - x1) / float(self.crop_width) 31 | spacing_h = (y2 - y1) / float(self.crop_height) 32 | 33 | nx0 = (x1 + spacing_w / 2 - 0.5) / float(image_width - 1) 34 | ny0 = (y1 + spacing_h / 2 - 0.5) / float(image_height - 1) 35 | nw = spacing_w * float(self.crop_width - 1) / float(image_width - 1) 36 | nh = spacing_h * float(self.crop_height - 1) / float(image_height - 1) 37 | 38 | boxes = torch.cat((ny0, nx0, ny0 + nh, nx0 + nw), 1) 39 | else: 40 | x1 = x1 / float(image_width - 1) 41 | x2 = x2 / float(image_width - 1) 42 | y1 = y1 / float(image_height - 1) 43 | y2 = y2 / float(image_height - 1) 44 | boxes = torch.cat((y1, x1, y2, x2), 1) 45 | 46 | boxes = boxes.detach().contiguous() 47 | box_ind = box_ind.detach() 48 | return CropAndResizeFunction(self.crop_height, self.crop_width, self.extrapolation_value)(featuremap, boxes, box_ind) 49 | -------------------------------------------------------------------------------- /model/component/roi_align/roi_align/src/crop_and_resize.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | //#include 3 | #include 4 | #include 5 | 6 | namespace torch { 7 | void CropAndResizePerBox( 8 | const float * image_data, 9 | const int batch_size, 10 | const int depth, 11 | const int image_height, 12 | const int image_width, 13 | 14 | const float * boxes_data, 15 | const int * box_index_data, 16 | const int start_box, 17 | const int limit_box, 18 | 19 | float * corps_data, 20 | const int crop_height, 21 | const int crop_width, 22 | const float extrapolation_value 23 | ) { 24 | const int image_channel_elements = image_height * image_width; 25 | const int image_elements = depth * image_channel_elements; 26 | 27 | const int channel_elements = crop_height * crop_width; 28 | const int crop_elements = depth * channel_elements; 29 | 30 | int b; 31 | #pragma omp parallel for 32 | for (b = start_box; b < limit_box; ++b) { 33 | const float * box = boxes_data + b * 4; 34 | const float y1 = box[0]; 35 | const float x1 = box[1]; 36 | const float y2 = box[2]; 37 | const float x2 = box[3]; 38 | 39 | const int b_in = box_index_data[b]; 40 | if (b_in < 0 || b_in >= batch_size) { 41 | printf("Error: batch_index %d out of range [0, %d)\n", b_in, batch_size); 42 | exit(-1); 43 | } 44 | 45 | const float height_scale = 46 | (crop_height > 1) 47 | ? (y2 - y1) * (image_height - 1) / (crop_height - 1) 48 | : 0; 49 | const float width_scale = 50 | (crop_width > 1) ? (x2 - x1) * (image_width - 1) / (crop_width - 1) 51 | : 0; 52 | 53 | for (int y = 0; y < crop_height; ++y) 54 | { 55 | const float in_y = (crop_height > 1) 56 | ? 
y1 * (image_height - 1) + y * height_scale 57 | : 0.5 * (y1 + y2) * (image_height - 1); 58 | 59 | if (in_y < 0 || in_y > image_height - 1) 60 | { 61 | for (int x = 0; x < crop_width; ++x) 62 | { 63 | for (int d = 0; d < depth; ++d) 64 | { 65 | // crops(b, y, x, d) = extrapolation_value; 66 | corps_data[crop_elements * b + channel_elements * d + y * crop_width + x] = extrapolation_value; 67 | } 68 | } 69 | continue; 70 | } 71 | 72 | const int top_y_index = floorf(in_y); 73 | const int bottom_y_index = ceilf(in_y); 74 | const float y_lerp = in_y - top_y_index; 75 | 76 | for (int x = 0; x < crop_width; ++x) 77 | { 78 | const float in_x = (crop_width > 1) 79 | ? x1 * (image_width - 1) + x * width_scale 80 | : 0.5 * (x1 + x2) * (image_width - 1); 81 | if (in_x < 0 || in_x > image_width - 1) 82 | { 83 | for (int d = 0; d < depth; ++d) 84 | { 85 | corps_data[crop_elements * b + channel_elements * d + y * crop_width + x] = extrapolation_value; 86 | } 87 | continue; 88 | } 89 | 90 | const int left_x_index = floorf(in_x); 91 | const int right_x_index = ceilf(in_x); 92 | const float x_lerp = in_x - left_x_index; 93 | 94 | for (int d = 0; d < depth; ++d) 95 | { 96 | const float *pimage = image_data + b_in * image_elements + d * image_channel_elements; 97 | 98 | const float top_left = pimage[top_y_index * image_width + left_x_index]; 99 | const float top_right = pimage[top_y_index * image_width + right_x_index]; 100 | const float bottom_left = pimage[bottom_y_index * image_width + left_x_index]; 101 | const float bottom_right = pimage[bottom_y_index * image_width + right_x_index]; 102 | 103 | const float top = top_left + (top_right - top_left) * x_lerp; 104 | const float bottom = 105 | bottom_left + (bottom_right - bottom_left) * x_lerp; 106 | 107 | corps_data[crop_elements * b + channel_elements * d + y * crop_width + x] = top + (bottom - top) * y_lerp; 108 | } 109 | } // end for x 110 | } // end for y 111 | } // end for b 112 | 113 | } 114 | 115 | #define CHECK_CUDA(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be a CPU tensor") 116 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 117 | #define CHECK_DIMS(x) AT_ASSERTM(x.dim() == 4, #x " must have 4 dimensions") 118 | 119 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 120 | #define CHECK_FLOAT(x) AT_ASSERTM(x.type().scalarType() == torch::ScalarType::Float, #x " must be float Tensor") 121 | #define CHECK_INT(x) AT_ASSERTM(x.type().scalarType() == torch::ScalarType::Int, #x " must be int Tensor") 122 | 123 | void crop_and_resize_forward( 124 | torch::Tensor image, 125 | torch::Tensor boxes, // [y1, x1, y2, x2] 126 | torch::Tensor box_index, // range in [0, batch_size) 127 | const float extrapolation_value, 128 | const int crop_height, 129 | const int crop_width, 130 | torch::Tensor crops 131 | ) { 132 | CHECK_INPUT(image); CHECK_FLOAT(image); CHECK_DIMS(image); 133 | CHECK_INPUT(boxes); CHECK_FLOAT(boxes); //TODO: check dims for other arguments required. 
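// Input contract (inferred from the checks above and the indexing in CropAndResizePerBox):
// `image` is an NCHW float CPU tensor, `boxes` is an Mx4 float tensor of
// [y1, x1, y2, x2] coordinates normalized to [0, 1] (they are rescaled by
// image_height - 1 and image_width - 1 during sampling, matching TensorFlow's
// crop_and_resize), `box_index` gives the batch element of each box, and
// `crops` is resized to [num_boxes, depth, crop_height, crop_width] below.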
134 | CHECK_INPUT(box_index); CHECK_INT(box_index); 135 | CHECK_INPUT(crops); CHECK_FLOAT(crops); 136 | 137 | const int batch_size = image.size(0); 138 | const int depth = image.size(1); 139 | const int image_height = image.size(2); 140 | const int image_width = image.size(3); 141 | 142 | const int num_boxes = boxes.size(0); 143 | 144 | crops.resize_({num_boxes, depth, crop_height, crop_width}); 145 | crops.zero_(); 146 | 147 | // crop_and_resize for each box 148 | CropAndResizePerBox( 149 | image.data(), 150 | batch_size, 151 | depth, 152 | image_height, 153 | image_width, 154 | 155 | boxes.data(), 156 | box_index.data(), 157 | 0, 158 | num_boxes, 159 | 160 | crops.data(), 161 | crop_height, 162 | crop_width, 163 | extrapolation_value 164 | ); 165 | 166 | } 167 | 168 | 169 | void crop_and_resize_backward( 170 | torch::Tensor grads, 171 | torch::Tensor boxes, // [y1, x1, y2, x2] 172 | torch::Tensor box_index, // range in [0, batch_size) 173 | torch::Tensor grads_image // resize to [bsize, c, hc, wc] 174 | ) { 175 | CHECK_INPUT(grads); CHECK_FLOAT(grads); 176 | CHECK_INPUT(boxes); CHECK_FLOAT(boxes); 177 | CHECK_INPUT(box_index); CHECK_INT(box_index); 178 | CHECK_INPUT(grads_image); CHECK_FLOAT(grads_image); CHECK_DIMS(grads_image); 179 | 180 | // shape 181 | const int batch_size = grads_image.size(0); 182 | const int depth = grads_image.size(1); 183 | const int image_height = grads_image.size(2); 184 | const int image_width = grads_image.size(3); 185 | 186 | const int num_boxes = grads.size(0); 187 | const int crop_height = grads.size(2); 188 | const int crop_width = grads.size(3); 189 | 190 | // n_elements 191 | const int image_channel_elements = image_height * image_width; 192 | const int image_elements = depth * image_channel_elements; 193 | 194 | const int channel_elements = crop_height * crop_width; 195 | const int crop_elements = depth * channel_elements; 196 | 197 | // init output space 198 | grads_image.zero_(); 199 | // THFloatTensor_zero(grads_image); 200 | 201 | // data pointer 202 | const float * grads_data = grads.data(); 203 | const float * boxes_data = boxes.data(); 204 | const int * box_index_data = box_index.data(); 205 | float * grads_image_data = grads_image.data(); 206 | 207 | for (int b = 0; b < num_boxes; ++b) { 208 | const float * box = boxes_data + b * 4; 209 | const float y1 = box[0]; 210 | const float x1 = box[1]; 211 | const float y2 = box[2]; 212 | const float x2 = box[3]; 213 | 214 | const int b_in = box_index_data[b]; 215 | if (b_in < 0 || b_in >= batch_size) { 216 | printf("Error: batch_index %d out of range [0, %d)\n", b_in, batch_size); 217 | exit(-1); 218 | } 219 | 220 | const float height_scale = 221 | (crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1) 222 | : 0; 223 | const float width_scale = 224 | (crop_width > 1) ? (x2 - x1) * (image_width - 1) / (crop_width - 1) 225 | : 0; 226 | 227 | for (int y = 0; y < crop_height; ++y) 228 | { 229 | const float in_y = (crop_height > 1) 230 | ? y1 * (image_height - 1) + y * height_scale 231 | : 0.5 * (y1 + y2) * (image_height - 1); 232 | if (in_y < 0 || in_y > image_height - 1) 233 | { 234 | continue; 235 | } 236 | const int top_y_index = floorf(in_y); 237 | const int bottom_y_index = ceilf(in_y); 238 | const float y_lerp = in_y - top_y_index; 239 | 240 | for (int x = 0; x < crop_width; ++x) 241 | { 242 | const float in_x = (crop_width > 1) 243 | ? 
x1 * (image_width - 1) + x * width_scale 244 | : 0.5 * (x1 + x2) * (image_width - 1); 245 | if (in_x < 0 || in_x > image_width - 1) 246 | { 247 | continue; 248 | } 249 | const int left_x_index = floorf(in_x); 250 | const int right_x_index = ceilf(in_x); 251 | const float x_lerp = in_x - left_x_index; 252 | 253 | for (int d = 0; d < depth; ++d) 254 | { 255 | float *pimage = grads_image_data + b_in * image_elements + d * image_channel_elements; 256 | const float grad_val = grads_data[crop_elements * b + channel_elements * d + y * crop_width + x]; 257 | 258 | const float dtop = (1 - y_lerp) * grad_val; 259 | pimage[top_y_index * image_width + left_x_index] += (1 - x_lerp) * dtop; 260 | pimage[top_y_index * image_width + right_x_index] += x_lerp * dtop; 261 | 262 | const float dbottom = y_lerp * grad_val; 263 | pimage[bottom_y_index * image_width + left_x_index] += (1 - x_lerp) * dbottom; 264 | pimage[bottom_y_index * image_width + right_x_index] += x_lerp * dbottom; 265 | } // end d 266 | } // end x 267 | } // end y 268 | } // end b 269 | } 270 | } 271 | 272 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 273 | m.def( 274 | "forward", 275 | &torch::crop_and_resize_forward, 276 | "crop_and_resize_forward"); 277 | m.def( 278 | "backward", 279 | &torch::crop_and_resize_backward, 280 | "crop_and_resize_forward"); 281 | } -------------------------------------------------------------------------------- /model/component/roi_align/roi_align/src/crop_and_resize.h: -------------------------------------------------------------------------------- 1 | namespace at { 2 | struct Tensor; 3 | } // namespace at 4 | namespace torch { 5 | void crop_and_resize_forward( 6 | at::Tensor image, 7 | at::Tensor boxes, // [y1, x1, y2, x2] 8 | at::Tensor box_index, // range in [0, batch_size) // int tensor 9 | const float extrapolation_value, 10 | const int crop_height, 11 | const int crop_width, 12 | at::Tensor crops 13 | ); 14 | 15 | void crop_and_resize_backward( 16 | at::Tensor grads, 17 | at::Tensor boxes, // [y1, x1, y2, x2] 18 | at::Tensor box_index, // range in [0, batch_size) // int 19 | at::Tensor grads_image // resize to [bsize, c, hc, wc] 20 | ); 21 | } -------------------------------------------------------------------------------- /model/component/roi_align/roi_align/src/crop_and_resize_gpu.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | //#include 4 | #include "cuda/crop_and_resize_kernel.h" 5 | 6 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 7 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 8 | #define CHECK_DIMS(x) AT_ASSERTM(x.dim() == 4, #x " must have 4 dimensions") 9 | 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | #define CHECK_FLOAT(x) AT_ASSERTM(x.type().scalarType() == torch::ScalarType::Float, #x " must be float Tensor") 13 | #define CHECK_INT(x) AT_ASSERTM(x.type().scalarType() == torch::ScalarType::Int, #x " must be int Tensor") 14 | //using namespace at; 15 | 16 | 17 | namespace torch { 18 | void crop_and_resize_gpu_forward( 19 | torch::Tensor image, 20 | torch::Tensor boxes, // [y1, x1, y2, x2] 21 | torch::Tensor box_index, // range in [0, batch_size) 22 | const float extrapolation_value, 23 | const int crop_height, 24 | const int crop_width, 25 | torch::Tensor crops 26 | ) { 27 | CHECK_INPUT(image); CHECK_FLOAT(image); CHECK_DIMS(image); 28 | CHECK_INPUT(boxes); CHECK_FLOAT(boxes); 29 | CHECK_INPUT(box_index); 
CHECK_INT(box_index); 30 | CHECK_INPUT(crops); CHECK_FLOAT(crops); 31 | 32 | const int batch_size = image.size(0); 33 | const int depth = image.size(1); 34 | const int image_height = image.size(2); 35 | const int image_width = image.size(3); 36 | 37 | const int num_boxes = boxes.size(0); 38 | 39 | // init output space 40 | // THCTensor_resize(state, crops, {num_boxes, depth, crop_height, crop_width}); 41 | 42 | crops.resize_({num_boxes, depth, crop_height, crop_width}); 43 | crops.zero_(); 44 | // THCudaTensor_resize4d(state, crops, num_boxes, depth, crop_height, crop_width); 45 | // THCudaTensor_zero(state, crops); 46 | 47 | 48 | 49 | // auto state = globalContext().getTHCState(); 50 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();// THCState_getCurrentStream(state); 51 | 52 | CropAndResizeLaucher( 53 | image.data(), 54 | boxes.data(), 55 | box_index.data(), 56 | num_boxes, batch_size, image_height, image_width, 57 | crop_height, crop_width, depth, extrapolation_value, 58 | crops.data(), 59 | stream 60 | ); 61 | } 62 | 63 | 64 | void crop_and_resize_gpu_backward( 65 | torch::Tensor grads, 66 | torch::Tensor boxes, // [y1, x1, y2, x2] 67 | torch::Tensor box_index, // range in [0, batch_size) 68 | torch::Tensor grads_image // resize to [bsize, c, hc, wc] 69 | ) { 70 | CHECK_INPUT(grads); CHECK_FLOAT(grads); 71 | CHECK_INPUT(boxes); CHECK_FLOAT(boxes); 72 | CHECK_INPUT(box_index); CHECK_INT(box_index); 73 | CHECK_INPUT(grads_image); CHECK_FLOAT(grads_image); CHECK_DIMS(grads_image); 74 | 75 | // shape 76 | const int batch_size = grads_image.size(0); 77 | const int depth = grads_image.size(1); 78 | const int image_height = grads_image.size(2); 79 | const int image_width = grads_image.size(3); 80 | 81 | const int num_boxes = grads.size(0); 82 | const int crop_height = grads.size(2); 83 | const int crop_width = grads.size(3); 84 | 85 | // init output space 86 | grads_image.zero_(); 87 | 88 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 89 | CropAndResizeBackpropImageLaucher( 90 | grads.data(), 91 | boxes.data(), 92 | box_index.data(), 93 | num_boxes, batch_size, image_height, image_width, 94 | crop_height, crop_width, depth, 95 | grads_image.data(), 96 | stream 97 | ); 98 | } 99 | } 100 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 101 | m.def( 102 | "forward", 103 | &torch::crop_and_resize_gpu_forward, 104 | "crop_and_resize_gpu_forward"); 105 | m.def( 106 | "backward", 107 | &torch::crop_and_resize_gpu_backward, 108 | "crop_and_resize_gpu_backward"); 109 | } 110 | -------------------------------------------------------------------------------- /model/component/roi_align/roi_align/src/crop_and_resize_gpu.h: -------------------------------------------------------------------------------- 1 | namespace torch { 2 | //cuda tensors 3 | void crop_and_resize_gpu_forward( 4 | torch::Tensor image, 5 | torch::Tensor boxes, // [y1, x1, y2, x2] 6 | torch::Tensor box_index, // range in [0, batch_size) // int 7 | const float extrapolation_value, 8 | const int crop_height, 9 | const int crop_width, 10 | torch::Tensor crops 11 | ); 12 | 13 | void crop_and_resize_gpu_backward( 14 | torch::Tensor grads, 15 | torch::Tensor boxes, // [y1, x1, y2, x2] 16 | torch::Tensor box_index, // range in [0, batch_size) // int 17 | torch::Tensor grads_image // resize to [bsize, c, hc, wc] 18 | ); 19 | } -------------------------------------------------------------------------------- /model/component/roi_align/roi_align/src/cuda/crop_and_resize_kernel.cu: 
-------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "crop_and_resize_kernel.h" 4 | 5 | #define CUDA_1D_KERNEL_LOOP(i, n) \ 6 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ 7 | i += blockDim.x * gridDim.x) 8 | 9 | 10 | __global__ 11 | void CropAndResizeKernel( 12 | const int nthreads, const float *image_ptr, const float *boxes_ptr, 13 | const int *box_ind_ptr, int num_boxes, int batch, int image_height, 14 | int image_width, int crop_height, int crop_width, int depth, 15 | float extrapolation_value, float *crops_ptr) 16 | { 17 | CUDA_1D_KERNEL_LOOP(out_idx, nthreads) 18 | { 19 | // NHWC: out_idx = d + depth * (w + crop_width * (h + crop_height * b)) 20 | // NCHW: out_idx = w + crop_width * (h + crop_height * (d + depth * b)) 21 | int idx = out_idx; 22 | const int x = idx % crop_width; 23 | idx /= crop_width; 24 | const int y = idx % crop_height; 25 | idx /= crop_height; 26 | const int d = idx % depth; 27 | const int b = idx / depth; 28 | 29 | const float y1 = boxes_ptr[b * 4]; 30 | const float x1 = boxes_ptr[b * 4 + 1]; 31 | const float y2 = boxes_ptr[b * 4 + 2]; 32 | const float x2 = boxes_ptr[b * 4 + 3]; 33 | 34 | const int b_in = box_ind_ptr[b]; 35 | if (b_in < 0 || b_in >= batch) 36 | { 37 | continue; 38 | } 39 | 40 | const float height_scale = 41 | (crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1) 42 | : 0; 43 | const float width_scale = 44 | (crop_width > 1) ? (x2 - x1) * (image_width - 1) / (crop_width - 1) : 0; 45 | 46 | const float in_y = (crop_height > 1) 47 | ? y1 * (image_height - 1) + y * height_scale 48 | : 0.5 * (y1 + y2) * (image_height - 1); 49 | if (in_y < 0 || in_y > image_height - 1) 50 | { 51 | crops_ptr[out_idx] = extrapolation_value; 52 | continue; 53 | } 54 | 55 | const float in_x = (crop_width > 1) 56 | ? 
x1 * (image_width - 1) + x * width_scale 57 | : 0.5 * (x1 + x2) * (image_width - 1); 58 | if (in_x < 0 || in_x > image_width - 1) 59 | { 60 | crops_ptr[out_idx] = extrapolation_value; 61 | continue; 62 | } 63 | 64 | const int top_y_index = floorf(in_y); 65 | const int bottom_y_index = ceilf(in_y); 66 | const float y_lerp = in_y - top_y_index; 67 | 68 | const int left_x_index = floorf(in_x); 69 | const int right_x_index = ceilf(in_x); 70 | const float x_lerp = in_x - left_x_index; 71 | 72 | const float *pimage = image_ptr + (b_in * depth + d) * image_height * image_width; 73 | const float top_left = pimage[top_y_index * image_width + left_x_index]; 74 | const float top_right = pimage[top_y_index * image_width + right_x_index]; 75 | const float bottom_left = pimage[bottom_y_index * image_width + left_x_index]; 76 | const float bottom_right = pimage[bottom_y_index * image_width + right_x_index]; 77 | 78 | const float top = top_left + (top_right - top_left) * x_lerp; 79 | const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp; 80 | crops_ptr[out_idx] = top + (bottom - top) * y_lerp; 81 | } 82 | } 83 | 84 | __global__ 85 | void CropAndResizeBackpropImageKernel( 86 | const int nthreads, const float *grads_ptr, const float *boxes_ptr, 87 | const int *box_ind_ptr, int num_boxes, int batch, int image_height, 88 | int image_width, int crop_height, int crop_width, int depth, 89 | float *grads_image_ptr) 90 | { 91 | CUDA_1D_KERNEL_LOOP(out_idx, nthreads) 92 | { 93 | // NHWC: out_idx = d + depth * (w + crop_width * (h + crop_height * b)) 94 | // NCHW: out_idx = w + crop_width * (h + crop_height * (d + depth * b)) 95 | int idx = out_idx; 96 | const int x = idx % crop_width; 97 | idx /= crop_width; 98 | const int y = idx % crop_height; 99 | idx /= crop_height; 100 | const int d = idx % depth; 101 | const int b = idx / depth; 102 | 103 | const float y1 = boxes_ptr[b * 4]; 104 | const float x1 = boxes_ptr[b * 4 + 1]; 105 | const float y2 = boxes_ptr[b * 4 + 2]; 106 | const float x2 = boxes_ptr[b * 4 + 3]; 107 | 108 | const int b_in = box_ind_ptr[b]; 109 | if (b_in < 0 || b_in >= batch) 110 | { 111 | continue; 112 | } 113 | 114 | const float height_scale = 115 | (crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1) 116 | : 0; 117 | const float width_scale = 118 | (crop_width > 1) ? (x2 - x1) * (image_width - 1) / (crop_width - 1) : 0; 119 | 120 | const float in_y = (crop_height > 1) 121 | ? y1 * (image_height - 1) + y * height_scale 122 | : 0.5 * (y1 + y2) * (image_height - 1); 123 | if (in_y < 0 || in_y > image_height - 1) 124 | { 125 | continue; 126 | } 127 | 128 | const float in_x = (crop_width > 1) 129 | ? 
x1 * (image_width - 1) + x * width_scale 130 | : 0.5 * (x1 + x2) * (image_width - 1); 131 | if (in_x < 0 || in_x > image_width - 1) 132 | { 133 | continue; 134 | } 135 | 136 | const int top_y_index = floorf(in_y); 137 | const int bottom_y_index = ceilf(in_y); 138 | const float y_lerp = in_y - top_y_index; 139 | 140 | const int left_x_index = floorf(in_x); 141 | const int right_x_index = ceilf(in_x); 142 | const float x_lerp = in_x - left_x_index; 143 | 144 | float *pimage = grads_image_ptr + (b_in * depth + d) * image_height * image_width; 145 | const float dtop = (1 - y_lerp) * grads_ptr[out_idx]; 146 | atomicAdd( 147 | pimage + top_y_index * image_width + left_x_index, 148 | (1 - x_lerp) * dtop 149 | ); 150 | atomicAdd( 151 | pimage + top_y_index * image_width + right_x_index, 152 | x_lerp * dtop 153 | ); 154 | 155 | const float dbottom = y_lerp * grads_ptr[out_idx]; 156 | atomicAdd( 157 | pimage + bottom_y_index * image_width + left_x_index, 158 | (1 - x_lerp) * dbottom 159 | ); 160 | atomicAdd( 161 | pimage + bottom_y_index * image_width + right_x_index, 162 | x_lerp * dbottom 163 | ); 164 | } 165 | } 166 | 167 | 168 | void CropAndResizeLaucher( 169 | const float *image_ptr, const float *boxes_ptr, 170 | const int *box_ind_ptr, int num_boxes, int batch, int image_height, 171 | int image_width, int crop_height, int crop_width, int depth, 172 | float extrapolation_value, float *crops_ptr, cudaStream_t stream) 173 | { 174 | const int total_count = num_boxes * crop_height * crop_width * depth; 175 | const int thread_per_block = 1024; 176 | const int block_count = (total_count + thread_per_block - 1) / thread_per_block; 177 | cudaError_t err; 178 | 179 | if (total_count > 0) 180 | { 181 | CropAndResizeKernel<<>>( 182 | total_count, image_ptr, boxes_ptr, 183 | box_ind_ptr, num_boxes, batch, image_height, image_width, 184 | crop_height, crop_width, depth, extrapolation_value, crops_ptr); 185 | 186 | err = cudaGetLastError(); 187 | if (cudaSuccess != err) 188 | { 189 | fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err)); 190 | exit(-1); 191 | } 192 | } 193 | } 194 | 195 | 196 | void CropAndResizeBackpropImageLaucher( 197 | const float *grads_ptr, const float *boxes_ptr, 198 | const int *box_ind_ptr, int num_boxes, int batch, int image_height, 199 | int image_width, int crop_height, int crop_width, int depth, 200 | float *grads_image_ptr, cudaStream_t stream) 201 | { 202 | const int total_count = num_boxes * crop_height * crop_width * depth; 203 | const int thread_per_block = 1024; 204 | const int block_count = (total_count + thread_per_block - 1) / thread_per_block; 205 | cudaError_t err; 206 | 207 | if (total_count > 0) 208 | { 209 | CropAndResizeBackpropImageKernel<<>>( 210 | total_count, grads_ptr, boxes_ptr, 211 | box_ind_ptr, num_boxes, batch, image_height, image_width, 212 | crop_height, crop_width, depth, grads_image_ptr); 213 | 214 | err = cudaGetLastError(); 215 | if (cudaSuccess != err) 216 | { 217 | fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err)); 218 | exit(-1); 219 | } 220 | } 221 | } -------------------------------------------------------------------------------- /model/component/roi_align/roi_align/src/cuda/crop_and_resize_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _CropAndResize_Kernel 2 | #define _CropAndResize_Kernel 3 | 4 | //#include 5 | #include 6 | #ifdef __cplusplus 7 | extern "C" { 8 | #endif 9 | 10 | void CropAndResizeLaucher( 11 | const float *image_ptr, const 
float *boxes_ptr, 12 | const int *box_ind_ptr, int num_boxes, int batch, int image_height, 13 | int image_width, int crop_height, int crop_width, int depth, 14 | float extrapolation_value, float *crops_ptr, cudaStream_t stream); 15 | 16 | void CropAndResizeBackpropImageLaucher( 17 | const float *grads_ptr, const float *boxes_ptr, 18 | const int *box_ind_ptr, int num_boxes, int batch, int image_height, 19 | int image_width, int crop_height, int crop_width, int depth, 20 | float *grads_image_ptr, cudaStream_t stream); 21 | 22 | #ifdef __cplusplus 23 | } 24 | #endif 25 | 26 | #endif -------------------------------------------------------------------------------- /model/component/roi_align/setup.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from setuptools import setup, find_packages 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension 4 | 5 | modules = [ 6 | CppExtension( 7 | 'roi_align.crop_and_resize_cpu', 8 | ['roi_align/src/crop_and_resize.cpp'] 9 | ) 10 | ] 11 | 12 | if torch.cuda.is_available(): 13 | modules.append( 14 | CUDAExtension( 15 | 'roi_align.crop_and_resize_gpu', 16 | ['roi_align/src/crop_and_resize_gpu.cpp', 17 | 'roi_align/src/cuda/crop_and_resize_kernel.cu'], 18 | extra_compile_args={'cxx': ['-g', '-fopenmp'], 19 | 'nvcc': ['-O2']} 20 | ) 21 | ) 22 | 23 | setup( 24 | name='roi_align', 25 | version='0.0.1', 26 | description='PyTorch version of RoIAlign', 27 | author='Long Chen', 28 | author_email='longch1024@gmail.com', 29 | url='https://github.com/longcw/RoIAlign.pytorch', 30 | packages=find_packages(exclude=('tests',)), 31 | 32 | ext_modules=modules, 33 | cmdclass={'build_ext': BuildExtension}, 34 | install_requires=['torch'] 35 | ) 36 | -------------------------------------------------------------------------------- /model/component/roi_align/test.sh: -------------------------------------------------------------------------------- 1 | python tests/test.py 2 | python tests/test2.py 3 | python tests/crop_and_resize_example.py 4 | -------------------------------------------------------------------------------- /model/component/roi_align/tests/crop_and_resize_example.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torchvision import transforms, utils 4 | from torch.autograd import Variable, gradcheck 5 | from roi_align.crop_and_resize import CropAndResizeFunction 6 | import matplotlib.pyplot as plt 7 | from skimage.io import imread 8 | 9 | 10 | def to_varabile(tensor, requires_grad=False, is_cuda=True): 11 | if is_cuda: 12 | tensor = tensor.cuda() 13 | var = Variable(tensor, requires_grad=requires_grad) 14 | return var 15 | 16 | 17 | crop_height = 500 18 | crop_width = 500 19 | is_cuda = torch.cuda.is_available() 20 | 21 | # In this simple example the number of images and boxes is 2 22 | img_path1 = 'tests/images/choco.png' 23 | img_path2 = 'tests/images/snow.png' 24 | 25 | # Define the boxes ( crops ) 26 | # box = [y1/heigth , x1/width , y2/heigth , x2/width] 27 | boxes_data = torch.FloatTensor([[0, 0, 1, 1], [0, 0, 0.5, 0.5]]) 28 | 29 | # Create an index to say which box crops which image 30 | box_index_data = torch.IntTensor([0, 1]) 31 | 32 | # Import the images from file 33 | image_data1 = transforms.ToTensor()(imread(img_path1)).unsqueeze(0) 34 | image_data2 = transforms.ToTensor()(imread(img_path2)).unsqueeze(0) 35 | 36 | # Create a batch of 2 images 37 | image_data = torch.cat((image_data1, 
image_data2), 0) 38 | 39 | # Convert from numpy to Variables 40 | image_torch = to_varabile(image_data, is_cuda=is_cuda) 41 | boxes = to_varabile(boxes_data, is_cuda=is_cuda) 42 | box_index = to_varabile(box_index_data, is_cuda=is_cuda) 43 | 44 | # Crops and resize bbox1 from img1 and bbox2 from img2 45 | crops_torch = CropAndResizeFunction(crop_height, crop_width, 0)(image_torch, boxes, box_index) 46 | 47 | # Visualize the crops 48 | print(crops_torch.data.size()) 49 | crops_torch_data = crops_torch.data.cpu().numpy().transpose(0, 2, 3, 1) 50 | fig = plt.figure() 51 | plt.subplot(121) 52 | plt.imshow(crops_torch_data[0]) 53 | plt.subplot(122) 54 | plt.imshow(crops_torch_data[1]) 55 | plt.show() 56 | -------------------------------------------------------------------------------- /model/component/roi_align/tests/images/choco.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinahHu/ADE-FewShot/41dc9cc481bfaf3bd9fb8bd76c1e63fcf127339d/model/component/roi_align/tests/images/choco.png -------------------------------------------------------------------------------- /model/component/roi_align/tests/images/snow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinahHu/ADE-FewShot/41dc9cc481bfaf3bd9fb8bd76c1e63fcf127339d/model/component/roi_align/tests/images/snow.png -------------------------------------------------------------------------------- /model/component/roi_align/tests/test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import sys 4 | from torch import nn 5 | from torch.autograd import Variable, gradcheck 6 | try: 7 | import tensorflow as tf 8 | import tensorflow.contrib.slim as slim 9 | except: 10 | print("Unexpected error:", sys.exc_info()[0]) 11 | tf = None 12 | 13 | from roi_align.crop_and_resize import CropAndResizeFunction 14 | from roi_align.roi_align import RoIAlign 15 | 16 | 17 | def to_varabile(arr, requires_grad=False, is_cuda=True): 18 | tensor = torch.from_numpy(arr) 19 | if is_cuda: 20 | tensor = tensor.cuda() 21 | var = Variable(tensor, requires_grad=requires_grad) 22 | return var 23 | 24 | 25 | def generate_data(batch_size, depth, im_height, im_width, n_boxes, xyxy=False, box_normalize=True): 26 | 27 | # random rois 28 | xs = np.random.uniform(0, im_width, size=(n_boxes, 2)) 29 | ys = np.random.uniform(0, im_height, size=(n_boxes, 2)) 30 | if box_normalize: 31 | xs /= (im_width - 1) 32 | ys /= (im_height - 1) 33 | 34 | xs.sort(axis=1) 35 | ys.sort(axis=1) 36 | 37 | if xyxy: 38 | boxes_data = np.stack((xs[:, 0], ys[:, 0], xs[:, 1], ys[:, 1]), axis=-1).astype(np.float32) 39 | else: 40 | boxes_data = np.stack((ys[:, 0], xs[:, 0], ys[:, 1], xs[:, 1]), axis=-1).astype(np.float32) 41 | box_index_data = np.random.randint(0, batch_size, size=n_boxes, dtype=np.int32) 42 | image_data = np.random.randn(batch_size, depth, im_height, im_width).astype(np.float32) 43 | 44 | return image_data, boxes_data, box_index_data 45 | 46 | 47 | def compare_with_tf(crop_height, crop_width, is_cuda=True): 48 | # generate data 49 | image_data, boxes_data, box_index_data = generate_data( 50 | batch_size=2, 51 | depth=128, 52 | im_height=200, 53 | im_width=200, 54 | n_boxes=10, 55 | xyxy=False, box_normalize=True) 56 | # boxes_tf_data = np.stack((boxes_data[:, 1], boxes_data[:, 0], boxes_data[:, 3], boxes_data[:, 2]), axis=1) 57 | # boxes_tf_data[:, 0::2] /= (image_data.shape[2] - 
1.) 58 | # boxes_tf_data[:, 1::2] /= (image_data.shape[3] - 1.) 59 | 60 | # rand conv layer 61 | conv_torch = nn.Conv2d(image_data.shape[1], 64, 3, padding=1, bias=False) 62 | if is_cuda: 63 | conv_torch = conv_torch.cuda() 64 | 65 | # pytorch forward 66 | image_torch = to_varabile(image_data, requires_grad=True, is_cuda=is_cuda) 67 | boxes = to_varabile(boxes_data, requires_grad=False, is_cuda=is_cuda) 68 | box_index = to_varabile(box_index_data, requires_grad=False, is_cuda=is_cuda) 69 | 70 | print('pytorch forward and backward start') 71 | crops_torch = CropAndResizeFunction(crop_height, crop_width, 0)(image_torch, boxes, box_index) 72 | crops_torch = conv_torch(crops_torch) 73 | crops_torch_data = crops_torch.data.cpu().numpy() 74 | 75 | # pytorch backward 76 | loss_torch = crops_torch.sum() 77 | loss_torch.backward() 78 | grad_torch_data = image_torch.grad.data.cpu().numpy() 79 | 80 | print('pytorch forward and backward end') 81 | 82 | # tf forward & backward 83 | image_tf = tf.placeholder(tf.float32, (None, None, None, None), name='image') 84 | boxes = tf.placeholder(tf.float32, (None, 4), name='boxes') 85 | box_index = tf.placeholder(tf.int32, (None,), name='box_index') 86 | 87 | image_t = tf.transpose(image_tf, (0, 2, 3, 1)) 88 | crops_tf = tf.image.crop_and_resize(image_t, boxes, box_index, (crop_height, crop_width)) 89 | conv_tf = tf.nn.conv2d(crops_tf, np.transpose(conv_torch.weight.data.cpu().numpy(), (2, 3, 1, 0)), 90 | [1, 1, 1, 1], padding='SAME') 91 | 92 | trans_tf = tf.transpose(conv_tf, (0, 3, 1, 2)) 93 | loss_tf = tf.reduce_sum(trans_tf) 94 | grad_tf = tf.gradients(loss_tf, image_tf)[0] 95 | 96 | with tf.Session() as sess: 97 | crops_tf_data, grad_tf_data = sess.run( 98 | (trans_tf, grad_tf), feed_dict={image_tf: image_data, boxes: boxes_data, box_index: box_index_data} 99 | ) 100 | 101 | crops_diff = np.abs(crops_tf_data - crops_torch_data) 102 | print('forward (maxval, min_err, max_err, mean_err):', 103 | crops_tf_data.max(), crops_diff.min(), crops_diff.max(), crops_diff.mean()) 104 | 105 | grad_diff = np.abs(grad_tf_data - grad_torch_data) 106 | print('backward (maxval, min_err, max_err, mean_err):', 107 | grad_tf_data.max(), grad_diff.min(), grad_diff.max(), grad_diff.mean()) 108 | 109 | 110 | def test_roialign(is_cuda=True): 111 | # generate data 112 | crop_height = 3 113 | crop_width = 3 114 | image_data, boxes_data, box_index_data = generate_data( 115 | batch_size=2, 116 | depth=2, 117 | im_height=10, 118 | im_width=10, 119 | n_boxes=2, 120 | xyxy=True, box_normalize=False) 121 | max_inp = np.abs(image_data).max() 122 | print('max_input:', max_inp) 123 | 124 | image_torch = to_varabile(image_data, requires_grad=True, is_cuda=is_cuda) 125 | boxes = to_varabile(boxes_data, requires_grad=False, is_cuda=is_cuda) 126 | box_index = to_varabile(box_index_data, requires_grad=False, is_cuda=is_cuda) 127 | 128 | roi_align = RoIAlign(crop_height, crop_width, transform_fpcoor=False) 129 | gradcheck(roi_align, (image_torch, boxes, box_index), eps=max_inp/500) 130 | 131 | print('test ok') 132 | 133 | 134 | if __name__ == '__main__': 135 | def main(): 136 | crop_height = 7 137 | crop_width = 7 138 | is_cuda = torch.cuda.is_available() 139 | 140 | if tf is not None: 141 | compare_with_tf(crop_height, crop_width, is_cuda=is_cuda) 142 | else: 143 | print('without tensorflow') 144 | 145 | test_roialign(is_cuda=is_cuda) 146 | 147 | main() 148 | -------------------------------------------------------------------------------- /model/component/roi_align/tests/test2.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.autograd import Variable 4 | 5 | from roi_align.roi_align import RoIAlign 6 | 7 | 8 | def to_varabile(arr, requires_grad=False, is_cuda=True): 9 | tensor = torch.from_numpy(arr) 10 | if is_cuda: 11 | tensor = tensor.cuda() 12 | var = Variable(tensor, requires_grad=requires_grad) 13 | return var 14 | 15 | 16 | # the data you want 17 | is_cuda = False 18 | image_data = np.tile(np.arange(7, dtype=np.float32), 7).reshape(7, 7) 19 | image_data = image_data[np.newaxis, np.newaxis] 20 | boxes_data = np.asarray([[0, 0, 3, 3]], dtype=np.float32) 21 | box_index_data = np.asarray([0], dtype=np.int32) 22 | 23 | image_torch = to_varabile(image_data, requires_grad=True, is_cuda=is_cuda) 24 | boxes = to_varabile(boxes_data, requires_grad=False, is_cuda=is_cuda) 25 | box_index = to_varabile(box_index_data, requires_grad=False, is_cuda=is_cuda) 26 | 27 | # set transform_fpcoor to False is the crop_and_resize 28 | roi_align = RoIAlign(3, 3, transform_fpcoor=True) 29 | print(roi_align(image_torch, boxes, box_index)) -------------------------------------------------------------------------------- /model/component/rotation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | 7 | class RotationClassifier(nn.Module): 8 | def __init__(self, args): 9 | super(RotationClassifier, self).__init__() 10 | self.num_class = 4 11 | self.fc = nn.Linear(512, self.num_class) 12 | self.loss = nn.CrossEntropyLoss(ignore_index=-1) 13 | self.global_pool = nn.AdaptiveAvgPool2d((1, 1)) 14 | 15 | def forward(self, x): 16 | feature, labels = x 17 | pooled_feature = self.global_pool(feature) 18 | pred = self.fc(pooled_feature.view(-1, 512)) 19 | loss = self.loss(pred, labels) 20 | 21 | return loss 22 | -------------------------------------------------------------------------------- /model/component/scene.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import numpy as np 4 | import torch.nn.functional as F 5 | 6 | 7 | class SceneClassifier(nn.Module): 8 | def __init__(self, args): 9 | super(SceneClassifier, self).__init__() 10 | self.scene_num = -1 11 | self.in_dim = args.feat_dim 12 | self.feat_dim = -1 13 | for supervision in args.supervision: 14 | if supervision['name'] == 'scene': 15 | self.scene_num = supervision['other']['scene_num'] 16 | self.pool_size = supervision['other']['pool_size'] 17 | self.pool = nn.AdaptiveAvgPool2d(self.pool_size) 18 | self.fc = nn.Linear(self.pool_size * self.pool_size * self.in_dim, self.scene_num) 19 | self.loss = nn.CrossEntropyLoss() 20 | self.mode = 'train' 21 | 22 | def forward(self, agg_data): 23 | """ 24 | forward pipeline, compute loss function 25 | :param agg_data: refer to ../base_model.py 26 | :return: loss 27 | """ 28 | if self.mode == 'diagnosis': 29 | return self.diagnosis(agg_data) 30 | loss_sum = 0 31 | x = agg_data['feature_map'] 32 | scene = agg_data['scene'].long() 33 | 34 | # Shallow supervision only 35 | x = self.pool(x.unsqueeze(0)).view(-1) 36 | label = scene 37 | score = self.fc(x) 38 | loss_sum += self.loss(score.unsqueeze(0), label.unsqueeze(0)) 39 | return loss_sum 40 | -------------------------------------------------------------------------------- /model/component/seg.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import json 6 | import math 7 | 8 | 9 | class MaskPredictor(nn.Module): 10 | def __init__(self, args): 11 | super(MaskPredictor, self).__init__() 12 | self.in_dim = args.feat_dim 13 | self.args = args 14 | self.down_sampling_rate = args.down_sampling_rate 15 | 16 | self.fc1 = nn.Conv2d(self.in_dim, self.args.num_base_class + 1, kernel_size=3, stride=1, padding=1) 17 | 18 | self.base_classes = json.load(open('data/ADE/ADE_Origin/base_list.json', 'r')) 19 | 20 | @staticmethod 21 | def compute_anchor_location(anchor, scale, original_scale): 22 | anchor = np.array(anchor.detach().cpu()) 23 | original_scale = np.array(original_scale) 24 | scale = np.array(scale.cpu()) 25 | anchor[:, 2] = np.floor(anchor[:, 2] * scale[0] * original_scale[0]) 26 | anchor[:, 3] = np.ceil(anchor[:, 3] * scale[0] * original_scale[0]) 27 | anchor[:, 0] = np.floor(anchor[:, 0] * scale[1] * original_scale[1]) 28 | anchor[:, 1] = np.ceil(anchor[:, 1] * scale[1] * original_scale[1]) 29 | return anchor.astype(np.int) 30 | 31 | @staticmethod 32 | def binary_transform(mask, label): 33 | return mask[:, int(label.item()), :, :] 34 | 35 | def forward(self, agg_input): 36 | """ 37 | take in the feature map and make predictions 38 | :param agg_input: input data 39 | :return: loss averaged over instances 40 | """ 41 | feature_map = agg_input['feature_map'] 42 | mask = agg_input['seg'] 43 | feature_map = feature_map.unsqueeze(0) 44 | predicted_map = self.fc1(feature_map) 45 | predicted_map = F.interpolate(predicted_map, size=(mask.shape[0], mask.shape[1]), mode='nearest') 46 | mask = mask.unsqueeze(0) 47 | weight = torch.ones(self.args.num_base_class + 1).cuda() 48 | weight[-1] = 0.1 49 | 50 | loss = F.cross_entropy(predicted_map, mask.long(), weight=weight) 51 | return loss 52 | -------------------------------------------------------------------------------- /model/novel_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import math 5 | from torch.nn.functional import cosine_similarity 6 | import torch.nn.functional as F 7 | 8 | 9 | class NovelClassifier(nn.Module): 10 | def __init__(self, args): 11 | super(NovelClassifier, self).__init__() 12 | self.in_dim = args.crop_height * args.crop_width * args.feat_dim 13 | self.num_class = args.num_novel_class 14 | self.fc = nn.Linear(self.in_dim, self.num_class) 15 | self.loss = nn.CrossEntropyLoss(ignore_index=-1) 16 | self.mode = 'train' 17 | 18 | if hasattr(args, 'range_of_compute'): 19 | self.range_of_compute = args.range_of_compute 20 | 21 | def predict(self, x): 22 | feature = x['feature'] 23 | label = x['label'].long() 24 | pred = self.fc(feature) 25 | return self.acc(pred, label) 26 | 27 | def probability(self, x): 28 | feature = x['feature'] 29 | label = x['label'].long() 30 | pred = self.fc(feature) 31 | prob = F.softmax(pred) 32 | return prob, label 33 | 34 | def forward(self, x): 35 | if self.mode == 'val': 36 | return self.predict(x) 37 | elif self.mode == 'diagnosis': 38 | return self.diagnosis(x) 39 | elif self.mode == 'prob': 40 | return self.probability(x) 41 | 42 | feature = x['feature'] 43 | label = x['label'].long() 44 | pred = self.fc(feature) 45 | loss = self.loss(pred, label) 46 | acc = self._acc(pred, label) 47 | return loss, acc 48 | 49 | def _acc(self, pred, label): 50 | 
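"""Top-1 accuracy over valid entries: predictions are argmaxed over classes and entries with label < 0 are ignored."""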
_, preds = torch.max(pred, dim=1) 51 | valid = (label >= 0).long() 52 | acc_sum = torch.sum(valid * (preds == label).long()) 53 | instance_sum = torch.sum(valid) 54 | acc = acc_sum.float() / (instance_sum.float() + 1e-10) 55 | return acc 56 | 57 | def acc(self, pred, label): 58 | category_acc = torch.zeros(2, self.num_class).cuda() 59 | acc_sum = 0 60 | num = pred.shape[0] 61 | preds = np.array(pred.detach().cpu()) 62 | preds = np.argsort(preds) 63 | label = np.array(label.detach().cpu()) 64 | for i in range(num): 65 | category_acc[1, label[i]] += 1 66 | if label[i] in preds[i, -self.range_of_compute:]: 67 | acc_sum += 1 68 | category_acc[0, label[i]] += 1 69 | acc = torch.tensor(acc_sum / (num + 1e-10)).cuda() 70 | return acc, category_acc 71 | 72 | 73 | class NovelCosClassifier(nn.Module): 74 | def __init__(self, args): 75 | super(NovelCosClassifier, self).__init__() 76 | self.in_dim = args.crop_height * args.crop_width * args.feat_dim 77 | self.num_class = args.num_novel_class 78 | self.fc = nn.Linear(self.in_dim, self.num_class) 79 | self.loss = nn.CrossEntropyLoss(ignore_index=-1) 80 | self.mode = 'train' 81 | 82 | self.t = torch.ones(1).cuda() * 10 83 | self.weight = nn.Parameter(torch.Tensor(self.num_class, self.in_dim)) 84 | self.reset_parameters() 85 | 86 | if hasattr(args, 'range_of_compute'): 87 | self.range_of_compute = args.range_of_compute 88 | 89 | def reset_parameters(self): 90 | stdv = 1. / math.sqrt(self.weight.size(1)) 91 | self.weight.data.uniform_(-stdv, stdv) 92 | 93 | def predict(self, x): 94 | feature = x['feature'] 95 | label = x['label'].long() 96 | xx = feature 97 | batch_size = xx.size(0) 98 | pred = self.t.cuda() * cosine_similarity( 99 | xx.unsqueeze(1).expand(batch_size, self.num_class, self.in_dim), 100 | self.weight.unsqueeze(0).expand(batch_size, self.num_class, self.in_dim).cuda(), 2) 101 | return self.acc(pred, label) 102 | 103 | def forward(self, x): 104 | if self.mode == 'val': 105 | return self.predict(x) 106 | elif self.mode == 'diagnosis': 107 | return self.diagnosis(x) 108 | 109 | feature = x['feature'] 110 | label = x['label'].long() 111 | 112 | xx = feature 113 | batch_size = xx.size(0) 114 | pred = self.t.cuda() * cosine_similarity( 115 | xx.unsqueeze(1).expand(batch_size, self.num_class, self.in_dim), 116 | self.weight.unsqueeze(0).expand(batch_size, self.num_class, self.in_dim).cuda(), 2) 117 | loss = self.loss(pred, label) 118 | 119 | acc = self._acc(pred, label) 120 | return loss, acc 121 | 122 | def _acc(self, pred, label): 123 | _, preds = torch.max(pred, dim=1) 124 | valid = (label >= 0).long() 125 | acc_sum = torch.sum(valid * (preds == label).long()) 126 | instance_sum = torch.sum(valid) 127 | acc = acc_sum.float() / (instance_sum.float() + 1e-10) 128 | return acc 129 | 130 | def acc(self, pred, label): 131 | category_acc = torch.zeros(2, self.num_class).cuda() 132 | acc_sum = 0 133 | num = pred.shape[0] 134 | preds = np.array(pred.detach().cpu()) 135 | preds = np.argsort(preds) 136 | label = np.array(label.detach().cpu()) 137 | for i in range(num): 138 | category_acc[1, label[i]] += 1 139 | if label[i] in preds[i, -self.range_of_compute:]: 140 | acc_sum += 1 141 | category_acc[0, label[i]] += 1 142 | acc = torch.tensor(acc_sum / (num + 1e-10)).cuda() 143 | return acc, category_acc 144 | -------------------------------------------------------------------------------- /model/parallel/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BinahHu/ADE-FewShot/41dc9cc481bfaf3bd9fb8bd76c1e63fcf127339d/model/parallel/__init__.py -------------------------------------------------------------------------------- /model/parallel/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 59 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 60 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 61 | and passed to a registered callback. 62 | - After receiving the messages, the master device should gather the information and determine to message passed 63 | back to each slave devices. 64 | """ 65 | 66 | def __init__(self, master_callback): 67 | """ 68 | Args: 69 | master_callback: a callback to be invoked after having collected messages from slave devices. 70 | """ 71 | self._master_callback = master_callback 72 | self._queue = queue.Queue() 73 | self._registry = collections.OrderedDict() 74 | self._activated = False 75 | 76 | def register_slave(self, identifier): 77 | """ 78 | Register an slave device. 79 | Args: 80 | identifier: an identifier, usually is the device id. 81 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 82 | """ 83 | if self._activated: 84 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 85 | self._activated = False 86 | self._registry.clear() 87 | future = FutureResult() 88 | self._registry[identifier] = _MasterRegistry(future) 89 | return SlavePipe(identifier, self._queue, future) 90 | 91 | def run_master(self, master_msg): 92 | """ 93 | Main entry for the master device in each forward pass. 
94 | The messages were first collected from each devices (including the master device), and then 95 | an callback will be invoked to compute the message to be sent back to each devices 96 | (including the master device). 97 | Args: 98 | master_msg: the message that the master want to send to itself. This will be placed as the first 99 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 100 | Returns: the message to be sent back to the master device. 101 | """ 102 | self._activated = True 103 | 104 | intermediates = [(0, master_msg)] 105 | for i in range(self.nr_slaves): 106 | intermediates.append(self._queue.get()) 107 | 108 | results = self._master_callback(intermediates) 109 | assert results[0][0] == 0, 'The first result should belongs to the master.' 110 | 111 | for i, res in results: 112 | if i == 0: 113 | continue 114 | self._registry[i].result.put(res) 115 | 116 | for i in range(self.nr_slaves): 117 | assert self._queue.get() is True 118 | 119 | return results[0][1] 120 | 121 | @property 122 | def nr_slaves(self): 123 | return len(self._registry) -------------------------------------------------------------------------------- /model/parallel/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 30 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 31 | Note that, as all modules are isomorphism, we assign each sub-module with a context 32 | (shared among multiple copies of this module on different devices). 33 | Through this context, different copies can share some information. 34 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 35 | of any slave copies. 36 | """ 37 | master_copy = modules[0] 38 | nr_modules = len(list(master_copy.modules())) 39 | ctxs = [CallbackContext() for _ in range(nr_modules)] 40 | 41 | for i, module in enumerate(modules): 42 | for j, m in enumerate(module.modules()): 43 | if hasattr(m, '__data_parallel_replicate__'): 44 | m.__data_parallel_replicate__(ctxs[j], i) 45 | 46 | 47 | class DataParallelWithCallback(DataParallel): 48 | """ 49 | Data Parallel with a replication callback. 50 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 51 | original `replicate` function. 52 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 53 | Examples: 54 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 55 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 56 | # sync_bn.__data_parallel_replicate__ will be invoked. 
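Intended as a drop-in replacement for `torch.nn.DataParallel` when the wrapped model contains modules (such as the synchronized batch norm layers in this package) that define a `__data_parallel_replicate__` hook.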
57 | """ 58 | 59 | def replicate(self, module, device_ids): 60 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 61 | execute_replication_callbacks(modules) 62 | return modules 63 | 64 | 65 | def patch_replication_callback(data_parallel): 66 | """ 67 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 68 | Useful when you have a customized `DataParallel` implementation. 69 | Examples: 70 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 71 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 72 | > patch_replication_callback(sync_bn) 73 | # this is equivalent to 74 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 75 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 76 | """ 77 | 78 | assert isinstance(data_parallel, DataParallel) 79 | 80 | old_replicate = data_parallel.replicate 81 | 82 | @functools.wraps(old_replicate) 83 | def new_replicate(module, device_ids): 84 | modules = old_replicate(module, device_ids) 85 | execute_replication_callbacks(modules) 86 | return modules 87 | 88 | data_parallel.replicate = new_replicate -------------------------------------------------------------------------------- /preprocessing/addcontext.py: -------------------------------------------------------------------------------- 1 | import random 2 | random.seed(73) 3 | 4 | 5 | def enlarge(x1, x2, bound, ratio): 6 | l = x2 - x1 7 | p = int(l * ratio - l) 8 | 9 | d1 = random.randint(0, min(x1, p)) 10 | d2 = p - d1 11 | if d1 < 0: 12 | print("d1 < 0 error!") 13 | if d2 < 0: 14 | print("d2 < 0 error!") 15 | 16 | x1 -= d1 17 | x2 += d2 18 | 19 | return [x1, x2] 20 | 21 | 22 | def add_context(args, box, shape): 23 | Ratio = args.ratio 24 | l, r, u, d = box 25 | H = shape[0] 26 | W = shape[1] 27 | 28 | boxArea = (r - l) * (d - u) 29 | xthd = W / (r - l) 30 | ythd = H / (d - u) 31 | if H * W / boxArea < Ratio: 32 | Ratio = H * W / boxArea 33 | 34 | if random.randint(0, 1) == 0: 35 | xratio = random.uniform(1.0, min(Ratio, xthd)) 36 | yratio = Ratio / xratio 37 | else: 38 | yratio = random.uniform(1.0, min(Ratio, ythd)) 39 | xratio = Ratio / yratio 40 | 41 | #xratio = 1.76 42 | #yratio = 1.54 43 | 44 | l, r = enlarge(l, r, W, xratio) 45 | u, d = enlarge(u, d, H, yratio) 46 | box = [l, r, u, d] 47 | 48 | return box 49 | -------------------------------------------------------------------------------- /preprocessing/generate_list.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import argparse 4 | import os 5 | import random 6 | import math 7 | from addcontext import add_context 8 | 9 | 10 | def base_generation(args): 11 | origin_path = os.path.join(args.root_dataset, args.origin_dataset) 12 | supervision_path = os.path.join(args.root_dataset, args.supervision_dataset) 13 | base_set_path = os.path.join(origin_path, 'base_set.json') 14 | img_path_path = os.path.join(origin_path, 'img_path.json') 15 | img_path2size_path = os.path.join(origin_path, 'img_path2size.json') 16 | base_list_path = os.path.join(origin_path, 'base_list.json') 17 | f = open(base_set_path, 'r') 18 | base_set = json.load(f) 19 | f.close() 20 | f = open(img_path_path, 'r') 21 | img_path = json.load(f) 22 | f.close() 23 | f = open(img_path2size_path, 'r') 24 | img_path2size = json.load(f) 25 | f.close() 26 | f = open(base_list_path, 'r') 27 | base_list = json.load(f) 28 | f.close() 29 | 30 | # get other supervision 31 | supervision_path_names = [] 32 | for
supervision_src in args.supervision_src: 33 | supervision_path_names.append({'name': supervision_src['name'], 'path': supervision_src['path'], 34 | 'type': supervision_src['type']}) 35 | supervision_contents = [] 36 | for supervision in supervision_path_names: 37 | path = os.path.join(supervision_path, supervision['path']) 38 | f = open(path, 'r') 39 | data = json.load(f) 40 | supervision_contents.append({'data': data, 'name': supervision['name'], 'type': supervision['type']}) 41 | f.close() 42 | 43 | # initialize the sample list, add image level information 44 | sample_list_train = [dict() for i in range(len(img_path))] 45 | for i in range(len(img_path)): 46 | sample_list_train[i]['fpath_img'] = img_path[i] 47 | sample_list_train[i]['height'], sample_list_train[i]['width'] = \ 48 | img_path2size[img_path[i]] 49 | sample_list_train[i]['index'] = i 50 | sample_list_train[i]['anchors'] = [] 51 | # add image level information 52 | for j in range(len(supervision_contents)): 53 | if supervision_contents[j]['type'] == 'img': 54 | sample_list_train[i][supervision_contents[j]['name']] = supervision_contents[j]['data'][i] 55 | 56 | sample_list_val = [dict() for i in range(len(img_path))] 57 | for i in range(len(img_path)): 58 | sample_list_val[i]['fpath_img'] = img_path[i] 59 | sample_list_val[i]['height'], sample_list_val[i]['width'] = \ 60 | img_path2size[img_path[i]] 61 | sample_list_val[i]['index'] = i 62 | sample_list_val[i]['anchors'] = [] 63 | 64 | # get the category information to split train and val 65 | # add supervision information to instance level 66 | all_list = [[] for category in base_list] 67 | for i, obj in enumerate(base_set): 68 | img_index = int(obj["img"]) 69 | category = base_list.index(int(obj["obj"])) 70 | box = obj["box"] 71 | path = img_path[int(obj["img"])] 72 | shape = img_path2size[path] 73 | if args.context: 74 | box = add_context(args, box, shape) 75 | annotation = {"img": img_index, "obj": category, "box": box} 76 | for supervision in supervision_contents: 77 | if supervision['type'] == 'inst': 78 | data = supervision['data'][i] 79 | annotation[supervision['name']] = data 80 | all_list[category].append(annotation) 81 | 82 | random.seed(73) 83 | for category in range(len(base_list)): 84 | if not all_list[category]: 85 | continue 86 | random.shuffle(all_list[category]) 87 | 88 | # split into train and val 89 | for i in range(len(base_list)): 90 | length = len(all_list[i]) 91 | if length == 0: 92 | continue 93 | for j in range(0, math.ceil(5 * length / 6)): 94 | img_index = all_list[i][j]['img'] 95 | anchor = dict() 96 | anchor['anchor'] = all_list[i][j]['box'] 97 | anchor['label'] = i 98 | # add instance level supervision for train 99 | for supervision in supervision_contents: 100 | if supervision['type'] == 'inst': 101 | anchor[supervision['name']] = all_list[i][j][supervision['name']] 102 | sample_list_train[img_index]['anchors'].append(anchor) 103 | 104 | for j in range(math.ceil(5 * length / 6), length): 105 | img_index = all_list[i][j]['img'] 106 | sample_list_val[img_index]['anchors'].append({'anchor': all_list[i][j]['box'], 'label': i}) 107 | 108 | output_path = os.path.join(args.root_dataset, args.output) 109 | output_train = os.path.join(output_path, 'base_img_train.json') 110 | f = open(output_train, 'w') 111 | json.dump(sample_list_train, f) 112 | f.close() 113 | output_val = os.path.join(output_path, 'base_img_val.json') 114 | f = open(output_val, 'w') 115 | json.dump(sample_list_val, f) 116 | f.close() 117 | 118 | 119 | def novel_generation(args): 120 |
origin_path = os.path.join(args.root_dataset, args.origin_dataset) 121 | novel_set_path = os.path.join(origin_path, 'novel_set.json') 122 | img_path_path = os.path.join(origin_path, 'img_path.json') 123 | img_path2size_path = os.path.join(origin_path, 'img_path2size.json') 124 | novel_list_path = os.path.join(origin_path, 'novel_val_list.json') 125 | f = open(novel_set_path, 'r') 126 | novel_set = json.load(f) 127 | f.close() 128 | f = open(img_path_path, 'r') 129 | img_path = json.load(f) 130 | f.close() 131 | f = open(img_path2size_path, 'r') 132 | img_path2size = json.load(f) 133 | f.close() 134 | f = open(novel_list_path, 'r') 135 | novel_list = json.load(f) 136 | f.close() 137 | 138 | # initialize the sample list 139 | sample_list_train = [dict() for i in range(len(img_path))] 140 | for i in range(len(img_path)): 141 | sample_list_train[i]['fpath_img'] = img_path[i] 142 | sample_list_train[i]['height'], sample_list_train[i]['width'] = \ 143 | img_path2size[img_path[i]] 144 | sample_list_train[i]['index'] = i 145 | sample_list_train[i]['anchors'] = [] 146 | sample_list_val = [dict() for i in range(len(img_path))] 147 | for i in range(len(img_path)): 148 | sample_list_val[i]['fpath_img'] = img_path[i] 149 | sample_list_val[i]['height'], sample_list_val[i]['width'] = \ 150 | img_path2size[img_path[i]] 151 | sample_list_val[i]['index'] = i 152 | sample_list_val[i]['anchors'] = [] 153 | 154 | # get the category information to split train and val 155 | all_list = [[] for category in novel_list] 156 | for obj in novel_set: 157 | img_index = int(obj["img"]) 158 | if int(obj["obj"]) not in novel_list: 159 | continue 160 | category = novel_list.index(int(obj["obj"])) 161 | box = obj["box"] 162 | path = img_path[int(obj["img"])] 163 | shape = img_path2size[path] 164 | if args.context: 165 | box = add_context(args, box, shape) 166 | annotation = {"img": img_index, "obj": category, "box": box} 167 | all_list[category].append(annotation) 168 | 169 | random.seed(73) 170 | for category in range(len(novel_list)): 171 | if not all_list[category]: 172 | continue 173 | random.shuffle(all_list[category]) 174 | 175 | # split into train and val 176 | for i in range(len(novel_list)): 177 | length = len(all_list[i]) 178 | if length == 0: 179 | continue 180 | for j in range(0, args.shot): 181 | img_index = all_list[i][j]['img'] 182 | sample_list_train[img_index]['anchors'].append({'anchor': all_list[i][j]['box'], 'label': i}) 183 | 184 | for j in range(args.shot, length): 185 | img_index = all_list[i][j]['img'] 186 | sample_list_val[img_index]['anchors'].append({'anchor': all_list[i][j]['box'], 'label': i}) 187 | 188 | output_path = os.path.join(args.root_dataset, args.output) 189 | output_train = os.path.join(output_path, 'novel_img_train.json') 190 | f = open(output_train, 'w') 191 | json.dump(sample_list_train, f) 192 | f.close() 193 | output_val = os.path.join(output_path, 'novel_img_val.json') 194 | f = open(output_val, 'w') 195 | json.dump(sample_list_val, f) 196 | f.close() 197 | 198 | 199 | if __name__ == '__main__': 200 | parser = argparse.ArgumentParser() 201 | parser.add_argument('-root_dataset', default='../data/ADE', help='data file') 202 | parser.add_argument('-origin_dataset', default='ADE_Origin/', help='origin dir') 203 | parser.add_argument('--supervision_dataset', default='ADE_Supervision/', help='supervision information') 204 | parser.add_argument('-part', default='Base', help='Base or Novel') 205 | parser.add_argument('-shot', default=5, type=int, help='shot in Novel') 206 |
parser.add_argument('-img_size', default='img_path2size.json', help='img size file') 207 | parser.add_argument('--supervision_src', default=json.load(open('./supervision.json', 'r')), type=list) 208 | parser.add_argument('-context', type=bool, default=True) 209 | parser.add_argument('-ratio', type=float, default=2.7) 210 | # example [{'type': 'img', 'name': 'seg', 'path': '1.json'}, 211 | # {'type': 'inst', 'name': 'attr', 'path': 'attr.json'}] 212 | 213 | args = parser.parse_args() 214 | setattr(args, 'output', 'ADE_' + args.part) 215 | if args.part == 'Base': 216 | base_generation(args) 217 | elif args.part == 'Novel': 218 | novel_generation(args) -------------------------------------------------------------------------------- /preprocessing/supervision.json: -------------------------------------------------------------------------------- 1 | [ 2 | {"name": "attr", "type": "inst", "path": "attr.json", "weight": 50, "other": {"attr_num": 159}}, 3 | {"name": "seg", "type": "img", "path": "seg.json"}, 4 | {"name": "bbox", "type": "inst", "path": "bbox.json"}, 5 | {"name": "scene", "type": "img", "path": "scene.json", "other": {"scene_num": 1739}}, 6 | {"name": "hierarchy", "type": "inst", "path": "hierarchy.json", "weight": 15, "other": {}} 7 | ] 8 | -------------------------------------------------------------------------------- /preprocessing/supervison_generation/attr.py: -------------------------------------------------------------------------------- 1 | """ 2 | generate original data file for attribute 3 | Format: 4 | [[attr_index for each sample], [..], ...] 5 | """ 6 | import json 7 | import numpy as np 8 | import os 9 | import argparse 10 | 11 | 12 | def generate_attr(args): 13 | f = open(args.attr_file, 'r') 14 | attr = json.load(f) 15 | f.close() 16 | f = open(args.base_set, 'r') 17 | base_set = json.load(f) 18 | f.close() 19 | f = open(args.base_list, 'r') 20 | base_list = json.load(f) 21 | f.close() 22 | 23 | attr_list = [] 24 | for i, sample in enumerate(base_set): 25 | category = base_list.index(int(sample["obj"])) 26 | attr_list.append(attr[category]) 27 | 28 | f = open(args.output, 'w') 29 | json.dump(attr_list, f) 30 | f.close() 31 | 32 | 33 | if __name__ == '__main__': 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument('--attr_file', default='../../data/ADE/ADE_Origin/part.json') 36 | parser.add_argument('--output', default='../../data/ADE/ADE_Supervision/part.json') 37 | parser.add_argument('--base_set', default='../../data/ADE/ADE_Origin/base_set.json') 38 | parser.add_argument('--base_list', default='../../data/ADE/ADE_Origin/base_list.json') 39 | parser.add_argument('--attr_num', default=113) 40 | 41 | args = parser.parse_args() 42 | 43 | generate_attr(args) -------------------------------------------------------------------------------- /preprocessing/supervison_generation/bbox.py: -------------------------------------------------------------------------------- 1 | """ 2 | generate original data file for bounding boxes 3 | Format: 4 | [[anchor for each sample], [..], ...] 
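(each entry is the corresponding base_set sample's 'box', presumably in the [l, r, u, d] pixel-coordinate order that add_context consumes)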
5 | """ 6 | import json 7 | import numpy as np 8 | import os 9 | import argparse 10 | 11 | 12 | def generate_bbox(args): 13 | f = open(args.base_set, 'r') 14 | base_set = json.load(f) 15 | f.close() 16 | 17 | bbox_list = [] 18 | for i, sample in enumerate(base_set): 19 | bbox_list.append(sample['box']) 20 | f = open(args.output, 'w') 21 | json.dump(bbox_list, f) 22 | f.close() 23 | 24 | 25 | if __name__ == '__main__': 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('--base_set', default='../../data/ADE/ADE_Origin/base_set.json') 28 | parser.add_argument('--base_list', default='../../data/ADE/ADE_Origin/base_list.json') 29 | parser.add_argument('--output', default='../../data/ADE/ADE_Supervision/bbox.json') 30 | 31 | args = parser.parse_args() 32 | 33 | generate_bbox(args) 34 | -------------------------------------------------------------------------------- /preprocessing/supervison_generation/bkg.py: -------------------------------------------------------------------------------- 1 | """ 2 | get a file for the segmentation mask with bkg 3 | store the path of the segmentation images 4 | """ 5 | import json 6 | import argparse 7 | import os 8 | import numpy as np 9 | import cv2 10 | 11 | 12 | def generate_bkg(args): 13 | """ 14 | get the img path, replace the tail of file name into mask type 15 | :param args: argument 16 | :return: nothing, store the segment file in the allocated path 17 | """ 18 | f = open(args.img_path_file, 'r') 19 | img_paths = json.load(f) 20 | f.close() 21 | 22 | f = open('../../data/ADE/ADE_Origin/all_list.json') 23 | base_list = json.load(f) 24 | f.close() 25 | 26 | base_map = {} 27 | for i in range(len(base_list)): 28 | base_map[base_list[i]] = i 29 | base_set = set(base_map.keys()) 30 | 31 | seg_paths = [] 32 | length = len(img_paths) 33 | for i, img_path in enumerate(img_paths): 34 | seg_path_original = img_path[:-4] + '_seg.png' 35 | if not os.path.exists(os.path.join('../../../' + seg_path_original)): 36 | raise RuntimeError('{} not exists'.format(seg_path_original)) 37 | seg_path = img_path[:-4] + '_seg_base.png' 38 | seg_paths.append(seg_path) 39 | 40 | segmentation = cv2.imread(os.path.join('../../../' + seg_path_original)) 41 | B, G, R = np.transpose(segmentation, (2, 0, 1)) 42 | seg_map = (G + 256 * (R / 10)) 43 | seg_map_new = np.zeros((seg_map.shape[0], seg_map.shape[1], 3)) 44 | H, W = seg_map.shape 45 | for h in range(H): 46 | for w in range(W): 47 | val = base_map[seg_map[h, w]] 48 | p0 = val % 256 49 | val = val // 256 50 | p1 = val % 256 51 | p2 = val // 256 52 | seg_map_new[h, w, 0] = p0 53 | seg_map_new[h, w, 1] = p1 54 | seg_map_new[h, w, 2] = p2 55 | cv2.imwrite('../../../' + seg_path, seg_map_new.astype(np.uint8)) 56 | 57 | if i % 1 == 0: 58 | print('{} / {}'.format(i, length)) 59 | 60 | f = open(args.output, 'w') 61 | json.dump(seg_paths, f) 62 | f.close() 63 | 64 | 65 | if __name__ == '__main__': 66 | parser = argparse.ArgumentParser() 67 | parser.add_argument('--img_path_file', default='../../data/ADE/ADE_Origin/img_path.json') 68 | parser.add_argument('--output', default='../../data/ADE/ADE_Supervision/bkg.json') 69 | args = parser.parse_args() 70 | 71 | generate_bkg(args) 72 | -------------------------------------------------------------------------------- /preprocessing/supervison_generation/hierarchy.py: -------------------------------------------------------------------------------- 1 | """ 2 | generate original data file for hierarchy 3 | Format: 4 | [[[class_index for level i], ... ], [..], ...] 
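(one list of per-level class indices for each base_set sample; the per-level widths are presumably those given as "layer_width" in the top-level supervision.json)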
5 | """ 6 | import json 7 | import numpy as np 8 | import os 9 | import argparse 10 | 11 | 12 | def generate_hierarchy(args): 13 | f = open(args.hierarchy_file, 'r') 14 | hierarchy = json.load(f) 15 | f.close() 16 | f = open(args.base_set, 'r') 17 | base_set = json.load(f) 18 | f.close() 19 | f = open(args.base_list, 'r') 20 | base_list = json.load(f) 21 | f.close() 22 | 23 | hierarchy_list = [] 24 | for i, sample in enumerate(base_set): 25 | category = base_list.index(int(sample["obj"])) 26 | hierarchy_list.append(hierarchy[category]) 27 | 28 | f = open(args.output, 'w') 29 | json.dump(hierarchy_list, f) 30 | f.close() 31 | 32 | 33 | if __name__ == '__main__': 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument('--hierarchy_file', default='../../data/ADE/ADE_Origin/hierarchy.json') 36 | parser.add_argument('--output', default='../../data/ADE/ADE_Supervision/hierarchy.json') 37 | parser.add_argument('--base_set', default='../../data/ADE/ADE_Origin/base_set.json') 38 | parser.add_argument('--base_list', default='../../data/ADE/ADE_Origin/base_list.json') 39 | parser.add_argument('--layer_width', default='../../data/ADE/ADE_Origin/layer_width.json') 40 | 41 | args = parser.parse_args() 42 | 43 | generate_hierarchy(args) -------------------------------------------------------------------------------- /preprocessing/supervison_generation/scene.py: -------------------------------------------------------------------------------- 1 | """ 2 | generate original data file for hierarchy 3 | Format: 4 | [[[class_index for level i], ... ], [..], ...] 5 | """ 6 | import json 7 | import numpy as np 8 | import os 9 | import argparse 10 | import re 11 | 12 | def generate_hierarchy(args): 13 | f = open(args.scene_file, 'r') 14 | scene = json.load(f) 15 | f.close() 16 | 17 | f = open(args.output, 'w') 18 | json.dump(scene, f) 19 | f.close() 20 | 21 | 22 | if __name__ == '__main__': 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('--scene_file', default='../../data/ADE/ADE_Origin/scene.json') 25 | parser.add_argument('--output', default='../../data/ADE/ADE_Supervision/scene.json') 26 | 27 | args = parser.parse_args() 28 | 29 | generate_hierarchy(args) -------------------------------------------------------------------------------- /preprocessing/supervison_generation/seg.py: -------------------------------------------------------------------------------- 1 | """ 2 | get a file for the segmentation mask 3 | store the path of the segmentation images 4 | """ 5 | import json 6 | import argparse 7 | import os 8 | import numpy as np 9 | import cv2 10 | 11 | 12 | def generate_seg(args): 13 | """ 14 | get the img path, replace the tail of file name into mask type 15 | :param args: argument 16 | :return: nothing, store the segment file in the allocated path 17 | """ 18 | f = open(args.img_path_file, 'r') 19 | img_paths = json.load(f) 20 | f.close() 21 | 22 | f = open('../../data/ADE/ADE_Origin/base_list.json') 23 | base_list = json.load(f) 24 | f.close() 25 | 26 | base_map = {} 27 | for i in range(len(base_list)): 28 | base_map[base_list[i]] = i 29 | base_set = set(base_map.keys()) 30 | 31 | seg_paths = [] 32 | length = len(img_paths) 33 | for i, img_path in enumerate(img_paths): 34 | seg_path_original = img_path[:-4] + '_seg.png' 35 | if not os.path.exists(os.path.join('../../../' + seg_path_original)): 36 | raise RuntimeError('{} not exists'.format(seg_path_original)) 37 | seg_path = img_path[:-4] + '_seg_base.png' 38 | seg_paths.append(seg_path) 39 | segmentation = 
cv2.imread(os.path.join('../../../' + seg_path_original)) 40 | B, G, R = np.transpose(segmentation, (2, 0, 1)) 41 | seg_map = (G + 256 * (R / 10)) 42 | H, W = seg_map.shape 43 | for h in range(H): 44 | for w in range(W): 45 | if seg_map[h, w] not in base_set: 46 | seg_map[h, w] = 189 47 | else: 48 | seg_map[h, w] = base_map[seg_map[h, w]] 49 | cv2.imwrite('../../../' + seg_path, seg_map.astype(np.uint8)) 50 | if i % 1 == 0: 51 | print('{} / {}'.format(i, length)) 52 | 53 | f = open(args.output, 'w') 54 | json.dump(seg_paths, f) 55 | f.close() 56 | 57 | 58 | if __name__ == '__main__': 59 | parser = argparse.ArgumentParser() 60 | parser.add_argument('--img_path_file', default='../../data/ADE/ADE_Origin/img_path.json') 61 | parser.add_argument('--output', default='../../data/ADE/ADE_Supervision/seg.json') 62 | args = parser.parse_args() 63 | 64 | generate_seg(args) 65 | -------------------------------------------------------------------------------- /supervision.json: -------------------------------------------------------------------------------- 1 | [{"name": "seg", "type": "img", "content": "map", "weight": 0.5, "lr": 1e-1, "other": {}}, 2 | {"name": "hierarchy", "type": "inst", "path": "hierarchy.json", "weight": 1.0, "lr": 1e-1, "other": {"layer_width": [10, 58, 145, 42]}}, 3 | {"name": "attr", "type": "inst", "content": "list", "weight": 25.0, "lr": 1e-1, "other": {"num_attr": 159}}, 4 | {"name": "part", "type": "inst", "path": "part.json", "weight": 25.0, "lr": 1e-1, "other": {"num_attr": 113}}, 5 | {"name": "scene", "type": "img", "content": "scene", "path": "scene.json", "weight": 0.2, "lr": 1e-1, "other": {"scene_num": 1739, "pool_size": 3 }}, 6 | {"name": "bbox", "type": "inst", "content": "list", "weight": 5.0, "lr": 1e-1, "other": {"pool_size": 3}}, 7 | {"name": "patch_location", "type": "self", "weight": 1, "lr": 1e-1, "other": {}}, 8 | {"name": "rotation", "type": "self", "weight": 10.0, "lr": 1e-1, "other": {}}] 9 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | srun -p long --gres=gpu:4 --pty python train.py --comment seg_attr_hierarchy_10 --ckpt ckpt/seg_attr_hierarchy_10/ --start_epoch 1 --num_epoch 5 --model_weight ../models/seg_attr_10.pth 2 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import functools 4 | import fnmatch 5 | import numpy as np 6 | import torch 7 | import json 8 | 9 | 10 | def find_recursive(root_dir, ext='.jpg'): 11 | files = [] 12 | for root, dirnames, filenames in os.walk(root_dir): 13 | for filename in fnmatch.filter(filenames, '*' + ext): 14 | files.append(os.path.join(root, filename)) 15 | return files 16 | 17 | 18 | class AverageMeter(object): 19 | """Computes and stores the average and current value""" 20 | def __init__(self): 21 | self.initialized = False 22 | self.val = None 23 | self.avg = None 24 | self.sum = None 25 | self.count = None 26 | 27 | def initialize(self, val, weight): 28 | self.val = val 29 | self.avg = val 30 | self.sum = val * weight 31 | self.count = weight 32 | self.initialized = True 33 | 34 | def update(self, val, weight=1): 35 | if not self.initialized: 36 | self.initialize(val, weight) 37 | else: 38 | self.add(val, weight) 39 | 40 | def add(self, val, weight): 41 | self.val = val 42 | self.sum += val * weight 43 | self.count += weight 44 
| self.avg = self.sum / self.count 45 | 46 | def value(self): 47 | return self.val 48 | 49 | def average(self): 50 | return self.avg 51 | 52 | 53 | def unique(ar, return_index=False, return_inverse=False, return_counts=False): 54 | ar = np.asanyarray(ar).flatten() 55 | 56 | optional_indices = return_index or return_inverse 57 | optional_returns = optional_indices or return_counts 58 | 59 | if ar.size == 0: 60 | if not optional_returns: 61 | ret = ar 62 | else: 63 | ret = (ar,) 64 | if return_index: 65 | ret += (np.empty(0, np.bool),) 66 | if return_inverse: 67 | ret += (np.empty(0, np.bool),) 68 | if return_counts: 69 | ret += (np.empty(0, np.intp),) 70 | return ret 71 | if optional_indices: 72 | perm = ar.argsort(kind='mergesort' if return_index else 'quicksort') 73 | aux = ar[perm] 74 | else: 75 | ar.sort() 76 | aux = ar 77 | flag = np.concatenate(([True], aux[1:] != aux[:-1])) 78 | 79 | if not optional_returns: 80 | ret = aux[flag] 81 | else: 82 | ret = (aux[flag],) 83 | if return_index: 84 | ret += (perm[flag],) 85 | if return_inverse: 86 | iflag = np.cumsum(flag) - 1 87 | inv_idx = np.empty(ar.shape, dtype=np.intp) 88 | inv_idx[perm] = iflag 89 | ret += (inv_idx,) 90 | if return_counts: 91 | idx = np.concatenate(np.nonzero(flag) + ([ar.size],)) 92 | ret += (np.diff(idx),) 93 | return ret 94 | 95 | 96 | def colorEncode(labelmap, colors, mode='BGR'): 97 | labelmap = labelmap.astype('int') 98 | labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3), 99 | dtype=np.uint8) 100 | for label in unique(labelmap): 101 | if label < 0: 102 | continue 103 | labelmap_rgb += (labelmap == label)[:, :, np.newaxis] * \ 104 | np.tile(colors[label], 105 | (labelmap.shape[0], labelmap.shape[1], 1)) 106 | 107 | if mode == 'BGR': 108 | return labelmap_rgb[:, :, ::-1] 109 | else: 110 | return labelmap_rgb 111 | 112 | 113 | def accuracy(preds, label): 114 | valid = (label >= 0) 115 | acc_sum = (valid * (preds == label)).sum() 116 | valid_sum = valid.sum() 117 | acc = float(acc_sum) / (valid_sum + 1e-10) 118 | return acc, valid_sum 119 | 120 | 121 | def intersectionAndUnion(imPred, imLab, numClass): 122 | imPred = np.asarray(imPred).copy() 123 | imLab = np.asarray(imLab).copy() 124 | 125 | imPred += 1 126 | imLab += 1 127 | # Remove classes from unlabeled pixels in gt image. 128 | # We should not penalize detections in unlabeled portions of the image. 
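(Assuming unlabeled pixels are marked -1, the +1 shift above maps them to 0, so they fall outside the histogram range and count toward neither intersection nor union.)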
129 | imPred = imPred * (imLab > 0) 130 | 131 | # Compute area intersection: 132 | intersection = imPred * (imPred == imLab) 133 | (area_intersection, _) = np.histogram( 134 | intersection, bins=numClass, range=(1, numClass)) 135 | 136 | # Compute area union: 137 | (area_pred, _) = np.histogram(imPred, bins=numClass, range=(1, numClass)) 138 | (area_lab, _) = np.histogram(imLab, bins=numClass, range=(1, numClass)) 139 | area_union = area_pred + area_lab - area_intersection 140 | 141 | return (area_intersection, area_union) 142 | 143 | 144 | class NotSupportedCliException(Exception): 145 | pass 146 | 147 | 148 | def process_range(xpu, inp): 149 | start, end = map(int, inp) 150 | if start > end: 151 | end, start = start, end 152 | return map(lambda x: '{}{}'.format(xpu, x), range(start, end+1)) 153 | 154 | 155 | REGEX = [ 156 | (re.compile(r'^gpu(\d+)$'), lambda x: ['gpu%s' % x[0]]), 157 | (re.compile(r'^(\d+)$'), lambda x: ['gpu%s' % x[0]]), 158 | (re.compile(r'^gpu(\d+)-(?:gpu)?(\d+)$'), 159 | functools.partial(process_range, 'gpu')), 160 | (re.compile(r'^(\d+)-(\d+)$'), 161 | functools.partial(process_range, 'gpu')), 162 | ] 163 | 164 | 165 | def parse_devices(input_devices): 166 | 167 | """Parse user's devices input str to standard format. 168 | e.g. [gpu0, gpu1, ...] 169 | 170 | """ 171 | ret = [] 172 | for d in input_devices.split(','): 173 | for regex, func in REGEX: 174 | m = regex.match(d.lower().strip()) 175 | if m: 176 | tmp = func(m.groups()) 177 | # prevent duplicate 178 | for x in tmp: 179 | if x not in ret: 180 | ret.append(x) 181 | break 182 | else: 183 | raise NotSupportedCliException( 184 | 'Can not recognize device: "{}"'.format(d)) 185 | return ret 186 | 187 | 188 | def selective_load_weights(network, path): 189 | print("Load weight from {}".format(path)) 190 | network.load_state_dict( 191 | torch.load(path, map_location=lambda storage, loc: storage), strict=False) 192 | 193 | 194 | def set_fixed_weights(network): 195 | print("Set fixed weights") 196 | final_blocks = len(network.backbone.layer4._modules.keys()) 197 | if final_blocks == 2: 198 | print("No need to fix") 199 | return 200 | 201 | for param in network.backbone.parameters(): 202 | param.requires_grad = False 203 | final_blocks = len(network.backbone.layer4._modules.keys()) 204 | if final_blocks != 2: 205 | for param in network.backbone.layer4._modules[str(final_blocks-1)].parameters(): 206 | param.requires_grad = False 207 | print("Fix end") 208 | 209 | 210 | def category_acc(acc_data, args): 211 | acc = acc_data[0].float() / (acc_data[1].float() + 1e-10) 212 | return acc.mean() 213 | --------------------------------------------------------------------------------
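For reference, a minimal usage sketch of the helpers defined in utils.py; the toy arrays, metric values, and device string below are illustrative assumptions and are not taken from the training code.

```
import numpy as np

from utils import AverageMeter, accuracy, intersectionAndUnion, parse_devices

# Running (weighted) average of a per-batch metric.
meter = AverageMeter()
for batch_acc in [0.50, 0.75, 1.00]:
    meter.update(batch_acc, weight=1)
print(meter.average())  # 0.75

# Pixel accuracy and per-class IoU terms on toy label maps; labels < 0 are ignored.
preds = np.array([[0, 1], [1, 2]])
label = np.array([[0, 1], [2, 2]])
acc, n_valid = accuracy(preds, label)                 # roughly 0.75 over 4 valid pixels
inter, union = intersectionAndUnion(preds, label, 3)  # per-class intersection / union counts
print(acc, n_valid, inter / (union + 1e-10))

# Parse a device specification string into the standard ['gpu0', ...] form.
print(parse_devices('gpu0-2,3'))  # ['gpu0', 'gpu1', 'gpu2', 'gpu3']
```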