├── .gitignore ├── LICENSE ├── LaMa_inpainting.ipynb ├── README.md ├── bin ├── analyze_errors.py ├── blur_predicts.py ├── calc_dataset_stats.py ├── debug │ └── analyze_overlapping_masks.sh ├── evaluate_predicts.py ├── evaluator_example.py ├── extract_masks.py ├── filter_sharded_dataset.py ├── gen_debug_mask_dataset.py ├── gen_mask_dataset.py ├── gen_mask_dataset_hydra.py ├── gen_outpainting_dataset.py ├── make_checkpoint.py ├── mask_example.py ├── paper_runfiles │ ├── blur_tests.sh │ ├── env.sh │ ├── find_best_checkpoint.py │ ├── generate_test_celeba-hq.sh │ ├── generate_test_ffhq.sh │ ├── generate_test_paris.sh │ ├── generate_test_paris_256.sh │ ├── generate_val_test.sh │ ├── predict_inner_features.sh │ └── update_test_data_stats.sh ├── predict.py ├── predict_inner_features.py ├── report_from_tb.py ├── sample_from_dataset.py ├── side_by_side.py ├── split_tar.py ├── to_jit.py └── train.py ├── conda_env.yml ├── configs ├── analyze_mask_errors.yaml ├── data_gen │ ├── random_medium_256.yaml │ ├── random_medium_512.yaml │ ├── random_thick_256.yaml │ ├── random_thick_512.yaml │ ├── random_thin_256.yaml │ └── random_thin_512.yaml ├── debug_mask_gen.yaml ├── eval1.yaml ├── eval2.yaml ├── eval2_cpu.yaml ├── eval2_gpu.yaml ├── eval2_jpg.yaml ├── eval2_segm.yaml ├── eval2_segm_test.yaml ├── eval2_test.yaml ├── places2-categories_157.txt ├── prediction │ └── default.yaml ├── test_large_30k.lst └── training │ ├── ablv2_work.yaml │ ├── ablv2_work_ffc075.yaml │ ├── ablv2_work_md.yaml │ ├── ablv2_work_no_fm.yaml │ ├── ablv2_work_no_segmpl.yaml │ ├── ablv2_work_no_segmpl_csdilirpl.yaml │ ├── ablv2_work_no_segmpl_csdilirpl_celeba_csdilirpl1_new.yaml │ ├── ablv2_work_no_segmpl_csirpl.yaml │ ├── ablv2_work_no_segmpl_csirpl_celeba_csirpl03_new.yaml │ ├── ablv2_work_no_segmpl_vgg.yaml │ ├── ablv2_work_no_segmpl_vgg_celeba_l2_vgg003_new.yaml │ ├── ablv2_work_nodil_segmpl.yaml │ ├── ablv2_work_small_holes.yaml │ ├── big-lama-celeba.yaml │ ├── big-lama-regular-celeba.yaml │ ├── big-lama-regular.yaml │ ├── big-lama.yaml │ ├── data │ ├── abl-02-thin-bb.yaml │ ├── abl-04-256-mh-dist-celeba.yaml │ ├── abl-04-256-mh-dist-web.yaml │ └── abl-04-256-mh-dist.yaml │ ├── discriminator │ └── pix2pixhd_nlayer.yaml │ ├── evaluator │ └── default_inpainted.yaml │ ├── generator │ ├── ffc_resnet_075.yaml │ ├── pix2pixhd_global.yaml │ ├── pix2pixhd_global_sigmoid.yaml │ └── pix2pixhd_multidilated_catin_4dil_9b.yaml │ ├── hydra │ ├── no_time.yaml │ └── overrides.yaml │ ├── lama-fourier-celeba.yaml │ ├── lama-fourier.yaml │ ├── lama-regular-celeba.yaml │ ├── lama-regular.yaml │ ├── lama_small_train_masks.yaml │ ├── location │ ├── celeba_example.yaml │ ├── docker.yaml │ └── places_example.yaml │ ├── optimizers │ └── default_optimizers.yaml │ ├── trainer │ ├── any_gpu_large_ssim_ddp_final.yaml │ ├── any_gpu_large_ssim_ddp_final_benchmark.yaml │ └── any_gpu_large_ssim_ddp_final_celeba.yaml │ └── visualizer │ └── directory.yaml ├── docker ├── 1_generate_masks_from_raw_images.sh ├── 2_predict_with_gpu.sh ├── 3_evaluate.sh ├── Dockerfile ├── Dockerfile-cuda111 ├── build-cuda111.sh ├── build.sh └── entrypoint.sh ├── fetch_data ├── celebahq_dataset_prepare.sh ├── celebahq_gen_masks.sh ├── eval_sampler.py ├── places_challenge_train_download.sh ├── places_standard_evaluation_prepare_data.sh ├── places_standard_test_val_gen_masks.sh ├── places_standard_test_val_prepare.sh ├── places_standard_test_val_sample.sh ├── places_standard_train_prepare.sh ├── sampler.py ├── train_shuffled.flist └── val_shuffled.flist ├── models ├── ade20k │ ├── 
__init__.py │ ├── base.py │ ├── color150.mat │ ├── mobilenet.py │ ├── object150_info.csv │ ├── resnet.py │ ├── segm_lib │ │ ├── nn │ │ │ ├── __init__.py │ │ │ ├── modules │ │ │ │ ├── __init__.py │ │ │ │ ├── batchnorm.py │ │ │ │ ├── comm.py │ │ │ │ ├── replicate.py │ │ │ │ ├── tests │ │ │ │ │ ├── test_numeric_batchnorm.py │ │ │ │ │ └── test_sync_batchnorm.py │ │ │ │ └── unittest.py │ │ │ └── parallel │ │ │ │ ├── __init__.py │ │ │ │ └── data_parallel.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── data │ │ │ ├── __init__.py │ │ │ ├── dataloader.py │ │ │ ├── dataset.py │ │ │ ├── distributed.py │ │ │ └── sampler.py │ │ │ └── th.py │ └── utils.py └── lpips_models │ ├── alex.pth │ ├── squeeze.pth │ └── vgg.pth ├── requirements.txt └── saicinpainting ├── __init__.py ├── evaluation ├── __init__.py ├── data.py ├── evaluator.py ├── losses │ ├── __init__.py │ ├── base_loss.py │ ├── fid │ │ ├── __init__.py │ │ ├── fid_score.py │ │ └── inception.py │ ├── lpips.py │ └── ssim.py ├── masks │ ├── README.md │ ├── __init__.py │ ├── countless │ │ ├── .gitignore │ │ ├── README.md │ │ ├── __init__.py │ │ ├── countless2d.py │ │ ├── countless3d.py │ │ ├── images │ │ │ ├── gcim.jpg │ │ │ ├── gray_segmentation.png │ │ │ ├── segmentation.png │ │ │ └── sparse.png │ │ ├── memprof │ │ │ ├── countless2d_gcim_N_1000.png │ │ │ ├── countless2d_quick_gcim_N_1000.png │ │ │ ├── countless3d.png │ │ │ ├── countless3d_dynamic.png │ │ │ ├── countless3d_dynamic_generalized.png │ │ │ └── countless3d_generalized.png │ │ ├── requirements.txt │ │ └── test.py │ └── mask.py ├── refinement.py ├── utils.py └── vis.py ├── training ├── __init__.py ├── data │ ├── __init__.py │ ├── aug.py │ ├── datasets.py │ └── masks.py ├── losses │ ├── __init__.py │ ├── adversarial.py │ ├── constants.py │ ├── distance_weighting.py │ ├── feature_matching.py │ ├── perceptual.py │ ├── segmentation.py │ └── style_loss.py ├── modules │ ├── __init__.py │ ├── base.py │ ├── depthwise_sep_conv.py │ ├── fake_fakes.py │ ├── ffc.py │ ├── multidilated_conv.py │ ├── multiscale.py │ ├── pix2pixhd.py │ ├── spatial_transform.py │ └── squeeze_excitation.py ├── trainers │ ├── __init__.py │ ├── base.py │ └── default.py └── visualizers │ ├── __init__.py │ ├── base.py │ ├── colors.py │ ├── directory.py │ └── noop.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # temporary files 132 | ## IDEA 133 | .idea/ 134 | ## vscode 135 | .vscode/ 136 | ## vim 137 | *.sw? 
138 | -------------------------------------------------------------------------------- /bin/blur_predicts.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | 5 | import cv2 6 | import numpy as np 7 | import tqdm 8 | 9 | from saicinpainting.evaluation.data import PrecomputedInpaintingResultsDataset 10 | from saicinpainting.evaluation.utils import load_yaml 11 | 12 | 13 | def main(args): 14 | config = load_yaml(args.config) 15 | 16 | if not args.predictdir.endswith('/'): 17 | args.predictdir += '/' 18 | 19 | dataset = PrecomputedInpaintingResultsDataset(args.datadir, args.predictdir, **config.dataset_kwargs) 20 | 21 | os.makedirs(os.path.dirname(args.outpath), exist_ok=True) 22 | 23 | for img_i in tqdm.trange(len(dataset)): 24 | pred_fname = dataset.pred_filenames[img_i] 25 | cur_out_fname = os.path.join(args.outpath, pred_fname[len(args.predictdir):]) 26 | os.makedirs(os.path.dirname(cur_out_fname), exist_ok=True) 27 | 28 | sample = dataset[img_i] 29 | img = sample['image'] 30 | mask = sample['mask'] 31 | inpainted = sample['inpainted'] 32 | 33 | inpainted_blurred = cv2.GaussianBlur(np.transpose(inpainted, (1, 2, 0)), 34 | ksize=(args.k, args.k), 35 | sigmaX=args.s, sigmaY=args.s, 36 | borderType=cv2.BORDER_REFLECT) 37 | 38 | cur_res = (1 - mask) * np.transpose(img, (1, 2, 0)) + mask * inpainted_blurred 39 | cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8') 40 | cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR) 41 | cv2.imwrite(cur_out_fname, cur_res) 42 | 43 | 44 | if __name__ == '__main__': 45 | import argparse 46 | 47 | aparser = argparse.ArgumentParser() 48 | aparser.add_argument('config', type=str, help='Path to evaluation config') 49 | aparser.add_argument('datadir', type=str, 50 | help='Path to folder with images and masks (output of gen_mask_dataset.py)') 51 | aparser.add_argument('predictdir', type=str, 52 | help='Path to folder with predicts (e.g. 
predict_hifill_baseline.py)') 53 | aparser.add_argument('outpath', type=str, help='Where to put results') 54 | aparser.add_argument('-s', type=float, default=0.1, help='Gaussian blur sigma') 55 | aparser.add_argument('-k', type=int, default=5, help='Kernel size in gaussian blur') 56 | 57 | main(aparser.parse_args()) 58 | -------------------------------------------------------------------------------- /bin/calc_dataset_stats.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | 5 | import numpy as np 6 | import tqdm 7 | from scipy.ndimage.morphology import distance_transform_edt 8 | 9 | from saicinpainting.evaluation.data import InpaintingDataset 10 | from saicinpainting.evaluation.vis import save_item_for_vis 11 | 12 | 13 | def main(args): 14 | dataset = InpaintingDataset(args.datadir, img_suffix='.png') 15 | 16 | area_bins = np.linspace(0, 1, args.area_bins + 1) 17 | 18 | heights = [] 19 | widths = [] 20 | image_areas = [] 21 | hole_areas = [] 22 | hole_area_percents = [] 23 | known_pixel_distances = [] 24 | 25 | area_bins_count = np.zeros(args.area_bins) 26 | area_bin_titles = [f'{area_bins[i] * 100:.0f}-{area_bins[i + 1] * 100:.0f}' for i in range(args.area_bins)] 27 | 28 | bin2i = [[] for _ in range(args.area_bins)] 29 | 30 | for i, item in enumerate(tqdm.tqdm(dataset)): 31 | h, w = item['image'].shape[1:] 32 | heights.append(h) 33 | widths.append(w) 34 | full_area = h * w 35 | image_areas.append(full_area) 36 | bin_mask = item['mask'] > 0.5 37 | hole_area = bin_mask.sum() 38 | hole_areas.append(hole_area) 39 | hole_percent = hole_area / full_area 40 | hole_area_percents.append(hole_percent) 41 | bin_i = np.clip(np.searchsorted(area_bins, hole_percent) - 1, 0, len(area_bins_count) - 1) 42 | area_bins_count[bin_i] += 1 43 | bin2i[bin_i].append(i) 44 | 45 | cur_dist = distance_transform_edt(bin_mask) 46 | cur_dist_inside_mask = cur_dist[bin_mask] 47 | known_pixel_distances.append(cur_dist_inside_mask.mean()) 48 | 49 | os.makedirs(args.outdir, exist_ok=True) 50 | with open(os.path.join(args.outdir, 'summary.txt'), 'w') as f: 51 | f.write(f'''Location: {args.datadir} 52 | 53 | Number of samples: {len(dataset)} 54 | 55 | Image height: min {min(heights):5d} max {max(heights):5d} mean {np.mean(heights):.2f} 56 | Image width: min {min(widths):5d} max {max(widths):5d} mean {np.mean(widths):.2f} 57 | Image area: min {min(image_areas):7d} max {max(image_areas):7d} mean {np.mean(image_areas):.2f} 58 | Hole area: min {min(hole_areas):7d} max {max(hole_areas):7d} mean {np.mean(hole_areas):.2f} 59 | Hole area %: min {min(hole_area_percents) * 100:2.2f} max {max(hole_area_percents) * 100:2.2f} mean {np.mean(hole_area_percents) * 100:2.2f} 60 | Dist 2known: min {min(known_pixel_distances):2.2f} max {max(known_pixel_distances):2.2f} mean {np.mean(known_pixel_distances):2.2f} median {np.median(known_pixel_distances):2.2f} 61 | 62 | Stats by hole area %: 63 | ''') 64 | for bin_i in range(args.area_bins): 65 | f.write(f'{area_bin_titles[bin_i]}%: ' 66 | f'samples number {area_bins_count[bin_i]}, ' 67 | f'{area_bins_count[bin_i] / len(dataset) * 100:.1f}%\n') 68 | 69 | for bin_i in range(args.area_bins): 70 | bindir = os.path.join(args.outdir, 'samples', area_bin_titles[bin_i]) 71 | os.makedirs(bindir, exist_ok=True) 72 | bin_idx = bin2i[bin_i] 73 | for sample_i in np.random.choice(bin_idx, size=min(len(bin_idx), args.samples_n), replace=False): 74 | save_item_for_vis(dataset[sample_i], os.path.join(bindir, f'{sample_i}.png')) 
75 | 76 | 77 | if __name__ == '__main__': 78 | import argparse 79 | 80 | aparser = argparse.ArgumentParser() 81 | aparser.add_argument('datadir', type=str, 82 | help='Path to folder with images and masks (output of gen_mask_dataset.py)') 83 | aparser.add_argument('outdir', type=str, help='Where to put results') 84 | aparser.add_argument('--samples-n', type=int, default=10, 85 | help='Number of sample images with masks to copy for visualization for each area bin') 86 | aparser.add_argument('--area-bins', type=int, default=10, help='How many area bins to have') 87 | 88 | main(aparser.parse_args()) 89 | -------------------------------------------------------------------------------- /bin/debug/analyze_overlapping_masks.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | BASEDIR="$(dirname $0)" 4 | 5 | # paths are valid for mml7 6 | 7 | # select images 8 | #ls /data/inpainting/work/data/train | shuf | head -2000 | xargs -n1 -I{} cp {} /data/inpainting/mask_analysis/src 9 | 10 | # generate masks 11 | #"$BASEDIR/../gen_debug_mask_dataset.py" \ 12 | # "$BASEDIR/../../configs/debug_mask_gen.yaml" \ 13 | # "/data/inpainting/mask_analysis/src" \ 14 | # "/data/inpainting/mask_analysis/generated" 15 | 16 | # predict 17 | #"$BASEDIR/../predict.py" \ 18 | # model.path="simple_pix2pix2_gap_sdpl_novgg_large_b18_ffc075_batch8x15/saved_checkpoint/r.suvorov_2021-04-30_14-41-12_train_simple_pix2pix2_gap_sdpl_novgg_large_b18_ffc075_batch8x15_epoch22-step-574999" \ 19 | # indir="/data/inpainting/mask_analysis/generated" \ 20 | # outdir="/data/inpainting/mask_analysis/predicted" \ 21 | # dataset.img_suffix=.jpg \ 22 | # +out_ext=.jpg 23 | 24 | # analyze good and bad samples 25 | "$BASEDIR/../analyze_errors.py" \ 26 | --only-report \ 27 | --n-jobs 8 \ 28 | "$BASEDIR/../../configs/analyze_mask_errors.yaml" \ 29 | "/data/inpainting/mask_analysis/small/generated" \ 30 | "/data/inpainting/mask_analysis/small/predicted" \ 31 | "/data/inpainting/mask_analysis/small/report" 32 | -------------------------------------------------------------------------------- /bin/evaluate_predicts.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | 5 | import pandas as pd 6 | 7 | from saicinpainting.evaluation.data import PrecomputedInpaintingResultsDataset 8 | from saicinpainting.evaluation.evaluator import InpaintingEvaluator, lpips_fid100_f1 9 | from saicinpainting.evaluation.losses.base_loss import SegmentationAwareSSIM, \ 10 | SegmentationClassStats, SSIMScore, LPIPSScore, FIDScore, SegmentationAwareLPIPS, SegmentationAwareFID 11 | from saicinpainting.evaluation.utils import load_yaml 12 | 13 | 14 | def main(args): 15 | config = load_yaml(args.config) 16 | 17 | dataset = PrecomputedInpaintingResultsDataset(args.datadir, args.predictdir, **config.dataset_kwargs) 18 | 19 | metrics = { 20 | 'ssim': SSIMScore(), 21 | 'lpips': LPIPSScore(), 22 | 'fid': FIDScore() 23 | } 24 | enable_segm = config.get('segmentation', dict(enable=False)).get('enable', False) 25 | if enable_segm: 26 | weights_path = os.path.expandvars(config.segmentation.weights_path) 27 | metrics.update(dict( 28 | segm_stats=SegmentationClassStats(weights_path=weights_path), 29 | segm_ssim=SegmentationAwareSSIM(weights_path=weights_path), 30 | segm_lpips=SegmentationAwareLPIPS(weights_path=weights_path), 31 | segm_fid=SegmentationAwareFID(weights_path=weights_path) 32 | )) 33 | evaluator = InpaintingEvaluator(dataset, scores=metrics, 34 
| integral_title='lpips_fid100_f1', integral_func=lpips_fid100_f1, 35 | **config.evaluator_kwargs) 36 | 37 | os.makedirs(os.path.dirname(args.outpath), exist_ok=True) 38 | 39 | results = evaluator.evaluate() 40 | 41 | results = pd.DataFrame(results).stack(1).unstack(0) 42 | results.dropna(axis=1, how='all', inplace=True) 43 | results.to_csv(args.outpath, sep='\t', float_format='%.4f') 44 | 45 | if enable_segm: 46 | only_short_results = results[[c for c in results.columns if not c[0].startswith('segm_')]].dropna(axis=1, how='all') 47 | only_short_results.to_csv(args.outpath + '_short', sep='\t', float_format='%.4f') 48 | 49 | print(only_short_results) 50 | 51 | segm_metrics_results = results[['segm_ssim', 'segm_lpips', 'segm_fid']].dropna(axis=1, how='all').transpose().unstack(0).reorder_levels([1, 0], axis=1) 52 | segm_metrics_results.drop(['mean', 'std'], axis=0, inplace=True) 53 | 54 | segm_stats_results = results['segm_stats'].dropna(axis=1, how='all').transpose() 55 | segm_stats_results.index = pd.MultiIndex.from_tuples(n.split('/') for n in segm_stats_results.index) 56 | segm_stats_results = segm_stats_results.unstack(0).reorder_levels([1, 0], axis=1) 57 | segm_stats_results.sort_index(axis=1, inplace=True) 58 | segm_stats_results.dropna(axis=0, how='all', inplace=True) 59 | 60 | segm_results = pd.concat([segm_metrics_results, segm_stats_results], axis=1, sort=True) 61 | segm_results.sort_values(('mask_freq', 'total'), ascending=False, inplace=True) 62 | 63 | segm_results.to_csv(args.outpath + '_segm', sep='\t', float_format='%.4f') 64 | else: 65 | print(results) 66 | 67 | 68 | if __name__ == '__main__': 69 | import argparse 70 | 71 | aparser = argparse.ArgumentParser() 72 | aparser.add_argument('config', type=str, help='Path to evaluation config') 73 | aparser.add_argument('datadir', type=str, 74 | help='Path to folder with images and masks (output of gen_mask_dataset.py)') 75 | aparser.add_argument('predictdir', type=str, 76 | help='Path to folder with predicts (e.g. 
predict_hifill_baseline.py)') 77 | aparser.add_argument('outpath', type=str, help='Where to put results') 78 | 79 | main(aparser.parse_args()) 80 | -------------------------------------------------------------------------------- /bin/evaluator_example.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | from skimage import io 7 | from skimage.transform import resize 8 | from torch.utils.data import Dataset 9 | 10 | from saicinpainting.evaluation.evaluator import InpaintingEvaluator 11 | from saicinpainting.evaluation.losses.base_loss import SSIMScore, LPIPSScore, FIDScore 12 | 13 | 14 | class SimpleImageDataset(Dataset): 15 | def __init__(self, root_dir, image_size=(400, 600)): 16 | self.root_dir = root_dir 17 | self.files = sorted(os.listdir(root_dir)) 18 | self.image_size = image_size 19 | 20 | def __getitem__(self, index): 21 | img_name = os.path.join(self.root_dir, self.files[index]) 22 | image = io.imread(img_name) 23 | image = resize(image, self.image_size, anti_aliasing=True) 24 | image = torch.FloatTensor(image).permute(2, 0, 1) 25 | return image 26 | 27 | def __len__(self): 28 | return len(self.files) 29 | 30 | 31 | def create_rectangle_mask(height, width): 32 | mask = np.ones((height, width)) 33 | up_left_corner = width // 4, height // 4 34 | down_right_corner = (width - up_left_corner[0] - 1, height - up_left_corner[1] - 1) 35 | cv2.rectangle(mask, up_left_corner, down_right_corner, (0, 0, 0), thickness=cv2.FILLED) 36 | return mask 37 | 38 | 39 | class Model(): 40 | def __call__(self, img_batch, mask_batch): 41 | mean = (img_batch * mask_batch[:, None, :, :]).sum(dim=(2, 3)) / mask_batch.sum(dim=(1, 2))[:, None] 42 | inpainted = mean[:, :, None, None] * (1 - mask_batch[:, None, :, :]) + img_batch * mask_batch[:, None, :, :] 43 | return inpainted 44 | 45 | 46 | class SimpleImageSquareMaskDataset(Dataset): 47 | def __init__(self, dataset): 48 | self.dataset = dataset 49 | self.mask = torch.FloatTensor(create_rectangle_mask(*self.dataset.image_size)) 50 | self.model = Model() 51 | 52 | def __getitem__(self, index): 53 | img = self.dataset[index] 54 | mask = self.mask.clone() 55 | inpainted = self.model(img[None, ...], mask[None, ...]) 56 | return dict(image=img, mask=mask, inpainted=inpainted) 57 | 58 | def __len__(self): 59 | return len(self.dataset) 60 | 61 | 62 | dataset = SimpleImageDataset('imgs') 63 | mask_dataset = SimpleImageSquareMaskDataset(dataset) 64 | model = Model() 65 | metrics = { 66 | 'ssim': SSIMScore(), 67 | 'lpips': LPIPSScore(), 68 | 'fid': FIDScore() 69 | } 70 | 71 | evaluator = InpaintingEvaluator( 72 | mask_dataset, scores=metrics, batch_size=3, area_grouping=True 73 | ) 74 | 75 | results = evaluator.evaluate(model) 76 | print(results) 77 | -------------------------------------------------------------------------------- /bin/extract_masks.py: -------------------------------------------------------------------------------- 1 | import PIL.Image as Image 2 | import numpy as np 3 | import os 4 | 5 | 6 | def main(args): 7 | if not args.indir.endswith('/'): 8 | args.indir += '/' 9 | os.makedirs(args.outdir, exist_ok=True) 10 | 11 | src_images = [ 12 | args.indir+fname for fname in os.listdir(args.indir)] 13 | 14 | tgt_masks = [ 15 | args.outdir+fname[:-4] + f'_mask000.png' 16 | for fname in os.listdir(args.indir)] 17 | 18 | for img_name, msk_name in zip(src_images, tgt_masks): 19 | #print(img) 20 | #print(msk) 21 | 22 | image = 
Image.open(img_name).convert('RGB') 23 | image = np.transpose(np.array(image), (2, 0, 1)) 24 | 25 | mask = (image == 255).astype(int) 26 | 27 | print(mask.dtype, mask.shape) 28 | 29 | 30 | Image.fromarray( 31 | np.clip(mask[0,:,:] * 255, 0, 255).astype('uint8'),mode='L' 32 | ).save(msk_name) 33 | 34 | 35 | 36 | 37 | ''' 38 | for infile in src_images: 39 | try: 40 | file_relpath = infile[len(indir):] 41 | img_outpath = os.path.join(outdir, file_relpath) 42 | os.makedirs(os.path.dirname(img_outpath), exist_ok=True) 43 | 44 | image = Image.open(infile).convert('RGB') 45 | 46 | mask = 47 | 48 | Image.fromarray( 49 | np.clip( 50 | cur_mask * 255, 0, 255).astype('uint8'), 51 | mode='L' 52 | ).save(cur_basename + f'_mask{i:03d}.png') 53 | ''' 54 | 55 | 56 | 57 | if __name__ == '__main__': 58 | import argparse 59 | aparser = argparse.ArgumentParser() 60 | aparser.add_argument('--indir', type=str, help='Path to folder with images') 61 | aparser.add_argument('--outdir', type=str, help='Path to folder to store aligned images and masks to') 62 | 63 | main(aparser.parse_args()) 64 | -------------------------------------------------------------------------------- /bin/filter_sharded_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | import math 5 | import os 6 | import random 7 | 8 | import braceexpand 9 | import webdataset as wds 10 | 11 | DEFAULT_CATS_FILE = os.path.join(os.path.dirname(__file__), '..', 'configs', 'places2-categories_157.txt') 12 | 13 | def is_good_key(key, cats): 14 | return any(c in key for c in cats) 15 | 16 | 17 | def main(args): 18 | if args.categories == 'nofilter': 19 | good_categories = None 20 | else: 21 | with open(args.categories, 'r') as f: 22 | good_categories = set(line.strip().split(' ')[0] for line in f if line.strip()) 23 | 24 | all_input_files = list(braceexpand.braceexpand(args.infile)) 25 | chunk_size = int(math.ceil(len(all_input_files) / args.n_read_streams)) 26 | 27 | input_iterators = [iter(wds.Dataset(all_input_files[start : start + chunk_size]).shuffle(args.shuffle_buffer)) 28 | for start in range(0, len(all_input_files), chunk_size)] 29 | output_datasets = [wds.ShardWriter(args.outpattern.format(i)) for i in range(args.n_write_streams)] 30 | 31 | good_readers = list(range(len(input_iterators))) 32 | step_i = 0 33 | good_samples = 0 34 | bad_samples = 0 35 | while len(good_readers) > 0: 36 | if step_i % args.print_freq == 0: 37 | print(f'Iterations done {step_i}; readers alive {good_readers}; good samples {good_samples}; bad samples {bad_samples}') 38 | 39 | step_i += 1 40 | 41 | ri = random.choice(good_readers) 42 | try: 43 | sample = next(input_iterators[ri]) 44 | except StopIteration: 45 | good_readers = list(set(good_readers) - {ri}) 46 | continue 47 | 48 | if good_categories is not None and not is_good_key(sample['__key__'], good_categories): 49 | bad_samples += 1 50 | continue 51 | 52 | wi = random.randint(0, args.n_write_streams - 1) 53 | output_datasets[wi].write(sample) 54 | good_samples += 1 55 | 56 | 57 | if __name__ == '__main__': 58 | import argparse 59 | 60 | aparser = argparse.ArgumentParser() 61 | aparser.add_argument('--categories', type=str, default=DEFAULT_CATS_FILE) 62 | aparser.add_argument('--shuffle-buffer', type=int, default=10000) 63 | aparser.add_argument('--n-read-streams', type=int, default=10) 64 | aparser.add_argument('--n-write-streams', type=int, default=10) 65 | aparser.add_argument('--print-freq', type=int, default=1000) 66 | 
aparser.add_argument('infile', type=str) 67 | aparser.add_argument('outpattern', type=str) 68 | 69 | main(aparser.parse_args()) 70 | -------------------------------------------------------------------------------- /bin/gen_debug_mask_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import glob 4 | import os 5 | 6 | import PIL.Image as Image 7 | import cv2 8 | import numpy as np 9 | import tqdm 10 | import shutil 11 | 12 | 13 | from saicinpainting.evaluation.utils import load_yaml 14 | 15 | 16 | def generate_masks_for_img(infile, outmask_pattern, mask_size=200, step=0.5): 17 | inimg = Image.open(infile) 18 | width, height = inimg.size 19 | step_abs = int(mask_size * step) 20 | 21 | mask = np.zeros((height, width), dtype='uint8') 22 | mask_i = 0 23 | 24 | for start_vertical in range(0, height - step_abs, step_abs): 25 | for start_horizontal in range(0, width - step_abs, step_abs): 26 | mask[start_vertical:start_vertical + mask_size, start_horizontal:start_horizontal + mask_size] = 255 27 | 28 | cv2.imwrite(outmask_pattern.format(mask_i), mask) 29 | 30 | mask[start_vertical:start_vertical + mask_size, start_horizontal:start_horizontal + mask_size] = 0 31 | mask_i += 1 32 | 33 | 34 | def main(args): 35 | if not args.indir.endswith('/'): 36 | args.indir += '/' 37 | if not args.outdir.endswith('/'): 38 | args.outdir += '/' 39 | 40 | config = load_yaml(args.config) 41 | 42 | in_files = list(glob.glob(os.path.join(args.indir, '**', f'*{config.img_ext}'), recursive=True)) 43 | for infile in tqdm.tqdm(in_files): 44 | outimg = args.outdir + infile[len(args.indir):] 45 | outmask_pattern = outimg[:-len(config.img_ext)] + '_mask{:04d}.png' 46 | 47 | os.makedirs(os.path.dirname(outimg), exist_ok=True) 48 | shutil.copy2(infile, outimg) 49 | 50 | generate_masks_for_img(infile, outmask_pattern, **config.gen_kwargs) 51 | 52 | 53 | if __name__ == '__main__': 54 | import argparse 55 | 56 | aparser = argparse.ArgumentParser() 57 | aparser.add_argument('config', type=str, help='Path to config for dataset generation') 58 | aparser.add_argument('indir', type=str, help='Path to folder with images') 59 | aparser.add_argument('outdir', type=str, help='Path to folder to store aligned images and masks to') 60 | 61 | main(aparser.parse_args()) 62 | -------------------------------------------------------------------------------- /bin/gen_outpainting_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import glob 3 | import logging 4 | import os 5 | import shutil 6 | import sys 7 | import traceback 8 | 9 | from saicinpainting.evaluation.data import load_image 10 | from saicinpainting.evaluation.utils import move_to_device 11 | 12 | os.environ['OMP_NUM_THREADS'] = '1' 13 | os.environ['OPENBLAS_NUM_THREADS'] = '1' 14 | os.environ['MKL_NUM_THREADS'] = '1' 15 | os.environ['VECLIB_MAXIMUM_THREADS'] = '1' 16 | os.environ['NUMEXPR_NUM_THREADS'] = '1' 17 | 18 | import cv2 19 | import hydra 20 | import numpy as np 21 | import torch 22 | import tqdm 23 | import yaml 24 | from omegaconf import OmegaConf 25 | from torch.utils.data._utils.collate import default_collate 26 | 27 | from saicinpainting.training.data.datasets import make_default_val_dataset 28 | from saicinpainting.training.trainers import load_checkpoint 29 | from saicinpainting.utils import register_debug_signal_handlers 30 | 31 | LOGGER = logging.getLogger(__name__) 32 | 33 | 34 | def main(args): 35 | try: 36 | if not 
args.indir.endswith('/'): 37 | args.indir += '/' 38 | 39 | for in_img in glob.glob(os.path.join(args.indir, '**', '*' + args.img_suffix), recursive=True): 40 | if 'mask' in os.path.basename(in_img): 41 | continue 42 | 43 | out_img_path = os.path.join(args.outdir, os.path.splitext(in_img[len(args.indir):])[0] + '.png') 44 | out_mask_path = f'{os.path.splitext(out_img_path)[0]}_mask.png' 45 | 46 | os.makedirs(os.path.dirname(out_img_path), exist_ok=True) 47 | 48 | img = load_image(in_img) 49 | height, width = img.shape[1:] 50 | pad_h, pad_w = int(height * args.coef / 2), int(width * args.coef / 2) 51 | 52 | mask = np.zeros((height, width), dtype='uint8') 53 | 54 | if args.expand: 55 | img = np.pad(img, ((0, 0), (pad_h, pad_h), (pad_w, pad_w))) 56 | mask = np.pad(mask, ((pad_h, pad_h), (pad_w, pad_w)), mode='constant', constant_values=255) 57 | else: 58 | mask[:pad_h] = 255 59 | mask[-pad_h:] = 255 60 | mask[:, :pad_w] = 255 61 | mask[:, -pad_w:] = 255 62 | 63 | # img = np.pad(img, ((0, 0), (pad_h * 2, pad_h * 2), (pad_w * 2, pad_w * 2)), mode='symmetric') 64 | # mask = np.pad(mask, ((pad_h * 2, pad_h * 2), (pad_w * 2, pad_w * 2)), mode = 'symmetric') 65 | 66 | img = np.clip(np.transpose(img, (1, 2, 0)) * 255, 0, 255).astype('uint8') 67 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 68 | cv2.imwrite(out_img_path, img) 69 | 70 | cv2.imwrite(out_mask_path, mask) 71 | except KeyboardInterrupt: 72 | LOGGER.warning('Interrupted by user') 73 | except Exception as ex: 74 | LOGGER.critical(f'Prediction failed due to {ex}:\n{traceback.format_exc()}') 75 | sys.exit(1) 76 | 77 | 78 | if __name__ == '__main__': 79 | import argparse 80 | 81 | aparser = argparse.ArgumentParser() 82 | aparser.add_argument('indir', type=str, help='Root directory with images') 83 | aparser.add_argument('outdir', type=str, help='Where to store results') 84 | aparser.add_argument('--img-suffix', type=str, default='.png', help='Input image extension') 85 | aparser.add_argument('--expand', action='store_true', help='Generate mask by padding (true) or by cropping (false)') 86 | aparser.add_argument('--coef', type=float, default=0.2, help='How much to crop/expand in order to get masks') 87 | 88 | main(aparser.parse_args()) 89 | -------------------------------------------------------------------------------- /bin/make_checkpoint.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import shutil 5 | 6 | import torch 7 | 8 | 9 | def get_checkpoint_files(s): 10 | s = s.strip() 11 | if ',' in s: 12 | return [get_checkpoint_files(chunk) for chunk in s.split(',')] 13 | return 'last.ckpt' if s == 'last' else f'{s}.ckpt' 14 | 15 | 16 | def main(args): 17 | checkpoint_fnames = get_checkpoint_files(args.epochs) 18 | if isinstance(checkpoint_fnames, str): 19 | checkpoint_fnames = [checkpoint_fnames] 20 | assert len(checkpoint_fnames) >= 1 21 | 22 | checkpoint_path = os.path.join(args.indir, 'models', checkpoint_fnames[0]) 23 | checkpoint = torch.load(checkpoint_path, map_location='cpu') 24 | del checkpoint['optimizer_states'] 25 | 26 | if len(checkpoint_fnames) > 1: 27 | for fname in checkpoint_fnames[1:]: 28 | print('sum', fname) 29 | sum_tensors_cnt = 0 30 | other_cp = torch.load(os.path.join(args.indir, 'models', fname), map_location='cpu') 31 | for k in checkpoint['state_dict'].keys(): 32 | if checkpoint['state_dict'][k].dtype is torch.float: 33 | checkpoint['state_dict'][k].data.add_(other_cp['state_dict'][k].data) 34 | sum_tensors_cnt += 1 35 | print('summed', 
sum_tensors_cnt, 'tensors') 36 | 37 | for k in checkpoint['state_dict'].keys(): 38 | if checkpoint['state_dict'][k].dtype is torch.float: 39 | checkpoint['state_dict'][k].data.mul_(1 / float(len(checkpoint_fnames))) 40 | 41 | state_dict = checkpoint['state_dict'] 42 | 43 | if not args.leave_discriminators: 44 | for k in list(state_dict.keys()): 45 | if k.startswith('discriminator.'): 46 | del state_dict[k] 47 | 48 | if not args.leave_losses: 49 | for k in list(state_dict.keys()): 50 | if k.startswith('loss_'): 51 | del state_dict[k] 52 | 53 | out_checkpoint_path = os.path.join(args.outdir, 'models', 'best.ckpt') 54 | os.makedirs(os.path.dirname(out_checkpoint_path), exist_ok=True) 55 | 56 | torch.save(checkpoint, out_checkpoint_path) 57 | 58 | shutil.copy2(os.path.join(args.indir, 'config.yaml'), 59 | os.path.join(args.outdir, 'config.yaml')) 60 | 61 | 62 | if __name__ == '__main__': 63 | import argparse 64 | 65 | aparser = argparse.ArgumentParser() 66 | aparser.add_argument('indir', 67 | help='Path to directory with output of training ' 68 | '(i.e. directory, which has samples, modules, config.yaml and train.log') 69 | aparser.add_argument('outdir', 70 | help='Where to put minimal checkpoint, which can be consumed by "bin/predict.py"') 71 | aparser.add_argument('--epochs', type=str, default='last', 72 | help='Which checkpoint to take. ' 73 | 'Can be "last" or integer - number of epoch') 74 | aparser.add_argument('--leave-discriminators', action='store_true', 75 | help='If enabled, the state of discriminators will not be removed from the checkpoint') 76 | aparser.add_argument('--leave-losses', action='store_true', 77 | help='If enabled, weights of nn-based losses (e.g. perceptual) will not be removed') 78 | 79 | main(aparser.parse_args()) 80 | -------------------------------------------------------------------------------- /bin/mask_example.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from skimage import io 3 | from skimage.transform import resize 4 | 5 | from saicinpainting.evaluation.masks.mask import SegmentationMask 6 | 7 | im = io.imread('imgs/ex4.jpg') 8 | im = resize(im, (512, 1024), anti_aliasing=True) 9 | mask_seg = SegmentationMask(num_variants_per_mask=10) 10 | mask_examples = mask_seg.get_masks(im) 11 | for i, example in enumerate(mask_examples): 12 | plt.imshow(example) 13 | plt.show() 14 | plt.imsave(f'tmp/img_masks/{i}.png', example) 15 | -------------------------------------------------------------------------------- /bin/paper_runfiles/blur_tests.sh: -------------------------------------------------------------------------------- 1 | ##!/usr/bin/env bash 2 | # 3 | ## !!! 
file set to make test_large_30k from the vanilla test_large: configs/test_large_30k.lst 4 | # 5 | ## paths to data are valid for mml7 6 | #PLACES_ROOT="/data/inpainting/Places365" 7 | #OUT_DIR="/data/inpainting/paper_data/Places365_val_test" 8 | # 9 | #source "$(dirname $0)/env.sh" 10 | # 11 | #for datadir in test_large_30k # val_large 12 | #do 13 | # for conf in random_thin_256 random_medium_256 random_thick_256 random_thin_512 random_medium_512 random_thick_512 14 | # do 15 | # "$BINDIR/gen_mask_dataset.py" "$CONFIGDIR/data_gen/${conf}.yaml" \ 16 | # "$PLACES_ROOT/$datadir" "$OUT_DIR/$datadir/$conf" --n-jobs 8 17 | # 18 | # "$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats" 19 | # done 20 | # 21 | # for conf in segm_256 segm_512 22 | # do 23 | # "$BINDIR/gen_mask_dataset.py" "$CONFIGDIR/data_gen/${conf}.yaml" \ 24 | # "$PLACES_ROOT/$datadir" "$OUT_DIR/$datadir/$conf" --n-jobs 2 25 | # 26 | # "$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats" 27 | # done 28 | #done 29 | # 30 | #IN_DIR="/data/inpainting/paper_data/Places365_val_test/test_large_30k/random_medium_512" 31 | #PRED_DIR="/data/inpainting/predictions/final/images/r.suvorov_2021-03-05_17-08-35_train_ablv2_work_resume_epoch37/random_medium_512" 32 | #BLUR_OUT_DIR="/data/inpainting/predictions/final/blur/images" 33 | # 34 | #for b in 0.1 35 | # 36 | #"$BINDIR/blur_predicts.py" "$BASEDIR/../../configs/eval2.yaml" "$CUR_IN_DIR" "$CUR_OUT_DIR" "$CUR_EVAL_DIR" 37 | # 38 | -------------------------------------------------------------------------------- /bin/paper_runfiles/env.sh: -------------------------------------------------------------------------------- 1 | DIRNAME="$(dirname $0)" 2 | DIRNAME="$(realpath "$DIRNAME")" 3 | 4 | BINDIR="$DIRNAME/.." 5 | SRCDIR="$BINDIR/.." 
6 | CONFIGDIR="$SRCDIR/configs" 7 | 8 | export PYTHONPATH="$SRCDIR:$PYTHONPATH" 9 | -------------------------------------------------------------------------------- /bin/paper_runfiles/find_best_checkpoint.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | import os 5 | from argparse import ArgumentParser 6 | 7 | 8 | def ssim_fid100_f1(metrics, fid_scale=100): 9 | ssim = metrics.loc['total', 'ssim']['mean'] 10 | fid = metrics.loc['total', 'fid']['mean'] 11 | fid_rel = max(0, fid_scale - fid) / fid_scale 12 | f1 = 2 * ssim * fid_rel / (ssim + fid_rel + 1e-3) 13 | return f1 14 | 15 | 16 | def find_best_checkpoint(model_list, models_dir): 17 | with open(model_list) as f: 18 | models = [m.strip() for m in f.readlines()] 19 | with open(f'{model_list}_best', 'w') as f: 20 | for model in models: 21 | print(model) 22 | best_f1 = 0 23 | best_epoch = 0 24 | best_step = 0 25 | with open(os.path.join(models_dir, model, 'train.log')) as fm: 26 | lines = fm.readlines() 27 | for line_index in range(len(lines)): 28 | line = lines[line_index] 29 | if 'Validation metrics after epoch' in line: 30 | sharp_index = line.index('#') 31 | cur_ep = line[sharp_index + 1:] 32 | comma_index = cur_ep.index(',') 33 | cur_ep = int(cur_ep[:comma_index]) 34 | total_index = line.index('total ') 35 | step = int(line[total_index:].split()[1].strip()) 36 | total_line = lines[line_index + 5] 37 | if not total_line.startswith('total'): 38 | continue 39 | words = total_line.strip().split() 40 | f1 = float(words[-1]) 41 | print(f'\tEpoch: {cur_ep}, f1={f1}') 42 | if f1 > best_f1: 43 | best_f1 = f1 44 | best_epoch = cur_ep 45 | best_step = step 46 | f.write(f'{model}\t{best_epoch}\t{best_step}\t{best_f1}\n') 47 | 48 | 49 | if __name__ == '__main__': 50 | parser = ArgumentParser() 51 | parser.add_argument('model_list') 52 | parser.add_argument('models_dir') 53 | args = parser.parse_args() 54 | find_best_checkpoint(args.model_list, args.models_dir) 55 | -------------------------------------------------------------------------------- /bin/paper_runfiles/generate_test_celeba-hq.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # paths to data are valid for mml-ws01 4 | OUT_DIR="/media/inpainting/paper_data/CelebA-HQ_val_test" 5 | 6 | source "$(dirname $0)/env.sh" 7 | 8 | for datadir in "val" "test" 9 | do 10 | for conf in random_thin_256 random_medium_256 random_thick_256 random_thin_512 random_medium_512 random_thick_512 11 | do 12 | "$BINDIR/gen_mask_dataset_hydra.py" -cn $conf datadir=$datadir location=mml-ws01-celeba-hq \ 13 | location.out_dir=$OUT_DIR cropping.out_square_crop=False 14 | 15 | "$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats" 16 | done 17 | done 18 | -------------------------------------------------------------------------------- /bin/paper_runfiles/generate_test_ffhq.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # paths to data are valid for mml-ws01 4 | OUT_DIR="/media/inpainting/paper_data/FFHQ_val" 5 | 6 | source "$(dirname $0)/env.sh" 7 | 8 | for datadir in test 9 | do 10 | for conf in random_thin_256 random_medium_256 random_thick_256 random_thin_512 random_medium_512 random_thick_512 11 | do 12 | "$BINDIR/gen_mask_dataset_hydra.py" -cn $conf datadir=$datadir location=mml-ws01-ffhq \ 13 | location.out_dir=$OUT_DIR cropping.out_square_crop=False 14 
| 15 | "$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats" 16 | done 17 | done 18 | -------------------------------------------------------------------------------- /bin/paper_runfiles/generate_test_paris.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # paths to data are valid for mml-ws01 4 | OUT_DIR="/media/inpainting/paper_data/Paris_StreetView_Dataset_val" 5 | 6 | source "$(dirname $0)/env.sh" 7 | 8 | for datadir in paris_eval_gt 9 | do 10 | for conf in random_thin_256 random_medium_256 random_thick_256 segm_256 11 | do 12 | "$BINDIR/gen_mask_dataset_hydra.py" -cn $conf datadir=$datadir location=mml-ws01-paris \ 13 | location.out_dir=OUT_DIR cropping.out_square_crop=False cropping.out_min_size=227 14 | 15 | "$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats" 16 | done 17 | done 18 | -------------------------------------------------------------------------------- /bin/paper_runfiles/generate_test_paris_256.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # paths to data are valid for mml-ws01 4 | OUT_DIR="/media/inpainting/paper_data/Paris_StreetView_Dataset_val_256" 5 | 6 | source "$(dirname $0)/env.sh" 7 | 8 | for datadir in paris_eval_gt 9 | do 10 | for conf in random_thin_256 random_medium_256 random_thick_256 segm_256 11 | do 12 | "$BINDIR/gen_mask_dataset_hydra.py" -cn $conf datadir=$datadir location=mml-ws01-paris \ 13 | location.out_dir=$OUT_DIR cropping.out_square_crop=False cropping.out_min_size=256 14 | 15 | "$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats" 16 | done 17 | done 18 | -------------------------------------------------------------------------------- /bin/paper_runfiles/generate_val_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # !!! 
file set to make test_large_30k from the vanilla test_large: configs/test_large_30k.lst 4 | 5 | # paths to data are valid for mml7 6 | PLACES_ROOT="/data/inpainting/Places365" 7 | OUT_DIR="/data/inpainting/paper_data/Places365_val_test" 8 | 9 | source "$(dirname $0)/env.sh" 10 | 11 | for datadir in test_large_30k # val_large 12 | do 13 | for conf in random_thin_256 random_medium_256 random_thick_256 random_thin_512 random_medium_512 random_thick_512 14 | do 15 | "$BINDIR/gen_mask_dataset.py" "$CONFIGDIR/data_gen/${conf}.yaml" \ 16 | "$PLACES_ROOT/$datadir" "$OUT_DIR/$datadir/$conf" --n-jobs 8 17 | 18 | "$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats" 19 | done 20 | 21 | for conf in segm_256 segm_512 22 | do 23 | "$BINDIR/gen_mask_dataset.py" "$CONFIGDIR/data_gen/${conf}.yaml" \ 24 | "$PLACES_ROOT/$datadir" "$OUT_DIR/$datadir/$conf" --n-jobs 2 25 | 26 | "$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats" 27 | done 28 | done 29 | -------------------------------------------------------------------------------- /bin/paper_runfiles/predict_inner_features.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # paths to data are valid for mml7 4 | 5 | source "$(dirname $0)/env.sh" 6 | 7 | "$BINDIR/predict_inner_features.py" \ 8 | -cn default_inner_features_ffc \ 9 | model.path="/data/inpainting/paper_data/final_models/ours/r.suvorov_2021-03-05_17-34-05_train_ablv2_work_ffc075_resume_epoch39" \ 10 | indir="/data/inpainting/paper_data/inner_features_vis/input/" \ 11 | outdir="/data/inpainting/paper_data/inner_features_vis/output/ffc" \ 12 | dataset.img_suffix=.png 13 | 14 | 15 | "$BINDIR/predict_inner_features.py" \ 16 | -cn default_inner_features_work \ 17 | model.path="/data/inpainting/paper_data/final_models/ours/r.suvorov_2021-03-05_17-08-35_train_ablv2_work_resume_epoch37" \ 18 | indir="/data/inpainting/paper_data/inner_features_vis/input/" \ 19 | outdir="/data/inpainting/paper_data/inner_features_vis/output/work" \ 20 | dataset.img_suffix=.png 21 | -------------------------------------------------------------------------------- /bin/paper_runfiles/update_test_data_stats.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # paths to data are valid for mml7 4 | 5 | source "$(dirname $0)/env.sh" 6 | 7 | #INDIR="/data/inpainting/paper_data/Places365_val_test/test_large_30k" 8 | # 9 | #for dataset in random_medium_256 random_medium_512 random_thick_256 random_thick_512 random_thin_256 random_thin_512 10 | #do 11 | # "$BINDIR/calc_dataset_stats.py" "$INDIR/$dataset" "$INDIR/${dataset}_stats2" 12 | #done 13 | # 14 | #"$BINDIR/calc_dataset_stats.py" "/data/inpainting/evalset2" "/data/inpainting/evalset2_stats2" 15 | 16 | 17 | INDIR="/data/inpainting/paper_data/CelebA-HQ_val_test/test" 18 | 19 | for dataset in random_medium_256 random_thick_256 random_thin_256 20 | do 21 | "$BINDIR/calc_dataset_stats.py" "$INDIR/$dataset" "$INDIR/${dataset}_stats2" 22 | done 23 | 24 | 25 | INDIR="/data/inpainting/paper_data/Paris_StreetView_Dataset_val_256/paris_eval_gt" 26 | 27 | for dataset in random_medium_256 random_thick_256 random_thin_256 28 | do 29 | "$BINDIR/calc_dataset_stats.py" "$INDIR/$dataset" "$INDIR/${dataset}_stats2" 30 | done -------------------------------------------------------------------------------- /bin/predict.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Example command: 4 | # ./bin/predict.py \ 5 | # model.path=<path to checkpoint, prepared by make_checkpoint.py> \ 6 | # indir=<path to input data> \ 7 | # outdir=<where to store predicts> 8 | 9 | import logging 10 | import os 11 | import sys 12 | import traceback 13 | 14 | from saicinpainting.evaluation.utils import move_to_device 15 | from saicinpainting.evaluation.refinement import refine_predict 16 | os.environ['OMP_NUM_THREADS'] = '1' 17 | os.environ['OPENBLAS_NUM_THREADS'] = '1' 18 | os.environ['MKL_NUM_THREADS'] = '1' 19 | os.environ['VECLIB_MAXIMUM_THREADS'] = '1' 20 | os.environ['NUMEXPR_NUM_THREADS'] = '1' 21 | 22 | import cv2 23 | import hydra 24 | import numpy as np 25 | import torch 26 | import tqdm 27 | import yaml 28 | from omegaconf import OmegaConf 29 | from torch.utils.data._utils.collate import default_collate 30 | 31 | from saicinpainting.training.data.datasets import make_default_val_dataset 32 | from saicinpainting.training.trainers import load_checkpoint 33 | from saicinpainting.utils import register_debug_signal_handlers 34 | 35 | LOGGER = logging.getLogger(__name__) 36 | 37 | 38 | @hydra.main(config_path='../configs/prediction', config_name='default.yaml') 39 | def main(predict_config: OmegaConf): 40 | try: 41 | if sys.platform != 'win32': 42 | register_debug_signal_handlers() # kill -10 <pid> will result in traceback dumped into log 43 | 44 | device = torch.device("cpu") 45 | 46 | train_config_path = os.path.join(predict_config.model.path, 'config.yaml') 47 | with open(train_config_path, 'r') as f: 48 | train_config = OmegaConf.create(yaml.safe_load(f)) 49 | 50 | train_config.training_model.predict_only = True 51 | train_config.visualizer.kind = 'noop' 52 | 53 | out_ext = predict_config.get('out_ext', '.png') 54 | 55 | checkpoint_path = os.path.join(predict_config.model.path, 56 | 'models', 57 | predict_config.model.checkpoint) 58 | model = load_checkpoint(train_config, checkpoint_path, strict=False, map_location='cpu') 59 | model.freeze() 60 | if not predict_config.get('refine', False): 61 | model.to(device) 62 | 63 | if not predict_config.indir.endswith('/'): 64 | predict_config.indir += '/' 65 | 66 | dataset = make_default_val_dataset(predict_config.indir, **predict_config.dataset) 67 | for img_i in tqdm.trange(len(dataset)): 68 | mask_fname = dataset.mask_filenames[img_i] 69 | cur_out_fname = os.path.join( 70 | predict_config.outdir, 71 | os.path.splitext(mask_fname[len(predict_config.indir):])[0] + out_ext 72 | ) 73 | os.makedirs(os.path.dirname(cur_out_fname), exist_ok=True) 74 | batch = default_collate([dataset[img_i]]) 75 | if predict_config.get('refine', False): 76 | assert 'unpad_to_size' in batch, "Unpadded size is required for the refinement" 77 | # image unpadding is taken care of in the refiner, so that output image 78 | # is same size as the input image 79 | cur_res = refine_predict(batch, model, **predict_config.refiner) 80 | cur_res = cur_res[0].permute(1,2,0).detach().cpu().numpy() 81 | else: 82 | with torch.no_grad(): 83 | batch = move_to_device(batch, device) 84 | batch['mask'] = (batch['mask'] > 0) * 1 85 | batch = model(batch) 86 | cur_res = batch[predict_config.out_key][0].permute(1, 2, 0).detach().cpu().numpy() 87 | unpad_to_size = batch.get('unpad_to_size', None) 88 | if unpad_to_size is not None: 89 | orig_height, orig_width = unpad_to_size 90 | cur_res = cur_res[:orig_height, :orig_width] 91 | 92 | cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8') 93 | cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR) 94 | 
cv2.imwrite(cur_out_fname, cur_res) 95 | 96 | except KeyboardInterrupt: 97 | LOGGER.warning('Interrupted by user') 98 | except Exception as ex: 99 | LOGGER.critical(f'Prediction failed due to {ex}:\n{traceback.format_exc()}') 100 | sys.exit(1) 101 | 102 | 103 | if __name__ == '__main__': 104 | main() 105 | -------------------------------------------------------------------------------- /bin/report_from_tb.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import glob 4 | import os 5 | import re 6 | 7 | import tensorflow as tf 8 | from torch.utils.tensorboard import SummaryWriter 9 | 10 | 11 | GROUPING_RULES = [ 12 | re.compile(r'^(?P<group>train|test|val|extra_val_.*?(256|512))_(?P<title>.*)', re.I) 13 | ] 14 | 15 | 16 | DROP_RULES = [ 17 | re.compile(r'_std$', re.I) 18 | ] 19 | 20 | 21 | def need_drop(tag): 22 | for rule in DROP_RULES: 23 | if rule.search(tag): 24 | return True 25 | return False 26 | 27 | 28 | def get_group_and_title(tag): 29 | for rule in GROUPING_RULES: 30 | match = rule.search(tag) 31 | if match is None: 32 | continue 33 | return match.group('group'), match.group('title') 34 | return None, None 35 | 36 | 37 | def main(args): 38 | os.makedirs(args.outdir, exist_ok=True) 39 | 40 | ignored_events = set() 41 | 42 | for orig_fname in glob.glob(args.inglob): 43 | cur_dirpath = os.path.dirname(orig_fname) # remove filename, this should point to "version_0" directory 44 | subdirname = os.path.basename(cur_dirpath) # == "version_0" most of the time 45 | exp_root_path = os.path.dirname(cur_dirpath) # remove "version_0" 46 | exp_name = os.path.basename(exp_root_path) 47 | 48 | writers_by_group = {} 49 | 50 | for e in tf.compat.v1.train.summary_iterator(orig_fname): 51 | for v in e.summary.value: 52 | if need_drop(v.tag): 53 | continue 54 | 55 | cur_group, cur_title = get_group_and_title(v.tag) 56 | if cur_group is None: 57 | if v.tag not in ignored_events: 58 | print(f'WARNING: Could not detect group for {v.tag}, ignoring it') 59 | ignored_events.add(v.tag) 60 | continue 61 | 62 | cur_writer = writers_by_group.get(cur_group, None) 63 | if cur_writer is None: 64 | if args.include_version: 65 | cur_outdir = os.path.join(args.outdir, exp_name, f'{subdirname}_{cur_group}') 66 | else: 67 | cur_outdir = os.path.join(args.outdir, exp_name, cur_group) 68 | cur_writer = SummaryWriter(cur_outdir) 69 | writers_by_group[cur_group] = cur_writer 70 | 71 | cur_writer.add_scalar(cur_title, v.simple_value, global_step=e.step, walltime=e.wall_time) 72 | 73 | 74 | if __name__ == '__main__': 75 | import argparse 76 | 77 | aparser = argparse.ArgumentParser() 78 | aparser.add_argument('inglob', type=str) 79 | aparser.add_argument('outdir', type=str) 80 | aparser.add_argument('--include-version', action='store_true', 81 | help='Include subdirectory name e.g. 
"version_0" into output path') 82 | 83 | main(aparser.parse_args()) 84 | -------------------------------------------------------------------------------- /bin/sample_from_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | 5 | import numpy as np 6 | import tqdm 7 | from skimage import io 8 | from skimage.segmentation import mark_boundaries 9 | 10 | from saicinpainting.evaluation.data import InpaintingDataset 11 | from saicinpainting.evaluation.vis import save_item_for_vis 12 | 13 | def save_mask_for_sidebyside(item, out_file): 14 | mask = item['mask']# > 0.5 15 | if mask.ndim == 3: 16 | mask = mask[0] 17 | mask = np.clip(mask * 255, 0, 255).astype('uint8') 18 | io.imsave(out_file, mask) 19 | 20 | def save_img_for_sidebyside(item, out_file): 21 | img = np.transpose(item['image'], (1, 2, 0)) 22 | img = np.clip(img * 255, 0, 255).astype('uint8') 23 | io.imsave(out_file, img) 24 | 25 | def save_masked_img_for_sidebyside(item, out_file): 26 | mask = item['mask'] 27 | img = item['image'] 28 | 29 | img = (1-mask) * img + mask 30 | img = np.transpose(img, (1, 2, 0)) 31 | 32 | img = np.clip(img * 255, 0, 255).astype('uint8') 33 | io.imsave(out_file, img) 34 | 35 | def main(args): 36 | dataset = InpaintingDataset(args.datadir, img_suffix='.png') 37 | 38 | area_bins = np.linspace(0, 1, args.area_bins + 1) 39 | 40 | heights = [] 41 | widths = [] 42 | image_areas = [] 43 | hole_areas = [] 44 | hole_area_percents = [] 45 | area_bins_count = np.zeros(args.area_bins) 46 | area_bin_titles = [f'{area_bins[i] * 100:.0f}-{area_bins[i + 1] * 100:.0f}' for i in range(args.area_bins)] 47 | 48 | bin2i = [[] for _ in range(args.area_bins)] 49 | 50 | for i, item in enumerate(tqdm.tqdm(dataset)): 51 | h, w = item['image'].shape[1:] 52 | heights.append(h) 53 | widths.append(w) 54 | full_area = h * w 55 | image_areas.append(full_area) 56 | hole_area = (item['mask'] == 1).sum() 57 | hole_areas.append(hole_area) 58 | hole_percent = hole_area / full_area 59 | hole_area_percents.append(hole_percent) 60 | bin_i = np.clip(np.searchsorted(area_bins, hole_percent) - 1, 0, len(area_bins_count) - 1) 61 | area_bins_count[bin_i] += 1 62 | bin2i[bin_i].append(i) 63 | 64 | os.makedirs(args.outdir, exist_ok=True) 65 | 66 | for bin_i in range(args.area_bins): 67 | bindir = os.path.join(args.outdir, area_bin_titles[bin_i]) 68 | os.makedirs(bindir, exist_ok=True) 69 | bin_idx = bin2i[bin_i] 70 | for sample_i in np.random.choice(bin_idx, size=min(len(bin_idx), args.samples_n), replace=False): 71 | item = dataset[sample_i] 72 | path = os.path.join(bindir, dataset.img_filenames[sample_i].split('/')[-1]) 73 | save_masked_img_for_sidebyside(item, path) 74 | 75 | 76 | if __name__ == '__main__': 77 | import argparse 78 | 79 | aparser = argparse.ArgumentParser() 80 | aparser.add_argument('--datadir', type=str, 81 | help='Path to folder with images and masks (output of gen_mask_dataset.py)') 82 | aparser.add_argument('--outdir', type=str, help='Where to put results') 83 | aparser.add_argument('--samples-n', type=int, default=10, 84 | help='Number of sample images with masks to copy for visualization for each area bin') 85 | aparser.add_argument('--area-bins', type=int, default=10, help='How many area bins to have') 86 | 87 | main(aparser.parse_args()) 88 | -------------------------------------------------------------------------------- /bin/side_by_side.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env 
-------------------------------------------------------------------------------- /bin/side_by_side.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import random 4 | 5 | import cv2 6 | import numpy as np 7 | 8 | from saicinpainting.evaluation.data import PrecomputedInpaintingResultsDataset 9 | from saicinpainting.evaluation.utils import load_yaml 10 | from saicinpainting.training.visualizers.base import visualize_mask_and_images 11 | 12 | 13 | def main(args): 14 | config = load_yaml(args.config) 15 | 16 | datasets = [PrecomputedInpaintingResultsDataset(args.datadir, cur_predictdir, **config.dataset_kwargs) 17 | for cur_predictdir in args.predictdirs] 18 | assert len({len(ds) for ds in datasets}) == 1 19 | len_first = len(datasets[0]) 20 | 21 | indices = list(range(len_first)) 22 | if len_first > args.max_n: 23 | indices = sorted(random.sample(indices, args.max_n)) 24 | 25 | os.makedirs(args.outpath, exist_ok=True) 26 | 27 | filename2i = {} 28 | 29 | keys = ['image'] + list(range(len(datasets))) 30 | for img_i in indices: 31 | try: 32 | mask_fname = os.path.basename(datasets[0].mask_filenames[img_i]) 33 | if mask_fname in filename2i: 34 | filename2i[mask_fname] += 1 35 | idx = filename2i[mask_fname] 36 | mask_fname_only, ext = os.path.splitext(mask_fname) 37 | mask_fname = f'{mask_fname_only}_{idx}{ext}' 38 | else: 39 | filename2i[mask_fname] = 1 40 | 41 | cur_vis_dict = datasets[0][img_i] 42 | for ds_i, ds in enumerate(datasets): 43 | cur_vis_dict[ds_i] = ds[img_i]['inpainted'] 44 | 45 | vis_img = visualize_mask_and_images(cur_vis_dict, keys, 46 | last_without_mask=False, 47 | mask_only_first=True, 48 | black_mask=args.black) 49 | vis_img = np.clip(vis_img * 255, 0, 255).astype('uint8') 50 | 51 | out_fname = os.path.join(args.outpath, mask_fname) 52 | 53 | 54 | 55 | vis_img = cv2.cvtColor(vis_img, cv2.COLOR_RGB2BGR) 56 | cv2.imwrite(out_fname, vis_img) 57 | except Exception as ex: 58 | print(f'Could not process {img_i} due to {ex}') 59 | 60 | 61 | if __name__ == '__main__': 62 | import argparse 63 | 64 | aparser = argparse.ArgumentParser() 65 | aparser.add_argument('--max-n', type=int, default=100, help='Maximum number of images to write') 66 | aparser.add_argument('--black', action='store_true', help='Whether to fill mask on GT with black') 67 | aparser.add_argument('config', type=str, help='Path to evaluation config (e.g. configs/eval1.yaml)') 68 | aparser.add_argument('outpath', type=str, help='Where to put results') 69 | aparser.add_argument('datadir', type=str, 70 | help='Path to folder with images and masks') 71 | aparser.add_argument('predictdirs', type=str, 72 | nargs='+', 73 | help='Path to folders with predicts') 74 | 75 | 76 | main(aparser.parse_args()) 77 |
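The duplicate-name handling in side_by_side.py needs os.path.splitext (not os.path.split, which splits directory from file name) so that the counter lands before the extension; the renaming behaviour in isolation (file names are made up):

import os

filename2i = {}

def dedup_name(mask_fname):
    # First occurrence keeps its name; later occurrences get a _<counter>
    # suffix inserted before the extension, mirroring the loop above.
    if mask_fname in filename2i:
        filename2i[mask_fname] += 1
        stem, ext = os.path.splitext(mask_fname)
        return f'{stem}_{filename2i[mask_fname]}{ext}'
    filename2i[mask_fname] = 1
    return mask_fname

print(dedup_name('000_mask.png'))  # 000_mask.png
print(dedup_name('000_mask.png'))  # 000_mask_2.png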
-------------------------------------------------------------------------------- /bin/split_tar.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | import tqdm 5 | import webdataset as wds 6 | 7 | 8 | def main(args): 9 | input_dataset = wds.Dataset(args.infile) 10 | output_dataset = wds.ShardWriter(args.outpattern) 11 | for rec in tqdm.tqdm(input_dataset): 12 | output_dataset.write(rec) 13 | 14 | 15 | if __name__ == '__main__': 16 | import argparse 17 | 18 | aparser = argparse.ArgumentParser() 19 | aparser.add_argument('infile', type=str) 20 | aparser.add_argument('outpattern', type=str) 21 | 22 | main(aparser.parse_args()) 23 | -------------------------------------------------------------------------------- /bin/to_jit.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from pathlib import Path 4 | import hydra 5 | import torch 6 | import yaml 7 | from omegaconf import OmegaConf 8 | from torch import nn 9 | 10 | from saicinpainting.training.trainers import load_checkpoint 11 | from saicinpainting.utils import register_debug_signal_handlers 12 | 13 | 14 | class JITWrapper(nn.Module): 15 | def __init__(self, model): 16 | super().__init__() 17 | self.model = model 18 | 19 | def forward(self, image, mask): 20 | batch = { 21 | "image": image, 22 | "mask": mask 23 | } 24 | out = self.model(batch) 25 | return out["inpainted"] 26 | 27 | 28 | @hydra.main(config_path="../configs/prediction", config_name="default.yaml") 29 | def main(predict_config: OmegaConf): 30 | if sys.platform != 'win32': 31 | register_debug_signal_handlers() # kill -10 <pid> will result in traceback dumped into log 32 | 33 | train_config_path = os.path.join(predict_config.model.path, "config.yaml") 34 | with open(train_config_path, "r") as f: 35 | train_config = OmegaConf.create(yaml.safe_load(f)) 36 | 37 | train_config.training_model.predict_only = True 38 | train_config.visualizer.kind = "noop" 39 | 40 | checkpoint_path = os.path.join( 41 | predict_config.model.path, "models", predict_config.model.checkpoint 42 | ) 43 | model = load_checkpoint( 44 | train_config, checkpoint_path, strict=False, map_location="cpu" 45 | ) 46 | model.eval() 47 | jit_model_wrapper = JITWrapper(model) 48 | 49 | image = torch.rand(1, 3, 120, 120) 50 | mask = torch.rand(1, 1, 120, 120) 51 | output = jit_model_wrapper(image, mask) 52 | 53 | if torch.cuda.is_available(): 54 | device = torch.device("cuda") 55 | else: 56 | device = torch.device("cpu") 57 | 58 | image = image.to(device) 59 | mask = mask.to(device) 60 | traced_model = torch.jit.trace(jit_model_wrapper.to(device), (image, mask), strict=False) 61 | 62 | save_path = Path(predict_config.save_path) 63 | save_path.parent.mkdir(parents=True, exist_ok=True) 64 | 65 | print(f"Saving big-lama.pt model to {save_path}") 66 | traced_model.save(str(save_path)) 67 | 68 | print("Checking jit model output...") 69 | jit_model = torch.jit.load(str(save_path)) 70 | jit_output = jit_model(image, mask) 71 | diff = (output - jit_output.cpu()).abs().sum() 72 | print(f"diff: {diff}") 73 | 74 | 75 | if __name__ == "__main__": 76 | main() 77 |
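Once exported by to_jit.py, the traced model can be used without any of the training code; a minimal loading sketch (the big-lama.pt filename and the 512x512 input size are illustrative assumptions; any spatial size divisible by the prediction-time pad_out_to_modulo of 8 should work):

import torch

# Hypothetical path: whatever save_path pointed at when running bin/to_jit.py.
model = torch.jit.load('big-lama.pt', map_location='cpu')
model.eval()

# The traced signature is (image, mask): image in [0, 1], mask where 1 marks holes.
image = torch.rand(1, 3, 512, 512)
mask = (torch.rand(1, 1, 512, 512) > 0.5).float()

with torch.no_grad():
    inpainted = model(image, mask)  # same spatial shape as image
print(inpainted.shape)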
-------------------------------------------------------------------------------- /bin/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import logging 4 | import os 5 | import sys 6 | import traceback 7 | 8 | os.environ['OMP_NUM_THREADS'] = '1' 9 | os.environ['OPENBLAS_NUM_THREADS'] = '1' 10 | os.environ['MKL_NUM_THREADS'] = '1' 11 | os.environ['VECLIB_MAXIMUM_THREADS'] = '1' 12 | os.environ['NUMEXPR_NUM_THREADS'] = '1' 13 | 14 | import hydra 15 | from omegaconf import OmegaConf 16 | from pytorch_lightning import Trainer 17 | from pytorch_lightning.callbacks import ModelCheckpoint 18 | from pytorch_lightning.loggers import TensorBoardLogger 19 | from pytorch_lightning.plugins import DDPPlugin 20 | 21 | from saicinpainting.training.trainers import make_training_model 22 | from saicinpainting.utils import register_debug_signal_handlers, handle_ddp_subprocess, handle_ddp_parent_process, \ 23 | handle_deterministic_config 24 | 25 | LOGGER = logging.getLogger(__name__) 26 | 27 | 28 | @handle_ddp_subprocess() 29 | @hydra.main(config_path='../configs/training', config_name='tiny_test.yaml') 30 | def main(config: OmegaConf): 31 | try: 32 | need_set_deterministic = handle_deterministic_config(config) 33 | 34 | if sys.platform != 'win32': 35 | register_debug_signal_handlers() # kill -10 <pid> will result in traceback dumped into log 36 | 37 | is_in_ddp_subprocess = handle_ddp_parent_process() 38 | 39 | config.visualizer.outdir = os.path.join(os.getcwd(), config.visualizer.outdir) 40 | if not is_in_ddp_subprocess: 41 | LOGGER.info(OmegaConf.to_yaml(config)) 42 | OmegaConf.save(config, os.path.join(os.getcwd(), 'config.yaml')) 43 | 44 | checkpoints_dir = os.path.join(os.getcwd(), 'models') 45 | os.makedirs(checkpoints_dir, exist_ok=True) 46 | 47 | # there is no need to suppress this logger in ddp, because it handles rank on its own 48 | metrics_logger = TensorBoardLogger(config.location.tb_dir, name=os.path.basename(os.getcwd())) 49 | metrics_logger.log_hyperparams(config) 50 | 51 | training_model = make_training_model(config) 52 | 53 | trainer_kwargs = OmegaConf.to_container(config.trainer.kwargs, resolve=True) 54 | if need_set_deterministic: 55 | trainer_kwargs['deterministic'] = True 56 | 57 | trainer = Trainer( 58 | # there is no need to suppress checkpointing in ddp, because it handles rank on its own 59 | callbacks=ModelCheckpoint(dirpath=checkpoints_dir, **config.trainer.checkpoint_kwargs), 60 | logger=metrics_logger, 61 | default_root_dir=os.getcwd(), 62 | **trainer_kwargs 63 | ) 64 | trainer.fit(training_model) 65 | except KeyboardInterrupt: 66 | LOGGER.warning('Interrupted by user') 67 | except Exception as ex: 68 | LOGGER.critical(f'Training failed due to {ex}:\n{traceback.format_exc()}') 69 | sys.exit(1) 70 | 71 | 72 | if __name__ == '__main__': 73 | main() 74 | -------------------------------------------------------------------------------- /configs/analyze_mask_errors.yaml: -------------------------------------------------------------------------------- 1 | dataset_kwargs: 2 | img_suffix: .jpg 3 | inpainted_suffix: .jpg 4 | 5 | take_global_top: 30 6 | take_worst_best_top: 30 7 | take_overlapping_top: 30
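The random_* presets that follow drive mask generation in bin/gen_mask_dataset.py; a small sketch of loading one with OmegaConf, assuming the repo root as working directory (the printed keys mirror the YAML below):

from omegaconf import OmegaConf

cfg = OmegaConf.load('configs/data_gen/random_medium_256.yaml')

print(cfg.generator_kind)                          # random
print(cfg.mask_generator_kwargs.irregular_kwargs)  # stroke count, width, length limits
print(cfg.cropping.out_min_size)                   # 256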
-------------------------------------------------------------------------------- /configs/data_gen/random_medium_256.yaml: -------------------------------------------------------------------------------- 1 | generator_kind: random 2 | 3 | mask_generator_kwargs: 4 | irregular_proba: 1 5 | irregular_kwargs: 6 | min_times: 4 7 | max_times: 5 8 | max_width: 50 9 | max_angle: 4 10 | max_len: 100 11 | 12 | box_proba: 0.3 13 | box_kwargs: 14 | margin: 0 15 | bbox_min_size: 10 16 | bbox_max_size: 50 17 | max_times: 5 18 | min_times: 1 19 | 20 | segm_proba: 0 21 | squares_proba: 0 22 | 23 | variants_n: 5 24 | 25 | max_masks_per_image: 1 26 | 27 | cropping: 28 | out_min_size: 256 29 | handle_small_mode: upscale 30 | out_square_crop: True 31 | crop_min_overlap: 1 32 | 33 | max_tamper_area: 0.5 34 | -------------------------------------------------------------------------------- /configs/data_gen/random_medium_512.yaml: -------------------------------------------------------------------------------- 1 | generator_kind: random 2 | 3 | mask_generator_kwargs: 4 | irregular_proba: 1 5 | irregular_kwargs: 6 | min_times: 4 7 | max_times: 10 8 | max_width: 100 9 | max_angle: 4 10 | max_len: 200 11 | 12 | box_proba: 0.3 13 | box_kwargs: 14 | margin: 0 15 | bbox_min_size: 30 16 | bbox_max_size: 150 17 | max_times: 5 18 | min_times: 1 19 | 20 | segm_proba: 0 21 | squares_proba: 0 22 | 23 | variants_n: 5 24 | 25 | max_masks_per_image: 1 26 | 27 | cropping: 28 | out_min_size: 512 29 | handle_small_mode: upscale 30 | out_square_crop: True 31 | crop_min_overlap: 1 32 | 33 | max_tamper_area: 0.5 34 | -------------------------------------------------------------------------------- /configs/data_gen/random_thick_256.yaml: -------------------------------------------------------------------------------- 1 | generator_kind: random 2 | 3 | mask_generator_kwargs: 4 | irregular_proba: 1 5 | irregular_kwargs: 6 | min_times: 1 7 | max_times: 5 8 | max_width: 100 9 | max_angle: 4 10 | max_len: 200 11 | 12 | box_proba: 0.3 13 | box_kwargs: 14 | margin: 10 15 | bbox_min_size: 30 16 | bbox_max_size: 150 17 | max_times: 3 18 | min_times: 1 19 | 20 | segm_proba: 0 21 | squares_proba: 0 22 | 23 | variants_n: 5 24 | 25 | max_masks_per_image: 1 26 | 27 | cropping: 28 | out_min_size: 256 29 | handle_small_mode: upscale 30 | out_square_crop: True 31 | crop_min_overlap: 1 32 | 33 | max_tamper_area: 0.5 34 | -------------------------------------------------------------------------------- /configs/data_gen/random_thick_512.yaml: -------------------------------------------------------------------------------- 1 | generator_kind: random 2 | 3 | mask_generator_kwargs: 4 | irregular_proba: 1 5 | irregular_kwargs: 6 | min_times: 1 7 | max_times: 5 8 | max_width: 250 9 | max_angle: 4 10 | max_len: 450 11 | 12 | box_proba: 0.3 13 | box_kwargs: 14 | margin: 10 15 | bbox_min_size: 30 16 | bbox_max_size: 300 17 | max_times: 4 18 | min_times: 1 19 | 20 | segm_proba: 0 21 | squares_proba: 0 22 | 23 | variants_n: 5 24 | 25 | max_masks_per_image: 1 26 | 27 | cropping: 28 | out_min_size: 512 29 | handle_small_mode: upscale 30 | out_square_crop: True 31 | crop_min_overlap: 1 32 | 33 | max_tamper_area: 0.5 34 | -------------------------------------------------------------------------------- /configs/data_gen/random_thin_256.yaml: -------------------------------------------------------------------------------- 1 | generator_kind: random 2 | 3 | mask_generator_kwargs: 4 | irregular_proba: 1 5 | irregular_kwargs: 6 | min_times: 4 7 | max_times: 50 8 | max_width: 10 9 | max_angle: 4 10 | max_len: 40 11 | box_proba: 0 12 | segm_proba: 0 13 | squares_proba: 0 14 | 15 | variants_n: 5 16 | 17 | max_masks_per_image: 1 18 | 19 | cropping: 20 | out_min_size: 256 21 | handle_small_mode: upscale 22 | out_square_crop: True 23 | crop_min_overlap: 1 24 | 25 | max_tamper_area:
0.5 26 | -------------------------------------------------------------------------------- /configs/data_gen/random_thin_512.yaml: -------------------------------------------------------------------------------- 1 | generator_kind: random 2 | 3 | mask_generator_kwargs: 4 | irregular_proba: 1 5 | irregular_kwargs: 6 | min_times: 4 7 | max_times: 70 8 | max_width: 20 9 | max_angle: 4 10 | max_len: 100 11 | box_proba: 0 12 | segm_proba: 0 13 | squares_proba: 0 14 | 15 | variants_n: 5 16 | 17 | max_masks_per_image: 1 18 | 19 | cropping: 20 | out_min_size: 512 21 | handle_small_mode: upscale 22 | out_square_crop: True 23 | crop_min_overlap: 1 24 | 25 | max_tamper_area: 0.5 26 | -------------------------------------------------------------------------------- /configs/debug_mask_gen.yaml: -------------------------------------------------------------------------------- 1 | img_ext: .jpg 2 | 3 | gen_kwargs: 4 | mask_size: 200 5 | step: 0.5 6 | -------------------------------------------------------------------------------- /configs/eval1.yaml: -------------------------------------------------------------------------------- 1 | evaluator_kwargs: 2 | batch_size: 8 3 | 4 | dataset_kwargs: 5 | img_suffix: .png 6 | inpainted_suffix: .jpg -------------------------------------------------------------------------------- /configs/eval2.yaml: -------------------------------------------------------------------------------- 1 | evaluator_kwargs: 2 | batch_size: 8 3 | device: cuda 4 | 5 | dataset_kwargs: 6 | img_suffix: .png 7 | inpainted_suffix: .png -------------------------------------------------------------------------------- /configs/eval2_cpu.yaml: -------------------------------------------------------------------------------- 1 | evaluator_kwargs: 2 | batch_size: 8 3 | device: cpu 4 | 5 | dataset_kwargs: 6 | img_suffix: .png 7 | inpainted_suffix: .png -------------------------------------------------------------------------------- /configs/eval2_gpu.yaml: -------------------------------------------------------------------------------- 1 | evaluator_kwargs: 2 | batch_size: 8 3 | 4 | dataset_kwargs: 5 | img_suffix: .png 6 | inpainted_suffix: .png -------------------------------------------------------------------------------- /configs/eval2_jpg.yaml: -------------------------------------------------------------------------------- 1 | evaluator_kwargs: 2 | batch_size: 8 3 | 4 | dataset_kwargs: 5 | img_suffix: .png 6 | inpainted_suffix: .jpg -------------------------------------------------------------------------------- /configs/eval2_segm.yaml: -------------------------------------------------------------------------------- 1 | evaluator_kwargs: 2 | batch_size: 8 3 | 4 | dataset_kwargs: 5 | img_suffix: .png 6 | inpainted_suffix: .png 7 | 8 | segmentation: 9 | enable: True 10 | weights_path: ${TORCH_HOME} 11 | -------------------------------------------------------------------------------- /configs/eval2_segm_test.yaml: -------------------------------------------------------------------------------- 1 | evaluator_kwargs: 2 | batch_size: 1 3 | 4 | dataset_kwargs: 5 | img_suffix: _input.png 6 | inpainted_suffix: .png 7 | pad_out_to_modulo: 8 8 | 9 | segmentation: 10 | enable: True 11 | weights_path: ${TORCH_HOME} 12 | -------------------------------------------------------------------------------- /configs/eval2_test.yaml: -------------------------------------------------------------------------------- 1 | evaluator_kwargs: 2 | batch_size: 1 3 | 4 | dataset_kwargs: 5 | img_suffix: _input.png 6 | 
inpainted_suffix: .png 7 | pad_out_to_modulo: 8 8 | -------------------------------------------------------------------------------- /configs/places2-categories_157.txt: -------------------------------------------------------------------------------- 1 | /a/airplane_cabin 1 2 | /a/airport_terminal 2 3 | /a/alcove 3 4 | /a/alley 4 5 | /a/amphitheater 5 6 | /a/amusement_park 7 7 | /a/apartment_building/outdoor 8 8 | /a/aqueduct 10 9 | /a/arcade 11 10 | /a/arch 12 11 | /a/archive 14 12 | /a/art_gallery 19 13 | /a/artists_loft 22 14 | /a/assembly_line 23 15 | /a/atrium/public 25 16 | /a/attic 26 17 | /a/auditorium 27 18 | /b/bakery/shop 31 19 | /b/balcony/exterior 32 20 | /b/balcony/interior 33 21 | /b/ballroom 35 22 | /b/banquet_hall 38 23 | /b/barndoor 41 24 | /b/basement 43 25 | /b/basketball_court/indoor 44 26 | /b/bathroom 45 27 | /b/bazaar/indoor 46 28 | /b/bazaar/outdoor 47 29 | /b/beach_house 49 30 | /b/bedchamber 51 31 | /b/bedroom 52 32 | /b/berth 55 33 | /b/boardwalk 57 34 | /b/boathouse 59 35 | /b/bookstore 60 36 | /b/booth/indoor 61 37 | /b/bow_window/indoor 63 38 | /b/bowling_alley 64 39 | /b/bridge 66 40 | /b/building_facade 67 41 | /b/bus_interior 70 42 | /b/bus_station/indoor 71 43 | /c/cabin/outdoor 74 44 | /c/campus 77 45 | /c/canal/urban 79 46 | /c/candy_store 80 47 | /c/carrousel 83 48 | /c/castle 84 49 | /c/chalet 87 50 | /c/childs_room 89 51 | /c/church/indoor 90 52 | /c/church/outdoor 91 53 | /c/closet 95 54 | /c/conference_center 101 55 | /c/conference_room 102 56 | /c/construction_site 103 57 | /c/corridor 106 58 | /c/cottage 107 59 | /c/courthouse 108 60 | /c/courtyard 109 61 | /d/delicatessen 114 62 | /d/department_store 115 63 | /d/diner/outdoor 119 64 | /d/dining_hall 120 65 | /d/dining_room 121 66 | /d/doorway/outdoor 123 67 | /d/dorm_room 124 68 | /d/downtown 125 69 | /d/driveway 127 70 | /e/elevator/door 129 71 | /e/elevator_lobby 130 72 | /e/elevator_shaft 131 73 | /e/embassy 132 74 | /e/entrance_hall 134 75 | /e/escalator/indoor 135 76 | /f/fastfood_restaurant 139 77 | /f/fire_escape 143 78 | /f/fire_station 144 79 | /f/food_court 148 80 | /g/galley 155 81 | /g/garage/outdoor 157 82 | /g/gas_station 158 83 | /g/gazebo/exterior 159 84 | /g/general_store/indoor 160 85 | /g/general_store/outdoor 161 86 | /g/greenhouse/outdoor 166 87 | /g/gymnasium/indoor 168 88 | /h/hangar/outdoor 170 89 | /h/hardware_store 172 90 | /h/home_office 176 91 | /h/home_theater 177 92 | /h/hospital 178 93 | /h/hotel/outdoor 181 94 | /h/hotel_room 182 95 | /h/house 183 96 | /h/hunting_lodge/outdoor 184 97 | /i/industrial_area 192 98 | /i/inn/outdoor 193 99 | /j/jacuzzi/indoor 195 100 | /j/jail_cell 196 101 | /k/kasbah 200 102 | /k/kitchen 203 103 | /l/laundromat 208 104 | /l/library/indoor 212 105 | /l/library/outdoor 213 106 | /l/lighthouse 214 107 | /l/living_room 215 108 | /l/loading_dock 216 109 | /l/lobby 217 110 | /l/lock_chamber 218 111 | /m/mansion 220 112 | /m/manufactured_home 221 113 | /m/mausoleum 226 114 | /m/medina 227 115 | /m/mezzanine 228 116 | /m/mosque/outdoor 230 117 | /m/movie_theater/indoor 235 118 | /m/museum/outdoor 237 119 | /n/nursery 240 120 | /o/oast_house 242 121 | /o/office 244 122 | /o/office_building 245 123 | /o/office_cubicles 246 124 | /p/pagoda 251 125 | /p/palace 252 126 | /p/pantry 253 127 | /p/parking_garage/indoor 255 128 | /p/parking_garage/outdoor 256 129 | /p/pavilion 260 130 | /p/pet_shop 261 131 | /p/porch 272 132 | /r/reception 280 133 | /r/recreation_room 281 134 | /r/restaurant_patio 286 135 | /r/rope_bridge 291 136 | /r/ruin 292 137 | 
/s/sauna 295 138 | /s/schoolhouse 296 139 | /s/server_room 298 140 | /s/shed 299 141 | /s/shopfront 301 142 | /s/shopping_mall/indoor 302 143 | /s/shower 303 144 | /s/skyscraper 307 145 | /s/staircase 317 146 | /s/storage_room 318 147 | /s/subway_station/platform 320 148 | /s/synagogue/outdoor 327 149 | /t/television_room 328 150 | /t/temple/asia 330 151 | /t/throne_room 331 152 | /t/tower 334 153 | /t/train_station/platform 337 154 | /u/utility_room 343 155 | /w/waiting_room 352 156 | /w/wet_bar 358 157 | /y/youth_hostel 363 -------------------------------------------------------------------------------- /configs/prediction/default.yaml: -------------------------------------------------------------------------------- 1 | indir: no # to be overriden in CLI 2 | outdir: no # to be overriden in CLI 3 | 4 | model: 5 | path: no # to be overriden in CLI 6 | checkpoint: best.ckpt 7 | 8 | dataset: 9 | kind: default 10 | img_suffix: .png 11 | pad_out_to_modulo: 8 12 | 13 | device: cuda 14 | out_key: inpainted 15 | 16 | refine: False # refiner will only run if this is True 17 | refiner: 18 | gpu_ids: 0,1 # the GPU ids of the machine to use. If only single GPU, use: "0," 19 | modulo: ${dataset.pad_out_to_modulo} 20 | n_iters: 15 # number of iterations of refinement for each scale 21 | lr: 0.002 # learning rate 22 | min_side: 512 # all sides of image on all scales should be >= min_side / sqrt(2) 23 | max_scales: 3 # max number of downscaling scales for the image-mask pyramid 24 | px_budget: 1800000 # pixels budget. Any image will be resized to satisfy height*width <= px_budget -------------------------------------------------------------------------------- /configs/training/ablv2_work.yaml: -------------------------------------------------------------------------------- 1 | run_title: '' 2 | 3 | training_model: 4 | kind: default 5 | visualize_each_iters: 1000 6 | concat_mask: True 7 | store_discr_outputs_for_vis: True 8 | 9 | losses: 10 | l1: 11 | weight_missing: 0 12 | weight_known: 10 13 | perceptual: 14 | weight: 0 15 | adversarial: 16 | kind: r1 17 | weight: 10 18 | gp_coef: 0.001 19 | mask_as_fake_target: True 20 | allow_scale_mask: True 21 | feature_matching: 22 | weight: 100 23 | resnet_pl: 24 | weight: 30 25 | weights_path: ${env:TORCH_HOME} 26 | 27 | defaults: 28 | - location: docker 29 | - data: abl-04-256-mh-dist 30 | - generator: pix2pixhd_global_sigmoid 31 | - discriminator: pix2pixhd_nlayer 32 | - optimizers: default_optimizers 33 | - visualizer: directory 34 | - evaluator: default_inpainted 35 | - trainer: any_gpu_large_ssim_ddp_final 36 | - hydra: overrides 37 | -------------------------------------------------------------------------------- /configs/training/ablv2_work_ffc075.yaml: -------------------------------------------------------------------------------- 1 | run_title: '' 2 | 3 | training_model: 4 | kind: default 5 | visualize_each_iters: 1000 6 | concat_mask: True 7 | store_discr_outputs_for_vis: True 8 | 9 | losses: 10 | l1: 11 | weight_missing: 0 12 | weight_known: 10 13 | perceptual: 14 | weight: 0 15 | adversarial: 16 | kind: r1 17 | weight: 10 18 | gp_coef: 0.001 19 | mask_as_fake_target: True 20 | allow_scale_mask: True 21 | feature_matching: 22 | weight: 100 23 | resnet_pl: 24 | weight: 30 25 | weights_path: ${env:TORCH_HOME} 26 | 27 | defaults: 28 | - location: docker 29 | - data: abl-04-256-mh-dist 30 | - generator: ffc_resnet_075 31 | - discriminator: pix2pixhd_nlayer 32 | - optimizers: default_optimizers 33 | - visualizer: directory 34 | - evaluator: 
default_inpainted 35 | - trainer: any_gpu_large_ssim_ddp_final 36 | - hydra: overrides 37 | -------------------------------------------------------------------------------- /configs/training/ablv2_work_md.yaml: -------------------------------------------------------------------------------- 1 | run_title: '' 2 | 3 | training_model: 4 | kind: default 5 | visualize_each_iters: 1000 6 | concat_mask: True 7 | store_discr_outputs_for_vis: True 8 | 9 | losses: 10 | l1: 11 | weight_missing: 0 12 | weight_known: 10 13 | perceptual: 14 | weight: 0 15 | adversarial: 16 | kind: r1 17 | weight: 10 18 | gp_coef: 0.001 19 | mask_as_fake_target: True 20 | allow_scale_mask: True 21 | feature_matching: 22 | weight: 100 23 | resnet_pl: 24 | weight: 30 25 | weights_path: ${env:TORCH_HOME} 26 | 27 | defaults: 28 | - location: docker 29 | - data: abl-04-256-mh-dist 30 | - generator: pix2pixhd_multidilated_catin_4dil_9b 31 | - discriminator: pix2pixhd_nlayer 32 | - optimizers: default_optimizers 33 | - visualizer: directory 34 | - evaluator: default_inpainted 35 | - trainer: any_gpu_large_ssim_ddp_final_benchmark 36 | - hydra: overrides 37 | -------------------------------------------------------------------------------- /configs/training/ablv2_work_no_fm.yaml: -------------------------------------------------------------------------------- 1 | run_title: '' 2 | 3 | training_model: 4 | kind: default 5 | visualize_each_iters: 1000 6 | concat_mask: True 7 | store_discr_outputs_for_vis: True 8 | 9 | losses: 10 | l1: 11 | weight_missing: 0 12 | weight_known: 10 13 | perceptual: 14 | weight: 0 15 | adversarial: 16 | kind: r1 17 | weight: 10 18 | gp_coef: 0.001 19 | mask_as_fake_target: True 20 | allow_scale_mask: True 21 | feature_matching: 22 | weight: 0 23 | resnet_pl: 24 | weight: 30 25 | weights_path: ${env:TORCH_HOME} 26 | 27 | defaults: 28 | - location: mlp-mow-final 29 | - data: abl-04-256-mh-dist 30 | - generator: pix2pixhd_global_sigmoid 31 | - discriminator: pix2pixhd_nlayer 32 | - optimizers: default_optimizers 33 | - visualizer: directory 34 | - evaluator: default_inpainted 35 | - trainer: any_gpu_large_ssim_ddp_final 36 | - hydra: overrides 37 | -------------------------------------------------------------------------------- /configs/training/ablv2_work_no_segmpl.yaml: -------------------------------------------------------------------------------- 1 | run_title: '' 2 | 3 | training_model: 4 | kind: default 5 | visualize_each_iters: 1000 6 | concat_mask: True 7 | store_discr_outputs_for_vis: True 8 | 9 | losses: 10 | l1: 11 | weight_missing: 0 12 | weight_known: 10 13 | perceptual: 14 | weight: 0 15 | adversarial: 16 | kind: r1 17 | weight: 10 18 | gp_coef: 0.001 19 | mask_as_fake_target: True 20 | allow_scale_mask: True 21 | feature_matching: 22 | weight: 100 23 | resnet_pl: 24 | weight: 0 25 | # weights_path: ${env:TORCH_HOME} 26 | 27 | defaults: 28 | - location: docker 29 | - data: abl-04-256-mh-dist 30 | - generator: pix2pixhd_global_sigmoid 31 | - discriminator: pix2pixhd_nlayer 32 | - optimizers: default_optimizers 33 | - visualizer: directory 34 | - evaluator: default_inpainted 35 | - trainer: any_gpu_large_ssim_ddp_final 36 | - hydra: overrides 37 | -------------------------------------------------------------------------------- /configs/training/ablv2_work_no_segmpl_csdilirpl.yaml: -------------------------------------------------------------------------------- 1 | run_title: '' 2 | 3 | training_model: 4 | kind: default 5 | visualize_each_iters: 1000 6 | concat_mask: True 7 | 
store_discr_outputs_for_vis: True 8 | 9 | losses: 10 | l1: 11 | weight_missing: 0 12 | weight_known: 10 13 | perceptual: 14 | weight: 0 15 | adversarial: 16 | kind: r1 17 | weight: 10 18 | gp_coef: 0.001 19 | mask_as_fake_target: True 20 | allow_scale_mask: True 21 | feature_matching: 22 | weight: 100 23 | resnet_pl: 24 | weight: 1 25 | segmentation: false 26 | 27 | defaults: 28 | - location: docker 29 | - data: abl-04-256-mh-dist 30 | - generator: pix2pixhd_global_sigmoid 31 | - discriminator: pix2pixhd_nlayer 32 | - optimizers: default_optimizers 33 | - visualizer: directory 34 | - evaluator: default_inpainted 35 | - trainer: any_gpu_large_ssim_ddp_final_benchmark 36 | - hydra: overrides 37 | -------------------------------------------------------------------------------- /configs/training/ablv2_work_no_segmpl_csdilirpl_celeba_csdilirpl1_new.yaml: -------------------------------------------------------------------------------- 1 | run_title: '' 2 | 3 | training_model: 4 | kind: default 5 | visualize_each_iters: 1000 6 | concat_mask: true 7 | store_discr_outputs_for_vis: true 8 | losses: 9 | l1: 10 | weight_missing: 0 11 | weight_known: 10 12 | perceptual: 13 | weight: 0 14 | adversarial: 15 | kind: r1 16 | weight: 10 17 | gp_coef: 0.001 18 | mask_as_fake_target: true 19 | allow_scale_mask: true 20 | feature_matching: 21 | weight: 100 22 | segm_pl: 23 | weight: 1 24 | imagenet_weights: true 25 | 26 | defaults: 27 | - location: celeba 28 | - data: abl-04-256-mh-dist-celeba 29 | - generator: pix2pixhd_global_sigmoid 30 | - discriminator: pix2pixhd_nlayer 31 | - optimizers: default_optimizers 32 | - visualizer: directory 33 | - evaluator: default_inpainted 34 | - trainer: any_gpu_large_ssim_ddp_final_celeba 35 | - hydra: overrides -------------------------------------------------------------------------------- /configs/training/ablv2_work_no_segmpl_csirpl.yaml: -------------------------------------------------------------------------------- 1 | run_title: '' 2 | 3 | training_model: 4 | kind: default 5 | visualize_each_iters: 1000 6 | concat_mask: True 7 | store_discr_outputs_for_vis: True 8 | 9 | losses: 10 | l1: 11 | weight_missing: 0 12 | weight_known: 10 13 | perceptual: 14 | weight: 0 15 | adversarial: 16 | kind: r1 17 | weight: 10 18 | gp_coef: 0.001 19 | mask_as_fake_target: True 20 | allow_scale_mask: True 21 | feature_matching: 22 | weight: 100 23 | resnet_pl: 24 | weight: 0.3 25 | arch_encoder: 'resnet50' 26 | segmentation: false 27 | 28 | defaults: 29 | - location: docker 30 | - data: abl-04-256-mh-dist 31 | - generator: pix2pixhd_global_sigmoid 32 | - discriminator: pix2pixhd_nlayer 33 | - optimizers: default_optimizers 34 | - visualizer: directory 35 | - evaluator: default_inpainted 36 | - trainer: any_gpu_large_ssim_ddp_final 37 | - hydra: overrides 38 | -------------------------------------------------------------------------------- /configs/training/ablv2_work_no_segmpl_csirpl_celeba_csirpl03_new.yaml: -------------------------------------------------------------------------------- 1 | run_title: '' 2 | 3 | training_model: 4 | kind: default 5 | visualize_each_iters: 1000 6 | concat_mask: true 7 | store_discr_outputs_for_vis: true 8 | losses: 9 | l1: 10 | weight_missing: 0 11 | weight_known: 10 12 | perceptual: 13 | weight: 0 14 | adversarial: 15 | kind: r1 16 | weight: 10 17 | gp_coef: 0.001 18 | mask_as_fake_target: true 19 | allow_scale_mask: true 20 | feature_matching: 21 | weight: 100 22 | segm_pl: 23 | weight: 0.3 24 | arch_encoder: resnet50 25 | imagenet_weights: true 26 
| 27 | defaults: 28 | - location: celeba 29 | - data: abl-04-256-mh-dist-celeba 30 | - generator: pix2pixhd_global_sigmoid 31 | - discriminator: pix2pixhd_nlayer 32 | - optimizers: default_optimizers 33 | - visualizer: directory 34 | - evaluator: default_inpainted 35 | - trainer: any_gpu_large_ssim_ddp_final_celeba 36 | - hydra: overrides 37 | -------------------------------------------------------------------------------- /configs/training/ablv2_work_no_segmpl_vgg.yaml: -------------------------------------------------------------------------------- 1 | run_title: '' 2 | 3 | training_model: 4 | kind: default 5 | visualize_each_iters: 1000 6 | concat_mask: True 7 | store_discr_outputs_for_vis: True 8 | 9 | losses: 10 | l1: 11 | weight_missing: 0 12 | weight_known: 10 13 | perceptual: 14 | weight: 0.03 15 | adversarial: 16 | kind: r1 17 | weight: 10 18 | gp_coef: 0.001 19 | mask_as_fake_target: True 20 | allow_scale_mask: True 21 | feature_matching: 22 | weight: 100 23 | resnet_pl: 24 | weight: 0 25 | # weights_path: ${env:TORCH_HOME} 26 | 27 | defaults: 28 | - location: docker 29 | - data: abl-04-256-mh-dist 30 | - generator: pix2pixhd_global_sigmoid 31 | - discriminator: pix2pixhd_nlayer 32 | - optimizers: default_optimizers 33 | - visualizer: directory 34 | - evaluator: default_inpainted 35 | - trainer: any_gpu_large_ssim_ddp_final 36 | - hydra: overrides 37 | -------------------------------------------------------------------------------- /configs/training/ablv2_work_no_segmpl_vgg_celeba_l2_vgg003_new.yaml: -------------------------------------------------------------------------------- 1 | run_title: '' 2 | 3 | training_model: 4 | kind: default 5 | visualize_each_iters: 1000 6 | concat_mask: true 7 | store_discr_outputs_for_vis: true 8 | losses: 9 | l1: 10 | weight_missing: 0 11 | weight_known: 10 12 | perceptual: 13 | weight: 0.03 14 | kwargs: 15 | metric: l2 16 | adversarial: 17 | kind: r1 18 | weight: 10 19 | gp_coef: 0.001 20 | mask_as_fake_target: true 21 | allow_scale_mask: true 22 | feature_matching: 23 | weight: 100 24 | segm_pl: 25 | weight: 0 26 | 27 | defaults: 28 | - location: celeba 29 | - data: abl-04-256-mh-dist-celeba 30 | - generator: pix2pixhd_global_sigmoid 31 | - discriminator: pix2pixhd_nlayer 32 | - optimizers: default_optimizers 33 | - visualizer: directory 34 | - evaluator: default_inpainted 35 | - trainer: any_gpu_large_ssim_ddp_final_celeba 36 | - hydra: overrides 37 | -------------------------------------------------------------------------------- /configs/training/ablv2_work_nodil_segmpl.yaml: -------------------------------------------------------------------------------- 1 | run_title: '' 2 | 3 | training_model: 4 | kind: default 5 | visualize_each_iters: 1000 6 | concat_mask: True 7 | store_discr_outputs_for_vis: True 8 | 9 | losses: 10 | l1: 11 | weight_missing: 0 12 | weight_known: 10 13 | perceptual: 14 | weight: 0 15 | adversarial: 16 | kind: r1 17 | weight: 10 18 | gp_coef: 0.001 19 | mask_as_fake_target: True 20 | allow_scale_mask: True 21 | feature_matching: 22 | weight: 100 23 | resnet_pl: 24 | arch_encoder: resnet50 25 | weight: 30 26 | weights_path: ${env:TORCH_HOME} 27 | 28 | defaults: 29 | - location: docker 30 | - data: abl-04-256-mh-dist 31 | - generator: pix2pixhd_global_sigmoid 32 | - discriminator: pix2pixhd_nlayer 33 | - optimizers: default_optimizers 34 | - visualizer: directory 35 | - evaluator: default_inpainted 36 | - trainer: any_gpu_large_ssim_ddp_final 37 | - hydra: overrides 38 | 
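Fields like weights_path: ${env:TORCH_HOME} in these training configs are OmegaConf interpolations resolved on access; a standalone sketch of the mechanism (the directory value is a made-up example, and the legacy env resolver used here is spelled oc.env in newer OmegaConf releases):

import os
from omegaconf import OmegaConf

os.environ.setdefault('TORCH_HOME', '/tmp/torch_home')  # hypothetical value

# Mirrors the resnet_pl block of the configs above; ${env:...} is expanded
# only when the field is accessed, not when the config is created.
cfg = OmegaConf.create({'resnet_pl': {'weight': 30, 'weights_path': '${env:TORCH_HOME}'}})
print(cfg.resnet_pl.weights_path)  # /tmp/torch_home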
-------------------------------------------------------------------------------- /configs/training/ablv2_work_small_holes.yaml: -------------------------------------------------------------------------------- 1 | run_title: '' 2 | 3 | training_model: 4 | kind: default 5 | visualize_each_iters: 1000 6 | concat_mask: True 7 | store_discr_outputs_for_vis: True 8 | 9 | losses: 10 | l1: 11 | weight_missing: 0 12 | weight_known: 10 13 | perceptual: 14 | weight: 0 15 | adversarial: 16 | kind: r1 17 | weight: 10 18 | gp_coef: 0.001 19 | mask_as_fake_target: True 20 | allow_scale_mask: True 21 | feature_matching: 22 | weight: 100 23 | resnet_pl: 24 | weight: 30 25 | weights_path: ${env:TORCH_HOME} 26 | 27 | defaults: 28 | - location: docker 29 | - data: abl-02-thin-bb 30 | - generator: pix2pixhd_global_sigmoid 31 | - discriminator: pix2pixhd_nlayer 32 | - optimizers: default_optimizers 33 | - visualizer: directory 34 | - evaluator: default_inpainted 35 | - trainer: any_gpu_large_ssim_ddp_final 36 | - hydra: overrides 37 | -------------------------------------------------------------------------------- /configs/training/big-lama-celeba.yaml: -------------------------------------------------------------------------------- 1 | run_title: '' 2 | 3 | training_model: 4 | kind: default 5 | visualize_each_iters: 1000 6 | concat_mask: true 7 | store_discr_outputs_for_vis: true 8 | losses: 9 | l1: 10 | weight_missing: 0 11 | weight_known: 10 12 | perceptual: 13 | weight: 0 14 | adversarial: 15 | kind: r1 16 | weight: 10 17 | gp_coef: 0.001 18 | mask_as_fake_target: true 19 | allow_scale_mask: true 20 | feature_matching: 21 | weight: 100 22 | resnet_pl: 23 | weight: 30 24 | weights_path: ${env:TORCH_HOME} 25 | 26 | generator: 27 | kind: ffc_resnet 28 | input_nc: 4 29 | output_nc: 3 30 | ngf: 64 31 | n_downsampling: 3 32 | n_blocks: 18 33 | add_out_act: sigmoid 34 | init_conv_kwargs: 35 | ratio_gin: 0 36 | ratio_gout: 0 37 | enable_lfu: false 38 | downsample_conv_kwargs: 39 | ratio_gin: ${generator.init_conv_kwargs.ratio_gout} 40 | ratio_gout: ${generator.downsample_conv_kwargs.ratio_gin} 41 | enable_lfu: false 42 | resnet_conv_kwargs: 43 | ratio_gin: 0.75 44 | ratio_gout: ${generator.resnet_conv_kwargs.ratio_gin} 45 | enable_lfu: false 46 | 47 | defaults: 48 | - location: celeba 49 | - data: abl-04-256-mh-dist-celeba 50 | - discriminator: pix2pixhd_nlayer 51 | - optimizers: default_optimizers 52 | - visualizer: directory 53 | - evaluator: default_inpainted 54 | - trainer: any_gpu_large_ssim_ddp_final_celeba 55 | - hydra: overrides 56 | -------------------------------------------------------------------------------- /configs/training/big-lama-regular-celeba.yaml: -------------------------------------------------------------------------------- 1 | run_title: '' 2 | 3 | generator: 4 | kind: pix2pixhd_global 5 | input_nc: 4 6 | output_nc: 3 7 | ngf: 64 8 | n_downsampling: 3 9 | n_blocks: 15 10 | conv_kind: default 11 | add_out_act: sigmoid 12 | 13 | training_model: 14 | kind: default 15 | visualize_each_iters: 1000 16 | concat_mask: true 17 | store_discr_outputs_for_vis: true 18 | 19 | losses: 20 | l1: 21 | weight_missing: 0 22 | weight_known: 10 23 | perceptual: 24 | weight: 0 25 | adversarial: 26 | kind: r1 27 | weight: 10 28 | gp_coef: 0.001 29 | mask_as_fake_target: true 30 | allow_scale_mask: true 31 | feature_matching: 32 | weight: 100 33 | resnet_pl: 34 | weight: 30 35 | weights_path: ${env:TORCH_HOME} 36 | 37 | defaults: 38 | - location: celeba 39 | - data: abl-04-256-mh-dist-celeba 40 | - discriminator: 
pix2pixhd_nlayer 41 | - optimizers: default_optimizers 42 | - visualizer: directory 43 | - evaluator: default_inpainted 44 | - trainer: any_gpu_large_ssim_ddp_final_celeba 45 | - hydra: overrides -------------------------------------------------------------------------------- /configs/training/big-lama-regular.yaml: -------------------------------------------------------------------------------- 1 | run_title: '' 2 | 3 | generator: 4 | kind: pix2pixhd_global 5 | input_nc: 4 6 | output_nc: 3 7 | ngf: 64 8 | n_downsampling: 3 9 | n_blocks: 15 10 | conv_kind: default 11 | add_out_act: sigmoid 12 | 13 | training_model: 14 | kind: default 15 | visualize_each_iters: 1000 16 | concat_mask: true 17 | store_discr_outputs_for_vis: true 18 | 19 | losses: 20 | l1: 21 | weight_missing: 0 22 | weight_known: 10 23 | perceptual: 24 | weight: 0 25 | adversarial: 26 | kind: r1 27 | weight: 10 28 | gp_coef: 0.001 29 | mask_as_fake_target: true 30 | allow_scale_mask: true 31 | feature_matching: 32 | weight: 100 33 | resnet_pl: 34 | weight: 30 35 | weights_path: ${env:TORCH_HOME} 36 | 37 | defaults: 38 | - location: docker 39 | - data: abl-04-256-mh-dist 40 | - discriminator: pix2pixhd_nlayer 41 | - optimizers: default_optimizers 42 | - visualizer: directory 43 | - evaluator: default_inpainted 44 | - trainer: any_gpu_large_ssim_ddp_final 45 | - hydra: overrides -------------------------------------------------------------------------------- /configs/training/big-lama.yaml: -------------------------------------------------------------------------------- 1 | run_title: '' 2 | 3 | training_model: 4 | kind: default 5 | visualize_each_iters: 1000 6 | concat_mask: true 7 | store_discr_outputs_for_vis: true 8 | losses: 9 | l1: 10 | weight_missing: 0 11 | weight_known: 10 12 | perceptual: 13 | weight: 0 14 | adversarial: 15 | kind: r1 16 | weight: 10 17 | gp_coef: 0.001 18 | mask_as_fake_target: true 19 | allow_scale_mask: true 20 | feature_matching: 21 | weight: 100 22 | resnet_pl: 23 | weight: 30 24 | weights_path: ${env:TORCH_HOME} 25 | 26 | generator: 27 | kind: ffc_resnet 28 | input_nc: 4 29 | output_nc: 3 30 | ngf: 64 31 | n_downsampling: 3 32 | n_blocks: 18 33 | add_out_act: sigmoid 34 | init_conv_kwargs: 35 | ratio_gin: 0 36 | ratio_gout: 0 37 | enable_lfu: false 38 | downsample_conv_kwargs: 39 | ratio_gin: ${generator.init_conv_kwargs.ratio_gout} 40 | ratio_gout: ${generator.downsample_conv_kwargs.ratio_gin} 41 | enable_lfu: false 42 | resnet_conv_kwargs: 43 | ratio_gin: 0.75 44 | ratio_gout: ${generator.resnet_conv_kwargs.ratio_gin} 45 | enable_lfu: false 46 | 47 | defaults: 48 | - location: docker 49 | - data: abl-04-256-mh-dist 50 | - discriminator: pix2pixhd_nlayer 51 | - optimizers: default_optimizers 52 | - visualizer: directory 53 | - evaluator: default_inpainted 54 | - trainer: any_gpu_large_ssim_ddp_final 55 | - hydra: overrides 56 | -------------------------------------------------------------------------------- /configs/training/data/abl-02-thin-bb.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | # try to resemble mask generation of DeepFill v2 4 | # official tf version: https://github.com/JiahuiYu/generative_inpainting/blob/master/inpaint_ops.py#L168 5 | # pytorch version: https://github.com/zhaoyuzhi/deepfillv2/blob/62dad2c601400e14d79f4d1e090c2effcb9bf3eb/deepfillv2/dataset.py#L40 6 | # another unofficial pytorch version: https://github.com/avalonstrel/GatedConvolution/blob/master/config/inpaint.yml 7 | # they are a bit 
different, official version has slightly larger masks 8 | 9 | batch_size: 10 10 | val_batch_size: 2 11 | num_workers: 3 12 | 13 | train: 14 | indir: ${location.data_root_dir}/train 15 | out_size: 256 16 | 17 | mask_gen_kwargs: # probabilities do not need to sum to 1, they are re-normalized in mask generator 18 | irregular_proba: 1 19 | irregular_kwargs: 20 | max_angle: 4 21 | max_len: 80 # math.sqrt(H*H+W*W) / 8 + math.sqrt(H*H+W*W) / 16 https://github.com/JiahuiYu/generative_inpainting/blob/master/inpaint_ops.py#L189 22 | max_width: 40 23 | max_times: 12 24 | min_times: 4 25 | 26 | box_proba: 1 27 | box_kwargs: 28 | margin: 0 29 | bbox_min_size: 30 30 | bbox_max_size: 128 31 | max_times: 1 32 | min_times: 1 33 | 34 | segm_proba: 0 # not working yet due to RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method 35 | 36 | transform_variant: default 37 | dataloader_kwargs: 38 | batch_size: ${data.batch_size} 39 | shuffle: True 40 | num_workers: ${data.num_workers} 41 | 42 | val: 43 | indir: ${location.data_root_dir}/val 44 | img_suffix: .png 45 | dataloader_kwargs: 46 | batch_size: ${data.val_batch_size} 47 | shuffle: False 48 | num_workers: ${data.num_workers} 49 | 50 | #extra_val: 51 | # random_thin_256: 52 | # indir: ${location.data_root_dir}/extra_val/random_thin_256 53 | # img_suffix: .png 54 | # dataloader_kwargs: 55 | # batch_size: ${data.val_batch_size} 56 | # shuffle: False 57 | # num_workers: ${data.num_workers} 58 | # random_medium_256: 59 | # indir: ${location.data_root_dir}/extra_val/random_medium_256 60 | # img_suffix: .png 61 | # dataloader_kwargs: 62 | # batch_size: ${data.val_batch_size} 63 | # shuffle: False 64 | # num_workers: ${data.num_workers} 65 | # random_thick_256: 66 | # indir: ${location.data_root_dir}/extra_val/random_thick_256 67 | # img_suffix: .png 68 | # dataloader_kwargs: 69 | # batch_size: ${data.val_batch_size} 70 | # shuffle: False 71 | # num_workers: ${data.num_workers} 72 | # random_thin_512: 73 | # indir: ${location.data_root_dir}/extra_val/random_thin_512 74 | # img_suffix: .png 75 | # dataloader_kwargs: 76 | # batch_size: ${data.val_batch_size} 77 | # shuffle: False 78 | # num_workers: ${data.num_workers} 79 | # random_medium_512: 80 | # indir: ${location.data_root_dir}/extra_val/random_medium_512 81 | # img_suffix: .png 82 | # dataloader_kwargs: 83 | # batch_size: ${data.val_batch_size} 84 | # shuffle: False 85 | # num_workers: ${data.num_workers} 86 | # random_thick_512: 87 | # indir: ${location.data_root_dir}/extra_val/random_thick_512 88 | # img_suffix: .png 89 | # dataloader_kwargs: 90 | # batch_size: ${data.val_batch_size} 91 | # shuffle: False 92 | # num_workers: ${data.num_workers} 93 | # segm_256: 94 | # indir: ${location.data_root_dir}/extra_val/segm_256 95 | # img_suffix: .png 96 | # dataloader_kwargs: 97 | # batch_size: ${data.val_batch_size} 98 | # shuffle: False 99 | # num_workers: ${data.num_workers} 100 | # segm_512: 101 | # indir: ${location.data_root_dir}/extra_val/segm_512 102 | # img_suffix: .png 103 | # dataloader_kwargs: 104 | # batch_size: ${data.val_batch_size} 105 | # shuffle: False 106 | # num_workers: ${data.num_workers} 107 | 108 | visual_test: 109 | indir: ${location.data_root_dir}/visual_test 110 | img_suffix: _input.png 111 | pad_out_to_modulo: 32 112 | dataloader_kwargs: 113 | batch_size: 1 114 | shuffle: False 115 | num_workers: ${data.num_workers} 116 | -------------------------------------------------------------------------------- 
/configs/training/data/abl-04-256-mh-dist-celeba.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | batch_size: 5 4 | val_batch_size: 3 5 | num_workers: 3 6 | 7 | train: 8 | indir: ${location.data_root_dir}/train_256 9 | out_size: 256 10 | mask_gen_kwargs: # probabilities do not need to sum to 1, they are re-normalized in mask generator 11 | irregular_proba: 1 12 | irregular_kwargs: 13 | max_angle: 4 14 | max_len: 200 15 | max_width: 100 16 | max_times: 5 17 | min_times: 1 18 | 19 | box_proba: 1 20 | box_kwargs: 21 | margin: 10 22 | bbox_min_size: 30 23 | bbox_max_size: 150 24 | max_times: 4 25 | min_times: 1 26 | 27 | segm_proba: 0 28 | 29 | transform_variant: no_augs 30 | dataloader_kwargs: 31 | batch_size: ${data.batch_size} 32 | shuffle: True 33 | num_workers: ${data.num_workers} 34 | 35 | val: 36 | indir: ${location.data_root_dir}/val_256 37 | img_suffix: .png 38 | dataloader_kwargs: 39 | batch_size: ${data.val_batch_size} 40 | shuffle: False 41 | num_workers: ${data.num_workers} 42 | 43 | visual_test: null 44 | -------------------------------------------------------------------------------- /configs/training/data/abl-04-256-mh-dist-web.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | batch_size: 10 4 | val_batch_size: 2 5 | num_workers: 3 6 | 7 | train: 8 | kind: default_web 9 | shuffle_buffer: 200 10 | indir: ${location.data_root_dir}/train_standard/part{00000..00039}.tar 11 | out_size: 256 12 | mask_gen_kwargs: # probabilities do not need to sum to 1, they are re-normalized in mask generator 13 | irregular_proba: 1 14 | irregular_kwargs: 15 | max_angle: 4 16 | max_len: 200 17 | max_width: 100 18 | max_times: 5 19 | min_times: 1 20 | 21 | box_proba: 1 22 | box_kwargs: 23 | margin: 10 24 | bbox_min_size: 30 25 | bbox_max_size: 150 26 | max_times: 4 27 | min_times: 1 28 | 29 | segm_proba: 0 30 | 31 | transform_variant: distortions 32 | dataloader_kwargs: 33 | batch_size: ${data.batch_size} 34 | shuffle: True 35 | num_workers: ${data.num_workers} 36 | 37 | val: 38 | indir: ${location.data_root_dir}/val 39 | img_suffix: .png 40 | dataloader_kwargs: 41 | batch_size: ${data.val_batch_size} 42 | shuffle: False 43 | num_workers: ${data.num_workers} 44 | 45 | #extra_val: 46 | # random_thin_256: 47 | # indir: ${location.data_root_dir}/final_extra_val/random_thin_256 48 | # img_suffix: .png 49 | # dataloader_kwargs: 50 | # batch_size: ${data.val_batch_size} 51 | # shuffle: False 52 | # num_workers: ${data.num_workers} 53 | # random_medium_256: 54 | # indir: ${location.data_root_dir}/final_extra_val/random_medium_256 55 | # img_suffix: .png 56 | # dataloader_kwargs: 57 | # batch_size: ${data.val_batch_size} 58 | # shuffle: False 59 | # num_workers: ${data.num_workers} 60 | # random_thick_256: 61 | # indir: ${location.data_root_dir}/final_extra_val/random_thick_256 62 | # img_suffix: .png 63 | # dataloader_kwargs: 64 | # batch_size: ${data.val_batch_size} 65 | # shuffle: False 66 | # num_workers: ${data.num_workers} 67 | # random_thin_512: 68 | # indir: ${location.data_root_dir}/final_extra_val/random_thin_512 69 | # img_suffix: .png 70 | # dataloader_kwargs: 71 | # batch_size: ${data.val_batch_size} 72 | # shuffle: False 73 | # num_workers: ${data.num_workers} 74 | # random_medium_512: 75 | # indir: ${location.data_root_dir}/final_extra_val/random_medium_512 76 | # img_suffix: .png 77 | # dataloader_kwargs: 78 | # batch_size: ${data.val_batch_size} 79 | # 
shuffle: False 80 | # num_workers: ${data.num_workers} 81 | # random_thick_512: 82 | # indir: ${location.data_root_dir}/final_extra_val/random_thick_512 83 | # img_suffix: .png 84 | # dataloader_kwargs: 85 | # batch_size: ${data.val_batch_size} 86 | # shuffle: False 87 | # num_workers: ${data.num_workers} 88 | # segm_256: 89 | # indir: ${location.data_root_dir}/final_extra_val/segm_256 90 | # img_suffix: .png 91 | # dataloader_kwargs: 92 | # batch_size: ${data.val_batch_size} 93 | # shuffle: False 94 | # num_workers: ${data.num_workers} 95 | # segm_512: 96 | # indir: ${location.data_root_dir}/final_extra_val/segm_512 97 | # img_suffix: .png 98 | # dataloader_kwargs: 99 | # batch_size: ${data.val_batch_size} 100 | # shuffle: False 101 | # num_workers: ${data.num_workers} 102 | 103 | visual_test: 104 | indir: ${location.data_root_dir}/visual_test 105 | img_suffix: _input.png 106 | pad_out_to_modulo: 32 107 | dataloader_kwargs: 108 | batch_size: 1 109 | shuffle: False 110 | num_workers: ${data.num_workers} 111 | -------------------------------------------------------------------------------- /configs/training/data/abl-04-256-mh-dist.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | batch_size: 10 4 | val_batch_size: 2 5 | num_workers: 3 6 | 7 | train: 8 | indir: ${location.data_root_dir}/train 9 | out_size: 256 10 | mask_gen_kwargs: # probabilities do not need to sum to 1, they are re-normalized in mask generator 11 | irregular_proba: 1 12 | irregular_kwargs: 13 | max_angle: 4 14 | max_len: 200 15 | max_width: 100 16 | max_times: 5 17 | min_times: 1 18 | 19 | box_proba: 1 20 | box_kwargs: 21 | margin: 10 22 | bbox_min_size: 30 23 | bbox_max_size: 150 24 | max_times: 4 25 | min_times: 1 26 | 27 | segm_proba: 0 28 | 29 | transform_variant: distortions 30 | dataloader_kwargs: 31 | batch_size: ${data.batch_size} 32 | shuffle: True 33 | num_workers: ${data.num_workers} 34 | 35 | val: 36 | indir: ${location.data_root_dir}/val 37 | img_suffix: .png 38 | dataloader_kwargs: 39 | batch_size: ${data.val_batch_size} 40 | shuffle: False 41 | num_workers: ${data.num_workers} 42 | 43 | #extra_val: 44 | # random_thin_256: 45 | # indir: ${location.data_root_dir}/extra_val/random_thin_256 46 | # img_suffix: .png 47 | # dataloader_kwargs: 48 | # batch_size: ${data.val_batch_size} 49 | # shuffle: False 50 | # num_workers: ${data.num_workers} 51 | # random_medium_256: 52 | # indir: ${location.data_root_dir}/extra_val/random_medium_256 53 | # img_suffix: .png 54 | # dataloader_kwargs: 55 | # batch_size: ${data.val_batch_size} 56 | # shuffle: False 57 | # num_workers: ${data.num_workers} 58 | # random_thick_256: 59 | # indir: ${location.data_root_dir}/extra_val/random_thick_256 60 | # img_suffix: .png 61 | # dataloader_kwargs: 62 | # batch_size: ${data.val_batch_size} 63 | # shuffle: False 64 | # num_workers: ${data.num_workers} 65 | # random_thin_512: 66 | # indir: ${location.data_root_dir}/extra_val/random_thin_512 67 | # img_suffix: .png 68 | # dataloader_kwargs: 69 | # batch_size: ${data.val_batch_size} 70 | # shuffle: False 71 | # num_workers: ${data.num_workers} 72 | # random_medium_512: 73 | # indir: ${location.data_root_dir}/extra_val/random_medium_512 74 | # img_suffix: .png 75 | # dataloader_kwargs: 76 | # batch_size: ${data.val_batch_size} 77 | # shuffle: False 78 | # num_workers: ${data.num_workers} 79 | # random_thick_512: 80 | # indir: ${location.data_root_dir}/extra_val/random_thick_512 81 | # img_suffix: .png 82 | # 
dataloader_kwargs: 83 | # batch_size: ${data.val_batch_size} 84 | # shuffle: False 85 | # num_workers: ${data.num_workers} 86 | # segm_256: 87 | # indir: ${location.data_root_dir}/extra_val/segm_256 88 | # img_suffix: .png 89 | # dataloader_kwargs: 90 | # batch_size: ${data.val_batch_size} 91 | # shuffle: False 92 | # num_workers: ${data.num_workers} 93 | # segm_512: 94 | # indir: ${location.data_root_dir}/extra_val/segm_512 95 | # img_suffix: .png 96 | # dataloader_kwargs: 97 | # batch_size: ${data.val_batch_size} 98 | # shuffle: False 99 | # num_workers: ${data.num_workers} 100 | 101 | visual_test: 102 | indir: ${location.data_root_dir}/visual_test 103 | img_suffix: .png 104 | pad_out_to_modulo: 32 105 | dataloader_kwargs: 106 | batch_size: 1 107 | shuffle: False 108 | num_workers: ${data.num_workers} 109 | -------------------------------------------------------------------------------- /configs/training/discriminator/pix2pixhd_nlayer.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | kind: pix2pixhd_nlayer 3 | input_nc: 3 4 | ndf: 64 5 | n_layers: 4 6 | -------------------------------------------------------------------------------- /configs/training/evaluator/default_inpainted.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | kind: default 3 | inpainted_key: inpainted # if you want to evaluate before blending with original image by mask, set predicted_image 4 | integral_kind: ssim_fid100_f1 5 | -------------------------------------------------------------------------------- /configs/training/generator/ffc_resnet_075.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | kind: ffc_resnet 3 | input_nc: 4 4 | output_nc: 3 5 | ngf: 64 6 | n_downsampling: 3 7 | n_blocks: 9 8 | add_out_act: sigmoid 9 | 10 | init_conv_kwargs: 11 | ratio_gin: 0 12 | ratio_gout: 0 13 | enable_lfu: False 14 | 15 | downsample_conv_kwargs: 16 | ratio_gin: ${generator.init_conv_kwargs.ratio_gout} 17 | ratio_gout: ${generator.downsample_conv_kwargs.ratio_gin} 18 | enable_lfu: False 19 | 20 | resnet_conv_kwargs: 21 | ratio_gin: 0.75 22 | ratio_gout: ${generator.resnet_conv_kwargs.ratio_gin} 23 | enable_lfu: False 24 | -------------------------------------------------------------------------------- /configs/training/generator/pix2pixhd_global.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | kind: pix2pixhd_global 3 | input_nc: 4 4 | output_nc: 3 5 | ngf: 64 6 | n_downsampling: 3 7 | n_blocks: 9 8 | conv_kind: default -------------------------------------------------------------------------------- /configs/training/generator/pix2pixhd_global_sigmoid.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | kind: pix2pixhd_global 3 | input_nc: 4 4 | output_nc: 3 5 | ngf: 64 6 | n_downsampling: 3 7 | n_blocks: 9 8 | conv_kind: default 9 | add_out_act: sigmoid 10 | -------------------------------------------------------------------------------- /configs/training/generator/pix2pixhd_multidilated_catin_4dil_9b.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | kind: pix2pixhd_multidilated 3 | input_nc: 4 4 | output_nc: 3 5 | ngf: 64 6 | n_downsampling: 3 7 | n_blocks: 9 8 | conv_kind: default 9 | add_out_act: sigmoid 10 | multidilation_kwargs: 11 | comb_mode: 
cat_in 12 | dilation_num: 4 13 | -------------------------------------------------------------------------------- /configs/training/hydra/no_time.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | run: 3 | dir: ${location.out_root_dir}/${env:USER}_${hydra:job.name}_${hydra:job.config_name}_${run_title} 4 | sweep: 5 | dir: ${hydra:run.dir}_sweep 6 | subdir: ${hydra.job.num} 7 | -------------------------------------------------------------------------------- /configs/training/hydra/overrides.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | run: 3 | dir: ${location.out_root_dir}/${env:USER}_${now:%Y-%m-%d_%H-%M-%S}_${hydra:job.name}_${hydra:job.config_name}_${run_title} 4 | sweep: 5 | dir: ${hydra:run.dir}_sweep 6 | subdir: ${hydra.job.num} 7 | -------------------------------------------------------------------------------- /configs/training/lama-fourier-celeba.yaml: -------------------------------------------------------------------------------- 1 | run_title: '' 2 | 3 | training_model: 4 | kind: default 5 | visualize_each_iters: 1000 6 | concat_mask: true 7 | store_discr_outputs_for_vis: true 8 | losses: 9 | l1: 10 | weight_missing: 0 11 | weight_known: 10 12 | perceptual: 13 | weight: 0 14 | adversarial: 15 | kind: r1 16 | weight: 10 17 | gp_coef: 0.001 18 | mask_as_fake_target: true 19 | allow_scale_mask: true 20 | feature_matching: 21 | weight: 100 22 | resnet_pl: 23 | weight: 30 24 | weights_path: ${env:TORCH_HOME} 25 | 26 | defaults: 27 | - location: celeba 28 | - data: abl-04-256-mh-dist-celeba 29 | - generator: ffc_resnet_075 30 | - discriminator: pix2pixhd_nlayer 31 | - optimizers: default_optimizers 32 | - visualizer: directory 33 | - evaluator: default_inpainted 34 | - trainer: any_gpu_large_ssim_ddp_final_celeba 35 | - hydra: overrides -------------------------------------------------------------------------------- /configs/training/lama-fourier.yaml: -------------------------------------------------------------------------------- 1 | run_title: '' 2 | 3 | training_model: 4 | kind: default 5 | visualize_each_iters: 1000 6 | concat_mask: true 7 | store_discr_outputs_for_vis: true 8 | losses: 9 | l1: 10 | weight_missing: 0 11 | weight_known: 10 12 | perceptual: 13 | weight: 0 14 | adversarial: 15 | kind: r1 16 | weight: 10 17 | gp_coef: 0.001 18 | mask_as_fake_target: true 19 | allow_scale_mask: true 20 | feature_matching: 21 | weight: 100 22 | resnet_pl: 23 | weight: 30 24 | weights_path: ${env:TORCH_HOME} 25 | 26 | defaults: 27 | - location: docker 28 | - data: abl-04-256-mh-dist 29 | - generator: ffc_resnet_075 30 | - discriminator: pix2pixhd_nlayer 31 | - optimizers: default_optimizers 32 | - visualizer: directory 33 | - evaluator: default_inpainted 34 | - trainer: any_gpu_large_ssim_ddp_final 35 | - hydra: overrides -------------------------------------------------------------------------------- /configs/training/lama-regular-celeba.yaml: -------------------------------------------------------------------------------- 1 | run_title: '' 2 | 3 | training_model: 4 | kind: default 5 | visualize_each_iters: 1000 6 | concat_mask: true 7 | store_discr_outputs_for_vis: true 8 | losses: 9 | l1: 10 | weight_missing: 0 11 | weight_known: 10 12 | perceptual: 13 | weight: 0 14 | adversarial: 15 | kind: r1 16 | weight: 10 17 | gp_coef: 0.001 18 | mask_as_fake_target: true 19 | allow_scale_mask: true 20 | feature_matching: 21 | weight: 100 22 | resnet_pl: 23 | 
weight: 30 24 | weights_path: ${env:TORCH_HOME} 25 | 26 | defaults: 27 | - location: celeba 28 | - data: abl-04-256-mh-dist-celeba 29 | - generator: pix2pixhd_global_sigmoid 30 | - discriminator: pix2pixhd_nlayer 31 | - optimizers: default_optimizers 32 | - visualizer: directory 33 | - evaluator: default_inpainted 34 | - trainer: any_gpu_large_ssim_ddp_final_celeba 35 | - hydra: overrides 36 |
-------------------------------------------------------------------------------- /configs/training/lama-regular.yaml: -------------------------------------------------------------------------------- 1 | run_title: '' 2 | 3 | training_model: 4 | kind: default 5 | visualize_each_iters: 1000 6 | concat_mask: true 7 | store_discr_outputs_for_vis: true 8 | losses: 9 | l1: 10 | weight_missing: 0 11 | weight_known: 10 12 | perceptual: 13 | weight: 0 14 | adversarial: 15 | kind: r1 16 | weight: 10 17 | gp_coef: 0.001 18 | mask_as_fake_target: true 19 | allow_scale_mask: true 20 | feature_matching: 21 | weight: 100 22 | resnet_pl: 23 | weight: 30 24 | weights_path: ${env:TORCH_HOME} 25 | 26 | defaults: 27 | - location: docker 28 | - data: abl-04-256-mh-dist 29 | - generator: pix2pixhd_global_sigmoid 30 | - discriminator: pix2pixhd_nlayer 31 | - optimizers: default_optimizers 32 | - visualizer: directory 33 | - evaluator: default_inpainted 34 | - trainer: any_gpu_large_ssim_ddp_final 35 | - hydra: overrides 36 |
-------------------------------------------------------------------------------- /configs/training/lama_small_train_masks.yaml: -------------------------------------------------------------------------------- 1 | run_title: '' 2 | 3 | training_model: 4 | kind: default 5 | visualize_each_iters: 1000 6 | concat_mask: true 7 | store_discr_outputs_for_vis: true 8 | 9 | losses: 10 | l1: 11 | weight_missing: 0 12 | weight_known: 10 13 | perceptual: 14 | weight: 0 15 | adversarial: 16 | kind: r1 17 | weight: 10 18 | gp_coef: 0.001 19 | mask_as_fake_target: true 20 | allow_scale_mask: true 21 | feature_matching: 22 | weight: 100 23 | resnet_pl: 24 | weight: 30 25 | weights_path: ${env:TORCH_HOME} 26 | 27 | defaults: 28 | - location: docker 29 | - data: abl-02-thin-bb 30 | - generator: pix2pixhd_global_sigmoid 31 | - discriminator: pix2pixhd_nlayer 32 | - optimizers: default_optimizers 33 | - visualizer: directory 34 | - evaluator: default_inpainted 35 | - trainer: any_gpu_large_ssim_ddp_final 36 | - hydra: overrides
-------------------------------------------------------------------------------- /configs/training/location/celeba_example.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | data_root_dir: /home/user/lama/celeba-hq-dataset/ 3 | out_root_dir: /home/user/lama/experiments/ 4 | tb_dir: /home/user/lama/tb_logs/ 5 | pretrained_models: /home/user/lama/ 6 |
-------------------------------------------------------------------------------- /configs/training/location/docker.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | data_root_dir: /data/data 3 | out_root_dir: /data/experiments 4 | tb_dir: /data/tb_logs 5 | pretrained_models: /some_path 6 |
-------------------------------------------------------------------------------- /configs/training/location/places_example.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | data_root_dir: /home/user/inpainting-lama/places_standard_dataset/ 3 | out_root_dir:
/home/user/inpainting-lama/experiments 4 | tb_dir: /home/user/inpainting-lama/tb_logs 5 | pretrained_models: /home/user/inpainting-lama/ 6 | -------------------------------------------------------------------------------- /configs/training/optimizers/default_optimizers.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | generator: 3 | kind: adam 4 | lr: 0.001 5 | discriminator: 6 | kind: adam 7 | lr: 0.0001 8 | -------------------------------------------------------------------------------- /configs/training/trainer/any_gpu_large_ssim_ddp_final.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | kwargs: 3 | gpus: -1 4 | accelerator: ddp 5 | max_epochs: 40 6 | gradient_clip_val: 1 7 | log_gpu_memory: None # set to min_max or all for debug 8 | limit_train_batches: 25000 9 | val_check_interval: ${trainer.kwargs.limit_train_batches} 10 | # fast_dev_run: True # uncomment for faster debug 11 | # track_grad_norm: 2 # uncomment to track L2 gradients norm 12 | log_every_n_steps: 250 13 | precision: 32 14 | # precision: 16 15 | # amp_backend: native 16 | # amp_level: O1 17 | # resume_from_checkpoint: path # override via command line trainer.resume_from_checkpoint=path_to_checkpoint 18 | terminate_on_nan: False 19 | # auto_scale_batch_size: True # uncomment to find largest batch size 20 | check_val_every_n_epoch: 1 21 | num_sanity_val_steps: 8 22 | # limit_val_batches: 1000000 23 | replace_sampler_ddp: False 24 | 25 | checkpoint_kwargs: 26 | verbose: True 27 | save_top_k: 5 28 | save_last: True 29 | period: 1 30 | monitor: val_ssim_fid100_f1_total_mean 31 | mode: max -------------------------------------------------------------------------------- /configs/training/trainer/any_gpu_large_ssim_ddp_final_benchmark.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | kwargs: 3 | gpus: -1 4 | accelerator: ddp 5 | max_epochs: 40 6 | gradient_clip_val: 1 7 | log_gpu_memory: None # set to min_max or all for debug 8 | limit_train_batches: 25000 9 | val_check_interval: ${trainer.kwargs.limit_train_batches} 10 | # fast_dev_run: True # uncomment for faster debug 11 | # track_grad_norm: 2 # uncomment to track L2 gradients norm 12 | log_every_n_steps: 250 13 | precision: 32 14 | # precision: 16 15 | # amp_backend: native 16 | # amp_level: O1 17 | # resume_from_checkpoint: path # override via command line trainer.resume_from_checkpoint=path_to_checkpoint 18 | terminate_on_nan: False 19 | # auto_scale_batch_size: True # uncomment to find largest batch size 20 | check_val_every_n_epoch: 1 21 | num_sanity_val_steps: 8 22 | # limit_val_batches: 1000000 23 | replace_sampler_ddp: False 24 | benchmark: True 25 | 26 | checkpoint_kwargs: 27 | verbose: True 28 | save_top_k: 5 29 | save_last: True 30 | period: 1 31 | monitor: val_ssim_fid100_f1_total_mean 32 | mode: max 33 | -------------------------------------------------------------------------------- /configs/training/trainer/any_gpu_large_ssim_ddp_final_celeba.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | kwargs: 3 | gpus: -1 4 | accelerator: ddp 5 | max_epochs: 40 6 | gradient_clip_val: 1 7 | log_gpu_memory: None 8 | limit_train_batches: 25000 9 | val_check_interval: 2600 10 | log_every_n_steps: 250 11 | precision: 32 12 | terminate_on_nan: False 13 | check_val_every_n_epoch: 1 14 | num_sanity_val_steps: 8 15 | 
replace_sampler_ddp: False 16 | checkpoint_kwargs: 17 | verbose: True 18 | save_top_k: 5 19 | save_last: True 20 | period: 1 21 | monitor: val_ssim_fid100_f1_total_mean 22 | mode: max -------------------------------------------------------------------------------- /configs/training/visualizer/directory.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | kind: directory 3 | outdir: samples 4 | key_order: 5 | - image 6 | - predicted_image 7 | - discr_output_fake 8 | - discr_output_real 9 | - inpainted 10 | rescale_keys: 11 | - discr_output_fake 12 | - discr_output_real 13 | -------------------------------------------------------------------------------- /docker/1_generate_masks_from_raw_images.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | if (( $# < 3 )) 5 | then 6 | echo "Usage: $0 config_name input_images_dir image_mask_dataset_out_dir [other args to gen_mask_dataset.py]" 7 | exit 1 8 | fi 9 | 10 | CURDIR="$(dirname $0)" 11 | SRCDIR="$CURDIR/.." 12 | SRCDIR="$(realpath $SRCDIR)" 13 | 14 | CONFIG_LOCAL_PATH="$(realpath $1)" 15 | INPUT_LOCAL_DIR="$(realpath $2)" 16 | OUTPUT_LOCAL_DIR="$(realpath $3)" 17 | shift 3 18 | 19 | mkdir -p "$OUTPUT_LOCAL_DIR" 20 | 21 | docker run \ 22 | -v "$SRCDIR":/home/user/project \ 23 | -v "$CONFIG_LOCAL_PATH":/data/config.yaml \ 24 | -v "$INPUT_LOCAL_DIR":/data/input \ 25 | -v "$OUTPUT_LOCAL_DIR":/data/output \ 26 | -u $(id -u):$(id -g) \ 27 | --name="lama-mask-gen" \ 28 | --rm \ 29 | windj007/lama \ 30 | /home/user/project/bin/gen_mask_dataset.py \ 31 | /data/config.yaml /data/input /data/output $@ 32 | -------------------------------------------------------------------------------- /docker/2_predict_with_gpu.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | if (( $# < 3 )) 5 | then 6 | echo "Usage: $0 model_dir input_dir output_dir [other arguments to predict.py]" 7 | exit 1 8 | fi 9 | 10 | CURDIR="$(dirname $0)" 11 | SRCDIR="$CURDIR/.." 12 | SRCDIR="$(realpath $SRCDIR)" 13 | 14 | MODEL_LOCAL_DIR="$(realpath $1)" 15 | INPUT_LOCAL_DIR="$(realpath $2)" 16 | OUTPUT_LOCAL_DIR="$(realpath $3)" 17 | shift 3 18 | 19 | mkdir -p "$OUTPUT_LOCAL_DIR" 20 | 21 | docker run \ 22 | -v "$SRCDIR":/home/user/project \ 23 | -v "$MODEL_LOCAL_DIR":/data/checkpoint \ 24 | -v "$INPUT_LOCAL_DIR":/data/input \ 25 | -v "$OUTPUT_LOCAL_DIR":/data/output \ 26 | -u $(id -u):$(id -g) \ 27 | --gpus all \ 28 | --name="lama-predict" \ 29 | --rm \ 30 | windj007/lama \ 31 | /home/user/project/bin/predict.py \ 32 | model.path=/data/checkpoint \ 33 | indir=/data/input \ 34 | outdir=/data/output \ 35 | dataset.img_suffix=.png \ 36 | $@ 37 | -------------------------------------------------------------------------------- /docker/3_evaluate.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | if (( $# < 3 )) 5 | then 6 | echo "Usage: $0 original_dataset_dir predictions_dir output_dir [other arguments to evaluate_predicts.py]" 7 | exit 1 8 | fi 9 | 10 | CURDIR="$(dirname $0)" 11 | SRCDIR="$CURDIR/.." 
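# resolve to an absolute path, since docker bind mounts require one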
12 | SRCDIR="$(realpath $SRCDIR)" 13 | 14 | ORIG_DATASET_LOCAL_DIR="$(realpath $1)" 15 | PREDICTIONS_LOCAL_DIR="$(realpath $2)" 16 | OUTPUT_LOCAL_DIR="$(realpath $3)" 17 | shift 3 18 | 19 | mkdir -p "$OUTPUT_LOCAL_DIR" 20 | 21 | docker run \ 22 | -v "$SRCDIR":/home/user/project \ 23 | -v "$ORIG_DATASET_LOCAL_DIR":/data/orig_dataset \ 24 | -v "$PREDICTIONS_LOCAL_DIR":/data/predictions \ 25 | -v "$OUTPUT_LOCAL_DIR":/data/output \ 26 | -u $(id -u):$(id -g) \ 27 | --name="lama-eval" \ 28 | --rm \ 29 | windj007/lama \ 30 | /home/user/project/bin/evaluate_predicts.py \ 31 | /home/user/project/configs/eval2_cpu.yaml \ 32 | /data/orig_dataset \ 33 | /data/predictions \ 34 | /data/output/metrics.yaml \ 35 | $@ 36 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:10.2-runtime-ubuntu18.04 2 | 3 | RUN apt-get update && \ 4 | apt-get upgrade -y && \ 5 | apt-get install -y wget mc tmux nano build-essential rsync libgl1 6 | 7 | ARG USERNAME=user 8 | RUN apt-get install -y sudo && \ 9 | addgroup --gid 1000 $USERNAME && \ 10 | adduser --uid 1000 --gid 1000 --disabled-password --gecos '' $USERNAME && \ 11 | adduser $USERNAME sudo && \ 12 | echo '%sudo ALL=(ALL) NOPASSWD:ALL' >> /etc/sudoers && \ 13 | USER=$USERNAME && \ 14 | GROUP=$USERNAME 15 | 16 | USER $USERNAME:$USERNAME 17 | WORKDIR "/home/$USERNAME" 18 | ENV PATH="/home/$USERNAME/miniconda3/bin:/home/$USERNAME/.local/bin:${PATH}" 19 | ENV PYTHONPATH="/home/$USERNAME/project" 20 | 21 | RUN wget -O /tmp/miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-py39_4.9.2-Linux-x86_64.sh && \ 22 | echo "536817d1b14cb1ada88900f5be51ce0a5e042bae178b5550e62f61e223deae7c /tmp/miniconda.sh" > /tmp/miniconda.sh.sha256 && \ 23 | sha256sum --check --status < /tmp/miniconda.sh.sha256 && \ 24 | bash /tmp/miniconda.sh -bt -p "/home/$USERNAME/miniconda3" && \ 25 | rm /tmp/miniconda.sh && \ 26 | conda build purge && \ 27 | conda init 28 | 29 | RUN pip install -U pip 30 | RUN pip install numpy scipy torch==1.8.1 torchvision opencv-python tensorflow joblib matplotlib pandas \ 31 | albumentations==0.5.2 pytorch-lightning==1.2.9 tabulate easydict==1.9.0 kornia==0.5.0 webdataset \ 32 | packaging gpustat tqdm pyyaml hydra-core==1.1.0.dev6 scikit-learn==0.24.2 tabulate 33 | RUN pip install scikit-image==0.17.2 34 | 35 | ENV TORCH_HOME="/home/$USERNAME/.torch" 36 | 37 | ADD entrypoint.sh /home/$USERNAME/.local/bin/entrypoint.sh 38 | ENTRYPOINT [ "entrypoint.sh" ] 39 | -------------------------------------------------------------------------------- /docker/Dockerfile-cuda111: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.1-runtime-ubuntu18.04 2 | 3 | RUN apt-get update && \ 4 | apt-get upgrade -y && \ 5 | apt-get install -y wget mc tmux nano build-essential rsync libgl1 6 | 7 | ARG USERNAME=user 8 | RUN apt-get install -y sudo && \ 9 | addgroup --gid 1000 $USERNAME && \ 10 | adduser --uid 1000 --gid 1000 --disabled-password --gecos '' $USERNAME && \ 11 | adduser $USERNAME sudo && \ 12 | echo '%sudo ALL=(ALL) NOPASSWD:ALL' >> /etc/sudoers && \ 13 | USER=$USERNAME && \ 14 | GROUP=$USERNAME 15 | 16 | USER $USERNAME:$USERNAME 17 | WORKDIR "/home/$USERNAME" 18 | ENV PATH="/home/$USERNAME/miniconda3/bin:/home/$USERNAME/.local/bin:${PATH}" 19 | ENV PYTHONPATH="/home/$USERNAME/project" 20 | 21 | RUN wget -O /tmp/miniconda.sh 
https://repo.anaconda.com/miniconda/Miniconda3-py39_4.9.2-Linux-x86_64.sh && \ 22 | echo "536817d1b14cb1ada88900f5be51ce0a5e042bae178b5550e62f61e223deae7c /tmp/miniconda.sh" > /tmp/miniconda.sh.sha256 && \ 23 | sha256sum --check --status < /tmp/miniconda.sh.sha256 && \ 24 | bash /tmp/miniconda.sh -bt -p "/home/$USERNAME/miniconda3" && \ 25 | rm /tmp/miniconda.sh && \ 26 | conda build purge && \ 27 | conda init 28 | 29 | RUN pip install -U pip 30 | RUN pip install torch==1.8.2+cu111 torchvision==0.9.2+cu111 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html 31 | RUN pip install numpy scipy opencv-python tensorflow joblib matplotlib pandas \ 32 | albumentations==0.5.2 pytorch-lightning==1.2.9 tabulate easydict==1.9.0 kornia==0.5.0 webdataset \ 33 | packaging gpustat tqdm pyyaml hydra-core==1.1.0.dev6 scikit-learn==0.24.2 tabulate 34 | RUN pip install scikit-image==0.17.2 35 | 36 | ENV TORCH_HOME="/home/$USERNAME/.torch" 37 | 38 | ADD entrypoint.sh /home/$USERNAME/.local/bin/entrypoint.sh 39 | ENTRYPOINT [ "entrypoint.sh" ] 40 | -------------------------------------------------------------------------------- /docker/build-cuda111.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | BASEDIR="$(dirname $0)" 4 | 5 | docker build -t windj007/lama:cuda111 -f "$BASEDIR/Dockerfile-cuda111" "$BASEDIR" 6 | -------------------------------------------------------------------------------- /docker/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | BASEDIR="$(dirname $0)" 4 | 5 | docker build -t windj007/lama -f "$BASEDIR/Dockerfile" "$BASEDIR" 6 | -------------------------------------------------------------------------------- /docker/entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | exec $@ 4 | -------------------------------------------------------------------------------- /fetch_data/celebahq_dataset_prepare.sh: -------------------------------------------------------------------------------- 1 | mkdir celeba-hq-dataset 2 | 3 | unzip data256x256.zip -d celeba-hq-dataset/ 4 | 5 | # Reindex 6 | for i in `echo {00001..30000}` 7 | do 8 | mv 'celeba-hq-dataset/data256x256/'$i'.jpg' 'celeba-hq-dataset/data256x256/'$[10#$i - 1]'.jpg' 9 | done 10 | 11 | 12 | # Split: split train -> train & val 13 | cat fetch_data/train_shuffled.flist | shuf > celeba-hq-dataset/temp_train_shuffled.flist 14 | cat celeba-hq-dataset/temp_train_shuffled.flist | head -n 2000 > celeba-hq-dataset/val_shuffled.flist 15 | cat celeba-hq-dataset/temp_train_shuffled.flist | tail -n +2001 > celeba-hq-dataset/train_shuffled.flist 16 | cat fetch_data/val_shuffled.flist > celeba-hq-dataset/visual_test_shuffled.flist 17 | 18 | mkdir celeba-hq-dataset/train_256/ 19 | mkdir celeba-hq-dataset/val_source_256/ 20 | mkdir celeba-hq-dataset/visual_test_source_256/ 21 | 22 | cat celeba-hq-dataset/train_shuffled.flist | xargs -I {} mv celeba-hq-dataset/data256x256/{} celeba-hq-dataset/train_256/ 23 | cat celeba-hq-dataset/val_shuffled.flist | xargs -I {} mv celeba-hq-dataset/data256x256/{} celeba-hq-dataset/val_source_256/ 24 | cat celeba-hq-dataset/visual_test_shuffled.flist | xargs -I {} mv celeba-hq-dataset/data256x256/{} celeba-hq-dataset/visual_test_source_256/ 25 | 26 | 27 | # create location config celeba.yaml 28 | PWD=$(pwd) 29 | DATASET=${PWD}/celeba-hq-dataset 30 | CELEBA=${PWD}/configs/training/location/celeba.yaml 31 | 32 | touch $CELEBA 33 | echo 
"# @package _group_" >> $CELEBA 34 | echo "data_root_dir: ${DATASET}/" >> $CELEBA 35 | echo "out_root_dir: ${PWD}/experiments/" >> $CELEBA 36 | echo "tb_dir: ${PWD}/tb_logs/" >> $CELEBA 37 | echo "pretrained_models: ${PWD}/" >> $CELEBA 38 | -------------------------------------------------------------------------------- /fetch_data/celebahq_gen_masks.sh: -------------------------------------------------------------------------------- 1 | python3 bin/gen_mask_dataset.py \ 2 | $(pwd)/configs/data_gen/random_thick_256.yaml \ 3 | celeba-hq-dataset/val_source_256/ \ 4 | celeba-hq-dataset/val_256/random_thick_256/ 5 | 6 | python3 bin/gen_mask_dataset.py \ 7 | $(pwd)/configs/data_gen/random_thin_256.yaml \ 8 | celeba-hq-dataset/val_source_256/ \ 9 | celeba-hq-dataset/val_256/random_thin_256/ 10 | 11 | python3 bin/gen_mask_dataset.py \ 12 | $(pwd)/configs/data_gen/random_medium_256.yaml \ 13 | celeba-hq-dataset/val_source_256/ \ 14 | celeba-hq-dataset/val_256/random_medium_256/ 15 | 16 | python3 bin/gen_mask_dataset.py \ 17 | $(pwd)/configs/data_gen/random_thick_256.yaml \ 18 | celeba-hq-dataset/visual_test_source_256/ \ 19 | celeba-hq-dataset/visual_test_256/random_thick_256/ 20 | 21 | python3 bin/gen_mask_dataset.py \ 22 | $(pwd)/configs/data_gen/random_thin_256.yaml \ 23 | celeba-hq-dataset/visual_test_source_256/ \ 24 | celeba-hq-dataset/visual_test_256/random_thin_256/ 25 | 26 | python3 bin/gen_mask_dataset.py \ 27 | $(pwd)/configs/data_gen/random_medium_256.yaml \ 28 | celeba-hq-dataset/visual_test_source_256/ \ 29 | celeba-hq-dataset/visual_test_256/random_medium_256/ 30 | -------------------------------------------------------------------------------- /fetch_data/eval_sampler.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | val_files_path = os.path.abspath('.') + '/places_standard_dataset/original/val/' 5 | list_of_random_val_files = os.path.abspath('.') + '/places_standard_dataset/original/eval_random_files.txt' 6 | val_files = [val_files_path + image for image in os.listdir(val_files_path)] 7 | 8 | print(f'Sampling 30000 images out of {len(val_files)} images in {val_files_path}' + \ 9 | f'and put their paths to {list_of_random_val_files}') 10 | 11 | print('In our paper we evaluate trained models on these 30k sampled (mask,image) pairs in our paper (check Sup. mat.)') 12 | 13 | random.shuffle(val_files) 14 | val_files_random = val_files[0:30000] 15 | 16 | with open(list_of_random_val_files, 'w') as fw: 17 | for filename in val_files_random: 18 | fw.write(filename+'\n') 19 | print('...done') 20 | 21 | -------------------------------------------------------------------------------- /fetch_data/places_challenge_train_download.sh: -------------------------------------------------------------------------------- 1 | mkdir places_challenge_dataset 2 | 3 | 4 | declare -a TARPARTS 5 | for i in {a..z} 6 | do 7 | TARPARTS[${#TARPARTS[@]}]="http://data.csail.mit.edu/places/places365/train_large_split/${i}.tar" 8 | done 9 | ls 10 | printf "%s\n" "${TARPARTS[@]}" > places_challenge_dataset/places365_train.txt 11 | 12 | cd places_challenge_dataset/ 13 | xargs -a places365_train.txt -n 1 -P 8 wget [...] 14 | ls *.tar | xargs -i tar xvf {} 15 | -------------------------------------------------------------------------------- /fetch_data/places_standard_evaluation_prepare_data.sh: -------------------------------------------------------------------------------- 1 | # 0. 
folder preparation 2 | mkdir -p places_standard_dataset/evaluation/hires/ 3 | mkdir -p places_standard_dataset/evaluation/random_thick_512/ 4 | mkdir -p places_standard_dataset/evaluation/random_thin_512/ 5 | mkdir -p places_standard_dataset/evaluation/random_medium_512/ 6 | mkdir -p places_standard_dataset/evaluation/random_thick_256/ 7 | mkdir -p places_standard_dataset/evaluation/random_thin_256/ 8 | mkdir -p places_standard_dataset/evaluation/random_medium_256/ 9 | 10 | # 1. sample 30000 new images 11 | OUT=$(python3 fetch_data/eval_sampler.py) 12 | echo ${OUT} 13 | 14 | FILELIST=$(cat places_standard_dataset/original/eval_random_files.txt) 15 | for i in $FILELIST 16 | do 17 | $(cp ${i} places_standard_dataset/evaluation/hires/) 18 | done 19 | 20 | 21 | # 2. generate all kinds of masks 22 | 23 | # all 512 24 | python3 bin/gen_mask_dataset.py \ 25 | $(pwd)/configs/data_gen/random_thick_512.yaml \ 26 | places_standard_dataset/evaluation/hires \ 27 | places_standard_dataset/evaluation/random_thick_512/ 28 | 29 | python3 bin/gen_mask_dataset.py \ 30 | $(pwd)/configs/data_gen/random_thin_512.yaml \ 31 | places_standard_dataset/evaluation/hires \ 32 | places_standard_dataset/evaluation/random_thin_512/ 33 | 34 | python3 bin/gen_mask_dataset.py \ 35 | $(pwd)/configs/data_gen/random_medium_512.yaml \ 36 | places_standard_dataset/evaluation/hires \ 37 | places_standard_dataset/evaluation/random_medium_512/ 38 | 39 | python3 bin/gen_mask_dataset.py \ 40 | $(pwd)/configs/data_gen/random_thick_256.yaml \ 41 | places_standard_dataset/evaluation/hires \ 42 | places_standard_dataset/evaluation/random_thick_256/ 43 | 44 | python3 bin/gen_mask_dataset.py \ 45 | $(pwd)/configs/data_gen/random_thin_256.yaml \ 46 | places_standard_dataset/evaluation/hires \ 47 | places_standard_dataset/evaluation/random_thin_256/ 48 | 49 | python3 bin/gen_mask_dataset.py \ 50 | $(pwd)/configs/data_gen/random_medium_256.yaml \ 51 | places_standard_dataset/evaluation/hires \ 52 | places_standard_dataset/evaluation/random_medium_256/ 53 | -------------------------------------------------------------------------------- /fetch_data/places_standard_test_val_gen_masks.sh: -------------------------------------------------------------------------------- 1 | mkdir -p places_standard_dataset/val/ 2 | mkdir -p places_standard_dataset/visual_test/ 3 | 4 | 5 | python3 bin/gen_mask_dataset.py \ 6 | $(pwd)/configs/data_gen/random_thick_512.yaml \ 7 | places_standard_dataset/val_hires/ \ 8 | places_standard_dataset/val/ 9 | 10 | python3 bin/gen_mask_dataset.py \ 11 | $(pwd)/configs/data_gen/random_thick_512.yaml \ 12 | places_standard_dataset/visual_test_hires/ \ 13 | places_standard_dataset/visual_test/ -------------------------------------------------------------------------------- /fetch_data/places_standard_test_val_prepare.sh: -------------------------------------------------------------------------------- 1 | mkdir -p places_standard_dataset/original/test/ 2 | tar -xvf test_large.tar -C places_standard_dataset/original/test/ 3 | 4 | mkdir -p places_standard_dataset/original/val/ 5 | tar -xvf val_large.tar -C places_standard_dataset/original/val/ 6 | -------------------------------------------------------------------------------- /fetch_data/places_standard_test_val_sample.sh: -------------------------------------------------------------------------------- 1 | mkdir -p places_standard_dataset/val_hires/ 2 | mkdir -p places_standard_dataset/visual_test_hires/ 3 | 4 | 5 | # randomly sample images for test and vis 6 | OUT=$(python3 
fetch_data/sampler.py) 7 | echo ${OUT} 8 | 9 | FILELIST=$(cat places_standard_dataset/original/test_random_files.txt) 10 | 11 | for i in $FILELIST 12 | do 13 | $(cp ${i} places_standard_dataset/val_hires/) 14 | done 15 | 16 | FILELIST=$(cat places_standard_dataset/original/val_random_files.txt) 17 | 18 | for i in $FILELIST 19 | do 20 | $(cp ${i} places_standard_dataset/visual_test_hires/) 21 | done 22 | 23 |
-------------------------------------------------------------------------------- /fetch_data/places_standard_train_prepare.sh: -------------------------------------------------------------------------------- 1 | mkdir -p places_standard_dataset/train 2 | 3 | # untar without folder structure 4 | tar -xvf train_large_places365standard.tar -C places_standard_dataset/train 5 | 6 | # create location config places_standard.yaml 7 | PWD=$(pwd) 8 | DATASET=${PWD}/places_standard_dataset 9 | PLACES=${PWD}/configs/training/location/places_standard.yaml 10 | 11 | touch $PLACES 12 | echo "# @package _group_" >> $PLACES 13 | echo "data_root_dir: ${DATASET}/" >> $PLACES 14 | echo "out_root_dir: ${PWD}/experiments/" >> $PLACES 15 | echo "tb_dir: ${PWD}/tb_logs/" >> $PLACES 16 | echo "pretrained_models: ${PWD}/" >> $PLACES 17 |
-------------------------------------------------------------------------------- /fetch_data/sampler.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | test_files_path = os.path.abspath('.') + '/places_standard_dataset/original/test/' 5 | list_of_random_test_files = os.path.abspath('.') + '/places_standard_dataset/original/test_random_files.txt' 6 | 7 | test_files = [ 8 | test_files_path + image for image in os.listdir(test_files_path) 9 | ] 10 | 11 | print(f'Sampling 2000 images out of {len(test_files)} images in {test_files_path} ' + \ 12 | f'and writing their paths to {list_of_random_test_files}') 13 | print('Our training procedure will pick the best checkpoints according to metrics computed on these images.') 14 | 15 | random.shuffle(test_files) 16 | test_files_random = test_files[0:2000] 17 | with open(list_of_random_test_files, 'w') as fw: 18 | for filename in test_files_random: 19 | fw.write(filename+'\n') 20 | print('...done') 21 | 22 | 23 | # -------------------------------- 24 | 25 | val_files_path = os.path.abspath('.') + '/places_standard_dataset/original/val/' 26 | list_of_random_val_files = os.path.abspath('.') + '/places_standard_dataset/original/val_random_files.txt' 27 | 28 | val_files = [ 29 | val_files_path + image for image in os.listdir(val_files_path) 30 | ] 31 | 32 | print(f'Sampling 100 images out of {len(val_files)} in {val_files_path} ' + \ 33 | f'and writing their paths to {list_of_random_val_files}') 34 | print('We use these images for a visual check of how the inpainting algorithm evolves from epoch to epoch') 35 | 36 | random.shuffle(val_files) 37 | val_files_random = val_files[0:100] 38 | with open(list_of_random_val_files, 'w') as fw: 39 | for filename in val_files_random: 40 | fw.write(filename+'\n') 41 | print('...done') 42 | 43 |
-------------------------------------------------------------------------------- /models/ade20k/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import * -------------------------------------------------------------------------------- /models/ade20k/color150.mat: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/advimman/lama/786f5936b27fb3dacd2b1ad799e4de968ea697e7/models/ade20k/color150.mat
-------------------------------------------------------------------------------- /models/ade20k/segm_lib/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .modules import * 2 | from .parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to 3 |
-------------------------------------------------------------------------------- /models/ade20k/segm_lib/nn/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback 13 |
-------------------------------------------------------------------------------- /models/ade20k/segm_lib/nn/modules/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute a replication callback `__data_parallel_replicate__` on each module created by original replication. 30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all modules are isomorphic, we assign each sub-module with a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 38 | of any slave copies. 39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 55 | the original `replicate` function.
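This gives modules such as SynchronizedBatchNorm a hook to finish their per-device setup once all replicas exist.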
56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have customized `DataParallel` implementation. 74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /models/ade20k/segm_lib/nn/modules/tests/test_numeric_batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : test_numeric_batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 
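# The test below compares plain nn.BatchNorm2d against a manual computation of the
# batch statistics, built on the unbiased-variance helper handy_var() defined below.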
8 | 9 | import unittest 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.autograd import Variable 14 | 15 | from sync_batchnorm.unittest import TorchTestCase 16 | 17 | 18 | def handy_var(a, unbias=True): 19 | n = a.size(0) 20 | asum = a.sum(dim=0) 21 | as_sum = (a ** 2).sum(dim=0) # a square sum 22 | sumvar = as_sum - asum * asum / n 23 | if unbias: 24 | return sumvar / (n - 1) 25 | else: 26 | return sumvar / n 27 | 28 | 29 | class NumericTestCase(TorchTestCase): 30 | def testNumericBatchNorm(self): 31 | a = torch.rand(16, 10) 32 | bn = nn.BatchNorm2d(10, momentum=1, eps=1e-5, affine=False) 33 | bn.train() 34 | 35 | a_var1 = Variable(a, requires_grad=True) 36 | b_var1 = bn(a_var1) 37 | loss1 = b_var1.sum() 38 | loss1.backward() 39 | 40 | a_var2 = Variable(a, requires_grad=True) 41 | a_mean2 = a_var2.mean(dim=0, keepdim=True) 42 | a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5)) 43 | # a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5) 44 | b_var2 = (a_var2 - a_mean2) / a_std2 45 | loss2 = b_var2.sum() 46 | loss2.backward() 47 | 48 | self.assertTensorClose(bn.running_mean, a.mean(dim=0)) 49 | self.assertTensorClose(bn.running_var, handy_var(a)) 50 | self.assertTensorClose(a_var1.data, a_var2.data) 51 | self.assertTensorClose(b_var1.data, b_var2.data) 52 | self.assertTensorClose(a_var1.grad, a_var2.grad) 53 | 54 | 55 | if __name__ == '__main__': 56 | unittest.main() 57 | -------------------------------------------------------------------------------- /models/ade20k/segm_lib/nn/modules/tests/test_sync_batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : test_sync_batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 
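# Note: the DataParallelWithCallback-based cases below use device_ids=[0, 1] and
# therefore need at least two CUDA devices; the single-module cases run on CPU.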
8 | 9 | import unittest 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.autograd import Variable 14 | 15 | from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback 16 | from sync_batchnorm.unittest import TorchTestCase 17 | 18 | 19 | def handy_var(a, unbias=True): 20 | n = a.size(0) 21 | asum = a.sum(dim=0) 22 | as_sum = (a ** 2).sum(dim=0) # a square sum 23 | sumvar = as_sum - asum * asum / n 24 | if unbias: 25 | return sumvar / (n - 1) 26 | else: 27 | return sumvar / n 28 | 29 | 30 | def _find_bn(module): 31 | for m in module.modules(): 32 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)): 33 | return m 34 | 35 | 36 | class SyncTestCase(TorchTestCase): 37 | def _syncParameters(self, bn1, bn2): 38 | bn1.reset_parameters() 39 | bn2.reset_parameters() 40 | if bn1.affine and bn2.affine: 41 | bn2.weight.data.copy_(bn1.weight.data) 42 | bn2.bias.data.copy_(bn1.bias.data) 43 | 44 | def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False): 45 | """Check the forward and backward for the customized batch normalization.""" 46 | bn1.train(mode=is_train) 47 | bn2.train(mode=is_train) 48 | 49 | if cuda: 50 | input = input.cuda() 51 | 52 | self._syncParameters(_find_bn(bn1), _find_bn(bn2)) 53 | 54 | input1 = Variable(input, requires_grad=True) 55 | output1 = bn1(input1) 56 | output1.sum().backward() 57 | input2 = Variable(input, requires_grad=True) 58 | output2 = bn2(input2) 59 | output2.sum().backward() 60 | 61 | self.assertTensorClose(input1.data, input2.data) 62 | self.assertTensorClose(output1.data, output2.data) 63 | self.assertTensorClose(input1.grad, input2.grad) 64 | self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean) 65 | self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var) 66 | 67 | def testSyncBatchNormNormalTrain(self): 68 | bn = nn.BatchNorm1d(10) 69 | sync_bn = SynchronizedBatchNorm1d(10) 70 | 71 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True) 72 | 73 | def testSyncBatchNormNormalEval(self): 74 | bn = nn.BatchNorm1d(10) 75 | sync_bn = SynchronizedBatchNorm1d(10) 76 | 77 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False) 78 | 79 | def testSyncBatchNormSyncTrain(self): 80 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) 81 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 82 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 83 | 84 | bn.cuda() 85 | sync_bn.cuda() 86 | 87 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True) 88 | 89 | def testSyncBatchNormSyncEval(self): 90 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) 91 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 92 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 93 | 94 | bn.cuda() 95 | sync_bn.cuda() 96 | 97 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True) 98 | 99 | def testSyncBatchNorm2DSyncTrain(self): 100 | bn = nn.BatchNorm2d(10) 101 | sync_bn = SynchronizedBatchNorm2d(10) 102 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 103 | 104 | bn.cuda() 105 | sync_bn.cuda() 106 | 107 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True) 108 | 109 | 110 | if __name__ == '__main__': 111 | unittest.main() 112 | -------------------------------------------------------------------------------- /models/ade20k/segm_lib/nn/modules/unittest.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol), 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /models/ade20k/segm_lib/nn/parallel/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to 2 | -------------------------------------------------------------------------------- /models/ade20k/segm_lib/nn/parallel/data_parallel.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf8 -*- 2 | 3 | import torch.cuda as cuda 4 | import torch.nn as nn 5 | import torch 6 | import collections 7 | from torch.nn.parallel._functions import Gather 8 | 9 | 10 | __all__ = ['UserScatteredDataParallel', 'user_scattered_collate', 'async_copy_to'] 11 | 12 | 13 | def async_copy_to(obj, dev, main_stream=None): 14 | if torch.is_tensor(obj): 15 | v = obj.cuda(dev, non_blocking=True) 16 | if main_stream is not None: 17 | v.data.record_stream(main_stream) 18 | return v 19 | elif isinstance(obj, collections.Mapping): 20 | return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()} 21 | elif isinstance(obj, collections.Sequence): 22 | return [async_copy_to(o, dev, main_stream) for o in obj] 23 | else: 24 | return obj 25 | 26 | 27 | def dict_gather(outputs, target_device, dim=0): 28 | """ 29 | Gathers variables from different GPUs on a specified device 30 | (-1 means the CPU), with dictionary support. 
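For example, per-GPU outputs [{'loss': t0}, {'loss': t1}] are gathered into a single
{'loss': ...} on the target device; nested mappings and sequences are handled recursively.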
31 | """ 32 | def gather_map(outputs): 33 | out = outputs[0] 34 | if torch.is_tensor(out): 35 | # MJY(20180330) HACK:: force nr_dims > 0 36 | if out.dim() == 0: 37 | outputs = [o.unsqueeze(0) for o in outputs] 38 | return Gather.apply(target_device, dim, *outputs) 39 | elif out is None: 40 | return None 41 | elif isinstance(out, collections.Mapping): 42 | return {k: gather_map([o[k] for o in outputs]) for k in out} 43 | elif isinstance(out, collections.Sequence): 44 | return type(out)(map(gather_map, zip(*outputs))) 45 | return gather_map(outputs) 46 | 47 | 48 | class DictGatherDataParallel(nn.DataParallel): 49 | def gather(self, outputs, output_device): 50 | return dict_gather(outputs, output_device, dim=self.dim) 51 | 52 | 53 | class UserScatteredDataParallel(DictGatherDataParallel): 54 | def scatter(self, inputs, kwargs, device_ids): 55 | assert len(inputs) == 1 56 | inputs = inputs[0] 57 | inputs = _async_copy_stream(inputs, device_ids) 58 | inputs = [[i] for i in inputs] 59 | assert len(kwargs) == 0 60 | kwargs = [{} for _ in range(len(inputs))] 61 | 62 | return inputs, kwargs 63 | 64 | 65 | def user_scattered_collate(batch): 66 | return batch 67 | 68 | 69 | def _async_copy(inputs, device_ids): 70 | nr_devs = len(device_ids) 71 | assert type(inputs) in (tuple, list) 72 | assert len(inputs) == nr_devs 73 | 74 | outputs = [] 75 | for i, dev in zip(inputs, device_ids): 76 | with cuda.device(dev): 77 | outputs.append(async_copy_to(i, dev)) 78 | 79 | return tuple(outputs) 80 | 81 | 82 | def _async_copy_stream(inputs, device_ids): 83 | nr_devs = len(device_ids) 84 | assert type(inputs) in (tuple, list) 85 | assert len(inputs) == nr_devs 86 | 87 | outputs = [] 88 | streams = [_get_stream(d) for d in device_ids] 89 | for i, dev, stream in zip(inputs, device_ids, streams): 90 | with cuda.device(dev): 91 | main_stream = cuda.current_stream() 92 | with cuda.stream(stream): 93 | outputs.append(async_copy_to(i, dev, main_stream=main_stream)) 94 | main_stream.wait_stream(stream) 95 | 96 | return outputs 97 | 98 | 99 | """Adapted from: torch/nn/parallel/_functions.py""" 100 | # background streams used for copying 101 | _streams = None 102 | 103 | 104 | def _get_stream(device): 105 | """Gets a background stream for copying between CPU and GPU""" 106 | global _streams 107 | if device == -1: 108 | return None 109 | if _streams is None: 110 | _streams = [None] * cuda.device_count() 111 | if _streams[device] is None: _streams[device] = cuda.Stream(device) 112 | return _streams[device] 113 | -------------------------------------------------------------------------------- /models/ade20k/segm_lib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .th import * 2 | -------------------------------------------------------------------------------- /models/ade20k/segm_lib/utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .dataset import Dataset, TensorDataset, ConcatDataset 3 | from .dataloader import DataLoader 4 | -------------------------------------------------------------------------------- /models/ade20k/segm_lib/utils/data/dataset.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import warnings 3 | 4 | from torch._utils import _accumulate 5 | from torch import randperm 6 | 7 | 8 | class Dataset(object): 9 | """An abstract class representing a Dataset. 10 | 11 | All other datasets should subclass it. 
All subclasses should override 12 | ``__len__``, that provides the size of the dataset, and ``__getitem__``, 13 | supporting integer indexing in range from 0 to len(self) exclusive. 14 | """ 15 | 16 | def __getitem__(self, index): 17 | raise NotImplementedError 18 | 19 | def __len__(self): 20 | raise NotImplementedError 21 | 22 | def __add__(self, other): 23 | return ConcatDataset([self, other]) 24 | 25 | 26 | class TensorDataset(Dataset): 27 | """Dataset wrapping data and target tensors. 28 | 29 | Each sample will be retrieved by indexing both tensors along the first 30 | dimension. 31 | 32 | Arguments: 33 | data_tensor (Tensor): contains sample data. 34 | target_tensor (Tensor): contains sample targets (labels). 35 | """ 36 | 37 | def __init__(self, data_tensor, target_tensor): 38 | assert data_tensor.size(0) == target_tensor.size(0) 39 | self.data_tensor = data_tensor 40 | self.target_tensor = target_tensor 41 | 42 | def __getitem__(self, index): 43 | return self.data_tensor[index], self.target_tensor[index] 44 | 45 | def __len__(self): 46 | return self.data_tensor.size(0) 47 | 48 | 49 | class ConcatDataset(Dataset): 50 | """ 51 | Dataset to concatenate multiple datasets. 52 | Purpose: useful to assemble different existing datasets, possibly 53 | large-scale datasets as the concatenation operation is done in an 54 | on-the-fly manner. 55 | 56 | Arguments: 57 | datasets (iterable): List of datasets to be concatenated 58 | """ 59 | 60 | @staticmethod 61 | def cumsum(sequence): 62 | r, s = [], 0 63 | for e in sequence: 64 | l = len(e) 65 | r.append(l + s) 66 | s += l 67 | return r 68 | 69 | def __init__(self, datasets): 70 | super(ConcatDataset, self).__init__() 71 | assert len(datasets) > 0, 'datasets should not be an empty iterable' 72 | self.datasets = list(datasets) 73 | self.cumulative_sizes = self.cumsum(self.datasets) 74 | 75 | def __len__(self): 76 | return self.cumulative_sizes[-1] 77 | 78 | def __getitem__(self, idx): 79 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 80 | if dataset_idx == 0: 81 | sample_idx = idx 82 | else: 83 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 84 | return self.datasets[dataset_idx][sample_idx] 85 | 86 | @property 87 | def cummulative_sizes(self): 88 | warnings.warn("cummulative_sizes attribute is renamed to " 89 | "cumulative_sizes", DeprecationWarning, stacklevel=2) 90 | return self.cumulative_sizes 91 | 92 | 93 | class Subset(Dataset): 94 | def __init__(self, dataset, indices): 95 | self.dataset = dataset 96 | self.indices = indices 97 | 98 | def __getitem__(self, idx): 99 | return self.dataset[self.indices[idx]] 100 | 101 | def __len__(self): 102 | return len(self.indices) 103 | 104 | 105 | def random_split(dataset, lengths): 106 | """ 107 | Randomly split a dataset into non-overlapping new datasets of given lengths 108 | ds 109 | 110 | Arguments: 111 | dataset (Dataset): Dataset to be split 112 | lengths (iterable): lengths of splits to be produced 113 | """ 114 | if sum(lengths) != len(dataset): 115 | raise ValueError("Sum of input lengths does not equal the length of the input dataset!") 116 | 117 | indices = randperm(sum(lengths)) 118 | return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)] 119 | -------------------------------------------------------------------------------- /models/ade20k/segm_lib/utils/data/distributed.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from 
.sampler import Sampler 4 | from torch.distributed import get_world_size, get_rank 5 | 6 | 7 | class DistributedSampler(Sampler): 8 | """Sampler that restricts data loading to a subset of the dataset. 9 | 10 | It is especially useful in conjunction with 11 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 12 | process can pass a DistributedSampler instance as a DataLoader sampler, 13 | and load a subset of the original dataset that is exclusive to it. 14 | 15 | .. note:: 16 | Dataset is assumed to be of constant size. 17 | 18 | Arguments: 19 | dataset: Dataset used for sampling. 20 | num_replicas (optional): Number of processes participating in 21 | distributed training. 22 | rank (optional): Rank of the current process within num_replicas. 23 | """ 24 | 25 | def __init__(self, dataset, num_replicas=None, rank=None): 26 | if num_replicas is None: 27 | num_replicas = get_world_size() 28 | if rank is None: 29 | rank = get_rank() 30 | self.dataset = dataset 31 | self.num_replicas = num_replicas 32 | self.rank = rank 33 | self.epoch = 0 34 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 35 | self.total_size = self.num_samples * self.num_replicas 36 | 37 | def __iter__(self): 38 | # deterministically shuffle based on epoch 39 | g = torch.Generator() 40 | g.manual_seed(self.epoch) 41 | indices = list(torch.randperm(len(self.dataset), generator=g)) 42 | 43 | # add extra samples to make it evenly divisible 44 | indices += indices[:(self.total_size - len(indices))] 45 | assert len(indices) == self.total_size 46 | 47 | # subsample 48 | offset = self.num_samples * self.rank 49 | indices = indices[offset:offset + self.num_samples] 50 | assert len(indices) == self.num_samples 51 | 52 | return iter(indices) 53 | 54 | def __len__(self): 55 | return self.num_samples 56 | 57 | def set_epoch(self, epoch): 58 | self.epoch = epoch 59 | -------------------------------------------------------------------------------- /models/ade20k/segm_lib/utils/data/sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Sampler(object): 5 | """Base class for all Samplers. 6 | 7 | Every Sampler subclass has to provide an __iter__ method, providing a way 8 | to iterate over indices of dataset elements, and a __len__ method that 9 | returns the length of the returned iterators. 10 | """ 11 | 12 | def __init__(self, data_source): 13 | pass 14 | 15 | def __iter__(self): 16 | raise NotImplementedError 17 | 18 | def __len__(self): 19 | raise NotImplementedError 20 | 21 | 22 | class SequentialSampler(Sampler): 23 | """Samples elements sequentially, always in the same order. 24 | 25 | Arguments: 26 | data_source (Dataset): dataset to sample from 27 | """ 28 | 29 | def __init__(self, data_source): 30 | self.data_source = data_source 31 | 32 | def __iter__(self): 33 | return iter(range(len(self.data_source))) 34 | 35 | def __len__(self): 36 | return len(self.data_source) 37 | 38 | 39 | class RandomSampler(Sampler): 40 | """Samples elements randomly, without replacement. 
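Equivalent to iterating over a random permutation of the dataset indices.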
41 | 42 | Arguments: 43 | data_source (Dataset): dataset to sample from 44 | """ 45 | 46 | def __init__(self, data_source): 47 | self.data_source = data_source 48 | 49 | def __iter__(self): 50 | return iter(torch.randperm(len(self.data_source)).long()) 51 | 52 | def __len__(self): 53 | return len(self.data_source) 54 | 55 | 56 | class SubsetRandomSampler(Sampler): 57 | """Samples elements randomly from a given list of indices, without replacement. 58 | 59 | Arguments: 60 | indices (list): a list of indices 61 | """ 62 | 63 | def __init__(self, indices): 64 | self.indices = indices 65 | 66 | def __iter__(self): 67 | return (self.indices[i] for i in torch.randperm(len(self.indices))) 68 | 69 | def __len__(self): 70 | return len(self.indices) 71 | 72 | 73 | class WeightedRandomSampler(Sampler): 74 | """Samples elements from [0,..,len(weights)-1] with given probabilities (weights). 75 | 76 | Arguments: 77 | weights (list) : a list of weights, not necessary summing up to one 78 | num_samples (int): number of samples to draw 79 | replacement (bool): if ``True``, samples are drawn with replacement. 80 | If not, they are drawn without replacement, which means that when a 81 | sample index is drawn for a row, it cannot be drawn again for that row. 82 | """ 83 | 84 | def __init__(self, weights, num_samples, replacement=True): 85 | self.weights = torch.DoubleTensor(weights) 86 | self.num_samples = num_samples 87 | self.replacement = replacement 88 | 89 | def __iter__(self): 90 | return iter(torch.multinomial(self.weights, self.num_samples, self.replacement)) 91 | 92 | def __len__(self): 93 | return self.num_samples 94 | 95 | 96 | class BatchSampler(object): 97 | """Wraps another sampler to yield a mini-batch of indices. 98 | 99 | Args: 100 | sampler (Sampler): Base sampler. 101 | batch_size (int): Size of mini-batch. 
102 | drop_last (bool): If ``True``, the sampler will drop the last batch if 103 | its size would be less than ``batch_size`` 104 | 105 | Example: 106 | >>> list(BatchSampler(range(10), batch_size=3, drop_last=False)) 107 | [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] 108 | >>> list(BatchSampler(range(10), batch_size=3, drop_last=True)) 109 | [[0, 1, 2], [3, 4, 5], [6, 7, 8]] 110 | """ 111 | 112 | def __init__(self, sampler, batch_size, drop_last): 113 | self.sampler = sampler 114 | self.batch_size = batch_size 115 | self.drop_last = drop_last 116 | 117 | def __iter__(self): 118 | batch = [] 119 | for idx in self.sampler: 120 | batch.append(idx) 121 | if len(batch) == self.batch_size: 122 | yield batch 123 | batch = [] 124 | if len(batch) > 0 and not self.drop_last: 125 | yield batch 126 | 127 | def __len__(self): 128 | if self.drop_last: 129 | return len(self.sampler) // self.batch_size 130 | else: 131 | return (len(self.sampler) + self.batch_size - 1) // self.batch_size 132 | -------------------------------------------------------------------------------- /models/ade20k/segm_lib/utils/th.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import numpy as np 4 | import collections 5 | 6 | __all__ = ['as_variable', 'as_numpy', 'mark_volatile'] 7 | 8 | def as_variable(obj): 9 | if isinstance(obj, Variable): 10 | return obj 11 | if isinstance(obj, collections.Sequence): 12 | return [as_variable(v) for v in obj] 13 | elif isinstance(obj, collections.Mapping): 14 | return {k: as_variable(v) for k, v in obj.items()} 15 | else: 16 | return Variable(obj) 17 | 18 | def as_numpy(obj): 19 | if isinstance(obj, collections.Sequence): 20 | return [as_numpy(v) for v in obj] 21 | elif isinstance(obj, collections.Mapping): 22 | return {k: as_numpy(v) for k, v in obj.items()} 23 | elif isinstance(obj, Variable): 24 | return obj.data.cpu().numpy() 25 | elif torch.is_tensor(obj): 26 | return obj.cpu().numpy() 27 | else: 28 | return np.array(obj) 29 | 30 | def mark_volatile(obj): 31 | if torch.is_tensor(obj): 32 | obj = Variable(obj) 33 | if isinstance(obj, Variable): 34 | obj.no_grad = True 35 | return obj 36 | elif isinstance(obj, collections.Mapping): 37 | return {k: mark_volatile(o) for k, o in obj.items()} 38 | elif isinstance(obj, collections.Sequence): 39 | return [mark_volatile(o) for o in obj] 40 | else: 41 | return obj 42 | -------------------------------------------------------------------------------- /models/ade20k/utils.py: -------------------------------------------------------------------------------- 1 | """Modified from https://github.com/CSAILVision/semantic-segmentation-pytorch""" 2 | 3 | import os 4 | import sys 5 | 6 | import numpy as np 7 | import torch 8 | 9 | try: 10 | from urllib import urlretrieve 11 | except ImportError: 12 | from urllib.request import urlretrieve 13 | 14 | 15 | def load_url(url, model_dir='./pretrained', map_location=None): 16 | if not os.path.exists(model_dir): 17 | os.makedirs(model_dir) 18 | filename = url.split('/')[-1] 19 | cached_file = os.path.join(model_dir, filename) 20 | if not os.path.exists(cached_file): 21 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) 22 | urlretrieve(url, cached_file) 23 | return torch.load(cached_file, map_location=map_location) 24 | 25 | 26 | def color_encode(labelmap, colors, mode='RGB'): 27 | labelmap = labelmap.astype('int') 28 | labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3), 29 | dtype=np.uint8) 
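# paint each label's pixels with that label's palette colour, skipping negative (unlabeled) entries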
30 | for label in np.unique(labelmap): 31 | if label < 0: 32 | continue 33 | labelmap_rgb += (labelmap == label)[:, :, np.newaxis] * \ 34 | np.tile(colors[label], 35 | (labelmap.shape[0], labelmap.shape[1], 1)) 36 | 37 | if mode == 'BGR': 38 | return labelmap_rgb[:, :, ::-1] 39 | else: 40 | return labelmap_rgb 41 |
-------------------------------------------------------------------------------- /models/lpips_models/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/advimman/lama/786f5936b27fb3dacd2b1ad799e4de968ea697e7/models/lpips_models/alex.pth -------------------------------------------------------------------------------- /models/lpips_models/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/advimman/lama/786f5936b27fb3dacd2b1ad799e4de968ea697e7/models/lpips_models/squeeze.pth -------------------------------------------------------------------------------- /models/lpips_models/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/advimman/lama/786f5936b27fb3dacd2b1ad799e4de968ea697e7/models/lpips_models/vgg.pth
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pyyaml 2 | tqdm 3 | numpy 4 | easydict==1.9.0 5 | scikit-image==0.17.2 6 | scikit-learn==0.24.2 7 | opencv-python 8 | tensorflow 9 | joblib 10 | matplotlib 11 | pandas 12 | albumentations==0.5.2 13 | hydra-core==1.1.0 14 | pytorch-lightning==1.2.9 15 | tabulate 16 | kornia==0.5.0 17 | webdataset 18 | packaging 19 | wldhx.yadisk-direct
-------------------------------------------------------------------------------- /saicinpainting/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/advimman/lama/786f5936b27fb3dacd2b1ad799e4de968ea697e7/saicinpainting/__init__.py
-------------------------------------------------------------------------------- /saicinpainting/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | 5 | from saicinpainting.evaluation.evaluator import InpaintingEvaluatorOnline, ssim_fid100_f1, lpips_fid100_f1 6 | from saicinpainting.evaluation.losses.base_loss import SSIMScore, LPIPSScore, FIDScore 7 | 8 | 9 | def make_evaluator(kind='default', ssim=True, lpips=True, fid=True, integral_kind=None, **kwargs): 10 | logging.info(f'Make evaluator {kind}') 11 | device = "cuda" if torch.cuda.is_available() else "cpu" 12 | metrics = {} 13 | if ssim: 14 | metrics['ssim'] = SSIMScore() 15 | if lpips: 16 | metrics['lpips'] = LPIPSScore() 17 | if fid: 18 | metrics['fid'] = FIDScore().to(device) 19 | 20 | if integral_kind is None: 21 | integral_func = None 22 | elif integral_kind == 'ssim_fid100_f1': 23 | integral_func = ssim_fid100_f1 24 | elif integral_kind == 'lpips_fid100_f1': 25 | integral_func = lpips_fid100_f1 26 | else: 27 | raise ValueError(f'Unexpected integral_kind={integral_kind}') 28 | 29 | if kind == 'default': 30 | return InpaintingEvaluatorOnline(scores=metrics, 31 | integral_func=integral_func, 32 | integral_title=integral_kind, 33 | **kwargs) 34 | --------------------------------------------------------------------------------
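A minimal usage sketch for the factory above (the keyword values come straight from its signature and the two supported integral kinds; everything else about the calling code is assumed):

    from saicinpainting.evaluation import make_evaluator

    # default online evaluator with SSIM and FID but no LPIPS, combined into the
    # ssim_fid100_f1 integral score that the trainer configs monitor for checkpointing
    evaluator = make_evaluator(kind='default', ssim=True, lpips=False, fid=True,
                               integral_kind='ssim_fid100_f1')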
/saicinpainting/evaluation/losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/advimman/lama/786f5936b27fb3dacd2b1ad799e4de968ea697e7/saicinpainting/evaluation/losses/__init__.py -------------------------------------------------------------------------------- /saicinpainting/evaluation/losses/fid/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/advimman/lama/786f5936b27fb3dacd2b1ad799e4de968ea697e7/saicinpainting/evaluation/losses/fid/__init__.py -------------------------------------------------------------------------------- /saicinpainting/evaluation/losses/ssim.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | class SSIM(torch.nn.Module): 7 | """SSIM. Modified from: 8 | https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py 9 | """ 10 | 11 | def __init__(self, window_size=11, size_average=True): 12 | super().__init__() 13 | self.window_size = window_size 14 | self.size_average = size_average 15 | self.channel = 1 16 | self.register_buffer('window', self._create_window(window_size, self.channel)) 17 | 18 | def forward(self, img1, img2): 19 | assert len(img1.shape) == 4 20 | 21 | channel = img1.size()[1] 22 | 23 | if channel == self.channel and self.window.data.type() == img1.data.type(): 24 | window = self.window 25 | else: 26 | window = self._create_window(self.window_size, channel) 27 | 28 | # window = window.to(img1.get_device()) 29 | window = window.type_as(img1) 30 | 31 | self.window = window 32 | self.channel = channel 33 | 34 | return self._ssim(img1, img2, window, self.window_size, channel, self.size_average) 35 | 36 | def _gaussian(self, window_size, sigma): 37 | gauss = torch.Tensor([ 38 | np.exp(-(x - (window_size // 2)) ** 2 / float(2 * sigma ** 2)) for x in range(window_size) 39 | ]) 40 | return gauss / gauss.sum() 41 | 42 | def _create_window(self, window_size, channel): 43 | _1D_window = self._gaussian(window_size, 1.5).unsqueeze(1) 44 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 45 | return _2D_window.expand(channel, 1, window_size, window_size).contiguous() 46 | 47 | def _ssim(self, img1, img2, window, window_size, channel, size_average=True): 48 | mu1 = F.conv2d(img1, window, padding=(window_size // 2), groups=channel) 49 | mu2 = F.conv2d(img2, window, padding=(window_size // 2), groups=channel) 50 | 51 | mu1_sq = mu1.pow(2) 52 | mu2_sq = mu2.pow(2) 53 | mu1_mu2 = mu1 * mu2 54 | 55 | sigma1_sq = F.conv2d( 56 | img1 * img1, window, padding=(window_size // 2), groups=channel) - mu1_sq 57 | sigma2_sq = F.conv2d( 58 | img2 * img2, window, padding=(window_size // 2), groups=channel) - mu2_sq 59 | sigma12 = F.conv2d( 60 | img1 * img2, window, padding=(window_size // 2), groups=channel) - mu1_mu2 61 | 62 | C1 = 0.01 ** 2 63 | C2 = 0.03 ** 2 64 | 65 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \ 66 | ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 67 | 68 | if size_average: 69 | return ssim_map.mean() 70 | 71 | return ssim_map.mean(1).mean(1).mean(1) 72 | 73 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): 74 | return 75 | -------------------------------------------------------------------------------- /saicinpainting/evaluation/masks/README.md: 
-------------------------------------------------------------------------------- 1 | # Current algorithm 2 | 3 | ## Choice of mask objects 4 | 5 | To identify objects suitable for mask generation, we use a panoptic segmentation model 6 | from [detectron2](https://github.com/facebookresearch/detectron2) trained on COCO. Categories of the detected instances 7 | belong either to the "stuff" or the "things" type. We require that object instances have a category belonging 8 | to "things". Besides, we set an upper bound on the area taken up by an object: we consider that a very large 9 | area indicates that the instance is either the background or a main object which should not be removed. 10 | 11 | ## Choice of position for mask 12 | 13 | We assume that the input image has size 2^n x 2^m. We downsample it using the 14 | [COUNTLESS](https://github.com/william-silversmith/countless) algorithm so that the width is equal to 15 | 64 = 2^6 = 2^{downsample_levels}. 16 | 17 | ### Augmentation 18 | 19 | There are several parameters for augmentation: 20 | - Scaling factor. We limit scaling to the case when the mask, scaled around its center, still fits inside the 21 | image completely. 22 | - 23 | 24 | ### Shift 25 | 26 | 27 | ## Select 28 | --------------------------------------------------------------------------------
/saicinpainting/evaluation/masks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/advimman/lama/786f5936b27fb3dacd2b1ad799e4de968ea697e7/saicinpainting/evaluation/masks/__init__.py --------------------------------------------------------------------------------
/saicinpainting/evaluation/masks/countless/.gitignore: -------------------------------------------------------------------------------- 1 | results --------------------------------------------------------------------------------
/saicinpainting/evaluation/masks/countless/README.md: -------------------------------------------------------------------------------- 1 | [![Build Status](https://travis-ci.org/william-silversmith/countless.svg?branch=master)](https://travis-ci.org/william-silversmith/countless) 2 | 3 | Python COUNTLESS Downsampling 4 | ============================= 5 | 6 | To install: 7 | 8 | `pip install -r requirements.txt` 9 | 10 | To test: 11 | 12 | `python test.py` 13 | 14 | To benchmark countless2d: 15 | 16 | `python countless2d.py images/gray_segmentation.png` 17 | 18 | To benchmark countless3d: 19 | 20 | `python countless3d.py` 21 | 22 | Adjust N and the list of algorithms inside each script to modify the run parameters. 23 | 24 | 25 | Python3 is slightly faster than Python2.
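For orientation, here is a minimal sketch of the idea these benchmarks measure: the zero-free 2D COUNTLESS variant (`simplest_countless` in `countless2d.py`). It assumes an even-sized integer label image with no zero labels; other variants in `countless2d.py` handle the zero-label case:

```python
import numpy as np

def simplest_countless(data):
    # the four corners of every 2x2 block
    a, b, c, d = data[::2, ::2], data[1::2, ::2], data[::2, 1::2], data[1::2, 1::2]
    ab_ac = a * ((a == b) | (a == c))  # a wins if it matches b or c
    bc = b * (b == c)                  # b wins if it matches c
    winner = ab_ac | bc                # at most one distinct nonzero value remains
    return winner + (winner == 0) * d  # no matching pair: fall back to d
```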
-------------------------------------------------------------------------------- /saicinpainting/evaluation/masks/countless/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/advimman/lama/786f5936b27fb3dacd2b1ad799e4de968ea697e7/saicinpainting/evaluation/masks/countless/__init__.py -------------------------------------------------------------------------------- /saicinpainting/evaluation/masks/countless/images/gcim.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/advimman/lama/786f5936b27fb3dacd2b1ad799e4de968ea697e7/saicinpainting/evaluation/masks/countless/images/gcim.jpg -------------------------------------------------------------------------------- /saicinpainting/evaluation/masks/countless/images/gray_segmentation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/advimman/lama/786f5936b27fb3dacd2b1ad799e4de968ea697e7/saicinpainting/evaluation/masks/countless/images/gray_segmentation.png -------------------------------------------------------------------------------- /saicinpainting/evaluation/masks/countless/images/segmentation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/advimman/lama/786f5936b27fb3dacd2b1ad799e4de968ea697e7/saicinpainting/evaluation/masks/countless/images/segmentation.png -------------------------------------------------------------------------------- /saicinpainting/evaluation/masks/countless/images/sparse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/advimman/lama/786f5936b27fb3dacd2b1ad799e4de968ea697e7/saicinpainting/evaluation/masks/countless/images/sparse.png -------------------------------------------------------------------------------- /saicinpainting/evaluation/masks/countless/memprof/countless2d_gcim_N_1000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/advimman/lama/786f5936b27fb3dacd2b1ad799e4de968ea697e7/saicinpainting/evaluation/masks/countless/memprof/countless2d_gcim_N_1000.png -------------------------------------------------------------------------------- /saicinpainting/evaluation/masks/countless/memprof/countless2d_quick_gcim_N_1000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/advimman/lama/786f5936b27fb3dacd2b1ad799e4de968ea697e7/saicinpainting/evaluation/masks/countless/memprof/countless2d_quick_gcim_N_1000.png -------------------------------------------------------------------------------- /saicinpainting/evaluation/masks/countless/memprof/countless3d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/advimman/lama/786f5936b27fb3dacd2b1ad799e4de968ea697e7/saicinpainting/evaluation/masks/countless/memprof/countless3d.png -------------------------------------------------------------------------------- /saicinpainting/evaluation/masks/countless/memprof/countless3d_dynamic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/advimman/lama/786f5936b27fb3dacd2b1ad799e4de968ea697e7/saicinpainting/evaluation/masks/countless/memprof/countless3d_dynamic.png 
-------------------------------------------------------------------------------- /saicinpainting/evaluation/masks/countless/memprof/countless3d_dynamic_generalized.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/advimman/lama/786f5936b27fb3dacd2b1ad799e4de968ea697e7/saicinpainting/evaluation/masks/countless/memprof/countless3d_dynamic_generalized.png -------------------------------------------------------------------------------- /saicinpainting/evaluation/masks/countless/memprof/countless3d_generalized.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/advimman/lama/786f5936b27fb3dacd2b1ad799e4de968ea697e7/saicinpainting/evaluation/masks/countless/memprof/countless3d_generalized.png -------------------------------------------------------------------------------- /saicinpainting/evaluation/masks/countless/requirements.txt: -------------------------------------------------------------------------------- 1 | Pillow>=6.2.0 2 | numpy>=1.16 3 | scipy 4 | tqdm 5 | memory_profiler 6 | six 7 | pytest -------------------------------------------------------------------------------- /saicinpainting/evaluation/utils.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | import yaml 4 | from easydict import EasyDict as edict 5 | import torch.nn as nn 6 | import torch 7 | 8 | 9 | def load_yaml(path): 10 | with open(path, 'r') as f: 11 | return edict(yaml.safe_load(f)) 12 | 13 | 14 | def move_to_device(obj, device): 15 | if isinstance(obj, nn.Module): 16 | return obj.to(device) 17 | if torch.is_tensor(obj): 18 | return obj.to(device) 19 | if isinstance(obj, (tuple, list)): 20 | return [move_to_device(el, device) for el in obj] 21 | if isinstance(obj, dict): 22 | return {name: move_to_device(val, device) for name, val in obj.items()} 23 | raise ValueError(f'Unexpected type {type(obj)}') 24 | 25 | 26 | class SmallMode(Enum): 27 | DROP = "drop" 28 | UPSCALE = "upscale" 29 | -------------------------------------------------------------------------------- /saicinpainting/evaluation/vis.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from skimage import io 3 | from skimage.segmentation import mark_boundaries 4 | 5 | 6 | def save_item_for_vis(item, out_file): 7 | mask = item['mask'] > 0.5 8 | if mask.ndim == 3: 9 | mask = mask[0] 10 | img = mark_boundaries(np.transpose(item['image'], (1, 2, 0)), 11 | mask, 12 | color=(1., 0., 0.), 13 | outline_color=(1., 1., 1.), 14 | mode='thick') 15 | 16 | if 'inpainted' in item: 17 | inp_img = mark_boundaries(np.transpose(item['inpainted'], (1, 2, 0)), 18 | mask, 19 | color=(1., 0., 0.), 20 | mode='outer') 21 | img = np.concatenate((img, inp_img), axis=1) 22 | 23 | img = np.clip(img * 255, 0, 255).astype('uint8') 24 | io.imsave(out_file, img) 25 | 26 | 27 | def save_mask_for_sidebyside(item, out_file): 28 | mask = item['mask']# > 0.5 29 | if mask.ndim == 3: 30 | mask = mask[0] 31 | mask = np.clip(mask * 255, 0, 255).astype('uint8') 32 | io.imsave(out_file, mask) 33 | 34 | def save_img_for_sidebyside(item, out_file): 35 | img = np.transpose(item['image'], (1, 2, 0)) 36 | img = np.clip(img * 255, 0, 255).astype('uint8') 37 | io.imsave(out_file, img) -------------------------------------------------------------------------------- /saicinpainting/training/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/advimman/lama/786f5936b27fb3dacd2b1ad799e4de968ea697e7/saicinpainting/training/__init__.py --------------------------------------------------------------------------------
/saicinpainting/training/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/advimman/lama/786f5936b27fb3dacd2b1ad799e4de968ea697e7/saicinpainting/training/data/__init__.py --------------------------------------------------------------------------------
/saicinpainting/training/data/aug.py: -------------------------------------------------------------------------------- 1 | from albumentations import DualIAATransform, to_tuple 2 | import imgaug.augmenters as iaa 3 | 4 | class IAAAffine2(DualIAATransform): 5 | """Place a regular grid of points on the input and randomly move the neighbourhood of these points around 6 | via affine transformations. 7 | 8 | Note: This class introduces interpolation artifacts to the mask if it has values other than {0;1} 9 | 10 | Args: 11 | p (float): probability of applying the transform. Default: 0.5. 12 | 13 | Targets: 14 | image, mask 15 | """ 16 | 17 | def __init__( 18 | self, 19 | scale=(0.7, 1.3), 20 | translate_percent=None, 21 | translate_px=None, 22 | rotate=0.0, 23 | shear=(-0.1, 0.1), 24 | order=1, 25 | cval=0, 26 | mode="reflect", 27 | always_apply=False, 28 | p=0.5, 29 | ): 30 | super(IAAAffine2, self).__init__(always_apply, p) 31 | self.scale = dict(x=scale, y=scale) 32 | self.translate_percent = to_tuple(translate_percent, 0) 33 | self.translate_px = to_tuple(translate_px, 0) 34 | self.rotate = to_tuple(rotate) 35 | self.shear = dict(x=shear, y=shear) 36 | self.order = order 37 | self.cval = cval 38 | self.mode = mode 39 | 40 | @property 41 | def processor(self): 42 | return iaa.Affine( 43 | self.scale, 44 | self.translate_percent, 45 | self.translate_px, 46 | self.rotate, 47 | self.shear, 48 | self.order, 49 | self.cval, 50 | self.mode, 51 | ) 52 | 53 | def get_transform_init_args_names(self): 54 | return ("scale", "translate_percent", "translate_px", "rotate", "shear", "order", "cval", "mode") 55 | 56 | 57 | class IAAPerspective2(DualIAATransform): 58 | """Perform a random four-point perspective transform of the input. 59 | 60 | Note: This class introduces interpolation artifacts to the mask if it has values other than {0;1} 61 | 62 | Args: 63 | scale ((float, float)): standard deviation of the normal distributions. These are used to sample 64 | the random distances of the subimage's corners from the full image's corners. Default: (0.05, 0.1). 65 | p (float): probability of applying the transform. Default: 0.5.
66 | 67 | Targets: 68 | image, mask 69 | """ 70 | 71 | def __init__(self, scale=(0.05, 0.1), keep_size=True, always_apply=False, p=0.5, 72 | order=1, cval=0, mode="replicate"): 73 | super(IAAPerspective2, self).__init__(always_apply, p) 74 | self.scale = to_tuple(scale, 1.0) 75 | self.keep_size = keep_size 76 | self.cval = cval 77 | self.mode = mode 78 | 79 | @property 80 | def processor(self): 81 | return iaa.PerspectiveTransform(self.scale, keep_size=self.keep_size, mode=self.mode, cval=self.cval) 82 | 83 | def get_transform_init_args_names(self): 84 | return ("scale", "keep_size") 85 | -------------------------------------------------------------------------------- /saicinpainting/training/losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/advimman/lama/786f5936b27fb3dacd2b1ad799e4de968ea697e7/saicinpainting/training/losses/__init__.py -------------------------------------------------------------------------------- /saicinpainting/training/losses/constants.py: -------------------------------------------------------------------------------- 1 | weights = {"ade20k": 2 | [6.34517766497462, 3 | 9.328358208955224, 4 | 11.389521640091116, 5 | 16.10305958132045, 6 | 20.833333333333332, 7 | 22.22222222222222, 8 | 25.125628140703515, 9 | 43.29004329004329, 10 | 50.5050505050505, 11 | 54.6448087431694, 12 | 55.24861878453038, 13 | 60.24096385542168, 14 | 62.5, 15 | 66.2251655629139, 16 | 84.74576271186442, 17 | 90.90909090909092, 18 | 91.74311926605505, 19 | 96.15384615384616, 20 | 96.15384615384616, 21 | 97.08737864077669, 22 | 102.04081632653062, 23 | 135.13513513513513, 24 | 149.2537313432836, 25 | 153.84615384615384, 26 | 163.93442622950818, 27 | 166.66666666666666, 28 | 188.67924528301887, 29 | 192.30769230769232, 30 | 217.3913043478261, 31 | 227.27272727272725, 32 | 227.27272727272725, 33 | 227.27272727272725, 34 | 303.03030303030306, 35 | 322.5806451612903, 36 | 333.3333333333333, 37 | 370.3703703703703, 38 | 384.61538461538464, 39 | 416.6666666666667, 40 | 416.6666666666667, 41 | 434.7826086956522, 42 | 434.7826086956522, 43 | 454.5454545454545, 44 | 454.5454545454545, 45 | 500.0, 46 | 526.3157894736842, 47 | 526.3157894736842, 48 | 555.5555555555555, 49 | 555.5555555555555, 50 | 555.5555555555555, 51 | 555.5555555555555, 52 | 555.5555555555555, 53 | 555.5555555555555, 54 | 555.5555555555555, 55 | 588.2352941176471, 56 | 588.2352941176471, 57 | 588.2352941176471, 58 | 588.2352941176471, 59 | 588.2352941176471, 60 | 666.6666666666666, 61 | 666.6666666666666, 62 | 666.6666666666666, 63 | 666.6666666666666, 64 | 714.2857142857143, 65 | 714.2857142857143, 66 | 714.2857142857143, 67 | 714.2857142857143, 68 | 714.2857142857143, 69 | 769.2307692307693, 70 | 769.2307692307693, 71 | 769.2307692307693, 72 | 833.3333333333334, 73 | 833.3333333333334, 74 | 833.3333333333334, 75 | 833.3333333333334, 76 | 909.090909090909, 77 | 1000.0, 78 | 1111.111111111111, 79 | 1111.111111111111, 80 | 1111.111111111111, 81 | 1111.111111111111, 82 | 1111.111111111111, 83 | 1250.0, 84 | 1250.0, 85 | 1250.0, 86 | 1250.0, 87 | 1250.0, 88 | 1428.5714285714287, 89 | 1428.5714285714287, 90 | 1428.5714285714287, 91 | 1428.5714285714287, 92 | 1428.5714285714287, 93 | 1428.5714285714287, 94 | 1428.5714285714287, 95 | 1666.6666666666667, 96 | 1666.6666666666667, 97 | 1666.6666666666667, 98 | 1666.6666666666667, 99 | 1666.6666666666667, 100 | 1666.6666666666667, 101 | 1666.6666666666667, 102 | 1666.6666666666667, 103 | 1666.6666666666667, 104 | 
1666.6666666666667, 105 | 1666.6666666666667, 106 | 2000.0, 107 | 2000.0, 108 | 2000.0, 109 | 2000.0, 110 | 2000.0, 111 | 2000.0, 112 | 2000.0, 113 | 2000.0, 114 | 2000.0, 115 | 2000.0, 116 | 2000.0, 117 | 2000.0, 118 | 2000.0, 119 | 2000.0, 120 | 2000.0, 121 | 2000.0, 122 | 2000.0, 123 | 2500.0, 124 | 2500.0, 125 | 2500.0, 126 | 2500.0, 127 | 2500.0, 128 | 2500.0, 129 | 2500.0, 130 | 2500.0, 131 | 2500.0, 132 | 2500.0, 133 | 2500.0, 134 | 2500.0, 135 | 2500.0, 136 | 3333.3333333333335, 137 | 3333.3333333333335, 138 | 3333.3333333333335, 139 | 3333.3333333333335, 140 | 3333.3333333333335, 141 | 3333.3333333333335, 142 | 3333.3333333333335, 143 | 3333.3333333333335, 144 | 3333.3333333333335, 145 | 3333.3333333333335, 146 | 3333.3333333333335, 147 | 3333.3333333333335, 148 | 3333.3333333333335, 149 | 5000.0, 150 | 5000.0, 151 | 5000.0] 152 | } -------------------------------------------------------------------------------- /saicinpainting/training/losses/feature_matching.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | def masked_l2_loss(pred, target, mask, weight_known, weight_missing): 8 | per_pixel_l2 = F.mse_loss(pred, target, reduction='none') 9 | pixel_weights = mask * weight_missing + (1 - mask) * weight_known 10 | return (pixel_weights * per_pixel_l2).mean() 11 | 12 | 13 | def masked_l1_loss(pred, target, mask, weight_known, weight_missing): 14 | per_pixel_l1 = F.l1_loss(pred, target, reduction='none') 15 | pixel_weights = mask * weight_missing + (1 - mask) * weight_known 16 | return (pixel_weights * per_pixel_l1).mean() 17 | 18 | 19 | def feature_matching_loss(fake_features: List[torch.Tensor], target_features: List[torch.Tensor], mask=None): 20 | if mask is None: 21 | res = torch.stack([F.mse_loss(fake_feat, target_feat) 22 | for fake_feat, target_feat in zip(fake_features, target_features)]).mean() 23 | else: 24 | res = 0 25 | norm = 0 26 | for fake_feat, target_feat in zip(fake_features, target_features): 27 | cur_mask = F.interpolate(mask, size=fake_feat.shape[-2:], mode='bilinear', align_corners=False) 28 | error_weights = 1 - cur_mask 29 | cur_val = ((fake_feat - target_feat).pow(2) * error_weights).mean() 30 | res = res + cur_val 31 | norm += 1 32 | res = res / norm 33 | return res 34 | -------------------------------------------------------------------------------- /saicinpainting/training/losses/perceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision 5 | 6 | from models.ade20k import ModelBuilder 7 | from saicinpainting.utils import check_and_warn_input_range 8 | 9 | 10 | IMAGENET_MEAN = torch.FloatTensor([0.485, 0.456, 0.406])[None, :, None, None] 11 | IMAGENET_STD = torch.FloatTensor([0.229, 0.224, 0.225])[None, :, None, None] 12 | 13 | 14 | class PerceptualLoss(nn.Module): 15 | def __init__(self, normalize_inputs=True): 16 | super(PerceptualLoss, self).__init__() 17 | 18 | self.normalize_inputs = normalize_inputs 19 | self.mean_ = IMAGENET_MEAN 20 | self.std_ = IMAGENET_STD 21 | 22 | vgg = torchvision.models.vgg19(pretrained=True).features 23 | vgg_avg_pooling = [] 24 | 25 | for weights in vgg.parameters(): 26 | weights.requires_grad = False 27 | 28 | for module in vgg.modules(): 29 | if module.__class__.__name__ == 'Sequential': 30 | continue 31 | elif module.__class__.__name__ == 'MaxPool2d': 32 | 
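# swap each MaxPool2d for an AvgPool2d with the same 2x2 kernel and stride: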
vgg_avg_pooling.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0)) 33 | else: 34 | vgg_avg_pooling.append(module) 35 | 36 | self.vgg = nn.Sequential(*vgg_avg_pooling) 37 | 38 | def do_normalize_inputs(self, x): 39 | return (x - self.mean_.to(x.device)) / self.std_.to(x.device) 40 | 41 | def partial_losses(self, input, target, mask=None): 42 | check_and_warn_input_range(target, 0, 1, 'PerceptualLoss target in partial_losses') 43 | 44 | # we expect input and target to be in [0, 1] range 45 | losses = [] 46 | 47 | if self.normalize_inputs: 48 | features_input = self.do_normalize_inputs(input) 49 | features_target = self.do_normalize_inputs(target) 50 | else: 51 | features_input = input 52 | features_target = target 53 | 54 | for layer in self.vgg[:30]: 55 | 56 | features_input = layer(features_input) 57 | features_target = layer(features_target) 58 | 59 | if layer.__class__.__name__ == 'ReLU': 60 | loss = F.mse_loss(features_input, features_target, reduction='none') 61 | 62 | if mask is not None: 63 | cur_mask = F.interpolate(mask, size=features_input.shape[-2:], 64 | mode='bilinear', align_corners=False) 65 | loss = loss * (1 - cur_mask) 66 | 67 | loss = loss.mean(dim=tuple(range(1, len(loss.shape)))) 68 | losses.append(loss) 69 | 70 | return losses 71 | 72 | def forward(self, input, target, mask=None): 73 | losses = self.partial_losses(input, target, mask=mask) 74 | return torch.stack(losses).sum(dim=0) 75 | 76 | def get_global_features(self, input): 77 | check_and_warn_input_range(input, 0, 1, 'PerceptualLoss input in get_global_features') 78 | 79 | if self.normalize_inputs: 80 | features_input = self.do_normalize_inputs(input) 81 | else: 82 | features_input = input 83 | 84 | features_input = self.vgg(features_input) 85 | return features_input 86 | 87 | 88 | class ResNetPL(nn.Module): 89 | def __init__(self, weight=1, 90 | weights_path=None, arch_encoder='resnet50dilated', segmentation=True): 91 | super().__init__() 92 | self.impl = ModelBuilder.get_encoder(weights_path=weights_path, 93 | arch_encoder=arch_encoder, 94 | arch_decoder='ppm_deepsup', 95 | fc_dim=2048, 96 | segmentation=segmentation) 97 | self.impl.eval() 98 | for w in self.impl.parameters(): 99 | w.requires_grad_(False) 100 | 101 | self.weight = weight 102 | 103 | def forward(self, pred, target): 104 | pred = (pred - IMAGENET_MEAN.to(pred)) / IMAGENET_STD.to(pred) 105 | target = (target - IMAGENET_MEAN.to(target)) / IMAGENET_STD.to(target) 106 | 107 | pred_feats = self.impl(pred, return_feature_maps=True) 108 | target_feats = self.impl(target, return_feature_maps=True) 109 | 110 | result = torch.stack([F.mse_loss(cur_pred, cur_target) 111 | for cur_pred, cur_target 112 | in zip(pred_feats, target_feats)]).sum() * self.weight 113 | return result 114 | -------------------------------------------------------------------------------- /saicinpainting/training/losses/segmentation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .constants import weights as constant_weights 6 | 7 | 8 | class CrossEntropy2d(nn.Module): 9 | def __init__(self, reduction="mean", ignore_label=255, weights=None, *args, **kwargs): 10 | """ 11 | weight (Tensor, optional): a manual rescaling weight given to each class. 
12 | If given, has to be a Tensor of size "nclasses" 13 | """ 14 | super(CrossEntropy2d, self).__init__() 15 | self.reduction = reduction 16 | self.ignore_label = ignore_label 17 | self.weights = weights 18 | if self.weights is not None: 19 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 20 | self.weights = torch.FloatTensor(constant_weights[weights]).to(device) 21 | 22 | def forward(self, predict, target): 23 | """ 24 | Args: 25 | predict:(n, c, h, w) 26 | target:(n, 1, h, w) 27 | """ 28 | target = target.long() 29 | assert not target.requires_grad 30 | assert predict.dim() == 4, "{0}".format(predict.size()) 31 | assert target.dim() == 4, "{0}".format(target.size()) 32 | assert predict.size(0) == target.size(0), "{0} vs {1} ".format(predict.size(0), target.size(0)) 33 | assert target.size(1) == 1, "{0}".format(target.size(1)) 34 | assert predict.size(2) == target.size(2), "{0} vs {1} ".format(predict.size(2), target.size(2)) 35 | assert predict.size(3) == target.size(3), "{0} vs {1} ".format(predict.size(3), target.size(3)) 36 | target = target.squeeze(1) 37 | n, c, h, w = predict.size() 38 | target_mask = (target >= 0) * (target != self.ignore_label) 39 | target = target[target_mask] 40 | predict = predict.transpose(1, 2).transpose(2, 3).contiguous() 41 | predict = predict[target_mask.view(n, h, w, 1).repeat(1, 1, 1, c)].view(-1, c) 42 | loss = F.cross_entropy(predict, target, weight=self.weights, reduction=self.reduction) 43 | return loss 44 | -------------------------------------------------------------------------------- /saicinpainting/training/modules/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from saicinpainting.training.modules.ffc import FFCResNetGenerator 4 | from saicinpainting.training.modules.pix2pixhd import GlobalGenerator, MultiDilatedGlobalGenerator, \ 5 | NLayerDiscriminator, MultidilatedNLayerDiscriminator 6 | 7 | def make_generator(config, kind, **kwargs): 8 | logging.info(f'Make generator {kind}') 9 | 10 | if kind == 'pix2pixhd_multidilated': 11 | return MultiDilatedGlobalGenerator(**kwargs) 12 | 13 | if kind == 'pix2pixhd_global': 14 | return GlobalGenerator(**kwargs) 15 | 16 | if kind == 'ffc_resnet': 17 | return FFCResNetGenerator(**kwargs) 18 | 19 | raise ValueError(f'Unknown generator kind {kind}') 20 | 21 | 22 | def make_discriminator(kind, **kwargs): 23 | logging.info(f'Make discriminator {kind}') 24 | 25 | if kind == 'pix2pixhd_nlayer_multidilated': 26 | return MultidilatedNLayerDiscriminator(**kwargs) 27 | 28 | if kind == 'pix2pixhd_nlayer': 29 | return NLayerDiscriminator(**kwargs) 30 | 31 | raise ValueError(f'Unknown discriminator kind {kind}') 32 | -------------------------------------------------------------------------------- /saicinpainting/training/modules/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Tuple, List 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from saicinpainting.training.modules.depthwise_sep_conv import DepthWiseSeperableConv 8 | from saicinpainting.training.modules.multidilated_conv import MultidilatedConv 9 | 10 | 11 | class BaseDiscriminator(nn.Module): 12 | @abc.abstractmethod 13 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: 14 | """ 15 | Predict scores and get intermediate activations. 
Useful for feature matching loss 16 | :return tuple (scores, list of intermediate activations) 17 | """ 18 | raise NotImplementedError() 19 | 20 | 21 | def get_conv_block_ctor(kind='default'): 22 | if not isinstance(kind, str): 23 | return kind 24 | if kind == 'default': 25 | return nn.Conv2d 26 | if kind == 'depthwise': 27 | return DepthWiseSeperableConv 28 | if kind == 'multidilated': 29 | return MultidilatedConv 30 | raise ValueError(f'Unknown convolutional block kind {kind}') 31 | 32 | 33 | def get_norm_layer(kind='bn'): 34 | if not isinstance(kind, str): 35 | return kind 36 | if kind == 'bn': 37 | return nn.BatchNorm2d 38 | if kind == 'in': 39 | return nn.InstanceNorm2d 40 | raise ValueError(f'Unknown norm block kind {kind}') 41 | 42 | 43 | def get_activation(kind='tanh'): 44 | if kind == 'tanh': 45 | return nn.Tanh() 46 | if kind == 'sigmoid': 47 | return nn.Sigmoid() 48 | if kind is False: 49 | return nn.Identity() 50 | raise ValueError(f'Unknown activation kind {kind}') 51 | 52 | 53 | class SimpleMultiStepGenerator(nn.Module): 54 | def __init__(self, steps: List[nn.Module]): 55 | super().__init__() 56 | self.steps = nn.ModuleList(steps) 57 | 58 | def forward(self, x): 59 | cur_in = x 60 | outs = [] 61 | for step in self.steps: 62 | cur_out = step(cur_in) 63 | outs.append(cur_out) 64 | cur_in = torch.cat((cur_in, cur_out), dim=1) 65 | return torch.cat(outs[::-1], dim=1) 66 | 67 | def deconv_factory(kind, ngf, mult, norm_layer, activation, max_features): 68 | if kind == 'convtranspose': 69 | return [nn.ConvTranspose2d(min(max_features, ngf * mult), 70 | min(max_features, int(ngf * mult / 2)), 71 | kernel_size=3, stride=2, padding=1, output_padding=1), 72 | norm_layer(min(max_features, int(ngf * mult / 2))), activation] 73 | elif kind == 'bilinear': 74 | return [nn.Upsample(scale_factor=2, mode='bilinear'), 75 | DepthWiseSeperableConv(min(max_features, ngf * mult), 76 | min(max_features, int(ngf * mult / 2)), 77 | kernel_size=3, stride=1, padding=1), 78 | norm_layer(min(max_features, int(ngf * mult / 2))), activation] 79 | else: 80 | raise ValueError(f"Invalid deconv kind: {kind}") --------------------------------------------------------------------------------
/saicinpainting/training/modules/depthwise_sep_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class DepthWiseSeperableConv(nn.Module): 5 | def __init__(self, in_dim, out_dim, *args, **kwargs): 6 | super().__init__() 7 | if 'groups' in kwargs: 8 | # ignoring groups for Depthwise Sep Conv 9 | del kwargs['groups'] 10 | 11 | self.depthwise = nn.Conv2d(in_dim, in_dim, *args, groups=in_dim, **kwargs) 12 | self.pointwise = nn.Conv2d(in_dim, out_dim, kernel_size=1) 13 | 14 | def forward(self, x): 15 | out = self.depthwise(x) 16 | out = self.pointwise(out) 17 | return out --------------------------------------------------------------------------------
/saicinpainting/training/modules/fake_fakes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from kornia.constants import SamplePadding 3 | from kornia.augmentation import RandomAffine, CenterCrop 4 | 5 | 6 | class FakeFakesGenerator: 7 | def __init__(self, aug_proba=0.5, img_aug_degree=30, img_aug_translate=0.2): 8 | self.grad_aug = RandomAffine(degrees=360, 9 | translate=0.2, 10 | padding_mode=SamplePadding.REFLECTION, 11 | keepdim=False, 12 | p=1) 13 | self.img_aug = RandomAffine(degrees=img_aug_degree, 14 | translate=img_aug_translate, 15 |
padding_mode=SamplePadding.REFLECTION, 16 | keepdim=True, 17 | p=1) 18 | self.aug_proba = aug_proba 19 | 20 | def __call__(self, input_images, masks): 21 | blend_masks = self._fill_masks_with_gradient(masks) 22 | blend_target = self._make_blend_target(input_images) 23 | result = input_images * (1 - blend_masks) + blend_target * blend_masks 24 | return result, blend_masks 25 | 26 | def _make_blend_target(self, input_images): 27 | batch_size = input_images.shape[0] 28 | permuted = input_images[torch.randperm(batch_size)] 29 | augmented = self.img_aug(input_images) 30 | is_aug = (torch.rand(batch_size, device=input_images.device)[:, None, None, None] < self.aug_proba).float() 31 | result = augmented * is_aug + permuted * (1 - is_aug) 32 | return result 33 | 34 | def _fill_masks_with_gradient(self, masks): 35 | batch_size, _, height, width = masks.shape 36 | grad = torch.linspace(0, 1, steps=width * 2, device=masks.device, dtype=masks.dtype) \ 37 | .view(1, 1, 1, -1).expand(batch_size, 1, height * 2, width * 2) 38 | grad = self.grad_aug(grad) 39 | grad = CenterCrop((height, width))(grad) 40 | grad *= masks 41 | 42 | grad_for_min = grad + (1 - masks) * 10 43 | grad -= grad_for_min.view(batch_size, -1).min(-1).values[:, None, None, None] 44 | grad /= grad.view(batch_size, -1).max(-1).values[:, None, None, None] + 1e-6 45 | grad.clamp_(min=0, max=1) 46 | 47 | return grad 48 | -------------------------------------------------------------------------------- /saicinpainting/training/modules/multidilated_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import random 4 | from saicinpainting.training.modules.depthwise_sep_conv import DepthWiseSeperableConv 5 | 6 | class MultidilatedConv(nn.Module): 7 | def __init__(self, in_dim, out_dim, kernel_size, dilation_num=3, comb_mode='sum', equal_dim=True, 8 | shared_weights=False, padding=1, min_dilation=1, shuffle_in_channels=False, use_depthwise=False, **kwargs): 9 | super().__init__() 10 | convs = [] 11 | self.equal_dim = equal_dim 12 | assert comb_mode in ('cat_out', 'sum', 'cat_in', 'cat_both'), comb_mode 13 | if comb_mode in ('cat_out', 'cat_both'): 14 | self.cat_out = True 15 | if equal_dim: 16 | assert out_dim % dilation_num == 0 17 | out_dims = [out_dim // dilation_num] * dilation_num 18 | self.index = sum([[i + j * (out_dims[0]) for j in range(dilation_num)] for i in range(out_dims[0])], []) 19 | else: 20 | out_dims = [out_dim // 2 ** (i + 1) for i in range(dilation_num - 1)] 21 | out_dims.append(out_dim - sum(out_dims)) 22 | index = [] 23 | starts = [0] + out_dims[:-1] 24 | lengths = [out_dims[i] // out_dims[-1] for i in range(dilation_num)] 25 | for i in range(out_dims[-1]): 26 | for j in range(dilation_num): 27 | index += list(range(starts[j], starts[j] + lengths[j])) 28 | starts[j] += lengths[j] 29 | self.index = index 30 | assert(len(index) == out_dim) 31 | self.out_dims = out_dims 32 | else: 33 | self.cat_out = False 34 | self.out_dims = [out_dim] * dilation_num 35 | 36 | if comb_mode in ('cat_in', 'cat_both'): 37 | if equal_dim: 38 | assert in_dim % dilation_num == 0 39 | in_dims = [in_dim // dilation_num] * dilation_num 40 | else: 41 | in_dims = [in_dim // 2 ** (i + 1) for i in range(dilation_num - 1)] 42 | in_dims.append(in_dim - sum(in_dims)) 43 | self.in_dims = in_dims 44 | self.cat_in = True 45 | else: 46 | self.cat_in = False 47 | self.in_dims = [in_dim] * dilation_num 48 | 49 | conv_type = DepthWiseSeperableConv if use_depthwise else nn.Conv2d 50 | 
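# one conv per branch: the dilation starts at min_dilation and doubles for each
# subsequent branch, while integer padding grows with it to preserve spatial size: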
dilation = min_dilation 51 | for i in range(dilation_num): 52 | if isinstance(padding, int): 53 | cur_padding = padding * dilation 54 | else: 55 | cur_padding = padding[i] 56 | convs.append(conv_type( 57 | self.in_dims[i], self.out_dims[i], kernel_size, padding=cur_padding, dilation=dilation, **kwargs 58 | )) 59 | if i > 0 and shared_weights: 60 | convs[-1].weight = convs[0].weight 61 | convs[-1].bias = convs[0].bias 62 | dilation *= 2 63 | self.convs = nn.ModuleList(convs) 64 | 65 | self.shuffle_in_channels = shuffle_in_channels 66 | if self.shuffle_in_channels: 67 | # shuffle list as shuffling of tensors is nondeterministic 68 | in_channels_permute = list(range(in_dim)) 69 | random.shuffle(in_channels_permute) 70 | # save as buffer so it is saved and loaded with checkpoint 71 | self.register_buffer('in_channels_permute', torch.tensor(in_channels_permute)) 72 | 73 | def forward(self, x): 74 | if self.shuffle_in_channels: 75 | x = x[:, self.in_channels_permute] 76 | 77 | outs = [] 78 | if self.cat_in: 79 | if self.equal_dim: 80 | x = x.chunk(len(self.convs), dim=1) 81 | else: 82 | new_x = [] 83 | start = 0 84 | for dim in self.in_dims: 85 | new_x.append(x[:, start:start+dim]) 86 | start += dim 87 | x = new_x 88 | for i, conv in enumerate(self.convs): 89 | if self.cat_in: 90 | input = x[i] 91 | else: 92 | input = x 93 | outs.append(conv(input)) 94 | if self.cat_out: 95 | out = torch.cat(outs, dim=1)[:, self.index] 96 | else: 97 | out = sum(outs) 98 | return out 99 | -------------------------------------------------------------------------------- /saicinpainting/training/modules/spatial_transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from kornia.geometry.transform import rotate 5 | 6 | 7 | class LearnableSpatialTransformWrapper(nn.Module): 8 | def __init__(self, impl, pad_coef=0.5, angle_init_range=80, train_angle=True): 9 | super().__init__() 10 | self.impl = impl 11 | self.angle = torch.rand(1) * angle_init_range 12 | if train_angle: 13 | self.angle = nn.Parameter(self.angle, requires_grad=True) 14 | self.pad_coef = pad_coef 15 | 16 | def forward(self, x): 17 | if torch.is_tensor(x): 18 | return self.inverse_transform(self.impl(self.transform(x)), x) 19 | elif isinstance(x, tuple): 20 | x_trans = tuple(self.transform(elem) for elem in x) 21 | y_trans = self.impl(x_trans) 22 | return tuple(self.inverse_transform(elem, orig_x) for elem, orig_x in zip(y_trans, x)) 23 | else: 24 | raise ValueError(f'Unexpected input type {type(x)}') 25 | 26 | def transform(self, x): 27 | height, width = x.shape[2:] 28 | pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef) 29 | x_padded = F.pad(x, [pad_w, pad_w, pad_h, pad_h], mode='reflect') 30 | x_padded_rotated = rotate(x_padded, angle=self.angle.to(x_padded)) 31 | return x_padded_rotated 32 | 33 | def inverse_transform(self, y_padded_rotated, orig_x): 34 | height, width = orig_x.shape[2:] 35 | pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef) 36 | 37 | y_padded = rotate(y_padded_rotated, angle=-self.angle.to(y_padded_rotated)) 38 | y_height, y_width = y_padded.shape[2:] 39 | y = y_padded[:, :, pad_h : y_height - pad_h, pad_w : y_width - pad_w] 40 | return y 41 | 42 | 43 | if __name__ == '__main__': 44 | layer = LearnableSpatialTransformWrapper(nn.Identity()) 45 | x = torch.arange(2* 3 * 15 * 15).view(2, 3, 15, 15).float() 46 | y = layer(x) 47 | assert x.shape == y.shape 48 | assert 
torch.allclose(x[:, :, 1:, 1:][:, :, :-1, :-1], y[:, :, 1:, 1:][:, :, :-1, :-1]) 49 | print('all ok') 50 | -------------------------------------------------------------------------------- /saicinpainting/training/modules/squeeze_excitation.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class SELayer(nn.Module): 5 | def __init__(self, channel, reduction=16): 6 | super(SELayer, self).__init__() 7 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 8 | self.fc = nn.Sequential( 9 | nn.Linear(channel, channel // reduction, bias=False), 10 | nn.ReLU(inplace=True), 11 | nn.Linear(channel // reduction, channel, bias=False), 12 | nn.Sigmoid() 13 | ) 14 | 15 | def forward(self, x): 16 | b, c, _, _ = x.size() 17 | y = self.avg_pool(x).view(b, c) 18 | y = self.fc(y).view(b, c, 1, 1) 19 | res = x * y.expand_as(x) 20 | return res 21 | -------------------------------------------------------------------------------- /saicinpainting/training/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | from saicinpainting.training.trainers.default import DefaultInpaintingTrainingModule 4 | 5 | 6 | def get_training_model_class(kind): 7 | if kind == 'default': 8 | return DefaultInpaintingTrainingModule 9 | 10 | raise ValueError(f'Unknown trainer module {kind}') 11 | 12 | 13 | def make_training_model(config): 14 | kind = config.training_model.kind 15 | kwargs = dict(config.training_model) 16 | kwargs.pop('kind') 17 | kwargs['use_ddp'] = config.trainer.kwargs.get('accelerator', None) == 'ddp' 18 | 19 | logging.info(f'Make training model {kind}') 20 | 21 | cls = get_training_model_class(kind) 22 | return cls(config, **kwargs) 23 | 24 | 25 | def load_checkpoint(train_config, path, map_location='cuda', strict=True): 26 | model: torch.nn.Module = make_training_model(train_config) 27 | state = torch.load(path, map_location=map_location) 28 | model.load_state_dict(state['state_dict'], strict=strict) 29 | model.on_load_checkpoint(state) 30 | return model 31 | -------------------------------------------------------------------------------- /saicinpainting/training/visualizers/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from saicinpainting.training.visualizers.directory import DirectoryVisualizer 4 | from saicinpainting.training.visualizers.noop import NoopVisualizer 5 | 6 | 7 | def make_visualizer(kind, **kwargs): 8 | logging.info(f'Make visualizer {kind}') 9 | 10 | if kind == 'directory': 11 | return DirectoryVisualizer(**kwargs) 12 | if kind == 'noop': 13 | return NoopVisualizer() 14 | 15 | raise ValueError(f'Unknown visualizer kind {kind}') 16 | -------------------------------------------------------------------------------- /saicinpainting/training/visualizers/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Dict, List 3 | 4 | import numpy as np 5 | import torch 6 | from skimage import color 7 | from skimage.segmentation import mark_boundaries 8 | 9 | from . 
import colors 10 | 11 | COLORS, _ = colors.generate_colors(151) # 151 - max classes for semantic segmentation 12 | 13 | 14 | class BaseVisualizer: 15 | @abc.abstractmethod 16 | def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None): 17 | """ 18 | Take a batch, make an image from it and visualize 19 | """ 20 | raise NotImplementedError() 21 | 22 | 23 | def visualize_mask_and_images(images_dict: Dict[str, np.ndarray], keys: List[str], 24 | last_without_mask=True, rescale_keys=None, mask_only_first=None, 25 | black_mask=False) -> np.ndarray: 26 | mask = images_dict['mask'] > 0.5 27 | result = [] 28 | for i, k in enumerate(keys): 29 | img = images_dict[k] 30 | img = np.transpose(img, (1, 2, 0)) 31 | 32 | if rescale_keys is not None and k in rescale_keys: 33 | img = img - img.min() 34 | img /= img.max() + 1e-5 35 | if len(img.shape) == 2: 36 | img = np.expand_dims(img, 2) 37 | 38 | if img.shape[2] == 1: 39 | img = np.repeat(img, 3, axis=2) 40 | elif (img.shape[2] > 3): 41 | img_classes = img.argmax(2) 42 | img = color.label2rgb(img_classes, colors=COLORS) 43 | 44 | if mask_only_first: 45 | need_mark_boundaries = i == 0 46 | else: 47 | need_mark_boundaries = i < len(keys) - 1 or not last_without_mask 48 | 49 | if need_mark_boundaries: 50 | if black_mask: 51 | img = img * (1 - mask[0][..., None]) 52 | img = mark_boundaries(img, 53 | mask[0], 54 | color=(1., 0., 0.), 55 | outline_color=(1., 1., 1.), 56 | mode='thick') 57 | result.append(img) 58 | return np.concatenate(result, axis=1) 59 | 60 | 61 | def visualize_mask_and_images_batch(batch: Dict[str, torch.Tensor], keys: List[str], max_items=10, 62 | last_without_mask=True, rescale_keys=None) -> np.ndarray: 63 | batch = {k: tens.detach().cpu().numpy() for k, tens in batch.items() 64 | if k in keys or k == 'mask'} 65 | 66 | batch_size = next(iter(batch.values())).shape[0] 67 | items_to_vis = min(batch_size, max_items) 68 | result = [] 69 | for i in range(items_to_vis): 70 | cur_dct = {k: tens[i] for k, tens in batch.items()} 71 | result.append(visualize_mask_and_images(cur_dct, keys, last_without_mask=last_without_mask, 72 | rescale_keys=rescale_keys)) 73 | return np.concatenate(result, axis=0) 74 | -------------------------------------------------------------------------------- /saicinpainting/training/visualizers/colors.py: -------------------------------------------------------------------------------- 1 | import random 2 | import colorsys 3 | 4 | import numpy as np 5 | import matplotlib 6 | matplotlib.use('agg') 7 | import matplotlib.pyplot as plt 8 | from matplotlib.colors import LinearSegmentedColormap 9 | 10 | 11 | def generate_colors(nlabels, type='bright', first_color_black=False, last_color_black=True, verbose=False): 12 | # https://stackoverflow.com/questions/14720331/how-to-generate-random-colors-in-matplotlib 13 | """ 14 | Creates a random colormap to be used together with matplotlib. Useful for segmentation tasks 15 | :param nlabels: Number of labels (size of colormap) 16 | :param type: 'bright' for strong colors, 'soft' for pastel colors 17 | :param first_color_black: Option to use first color as black, True or False 18 | :param last_color_black: Option to use last color as black, True or False 19 | :param verbose: Prints the number of labels and shows the colormap. 
True or False 20 | :return: colormap for matplotlib 21 | """ 22 | if type not in ('bright', 'soft'): 23 | print ('Please choose "bright" or "soft" for type') 24 | return 25 | 26 | if verbose: 27 | print('Number of labels: ' + str(nlabels)) 28 | 29 | # Generate color map for bright colors, based on hsv 30 | if type == 'bright': 31 | randHSVcolors = [(np.random.uniform(low=0.0, high=1), 32 | np.random.uniform(low=0.2, high=1), 33 | np.random.uniform(low=0.9, high=1)) for i in range(nlabels)] 34 | 35 | # Convert HSV list to RGB 36 | randRGBcolors = [] 37 | for HSVcolor in randHSVcolors: 38 | randRGBcolors.append(colorsys.hsv_to_rgb(HSVcolor[0], HSVcolor[1], HSVcolor[2])) 39 | 40 | if first_color_black: 41 | randRGBcolors[0] = [0, 0, 0] 42 | 43 | if last_color_black: 44 | randRGBcolors[-1] = [0, 0, 0] 45 | 46 | random_colormap = LinearSegmentedColormap.from_list('new_map', randRGBcolors, N=nlabels) 47 | 48 | # Generate soft pastel colors, by limiting the RGB spectrum 49 | if type == 'soft': 50 | low = 0.6 51 | high = 0.95 52 | randRGBcolors = [(np.random.uniform(low=low, high=high), 53 | np.random.uniform(low=low, high=high), 54 | np.random.uniform(low=low, high=high)) for i in range(nlabels)] 55 | 56 | if first_color_black: 57 | randRGBcolors[0] = [0, 0, 0] 58 | 59 | if last_color_black: 60 | randRGBcolors[-1] = [0, 0, 0] 61 | random_colormap = LinearSegmentedColormap.from_list('new_map', randRGBcolors, N=nlabels) 62 | 63 | # Display colorbar 64 | if verbose: 65 | from matplotlib import colors, colorbar 66 | from matplotlib import pyplot as plt 67 | fig, ax = plt.subplots(1, 1, figsize=(15, 0.5)) 68 | 69 | bounds = np.linspace(0, nlabels, nlabels + 1) 70 | norm = colors.BoundaryNorm(bounds, nlabels) 71 | 72 | cb = colorbar.ColorbarBase(ax, cmap=random_colormap, norm=norm, spacing='proportional', ticks=None, 73 | boundaries=bounds, format='%1i', orientation=u'horizontal') 74 | 75 | return randRGBcolors, random_colormap 76 | 77 | -------------------------------------------------------------------------------- /saicinpainting/training/visualizers/directory.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import numpy as np 5 | 6 | from saicinpainting.training.visualizers.base import BaseVisualizer, visualize_mask_and_images_batch 7 | from saicinpainting.utils import check_and_warn_input_range 8 | 9 | 10 | class DirectoryVisualizer(BaseVisualizer): 11 | DEFAULT_KEY_ORDER = 'image predicted_image inpainted'.split(' ') 12 | 13 | def __init__(self, outdir, key_order=DEFAULT_KEY_ORDER, max_items_in_batch=10, 14 | last_without_mask=True, rescale_keys=None): 15 | self.outdir = outdir 16 | os.makedirs(self.outdir, exist_ok=True) 17 | self.key_order = key_order 18 | self.max_items_in_batch = max_items_in_batch 19 | self.last_without_mask = last_without_mask 20 | self.rescale_keys = rescale_keys 21 | 22 | def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None): 23 | check_and_warn_input_range(batch['image'], 0, 1, 'DirectoryVisualizer target image') 24 | vis_img = visualize_mask_and_images_batch(batch, self.key_order, max_items=self.max_items_in_batch, 25 | last_without_mask=self.last_without_mask, 26 | rescale_keys=self.rescale_keys) 27 | 28 | vis_img = np.clip(vis_img * 255, 0, 255).astype('uint8') 29 | 30 | curoutdir = os.path.join(self.outdir, f'epoch{epoch_i:04d}{suffix}') 31 | os.makedirs(curoutdir, exist_ok=True) 32 | rank_suffix = f'_r{rank}' if rank is not None else '' 33 | out_fname = os.path.join(curoutdir, 
f'batch{batch_i:07d}{rank_suffix}.jpg') 34 | 35 | vis_img = cv2.cvtColor(vis_img, cv2.COLOR_RGB2BGR) 36 | cv2.imwrite(out_fname, vis_img) 37 | -------------------------------------------------------------------------------- /saicinpainting/training/visualizers/noop.py: -------------------------------------------------------------------------------- 1 | from saicinpainting.training.visualizers.base import BaseVisualizer 2 | 3 | 4 | class NoopVisualizer(BaseVisualizer): 5 | def __init__(self, *args, **kwargs): 6 | pass 7 | 8 | def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None): 9 | pass 10 | --------------------------------------------------------------------------------
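To round out the visualizer factory defined in `saicinpainting/training/visualizers/__init__.py`, a small usage sketch (the `outdir` value is an arbitrary example):

```python
from saicinpainting.training.visualizers import make_visualizer

# DirectoryVisualizer writes side-by-side JPEG panels under <outdir>/epochXXXX/
vis = make_visualizer(kind='directory', outdir='/tmp/lama_vis')

# NoopVisualizer accepts the same calls and silently does nothing
quiet = make_visualizer(kind='noop')
```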