├── Generation
│   ├── data
│   │   └── ImageNet1K.py
│   └── utils
│       ├── __pycache__
│       │   ├── utils.cpython-37.pyc
│       │   ├── utils.cpython-38.pyc
│       │   └── utils.cpython-39.pyc
│       └── utils.py
├── LICENSE
├── README.md
├── data
│   └── new_load_data.py
├── extract.sh
├── extract_feature.py
├── finetune
│   ├── train_lora.sh
│   └── train_text_to_image_lora.py
├── generate.py
├── requirements.txt
├── shell_generate.sh
├── train.py
└── train.sh

/Generation/data/ImageNet1K.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | from torch.utils.data import Dataset
5 | from torchvision import transforms
6 | from PIL import Image
7 |
8 | def get_image_paths_from_file(file_path):
9 |     """
10 |     Extracts the lists of image paths and labels from a text file.
11 |     Each line in the file is assumed to have the format '/path/to/image.jpeg number'
12 |
13 |     Args:
14 |         file_path (str): The path to the text file.
15 |
16 |     Returns:
17 |         Tuple[List[str], List[str]]: The image paths and their corresponding labels.
18 |     """
19 |     with open(file_path, 'r') as f:
20 |         lines = f.readlines()
21 |
22 |
23 |     image_paths, labels = [], []
24 |     for line in lines:
25 |         # if if_imagenette:
26 |         #     class_name = line.split()[0].split('/')[1]
27 |         #     if not class_name in _LABEL_MAP:
28 |         #         continue
29 |         image_paths.append(line.split()[0])
30 |         labels.append(line.split()[-1])
31 |
32 |     # image_paths = [line.split()[0] for line in lines]
33 |     # labels = [line.split()[-1] for line in lines]
34 |     # print(image_paths)
35 |     return image_paths, labels
36 |
37 | def mirror_directory_structure(img_paths, source_directory, dest_directory):
38 |     """
39 |     Creates a mirror of the directory structure of source_directory in dest_directory.
40 |     It does this based on the unique class directories specified in img_paths.
41 |
42 |     Args:
43 |         img_paths (list): List of image paths with structure 'class_name/image_name.jpeg'.
44 |         source_directory (str): The path of the source directory.
45 |         dest_directory (str): The path of the destination directory.
46 |
47 |     Returns:
48 |         None
49 |     """
50 |
51 |     unique_class_names = set(path.split('/')[1] for path in img_paths)
52 |     for class_name in unique_class_names:
53 |         os.makedirs(os.path.join(dest_directory, class_name), exist_ok=True)
54 |
55 | def create_ImageNetFolder(root_dir, out_dir):
56 |     image_paths, labels = get_image_paths_from_file(os.path.join(root_dir,"file_list.txt"))
57 |     mirror_directory_structure(image_paths, root_dir, out_dir)
58 |
59 | if __name__ == "__main__":
60 |     torch.backends.cudnn.benchmark = True
61 |     # NOTE: no main() is defined in this module; create_ImageNetFolder() is invoked from extract_feature.py.
--------------------------------------------------------------------------------
/Generation/utils/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BAAI-DCAI/Training-Data-Synthesis/3be248a128eba6e20787bd757f1d15a0fb35c3e3/Generation/utils/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/Generation/utils/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BAAI-DCAI/Training-Data-Synthesis/3be248a128eba6e20787bd757f1d15a0fb35c3e3/Generation/utils/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/Generation/utils/__pycache__/utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BAAI-DCAI/Training-Data-Synthesis/3be248a128eba6e20787bd757f1d15a0fb35c3e3/Generation/utils/__pycache__/utils.cpython-39.pyc
--------------------------------------------------------------------------------
/Generation/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from itertools import islice
3 | import random
4 | from PIL import Image
5 | from PIL import Image, ImageDraw, ImageFont
6 |
7 |
8 |
9 |
10 |
11 | def mkdir(path):
12 |     folder = os.path.exists(path)
13 |     if not folder:
14 |         os.makedirs(path)
15 |
16 | def image_grid(imgs, rows, cols):
17 |     assert len(imgs) == rows*cols
18 |
19 |     w, h = imgs[0].size
20 |     grid = Image.new('RGB', size=(cols*w, rows*h))
21 |     grid_w, grid_h = grid.size
22 |
23 |     for i, img in enumerate(imgs):
24 |         grid.paste(img, box=(i%cols*w, i//cols*h))
25 |     return grid
26 |
27 | def image_legend(class_names, colors):
28 |     # Create a legend image
29 |     legend_width = 512
30 |     # legend_height = len(class_names) *
31 |     legend_height = 512
32 |     legend_image = Image.new('RGBA', (legend_width, legend_height), (255, 255, 255, 128))
33 |
34 |     # Draw class names and colors on the legend image
35 |     draw = ImageDraw.Draw(legend_image)
36 |     # font = ImageFont.truetype("arial.ttf", 15)
37 |
38 |     for idx, (class_name, color) in enumerate(zip(class_names, colors)):
39 |         draw.rectangle([(10, idx * 30), (30, idx * 30 + 20)], fill=color)
40 |         draw.text((40, idx * 30), class_name, fill='black')
41 |
42 |     return legend_image
43 |
44 | # ControlNet
45 | def ade_palette():
46 |     """ADE20K palette that maps each class to RGB values."""
47 |     return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
48 |             [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
49 |             [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
50 |             [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
51 |             [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
52 |             [0, 102, 200], [61, 230, 250], [255, 6, 51], [11,
102, 255], 53 | [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], 54 | [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], 55 | [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], 56 | [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], 57 | [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], 58 | [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], 59 | [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], 60 | [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], 61 | [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], 62 | [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255], 63 | [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], 64 | [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0], 65 | [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], 66 | [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255], 67 | [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], 68 | [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255], 69 | [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], 70 | [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255], 71 | [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], 72 | [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], 73 | [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], 74 | [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112], 75 | [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160], 76 | [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163], 77 | [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], 78 | [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0], 79 | [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], 80 | [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204], 81 | [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], 82 | [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], 83 | [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], 84 | [102, 255, 0], [92, 0, 255]] -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 BAAI-DCAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Real-Fake: Effective Training Data Synthesis Through Distribution Matching
2 | [PDF](https://arxiv.org/pdf/2310.10402.pdf) [Project Page](https://torrvision.com/realfake/)
3 |
4 | Synthetic training data has gained prominence in numerous learning tasks and scenarios, offering advantages such as dataset augmentation, generalization evaluation, and privacy preservation. Despite these benefits, the efficiency of synthetic data generated by current methodologies remains inferior when training advanced deep models exclusively, limiting its practical utility. To address this challenge, we analyze the principles underlying training data synthesis for supervised learning and elucidate a principled theoretical framework from the distribution-matching perspective that explicates the mechanisms governing synthesis efficacy. Through extensive experiments, we demonstrate the effectiveness of our synthetic data across diverse image classification tasks, both as a replacement for and augmentation to real datasets, while also benefiting challenging tasks such as out-of-distribution generalization and privacy preservation.
5 |
6 |
7 | ## Installation
8 |
9 | The project has been tested with PyTorch 2.0.1 and CUDA 11.7.
10 |
11 | ### Install Required Environment
12 |
13 | ```bash
14 | pip3 install -r requirements.txt
15 | ```
16 |
17 | ## Prepare Dataset
18 |
19 | Download ImageNet-1K and arrange each split folder (e.g. `train/`) so that it contains a `file_list.txt` listing one whitespace-separated `image_path label` pair per line, as expected by `data/new_load_data.py`. A minimal parsing sketch is given at the end of this README.
20 |
21 | ## Download Generated Synthetic Dataset
22 |
23 | You can download the generated synthetic data from [Dataset Link](https://huggingface.co/datasets/JianhaoDYDY/Real-Fake). Please follow the instructions on the Hugging Face dataset page.
24 |
25 | ## (Optional) Generate Synthetic Dataset from Scratch
26 | Download ImageNet-1K from [this link](https://www.image-net.org/download.php).
27 |
28 | ### Extract CLIP Embeddings for ImageNet-1K
29 |
30 | 1. Check `./extract.sh` and specify the path to the ImageNet data.
31 |
32 | ```bash
33 | bash extract.sh
34 | ```
35 |
36 | ### Get BLIP2 Captions for ImageNet-1K
37 |
38 | Use the BLIP2 captioning pipeline. Refer to [this paper](https://arxiv.org/abs/2307.08526) for details.
39 |
40 | ### Implement Modifications on Diffusers for Customized Training Data Synthesis
41 | TODO: release the modified diffusers package for direct installation.
42 |
43 | ### Train LoRA
44 |
45 | 1. Set `CACHE_ROOT`/`MODEL_NAME` to the folder caching Stable Diffusion.
46 | 2. Check `./finetune/train_lora.sh` and specify the data version in "versions" for training LoRA.
47 |
48 | ```bash
49 | bash ./finetune/train_lora.sh
50 | ```
51 |
52 | ### Generate Synthetic Dataset
53 |
54 | 1. After training, load the trained LoRA weights to generate the synthetic dataset.
55 | 2. Check `shell_generate.sh` and specify the data version (1 out of 20) in "versions" for generation.
56 | 3. Review the parameter `--nchunks 8` (the number of GPUs, e.g. 8).
57 |
58 | ```bash
59 | bash shell_generate.sh
60 | ```
61 |
62 | This will save one version of the dataset to `./SyntheticData`.
63 |
64 | ## Evaluate
65 |
66 | 1. Check `train.sh` and specify `--data_dir` with the data version for training on the generated synthetic data.
67 | 2. Review `CUDA_VISIBLE_DEVICES=0,1,2` and `--nproc_per_node=3` to specify the number of GPUs used.
68 |
69 | ```bash
70 | bash train.sh
71 | ```
72 |
73 | This will save results and the model to `./experiments/`.
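### Expected `file_list.txt` format (reference sketch)

The loaders in `data/new_load_data.py` and `Generation/data/ImageNet1K.py` read a `file_list.txt` from each split folder; according to their docstrings, every line holds a whitespace-separated `'/path/to/image.jpeg number'` pair. The snippet below is a minimal parsing sketch added for orientation; the path in the usage comment is an illustrative assumption, not a file shipped with this repository.

```python
def read_file_list(file_list_path):
    """Parse file_list.txt: each line is '<image path> <integer label>'."""
    image_paths, labels = [], []
    with open(file_list_path, "r") as f:
        for line in f:
            parts = line.split()
            if not parts:
                continue  # skip blank lines
            # Keep everything except the last token as the path (some ImageNet-A paths contain spaces).
            image_paths.append(" ".join(parts[:-1]))
            labels.append(int(parts[-1]))
    return image_paths, labels

# Illustrative usage (adjust the root to your ImageNet-1K location):
# paths, labels = read_file_list("/path/to/ImageNet-1K/train/file_list.txt")
```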
74 | -------------------------------------------------------------------------------- /data/new_load_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | from collections import defaultdict 5 | 6 | import torch 7 | import torchvision 8 | import torchvision.transforms as transforms 9 | from torch.utils.data import Dataset, ConcatDataset 10 | from torchvision.datasets import ImageNet 11 | 12 | from PIL import Image 13 | import json 14 | 15 | imagenet_classes = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", 16 | "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", 17 | "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", 18 | "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", 19 | "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", 20 | "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", 21 | "box turtle", "banded gecko", "green iguana", "Carolina anole", 22 | "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", 23 | "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", 24 | "American alligator", "triceratops", "worm snake", "ring-necked snake", 25 | "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", 26 | "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", 27 | "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", 28 | "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", 29 | "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", 30 | "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", 31 | "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", 32 | "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", 33 | "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", 34 | "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", 35 | "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", 36 | "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", 37 | "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", 38 | "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", 39 | "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", 40 | "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", 41 | "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", 42 | "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", 43 | "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", 44 | "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", 45 | "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", 46 | "Weimaraner", "Staffordshire Bull Terrier", "American 
Staffordshire Terrier", 47 | "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", 48 | "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", 49 | "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", 50 | "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", 51 | "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", 52 | "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", 53 | "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", 54 | "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", 55 | "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", 56 | "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", 57 | "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", 58 | "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", 59 | "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", 60 | "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", 61 | "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", 62 | "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky", 63 | "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", 64 | "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", 65 | "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", 66 | "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", 67 | "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", 68 | "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", 69 | "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", 70 | "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", 71 | "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", 72 | "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", 73 | "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", 74 | "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", 75 | "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", 76 | "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", 77 | "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", 78 | "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", 79 | "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", 80 | "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", 81 | "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", 82 | "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", 83 | "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", 84 | "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel 
monkey", 85 | "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", 86 | "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", 87 | "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", 88 | "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", 89 | "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", 90 | "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", 91 | "baluster handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", 92 | "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", 93 | "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", 94 | "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", 95 | "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", 96 | "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", 97 | "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", 98 | "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", 99 | "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box carton", 100 | "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", 101 | "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", 102 | "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", 103 | "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", 104 | "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", 105 | "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", 106 | "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", 107 | "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", 108 | "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", 109 | "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", 110 | "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", 111 | "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", 112 | "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", 113 | "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", 114 | "freight car", "French horn", "frying pan", "fur coat", "garbage truck", 115 | "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", 116 | "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", 117 | "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", 118 | "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", 119 | "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", 120 | "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", 121 | "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", 122 | 
"ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", 123 | "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", 124 | "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", 125 | "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", 126 | "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", 127 | "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", 128 | "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", 129 | "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", 130 | "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", 131 | "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", 132 | "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", 133 | "oxygen mask", "product packet packaging", "paddle", "paddle wheel", "padlock", "paintbrush", 134 | "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", 135 | "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", 136 | "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", 137 | "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", 138 | "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", 139 | "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", 140 | "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", 141 | "printer", "prison", "projectile missile", "projector", "hockey puck", "punching bag", "purse", "quill", 142 | "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", 143 | "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", 144 | "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", 145 | "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", 146 | "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", 147 | "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", 148 | "shoji screen room divider", "shopping basket", "shopping cart", "shovel", "shower cap", 149 | "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", 150 | "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", 151 | "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", 152 | "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", 153 | "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", 154 | "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", 155 | "submarine", "suit", "sundial", "sunglass", "sunglasses", "sunscreen", "suspension bridge", 156 | "mop", "sweatshirt", "swim trunks shorts", "swing", "electrical switch", "syringe", 157 | "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", 158 | "thatched roof", "front curtain", "thimble", "threshing 
machine", "throne", "tile roof", 159 | "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", 160 | "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", 161 | "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", 162 | "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", 163 | "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", 164 | "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", 165 | "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", 166 | "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", 167 | "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", 168 | "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", 169 | "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", 170 | "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", 171 | "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", 172 | "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", 173 | "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", 174 | "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", 175 | "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", 176 | "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", 177 | "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", 178 | "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", 179 | "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"] 180 | 181 | # easy classified classes 182 | # ['n02979186', 'n03888257', 'n03394916', 'n03445777', 'n03028079', 'n01440764', 'n02102040', 'n03000684', 'n03425413', 'n03417042'] 183 | imagenette = [0, 217, 482, 491, 497, 566, 569, 571, 574, 701] 184 | 185 | # ["australian_terrier", "border_terrier", "samoyed", "beagle", "shih-tzu", "english_foxhound", "rhodesian_ridgeback", "dingo", "golden_retriever", "english_sheepdog"] 186 | imagewoof = [193, 182, 258, 162, 155, 167, 159, 273, 207, 229] 187 | 188 | # ["tabby_cat", "bengal_cat", "persian_cat", "siamese_cat", "egyptian_cat", "lion", "tiger", "jaguar", "snow_leopard", "lynx"] 189 | imagemeow = [281, 282, 283, 284, 285, 291, 292, 290, 289, 287] 190 | 191 | # ["peacock", "flamingo", "macaw", "pelican", "king_penguin", "bald_eagle", "toucan", "ostrich", "black_swan", "cockatoo"] 192 | imagesquawk = [84, 130, 88, 144, 145, 22, 96, 9, 100, 89] 193 | 194 | # ["pineapple", "banana", "strawberry", "orange", "lemon", "pomegranate", "fig", "bell_pepper", "cucumber", "green_apple"] 195 | imagefruit = [953, 954, 949, 950, 951, 957, 952, 945, 943, 948] 196 | 197 | # ["bee", "ladys slipper", "banana", "lemon", "corn", "school_bus", "honeycomb", "lion", "garden_spider", "goldfinch"] 198 | imageyellow = [309, 986, 954, 951, 987, 779, 599, 291, 72, 11] 199 | 200 | # 'n03787032', 'n02114855', 'n02877765', 'n02086910', 'n04099969', 'n03032252', 'n02119022', 'n03062245', 'n02116738', 'n02869837', 'n01980166', 
'n04589890', 'n03777754', 'n02138441', 'n03379051', 'n02259212', 'n03492542', 'n03530642', 'n03785016', 'n02093428', 'n04429376', 'n04517823', 'n04067472', 'n02231487', 'n01978455', 'n04026417', 'n01983481', 'n02107142', 'n02105505', 'n02488291', 'n03584829', 'n04111531', 'n02859443', 'n03930630', 'n02091831', 'n07831146', 'n02804414', 'n02109047', 'n04336792', 'n04493381', 'n02018207', 'n02087046', 'n04435653', 'n03637318', 'n02123045', 'n03085013', 'n04485082', 'n02788148', 'n02100583', 'n02104029', 'n02108089', 'n13040303', 'n04592741', 'n01820546', 'n03903868', 'n02326432', 'n01729322', 'n07836838', 'n02085620', 'n01692333', 'n03837869', 'n03947888', 'n02701002', 'n02089973', 'n01749939', 'n02009229', 'n02483362', 'n02974003', 'n03891251', 'n02099849', 'n03794056', 'n03494278', 'n03424325', 'n04238763', 'n02106550', 'n03259280', 'n03017168', 'n01558993', 'n01773797', 'n04418357', 'n02113978', 'n07753275', 'n01735189', 'n02086240', 'n04127249', 'n01855672', 'n03775546', 'n04136333', 'n02089867', 'n02090622', 'n03642806', 'n02172182', 'n07714571', 'n07715103', 'n13037406', 'n02113799', 'n03764736', 'n02396427', 'n03594734', 'n04229816' 201 | imagenet100 = [15, 45, 54, 57, 64, 74, 90, 99, 119, 120, 122, 131, 137, 151, 155, 157, 158, 166, 167, 169, 176, 180, 209, 211, 222, 228, 234, 236, 242, 246, 267, 268, 272, 275, 277, 281, 299, 305, 313, 317, 331, 342, 368, 374, 407, 421, 431, 449, 452, 455, 479, 494, 498, 503, 508, 544, 560, 570, 592, 593, 599, 606, 608, 619, 620, 653, 659, 662, 665, 667, 674, 682, 703, 708, 717, 724, 748, 758, 765, 766, 772, 775, 796, 798, 830, 854, 857, 858, 872, 876, 882, 904, 908, 936, 938, 953, 959, 960, 993, 994] 202 | 203 | imageneta = [6, 11, 13, 15, 17, 22, 23, 27, 30, 37, 39, 42, 47, 50, 57, 70, 71, 76, 79, 89, 90, 94, 96, 97, 99, 105, 107, 108, 110, 113, 124, 125, 130, 132, 143, 144, 150, 151, 207, 234, 235, 254, 277, 283, 287, 291, 295, 298, 301, 306, 307, 308, 309, 310, 311, 313, 314, 315, 317, 319, 323, 324, 326, 327, 330, 334, 335, 336, 347, 361, 363, 372, 378, 386, 397, 400, 401, 402, 404, 407, 411, 416, 417, 420, 425, 428, 430, 437, 438, 445, 456, 457, 461, 462, 470, 472, 483, 486, 488, 492, 496, 514, 516, 528, 530, 539, 542, 543, 549, 552, 557, 561, 562, 569, 572, 573, 575, 579, 589, 606, 607, 609, 614, 626, 627, 640, 641, 642, 643, 658, 668, 677, 682, 684, 687, 701, 704, 719, 736, 746, 749, 752, 758, 763, 765, 768, 773, 774, 776, 779, 780, 786, 792, 797, 802, 803, 804, 813, 815, 820, 823, 831, 833, 835, 839, 845, 847, 850, 859, 862, 870, 879, 880, 888, 890, 897, 900, 907, 913, 924, 932, 933, 934, 937, 943, 945, 947, 951, 954, 956, 957, 959, 971, 972, 980, 981, 984, 986, 987, 988] 204 | 205 | imagenetr = [1, 2, 4, 6, 8, 9, 11, 13, 22, 23, 26, 29, 31, 39, 47, 63, 71, 76, 79, 84, 90, 94, 96, 97, 99, 100, 105, 107, 113, 122, 125, 130, 132, 144, 145, 147, 148, 150, 151, 155, 160, 161, 162, 163, 171, 172, 178, 187, 195, 199, 203, 207, 208, 219, 231, 232, 234, 235, 242, 245, 247, 250, 251, 254, 259, 260, 263, 265, 267, 269, 276, 277, 281, 288, 289, 291, 292, 293, 296, 299, 301, 308, 309, 310, 311, 314, 315, 319, 323, 327, 330, 334, 335, 337, 338, 340, 341, 344, 347, 353, 355, 361, 362, 365, 366, 367, 368, 372, 388, 390, 393, 397, 401, 407, 413, 414, 425, 428, 430, 435, 437, 441, 447, 448, 457, 462, 463, 469, 470, 471, 472, 476, 483, 487, 515, 546, 555, 558, 570, 579, 583, 587, 593, 594, 596, 609, 613, 617, 621, 629, 637, 657, 658, 701, 717, 724, 763, 768, 774, 776, 779, 780, 787, 805, 812, 815, 820, 824, 833, 847, 852, 866, 875, 883, 889, 895, 907, 928, 
931, 932, 933, 934, 936, 937, 943, 945, 947, 948, 949, 951, 953, 954, 957, 963, 965, 967, 980, 981, 983, 988] 206 | 207 | imagenetvis = [0, 217] 208 | 209 | imagenet1k = list(range(1000)) 210 | 211 | imagenet_subclass_dict = { 212 | "imagenette" : imagenette, 213 | "imagewoof" : imagewoof, 214 | "imagefruit": imagefruit, 215 | "imageyellow": imageyellow, 216 | "imagemeow": imagemeow, 217 | "imagesquawk": imagesquawk, 218 | "imagenet100": imagenet100, 219 | "imagenet1k": imagenet1k, 220 | "imagenetv2": imagenet1k, 221 | "imagenet_sketch":imagenet1k, 222 | "imageneta":imageneta, 223 | "imagenetr":imagenetr, 224 | "imagenetvis":imagenetvis, 225 | } 226 | 227 | ood_testset = ["imagenetv2","imagenet_sketch","imageneta","imagenetr"] 228 | 229 | mean=(0.48145466, 0.4578275, 0.40821073) 230 | std=(0.26862954, 0.26130258, 0.27577711) 231 | train_preprocess = transforms.Compose([ 232 | transforms.RandomResizedCrop(size=224, scale=(0.5, 1), interpolation=transforms.InterpolationMode.BICUBIC), 233 | transforms.RandomHorizontalFlip(p=0.5), 234 | transforms.ToTensor(), 235 | transforms.Normalize(mean=mean, std=std) 236 | ]) 237 | test_preprocess = transforms.Compose([ 238 | transforms.Resize((224, 224)), 239 | transforms.ToTensor(), 240 | transforms.Normalize(mean=mean, std=std) 241 | ]) 242 | 243 | def get_image_paths_from_file(file_path, subset): 244 | with open(file_path, 'r') as f: 245 | lines = f.readlines() 246 | 247 | valid_labels = imagenet_subclass_dict[subset] 248 | 249 | image_paths, labels = [], [] 250 | for line in lines: 251 | label = int(line.split()[-1]) 252 | if not label in valid_labels: 253 | continue 254 | if subset in ['imageneta']: 255 | image_paths.append(' '.join(line.split()[:-1])) 256 | 257 | else: 258 | image_paths.append(line.split()[0]) 259 | labels.append(label) 260 | 261 | label_map = {value: index for index, value in enumerate(valid_labels)} 262 | 263 | print("Total Length:",len(image_paths),len(labels)) 264 | 265 | 266 | return image_paths, labels, label_map 267 | 268 | class ImageNetCustom(Dataset): 269 | def __init__(self, root_dir="/zhaobai/ImageNet-1K/train", subset="imagenet1k", transform=None, file_list="file_list.txt"): 270 | self.root_dir = root_dir 271 | self.image_paths, self.labels, self.label_map = get_image_paths_from_file(os.path.join(self.root_dir,file_list), subset) 272 | self.transform = transform 273 | self.subset = subset 274 | 275 | def __len__(self): 276 | return len(self.image_paths) 277 | 278 | def __getitem__(self, idx): 279 | image_name = self.image_paths[idx] 280 | image_path = f"{self.root_dir}/{image_name}" 281 | label = self.labels[idx] 282 | image = Image.open(image_path).convert("RGB") 283 | if self.transform: 284 | image = self.transform(image) 285 | 286 | if self.subset not in ['imageneta','imagenetr']: 287 | 288 | label = self.label_map[label] 289 | 290 | return image, label 291 | 292 | class ImageNetGeneration(Dataset): 293 | r""" 294 | Subset of a dataset at specified indices. 295 | 296 | Arguments: 297 | dataset (Dataset): The whole Dataset 298 | indices (sequence): Indices in the whole set selected for subset 299 | labels(sequence) : targets as required for the indices. 
will be the same length as indices 300 | """ 301 | def __init__(self, root_dir="/zhaobai/ImageNet-1K/train", subset="imagenet1k", file_list="file_list.txt", use_caption=False): 302 | self.root_dir = root_dir 303 | self.image_paths, self.labels, self.label_map = get_image_paths_from_file(os.path.join(self.root_dir,file_list), subset) 304 | self.caption_path = "/zhaobai46g/Datasets/ImageNet_BLIP2_caption_json/ImageNet_BLIP2_caption_json" 305 | 306 | def __len__(self): 307 | return len(self.image_paths) 308 | 309 | def __getitem__(self, idx): 310 | image_name = self.image_paths[idx] 311 | image_path = f"{self.root_dir}/{image_name}" 312 | label = self.labels[idx] 313 | # image = Image.open(image_path).convert("RGB") 314 | class_name = imagenet_classes[int(label)] 315 | image_name = image_name.split(".")[0][1:] 316 | 317 | return (label, image_path, image_name, class_name) 318 | 319 | class CombinedImageNetCustom(Dataset): 320 | def __init__(self, root_dirs=["/zhaobai/ImageNet-1K/train"], subset="imagenet1k", transform=None, file_list="file_list.txt"): 321 | self.root_dirs = root_dirs 322 | self.transform = transform 323 | 324 | # Aggregate image paths and labels from all directories 325 | self.image_paths = [] 326 | self.labels = [] 327 | self.label_map = {} 328 | 329 | for root_dir in self.root_dirs: 330 | image_paths, labels, label_map = get_image_paths_from_file(os.path.join(root_dir, file_list), subset) 331 | 332 | # Include the full path (with root_dir) in self.image_paths 333 | full_image_paths = [f"{root_dir}/{image_path}" for image_path in image_paths] 334 | 335 | self.image_paths.extend(full_image_paths) 336 | self.labels.extend(labels) 337 | 338 | # Merge label maps (assuming they are consistent across datasets) 339 | self.label_map.update(label_map) 340 | 341 | def __len__(self): 342 | return len(self.image_paths) 343 | 344 | def __getitem__(self, idx): 345 | image_path = self.image_paths[idx] # Directly use the full path 346 | label = self.labels[idx] 347 | image = Image.open(image_path).convert("RGB") 348 | if self.transform: 349 | image = self.transform(image) 350 | 351 | label = self.label_map[label] 352 | 353 | return image, label 354 | 355 | 356 | def get_dataset(root, split="train",subset='imagenet1k',filelist="file_list.txt",transform=None): 357 | if subset not in ood_testset: 358 | root_dir = os.path.join(root,split) 359 | preprocess = train_preprocess if split == 'train' else test_preprocess 360 | else: 361 | root_dir = root 362 | preprocess = test_preprocess 363 | 364 | if transform: 365 | preprocess = transform 366 | dataset = ImageNetCustom(root_dir=root_dir, subset=subset, transform=preprocess, file_list=filelist) 367 | return dataset 368 | 369 | def get_generation_dataset(root, split="train",subset='imagenet1k',filelist="file_list.txt"): 370 | root_dir = os.path.join(root,split) 371 | dataset = ImageNetGeneration(root_dir=root_dir, subset=subset, file_list=filelist) 372 | return dataset 373 | 374 | def get_combined_dataset(roots,split="train",subset='imagenet1k',filelist="file_list.txt"): 375 | preprocess = train_preprocess if split == 'train' else test_preprocess 376 | dataset = CombinedImageNetCustom(root_dirs=roots, subset=subset, transform=preprocess, file_list=filelist) 377 | return dataset 378 | 379 | 380 | 381 | if __name__ == "__main__": 382 | # root = "/zhaobai/ImageNet-1K/train" 383 | root = "/zhaobai/yuanjianhao/ImageNet_Syn_v124/train/" 384 | ds = ImageNetCustom(root_dir=root, subset="imagenette", transform=train_preprocess) 
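For orientation, the dataset helpers above are consumed through `get_dataset` / `get_generation_dataset` plus a standard PyTorch `DataLoader` (see `extract_feature.py`). Below is a minimal usage sketch; the ImageNet root path is a placeholder assumption.

```python
from torch.utils.data import DataLoader
from data.new_load_data import get_dataset, get_generation_dataset

# Placeholder ImageNet-1K root containing a train/ split with a file_list.txt.
imagenet_root = "/path/to/ImageNet-1K"

# Yields (image, label) pairs with the CLIP-style train-time augmentation defined above.
train_set = get_dataset(imagenet_root, split="train", subset="imagenette", filelist="file_list.txt")
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=4)

# Yields (label, image_path, image_name, class_name) tuples, as consumed by extract_feature.py.
gen_set = get_generation_dataset(imagenet_root, split="train", subset="imagenet1k", filelist="file_list.txt")

for images, labels in train_loader:
    print(images.shape, labels.shape)  # e.g. torch.Size([64, 3, 224, 224]) torch.Size([64])
    break
```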
-------------------------------------------------------------------------------- /extract.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | CUDA_VISIBLE_DEVICES=0 python extract_feature.py --index 0 --imagenet_path "PATH TO ImageNet-1K" 3 | 4 | wait 5 | echo "All processes completed" -------------------------------------------------------------------------------- /extract_feature.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | from tqdm import tqdm 5 | from PIL import Image 6 | import numpy as np 7 | import torch 8 | import torchvision 9 | from torch.utils.data import Dataset, DataLoader, Subset 10 | from diffusers.utils import load_image 11 | from Generation.data.ImageNet1K import create_ImageNetFolder 12 | from data.new_load_data import get_generation_dataset 13 | from transformers import AutoProcessor, CLIPModel 14 | from diffusers.image_processor import VaeImageProcessor 15 | 16 | class ImgFeatureExtractor(): 17 | def __init__(self,args): 18 | self.device = "cuda" 19 | self.args = args 20 | self.model = "VIT_L" 21 | 22 | def extract_feature(self): 23 | bsz = 16 24 | create_ImageNetFolder(root_dir=f'{self.args.imagenet_path}train', out_dir=f"./LoRA/ImageNet1K_CLIPEmbedding/{self.model}") 25 | ImageNetPath = self.args.imagenet_path 26 | dataset = "imagenet1k" 27 | real_dst_train = get_generation_dataset(ImageNetPath, split="train",subset=dataset,filelist="file_list.txt") 28 | dataloader = self.get_subdataset_loader(real_dst_train, bsz, num_chunks=1) 29 | # Model 30 | if self.model in ["VIT_L"]: 31 | model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(self.device) 32 | processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14") 33 | 34 | for i, batch in tqdm(enumerate(dataloader), total=len(dataloader)): 35 | targets, image_paths, image_names, class_names = batch 36 | bs = len(image_paths) 37 | out_paths = [os.path.join(f"./LoRA/ImageNet1K_CLIPEmbedding/{self.model}",f'{image_names[idx]}.pt') for idx in range(bs)] 38 | if os.path.exists(out_paths[-1]): 39 | continue 40 | 41 | if self.model in ["VIT_L"]: 42 | images = [Image.open(image_paths[idx]) for idx in range(bs)] 43 | inputs = processor(images=images, return_tensors="pt").to(self.device) 44 | image_features = model.get_image_features(**inputs).to(torch.float16) 45 | 46 | for idx in range(bs): 47 | torch.save(image_features[idx], out_paths[idx]) 48 | 49 | def get_subdataset_loader(self, real_dst_train, bsz, num_chunks=4): 50 | # split Task 51 | # num_chunks = 8 52 | chunk_size = len(real_dst_train) // num_chunks 53 | chunk_index = self.args.index 54 | if chunk_index == num_chunks-1: 55 | subset_indices = range(chunk_index*chunk_size, len(real_dst_train)) 56 | else: 57 | subset_indices = range(chunk_index*chunk_size, (chunk_index+1)*chunk_size) 58 | subset_dataset = Subset(real_dst_train, indices=subset_indices) 59 | dataloader = DataLoader(subset_dataset, batch_size=bsz, shuffle=False, num_workers=4) 60 | return dataloader 61 | 62 | 63 | def get_args(): 64 | parser = argparse.ArgumentParser() 65 | parser.add_argument("--index",default=0,type=int,help="split task") 66 | parser.add_argument("--imagenet_path",default="",type=str,help="path to imagenet") 67 | args = parser.parse_args() 68 | return args 69 | 70 | def main(): 71 | args = get_args() 72 | extractor = ImgFeatureExtractor(args) 73 | extractor.extract_feature() 74 | 75 | 76 | 77 | 78 | if __name__ == 
"__main__": 79 | torch.backends.cudnn.benchmark = True 80 | main() -------------------------------------------------------------------------------- /finetune/train_lora.sh: -------------------------------------------------------------------------------- 1 | # export CACHE_ROOT="ROOT PATH" 2 | # export MODEL_NAME="${CACHE_ROOT}/huggingface/models--runwayml--stable-diffusion-v1-5/snapshots/39593d5650112b4cc580433f6b0435385882d819" 3 | export MODEL_NAME="PATH TO SD" 4 | versions=('v1') 5 | methods=("gt_dm") 6 | length=${#versions[@]} 7 | # Loop through the array and set the OUTPUT_DIR variable accordingly 8 | for method in "${methods[@]}"; do 9 | for ((i=0; i<$length; i++)); do 10 | version="${versions[$i]}" 11 | export OUTPUT_DIR="./LoRA/checkpoint/${method}_${version}/all" 12 | export DATASET_NAME="./LoRA/train/" 13 | export LOG_DIR="./LoRA/train/logs" 14 | echo "Current OUTPUT_DIR: $OUTPUT_DIR, $i" 15 | if [ -f "$OUTPUT_DIR/pytorch_lora_weights.bin" ] || [ -f "$OUTPUT_DIR/pytorch_lora_weights.safetensors" ]; then 16 | echo "Folder exists. Skipping script execution." 17 | else 18 | if [ "$method" == "gt_dm" ]; then 19 | echo "script execution. ${method}" 20 | accelerate launch --mixed_precision="fp16" ./finetune/train_text_to_image_lora.py \ 21 | --pretrained_model_name_or_path=$MODEL_NAME \ 22 | --train_data_dir $DATASET_NAME --caption_column="text" \ 23 | --report_to=tensorboard \ 24 | --resolution=512 --random_flip \ 25 | --train_batch_size=8 \ 26 | --num_train_epochs=100 --checkpointing_steps=500 \ 27 | --learning_rate=1e-04 --lr_scheduler="constant" \ 28 | --seed=42 \ 29 | --output_dir=${OUTPUT_DIR} \ 30 | --snr_gamma=5 \ 31 | --guidance_token=8 \ 32 | --dist_match=0.003 \ 33 | --logging_dir $LOG_DIR \ 34 | --exp_id ${i} \ 35 | else 36 | echo "Method not implemented" 37 | fi 38 | wait 39 | echo "All processes completed" 40 | fi 41 | done 42 | done -------------------------------------------------------------------------------- /finetune/train_text_to_image_lora.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Fine-tuning script for Stable Diffusion for text2image with support for LoRA.""" 16 | 17 | import argparse 18 | import logging 19 | import math 20 | import os 21 | import random 22 | import shutil 23 | from pathlib import Path 24 | import copy 25 | from collections import deque 26 | 27 | import datasets 28 | import numpy as np 29 | import torch 30 | import torch.nn.functional as F 31 | import torch.utils.checkpoint 32 | import transformers 33 | from accelerate import Accelerator 34 | from accelerate.logging import get_logger 35 | from accelerate.utils import ProjectConfiguration, set_seed 36 | from datasets import load_dataset 37 | from huggingface_hub import create_repo, upload_folder 38 | from packaging import version 39 | from torchvision import transforms 40 | from tqdm.auto import tqdm 41 | from transformers import CLIPTextModel, CLIPTokenizer 42 | 43 | import diffusers 44 | from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel 45 | from diffusers.loaders import AttnProcsLayers 46 | from diffusers.models.attention_processor import LoRAAttnProcessor 47 | from diffusers.optimization import get_scheduler 48 | from diffusers.utils import check_min_version, is_wandb_available 49 | from diffusers.utils.import_utils import is_xformers_available 50 | 51 | 52 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 53 | check_min_version("0.20.0.dev0") 54 | 55 | logger = get_logger(__name__, log_level="INFO") 56 | 57 | def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str, repo_folder=None): 58 | img_str = "" 59 | for i, image in enumerate(images): 60 | image.save(os.path.join(repo_folder, f"image_{i}.png")) 61 | img_str += f"![img_{i}](./image_{i}.png)\n" 62 | 63 | yaml = f""" 64 | --- 65 | license: creativeml-openrail-m 66 | base_model: {base_model} 67 | tags: 68 | - stable-diffusion 69 | - stable-diffusion-diffusers 70 | - text-to-image 71 | - diffusers 72 | - lora 73 | inference: true 74 | --- 75 | """ 76 | model_card = f""" 77 | # LoRA text2image fine-tuning - {repo_id} 78 | These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n 79 | {img_str} 80 | """ 81 | with open(os.path.join(repo_folder, "README.md"), "w") as f: 82 | f.write(yaml + model_card) 83 | 84 | 85 | 86 | def parse_args(): 87 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 88 | parser.add_argument('--exp_id', type=int, default=0, help='mode id') 89 | parser.add_argument( 90 | "--pretrained_model_name_or_path", 91 | type=str, 92 | default=None, 93 | required=True, 94 | help="Path to pretrained model or model identifier from huggingface.co/models.", 95 | ) 96 | parser.add_argument( 97 | "--revision", 98 | type=str, 99 | default=None, 100 | required=False, 101 | help="Revision of pretrained model identifier from huggingface.co/models.", 102 | ) 103 | parser.add_argument( 104 | "--dataset_name", 105 | type=str, 106 | default=None, 107 | help=( 108 | "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," 109 | " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," 110 | " or to a folder containing files that 🤗 Datasets can understand." 
111 | ), 112 | ) 113 | parser.add_argument( 114 | "--dataset_config_name", 115 | type=str, 116 | default=None, 117 | help="The config of the Dataset, leave as None if there's only one config.", 118 | ) 119 | parser.add_argument( 120 | "--train_data_dir", 121 | type=str, 122 | default=None, 123 | help=( 124 | "A folder containing the training data. Folder contents must follow the structure described in" 125 | " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" 126 | " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." 127 | ), 128 | ) 129 | parser.add_argument( 130 | "--image_column", type=str, default="image", help="The column of the dataset containing an image." 131 | ) 132 | parser.add_argument( 133 | "--caption_column", 134 | type=str, 135 | default="text", 136 | help="The column of the dataset containing a caption or a list of captions.", 137 | ) 138 | parser.add_argument( 139 | "--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference." 140 | ) 141 | parser.add_argument( 142 | "--num_validation_images", 143 | type=int, 144 | default=4, 145 | help="Number of images that should be generated during validation with `validation_prompt`.", 146 | ) 147 | parser.add_argument( 148 | "--validation_epochs", 149 | type=int, 150 | default=1, 151 | help=( 152 | "Run fine-tuning validation every X epochs. The validation process consists of running the prompt" 153 | " `args.validation_prompt` multiple times: `args.num_validation_images`." 154 | ), 155 | ) 156 | parser.add_argument( 157 | "--max_train_samples", 158 | type=int, 159 | default=None, 160 | help=( 161 | "For debugging purposes or quicker training, truncate the number of training examples to this " 162 | "value if set." 163 | ), 164 | ) 165 | parser.add_argument( 166 | "--output_dir", 167 | type=str, 168 | default="sd-model-finetuned-lora", 169 | help="The output directory where the model predictions and checkpoints will be written.", 170 | ) 171 | parser.add_argument( 172 | "--cache_dir", 173 | type=str, 174 | default=None, 175 | help="The directory where the downloaded models and datasets will be stored.", 176 | ) 177 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 178 | parser.add_argument( 179 | "--resolution", 180 | type=int, 181 | default=512, 182 | help=( 183 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 184 | " resolution" 185 | ), 186 | ) 187 | parser.add_argument( 188 | "--center_crop", 189 | default=False, 190 | action="store_true", 191 | help=( 192 | "Whether to center crop the input images to the resolution. If not set, the images will be randomly" 193 | " cropped. The images will be resized to the resolution first before cropping." 194 | ), 195 | ) 196 | parser.add_argument( 197 | "--random_flip", 198 | action="store_true", 199 | help="whether to randomly flip images horizontally", 200 | ) 201 | parser.add_argument( 202 | "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." 203 | ) 204 | parser.add_argument("--num_train_epochs", type=int, default=100) 205 | parser.add_argument( 206 | "--max_train_steps", 207 | type=int, 208 | default=None, 209 | help="Total number of training steps to perform. 
If provided, overrides num_train_epochs.", 210 | ) 211 | parser.add_argument( 212 | "--gradient_accumulation_steps", 213 | type=int, 214 | default=1, 215 | help="Number of updates steps to accumulate before performing a backward/update pass.", 216 | ) 217 | parser.add_argument( 218 | "--gradient_checkpointing", 219 | action="store_true", 220 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 221 | ) 222 | parser.add_argument( 223 | "--learning_rate", 224 | type=float, 225 | default=1e-4, 226 | help="Initial learning rate (after the potential warmup period) to use.", 227 | ) 228 | parser.add_argument( 229 | "--scale_lr", 230 | action="store_true", 231 | default=False, 232 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 233 | ) 234 | parser.add_argument( 235 | "--lr_scheduler", 236 | type=str, 237 | default="constant", 238 | help=( 239 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 240 | ' "constant", "constant_with_warmup"]' 241 | ), 242 | ) 243 | parser.add_argument( 244 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 245 | ) 246 | parser.add_argument( 247 | "--snr_gamma", 248 | type=float, 249 | default=None, 250 | help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " 251 | "More details here: https://arxiv.org/abs/2303.09556.", 252 | ) 253 | parser.add_argument( 254 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 255 | ) 256 | parser.add_argument( 257 | "--allow_tf32", 258 | action="store_true", 259 | help=( 260 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" 261 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 262 | ), 263 | ) 264 | parser.add_argument( 265 | "--dataloader_num_workers", 266 | type=int, 267 | default=0, 268 | help=( 269 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 270 | ), 271 | ) 272 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 273 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 274 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 275 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 276 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 277 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 278 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 279 | parser.add_argument( 280 | "--prediction_type", 281 | type=str, 282 | default=None, 283 | help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. 
If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.", 284 | ) 285 | parser.add_argument( 286 | "--hub_model_id", 287 | type=str, 288 | default=None, 289 | help="The name of the repository to keep in sync with the local `output_dir`.", 290 | ) 291 | parser.add_argument( 292 | "--logging_dir", 293 | type=str, 294 | default="logs", 295 | help=( 296 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 297 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 298 | ), 299 | ) 300 | parser.add_argument( 301 | "--mixed_precision", 302 | type=str, 303 | default=None, 304 | choices=["no", "fp16", "bf16"], 305 | help=( 306 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 307 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 308 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 309 | ), 310 | ) 311 | parser.add_argument( 312 | "--report_to", 313 | type=str, 314 | default="tensorboard", 315 | help=( 316 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 317 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 318 | ), 319 | ) 320 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 321 | parser.add_argument( 322 | "--checkpointing_steps", 323 | type=int, 324 | default=500, 325 | help=( 326 | "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" 327 | " training using `--resume_from_checkpoint`." 328 | ), 329 | ) 330 | parser.add_argument( 331 | "--checkpoints_total_limit", 332 | type=int, 333 | default=None, 334 | help=("Max number of checkpoints to store."), 335 | ) 336 | parser.add_argument( 337 | "--resume_from_checkpoint", 338 | type=str, 339 | default=None, 340 | help=( 341 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 342 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 343 | ), 344 | ) 345 | parser.add_argument( 346 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 
347 | ) 348 | parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") 349 | parser.add_argument( 350 | "--rank", 351 | type=int, 352 | default=4, 353 | help=("The dimension of the LoRA update matrices."), 354 | ) 355 | ### Add DY 356 | parser.add_argument("--guidance_token", type=float, default=0, help="If use Guidance Token") 357 | parser.add_argument( 358 | "--dist_match", 359 | type=float, 360 | default=None, 361 | help="If add distribution match loss, and what's the weight", 362 | ) 363 | 364 | args = parser.parse_args() 365 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 366 | if env_local_rank != -1 and env_local_rank != args.local_rank: 367 | args.local_rank = env_local_rank 368 | 369 | # Sanity checks 370 | if args.dataset_name is None and args.train_data_dir is None: 371 | raise ValueError("Need either a dataset name or a training folder.") 372 | 373 | return args 374 | 375 | 376 | DATASET_NAME_MAPPING = { 377 | "lambdalabs/pokemon-blip-captions": ("image", "text"), 378 | } 379 | 380 | 381 | 382 | def main(): 383 | args = parse_args() 384 | logging_dir = Path(args.output_dir, args.logging_dir) 385 | 386 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) 387 | 388 | accelerator = Accelerator( 389 | gradient_accumulation_steps=args.gradient_accumulation_steps, 390 | mixed_precision=args.mixed_precision, 391 | log_with=args.report_to, 392 | project_config=accelerator_project_config, 393 | ) 394 | if args.report_to == "wandb": 395 | if not is_wandb_available(): 396 | raise ImportError("Make sure to install wandb if you want to use it for logging during training.") 397 | import wandb 398 | 399 | # Make one log on every process with the configuration for debugging. 400 | logging.basicConfig( 401 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 402 | datefmt="%m/%d/%Y %H:%M:%S", 403 | level=logging.INFO, 404 | ) 405 | logger.info(accelerator.state, main_process_only=False) 406 | if accelerator.is_local_main_process: 407 | datasets.utils.logging.set_verbosity_warning() 408 | transformers.utils.logging.set_verbosity_warning() 409 | diffusers.utils.logging.set_verbosity_info() 410 | else: 411 | datasets.utils.logging.set_verbosity_error() 412 | transformers.utils.logging.set_verbosity_error() 413 | diffusers.utils.logging.set_verbosity_error() 414 | 415 | # If passed along, set the training seed now. 416 | if args.seed is not None: 417 | set_seed(args.seed) 418 | 419 | # Handle the repository creation 420 | if accelerator.is_main_process: 421 | if args.output_dir is not None: 422 | os.makedirs(args.output_dir, exist_ok=True) 423 | 424 | if args.push_to_hub: 425 | repo_id = create_repo( 426 | repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token 427 | ).repo_id 428 | # Load scheduler, tokenizer and models. 
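# The pipeline components below are loaded from the standard diffusers checkpoint layout, i.e.
# subfolders of `--pretrained_model_name_or_path` (illustrative, assuming an SD 1.x checkpoint
# such as `runwayml/stable-diffusion-v1-5`): scheduler/ (DDPMScheduler), tokenizer/ (CLIPTokenizer),
# text_encoder/ (CLIPTextModel), vae/ (AutoencoderKL) and unet/ (UNet2DConditionModel).
# Only the UNet later receives trainable LoRA attention processors; everything else stays frozen.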
429 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") 430 | tokenizer = CLIPTokenizer.from_pretrained( 431 | args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision 432 | ) 433 | text_encoder = CLIPTextModel.from_pretrained( 434 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision 435 | ) 436 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) 437 | unet = UNet2DConditionModel.from_pretrained( 438 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, local_files_only=True 439 | ) 440 | 441 | # freeze parameters of models to save more memory 442 | unet.requires_grad_(False) 443 | vae.requires_grad_(False) 444 | text_encoder.requires_grad_(False) 445 | 446 | 447 | 448 | # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision 449 | # as these weights are only used for inference, keeping weights in full precision is not required. 450 | weight_dtype = torch.float32 451 | if accelerator.mixed_precision == "fp16": 452 | weight_dtype = torch.float16 453 | elif accelerator.mixed_precision == "bf16": 454 | weight_dtype = torch.bfloat16 455 | 456 | # Move unet, vae and text_encoder to device and cast to weight_dtype 457 | unet.to(accelerator.device, dtype=weight_dtype) 458 | vae.to(accelerator.device, dtype=weight_dtype) 459 | text_encoder.to(accelerator.device, dtype=weight_dtype) 460 | 461 | # now we will add new LoRA weights to the attention layers 462 | # It's important to realize here how many attention weights will be added and of which sizes 463 | # The sizes of the attention layers consist only of two different variables: 464 | # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`. 465 | # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`. 466 | 467 | # Let's first see how many attention processors we will have to set. 
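# Keys of `unet.attn_processors` mirror the UNet module paths, e.g. (illustrative for SD 1.x):
#   "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor"  -> self-attention, cross_attention_dim=None
#   "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor"  -> cross-attention, cross_attention_dim=768
# The string checks on `name` in the loop below rely on exactly this naming scheme.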
468 | # For Stable Diffusion, it should be equal to: 469 | # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12 470 | # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2 471 | # - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18 472 | # => 32 layers 473 | 474 | # Set correct lora layers 475 | lora_attn_procs = {} 476 | for name in unet.attn_processors.keys(): 477 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim 478 | if name.startswith("mid_block"): 479 | hidden_size = unet.config.block_out_channels[-1] 480 | elif name.startswith("up_blocks"): 481 | block_id = int(name[len("up_blocks.")]) 482 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id] 483 | elif name.startswith("down_blocks"): 484 | block_id = int(name[len("down_blocks.")]) 485 | hidden_size = unet.config.block_out_channels[block_id] 486 | 487 | lora_attn_procs[name] = LoRAAttnProcessor( 488 | hidden_size=hidden_size, 489 | cross_attention_dim=cross_attention_dim, 490 | rank=args.rank, 491 | ) 492 | 493 | unet.set_attn_processor(lora_attn_procs) 494 | 495 | if args.enable_xformers_memory_efficient_attention: 496 | if is_xformers_available(): 497 | import xformers 498 | 499 | xformers_version = version.parse(xformers.__version__) 500 | if xformers_version == version.parse("0.0.16"): 501 | logger.warn( 502 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." 503 | ) 504 | unet.enable_xformers_memory_efficient_attention() 505 | else: 506 | raise ValueError("xformers is not available. Make sure it is installed correctly") 507 | 508 | def compute_snr(timesteps): 509 | """ 510 | Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 511 | """ 512 | alphas_cumprod = noise_scheduler.alphas_cumprod 513 | sqrt_alphas_cumprod = alphas_cumprod**0.5 514 | sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 515 | 516 | # Expand the tensors. 517 | # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 518 | sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() 519 | while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): 520 | sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] 521 | alpha = sqrt_alphas_cumprod.expand(timesteps.shape) 522 | 523 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() 524 | while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): 525 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] 526 | sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) 527 | 528 | # Compute SNR. 
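# With alpha = sqrt(alphas_cumprod[t]) and sigma = sqrt(1 - alphas_cumprod[t]) built above, the
# line below reduces to SNR(t) = alphas_cumprod[t] / (1 - alphas_cumprod[t]).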
529 | snr = (alpha / sigma) ** 2 530 | return snr 531 | 532 | lora_layers = AttnProcsLayers(unet.attn_processors) 533 | 534 | # Enable TF32 for faster training on Ampere GPUs, 535 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 536 | if args.allow_tf32: 537 | torch.backends.cuda.matmul.allow_tf32 = True 538 | 539 | if args.scale_lr: 540 | args.learning_rate = ( 541 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 542 | ) 543 | 544 | # Initialize the optimizer 545 | if args.use_8bit_adam: 546 | try: 547 | import bitsandbytes as bnb 548 | except ImportError: 549 | raise ImportError( 550 | "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" 551 | ) 552 | 553 | optimizer_cls = bnb.optim.AdamW8bit 554 | else: 555 | optimizer_cls = torch.optim.AdamW 556 | 557 | optimizer = optimizer_cls( 558 | lora_layers.parameters(), 559 | lr=args.learning_rate, 560 | betas=(args.adam_beta1, args.adam_beta2), 561 | weight_decay=args.adam_weight_decay, 562 | eps=args.adam_epsilon, 563 | ) 564 | 565 | # Get the datasets: you can either provide your own training and evaluation files (see below) 566 | # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). 567 | 568 | # In distributed training, the load_dataset function guarantees that only one local process can concurrently 569 | # download the dataset. 570 | 571 | if args.dataset_name is not None: 572 | # Downloading and loading a dataset from the hub. 573 | dataset = load_dataset( 574 | args.dataset_name, 575 | args.dataset_config_name, 576 | cache_dir=args.cache_dir, 577 | ) 578 | else: 579 | #! load our dataset 580 | data_files = {} 581 | if args.train_data_dir is not None: 582 | data_files["train"] = os.path.join(args.train_data_dir, "**") 583 | dataset = load_dataset( 584 | "imagefolder", 585 | data_files=data_files, 586 | cache_dir=args.cache_dir, 587 | ) 588 | # See more about loading custom images at 589 | # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder 590 | 591 | # Preprocessing the datasets. 592 | # We need to tokenize inputs and targets. 593 | #! column_names: ['image', 'text', 'img_features'] 594 | #! metadata: {"file_name": "n04517823/n04517823_8279.JPEG", "text": "photo of vacuum cleaner, a vacuum cleaner sitting on a tiled floor", "img_features": "n04517823/n04517823_8279.pt"} 595 | column_names = dataset["train"].column_names 596 | 597 | # 6. Get the column names for input/target. 598 | dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) #! none 599 | if args.image_column is None: 600 | image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] 601 | else: #! args.image_column=image 602 | image_column = args.image_column 603 | if image_column not in column_names: 604 | raise ValueError( 605 | f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" 606 | ) 607 | if args.caption_column is None: 608 | caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1] 609 | else: #! args.caption_column = text 610 | caption_column = args.caption_column 611 | if caption_column not in column_names: 612 | raise ValueError( 613 | f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" 614 | ) 615 | embedding_column = 'img_features' 616 | # Preprocessing the datasets. 
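# Shape sketch (assuming the default SD 1.x CLIP tokenizer with model_max_length = 77):
# `tokenize_captions` pads/truncates every caption to 77 tokens and returns `input_ids` of shape
# (batch_size, 77), while `load_embeddings` pairs each image with its pre-extracted CLIP feature
# tensor loaded from ./LoRA/CLIPEmbedding/train/<class_id>/<image_name>.pt.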
617 | # We need to tokenize input captions and transform the images. 618 | def tokenize_captions(examples, is_train=True): 619 | captions = [] 620 | for caption in examples[caption_column]: 621 | if isinstance(caption, str): 622 | captions.append(caption) 623 | elif isinstance(caption, (list, np.ndarray)): 624 | # take a random caption if there are multiple 625 | captions.append(random.choice(caption) if is_train else caption[0]) 626 | else: 627 | raise ValueError( 628 | f"Caption column `{caption_column}` should contain either strings or lists of strings." 629 | ) 630 | inputs = tokenizer( 631 | captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" 632 | ) 633 | return inputs.input_ids 634 | 635 | # Load Pre-extracted CLIP embeddings: 636 | def load_embeddings(examples, is_train=True): 637 | embeddings = [] 638 | for emb_path in examples[embedding_column]: 639 | embeddings.append(torch.load(os.path.join("./LoRA/CLIPEmbedding/train", emb_path))) 640 | return embeddings 641 | 642 | # Preprocessing the datasets. 643 | train_transforms = transforms.Compose( 644 | [ #! args.resolution = 512 645 | transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), 646 | transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), 647 | transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), 648 | transforms.ToTensor(), 649 | transforms.Normalize([0.5], [0.5]), 650 | ] 651 | ) 652 | 653 | def preprocess_train(examples): 654 | images = [image.convert("RGB") for image in examples[image_column]] 655 | examples["pixel_values"] = [train_transforms(image) for image in images] 656 | examples["input_ids"] = tokenize_captions(examples) 657 | examples["img_features"] = load_embeddings(examples) 658 | return examples 659 | 660 | with accelerator.main_process_first(): 661 | # if args.max_train_samples is not None: 662 | # dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) 663 | def split_half_data(dataset): 664 | #! select the half data 665 | np.random.seed(0) 666 | num = 10 667 | keep = np.random.uniform(0,1,size=(num, len(dataset["train"]))) 668 | order = keep.argsort(0) 669 | keep = order < int(0.5 * num) 670 | indices = np.array(keep[args.exp_id], dtype=bool) 671 | indices = np.where(indices==True)[0] 672 | np.save(os.path.join("lira_exp/indices", f"indice_{args.exp_id}.npy"), indices) 673 | dataset["train"] = dataset["train"].select(indices) 674 | return dataset 675 | #! 
end of selection 676 | # Set the training transforms 677 | dataset = split_half_data(dataset) 678 | train_dataset = dataset["train"].with_transform(preprocess_train) 679 | 680 | # Customise 681 | def collate_fn(examples): 682 | pixel_values = torch.stack([example["pixel_values"] for example in examples]) 683 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() 684 | input_ids = torch.stack([example["input_ids"] for example in examples]) 685 | img_embeddings = torch.stack([example["img_features"] for example in examples]) 686 | return {"pixel_values": pixel_values, "input_ids": input_ids, "img_features": img_embeddings} 687 | 688 | # DataLoaders creation: 689 | train_dataloader = torch.utils.data.DataLoader( 690 | train_dataset, 691 | shuffle=True, 692 | collate_fn=collate_fn, 693 | batch_size=args.train_batch_size, 694 | num_workers=args.dataloader_num_workers, 695 | ) 696 | 697 | # Scheduler and math around the number of training steps. 698 | overrode_max_train_steps = False 699 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 700 | if args.max_train_steps is None: 701 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 702 | overrode_max_train_steps = True 703 | 704 | lr_scheduler = get_scheduler( 705 | args.lr_scheduler, 706 | optimizer=optimizer, 707 | num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, 708 | num_training_steps=args.max_train_steps * accelerator.num_processes, 709 | ) 710 | 711 | # Prepare everything with our `accelerator`. 712 | lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 713 | lora_layers, optimizer, train_dataloader, lr_scheduler 714 | ) 715 | 716 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 717 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 718 | if overrode_max_train_steps: 719 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 720 | # Afterwards we recalculate our number of training epochs 721 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 722 | 723 | # We need to initialize the trackers we use, and also store our configuration. 724 | # The trackers initializes automatically on the main process. 725 | if accelerator.is_main_process: 726 | accelerator.init_trackers("text2image-fine-tune", config=vars(args)) 727 | 728 | # Train! 729 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 730 | 731 | logger.info("***** Running training *****") 732 | logger.info(f" Num examples = {len(train_dataset)}") 733 | logger.info(f" Num Epochs = {args.num_train_epochs}") 734 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 735 | logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") 736 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 737 | logger.info(f" Total optimization steps = {args.max_train_steps}") 738 | global_step = 0 739 | first_epoch = 0 740 | 741 | # Potentially load in the weights and states from a previous save 742 | if args.resume_from_checkpoint: 743 | if args.resume_from_checkpoint != "latest": 744 | path = os.path.basename(args.resume_from_checkpoint) 745 | else: 746 | # Get the most recent checkpoint 747 | dirs = os.listdir(args.output_dir) 748 | dirs = [d for d in dirs if d.startswith("checkpoint")] 749 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 750 | path = dirs[-1] if len(dirs) > 0 else None 751 | 752 | if path is None: 753 | accelerator.print( 754 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 755 | ) 756 | args.resume_from_checkpoint = None 757 | else: 758 | accelerator.print(f"Resuming from checkpoint {path}") 759 | accelerator.load_state(os.path.join(args.output_dir, path)) 760 | global_step = int(path.split("-")[1]) 761 | 762 | resume_global_step = global_step * args.gradient_accumulation_steps 763 | first_epoch = global_step // num_update_steps_per_epoch 764 | resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) 765 | 766 | # Only show the progress bar once on each machine. 767 | progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) 768 | progress_bar.set_description("Steps") 769 | 770 | for epoch in range(first_epoch, args.num_train_epochs): 771 | unet.train() 772 | train_loss = 0.0 773 | for step, batch in enumerate(train_dataloader): 774 | # Skip steps until we reach the resumed step 775 | if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: 776 | if step % args.gradient_accumulation_steps == 0: 777 | progress_bar.update(1) 778 | continue 779 | 780 | with accelerator.accumulate(unet): 781 | # Convert images to latent space 782 | #! batch['pixel_values'].shape:[8, 3, 512, 512], batch['input_ids'].shape:[8, 77], batch['img_features'].shape:[8, 768] 783 | latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() 784 | latents = latents * vae.config.scaling_factor 785 | # print("latent",latents.shape) 786 | # Sample noise that we'll add to the latents 787 | noise = torch.randn_like(latents) 788 | if args.noise_offset: 789 | # https://www.crosslabs.org//blog/diffusion-with-offset-noise 790 | noise += args.noise_offset * torch.randn( 791 | (latents.shape[0], latents.shape[1], 1, 1), device=latents.device 792 | ) 793 | 794 | bsz = latents.shape[0] 795 | # Sample a random timestep for each image 796 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) 797 | timesteps = timesteps.long() 798 | 799 | # Add noise to the latents according to the noise magnitude at each timestep 800 | # (this is the forward diffusion process) 801 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 802 | 803 | # Get the text embedding for conditioning 804 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] 805 | 806 | if args.guidance_token > 0: #! 
8.0 807 | guidance_token = torch.mean(batch["img_features"],axis=0) 808 | # hidden_dim = guidance_token.shape[-1] 809 | # repeated_guidance_token = guidance_token.repeat(bsz).view(bsz, hidden_dim) # n * 1 * 784 810 | # encoder_hidden_states[:, -1, :] = repeated_guidance_token # n * 77 * 784 811 | 812 | guidance_token = batch["img_features"].view(bsz, -1) 813 | encoder_hidden_states[:, -1, :] = guidance_token 814 | 815 | # Get the target for loss depending on the prediction type 816 | if args.prediction_type is not None: 817 | # set prediction_type of scheduler if defined 818 | noise_scheduler.register_to_config(prediction_type=args.prediction_type) 819 | 820 | if noise_scheduler.config.prediction_type == "epsilon": 821 | target = noise 822 | elif noise_scheduler.config.prediction_type == "v_prediction": 823 | target = noise_scheduler.get_velocity(latents, noise, timesteps) 824 | else: 825 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 826 | 827 | # Predict the noise residual and compute loss 828 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 829 | if args.snr_gamma is None: 830 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 831 | else: 832 | # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. 833 | # Since we predict the noise instead of x_0, the original formulation is slightly changed. 834 | # This is discussed in Section 4.2 of the same paper. 835 | snr = compute_snr(timesteps) 836 | mse_loss_weights = ( 837 | torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr 838 | ) 839 | # We first calculate the original loss. Then we mean over the non-batch dimensions and 840 | # rebalance the sample-wise losses with their respective loss weights. 841 | # Finally, we take the mean of the rebalanced loss. 842 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") 843 | loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights 844 | loss = loss.mean() 845 | # print("Loss",loss) 846 | 847 | #### Add Dist Matching Loss #### 848 | if args.dist_match: 849 | mse_loss_weights = mse_loss_weights.view(bsz, 1, 1, 1) 850 | model_pred_ws = (model_pred.float() * mse_loss_weights).sum(dim=0) 851 | target_ws = (target.float() * mse_loss_weights).sum(dim=0) 852 | dist_loss = F.mse_loss(model_pred_ws, target_ws, reduction="mean") 853 | loss = loss + dist_loss * args.dist_match 854 | # print("Dist",dist_loss) 855 | 856 | # Gather the losses across all processes for logging (if we use distributed training). 
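# The scalar loss is repeated to `train_batch_size`, gathered from every process and averaged, so
# `train_loss` logs a per-sample mean over all GPUs; e.g. with 8 processes and a per-device batch
# of 8, the gathered tensor has 64 entries before `.mean()` (illustrative numbers).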
857 | avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() 858 | train_loss += avg_loss.item() / args.gradient_accumulation_steps 859 | 860 | # Backpropagate 861 | accelerator.backward(loss) 862 | if accelerator.sync_gradients: 863 | params_to_clip = lora_layers.parameters() 864 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) 865 | optimizer.step() 866 | lr_scheduler.step() 867 | optimizer.zero_grad() 868 | 869 | # Checks if the accelerator has performed an optimization step behind the scenes 870 | if accelerator.sync_gradients: 871 | progress_bar.update(1) 872 | global_step += 1 873 | accelerator.log({"train_loss": train_loss}, step=global_step) 874 | train_loss = 0.0 875 | 876 | if global_step % args.checkpointing_steps == 0: 877 | if accelerator.is_main_process: 878 | # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` 879 | if args.checkpoints_total_limit is not None: 880 | checkpoints = os.listdir(args.output_dir) 881 | checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] 882 | checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) 883 | 884 | # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints 885 | if len(checkpoints) >= args.checkpoints_total_limit: 886 | num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 887 | removing_checkpoints = checkpoints[0:num_to_remove] 888 | 889 | logger.info( 890 | f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" 891 | ) 892 | logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") 893 | 894 | for removing_checkpoint in removing_checkpoints: 895 | removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) 896 | shutil.rmtree(removing_checkpoint) 897 | 898 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 899 | accelerator.save_state(save_path) 900 | logger.info(f"Saved state to {save_path}") 901 | 902 | logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 903 | progress_bar.set_postfix(**logs) 904 | 905 | if global_step >= args.max_train_steps: 906 | break 907 | 908 | if accelerator.is_main_process: 909 | if args.validation_prompt is not None and epoch % args.validation_epochs == 0: 910 | logger.info( 911 | f"Running validation... \n Generating {args.num_validation_images} images with prompt:" 912 | f" {args.validation_prompt}." 
913 | ) 914 | # create pipeline 915 | pipeline = DiffusionPipeline.from_pretrained( 916 | args.pretrained_model_name_or_path, 917 | unet=accelerator.unwrap_model(unet), 918 | revision=args.revision, 919 | torch_dtype=weight_dtype, 920 | ) 921 | pipeline = pipeline.to(accelerator.device) 922 | pipeline.set_progress_bar_config(disable=True) 923 | 924 | # run inference 925 | generator = torch.Generator(device=accelerator.device) 926 | if args.seed is not None: 927 | generator = generator.manual_seed(args.seed) 928 | images = [] 929 | for _ in range(args.num_validation_images): 930 | images.append( 931 | pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0] 932 | ) 933 | 934 | for tracker in accelerator.trackers: 935 | if tracker.name == "tensorboard": 936 | np_images = np.stack([np.asarray(img) for img in images]) 937 | tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") 938 | if tracker.name == "wandb": 939 | tracker.log( 940 | { 941 | "validation": [ 942 | wandb.Image(image, caption=f"{i}: {args.validation_prompt}") 943 | for i, image in enumerate(images) 944 | ] 945 | } 946 | ) 947 | 948 | del pipeline 949 | torch.cuda.empty_cache() 950 | 951 | # Save the lora layers 952 | accelerator.wait_for_everyone() 953 | if accelerator.is_main_process: 954 | unet = unet.to(torch.float32) 955 | unet.save_attn_procs(args.output_dir) 956 | 957 | if args.push_to_hub: # False 958 | save_model_card( 959 | repo_id, 960 | images=images, 961 | base_model=args.pretrained_model_name_or_path, 962 | dataset_name=args.dataset_name, 963 | repo_folder=args.output_dir, 964 | ) 965 | upload_folder( 966 | repo_id=repo_id, 967 | folder_path=args.output_dir, 968 | commit_message="End of training", 969 | ignore_patterns=["step_*", "epoch_*"], 970 | ) 971 | 972 | accelerator.end_training() 973 | 974 | 975 | if __name__ == "__main__": 976 | main() 977 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import shutil 4 | import argparse 5 | import json 6 | from tqdm import tqdm 7 | from PIL import Image 8 | import numpy as np 9 | import torch 10 | import torchvision 11 | from torch.utils.data import Dataset, DataLoader, Subset 12 | from transformers import AutoImageProcessor, UperNetForSemanticSegmentation 13 | from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, DPMSolverMultistepScheduler, UniPCMultistepScheduler, StableDiffusionImg2ImgPipeline, StableDiffusionPipeline 14 | from diffusers import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline 15 | from diffusers.models import AutoencoderKL 16 | from diffusers.utils import load_image 17 | 18 | from Generation.data.ImageNet1K import create_ImageNetFolder 19 | # Dataset 20 | from data.new_load_data import get_generation_dataset 21 | 22 | print("Package Load Check Done") 23 | 24 | 25 | def load_caption_dict(image_names,caption_path): 26 | class_lis = set([image_name.split("/")[0] for image_name in image_names]) 27 | dict_lis = [] 28 | for class_name in class_lis: 29 | with open(os.path.join(caption_path,f"{class_name}.json"), 'r') as file: 30 | dict_lis.append(json.load(file)) 31 | caption_dict = {key: value for dictionary in dict_lis for key, value in dictionary.items()} 32 | return caption_dict 33 | 34 | def group_lists(list1, list2, list3, list4, list5): 35 | grouped_data = {} 36 | for idx, item in enumerate(list1): 37 | if 
item not in grouped_data: 38 | grouped_data[item] = ([list2[idx]], [list3[idx]], [list4[idx]], [list5[idx]]) 39 | else: 40 | grouped_data[item][0].append(list2[idx]) 41 | grouped_data[item][1].append(list3[idx]) 42 | grouped_data[item][2].append(list4[idx]) 43 | grouped_data[item][3].append(list5[idx]) 44 | 45 | grouped_list = [(key, grouped_data[key][0], grouped_data[key][1], grouped_data[key][2], grouped_data[key][3]) for key in grouped_data] 46 | return grouped_list 47 | 48 | def get_args(): 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument("--dataset",default="imagenette", help="Which Dataset") 51 | parser.add_argument("--index",default=0,type=int,help="split task") 52 | parser.add_argument("--version",default="v57",type=str,help="out_version") 53 | parser.add_argument("--lora_path",default=None,type=str,help="lora path") 54 | parser.add_argument("--batch_size",default=8,type=int,help="batch size") 55 | parser.add_argument('--use_caption', default='blip2', type=str, help="use caption model") 56 | parser.add_argument('--img_size',default=512,type=int, help='Generation Image Size') 57 | 58 | parser.add_argument('--method', default='SD_T2I', type=str, help="generation method") 59 | parser.add_argument('--use_guidance', default='No', type=str, help="guidance token") 60 | parser.add_argument('--if_SDXL',default='No', type=str, help="SDXL") 61 | parser.add_argument('--if_full',default='Yes',type=str, help='singleLora') 62 | parser.add_argument('--if_compile',default='No',type=str, help='compile?') 63 | parser.add_argument('--image_strength',default=0.75,type=float,help="init image strength") 64 | parser.add_argument('--nchunks',default=8,type=int,help="No. subprocess") 65 | 66 | parser.add_argument("--imagenet_path",default="",type=str,help="path to imagenet") 67 | parser.add_argument("--syn_path",default="",type=str,help="path to synthetic data") 68 | 69 | # Parameters 70 | parser.add_argument('--cross_attention_scale', default=0.5, type=float, help="lora scale") 71 | parser.add_argument('--ref_version',default='v120',type=str, help='version to refine') 72 | 73 | args = parser.parse_args() 74 | return args 75 | 76 | class StableDiffusionHandler: 77 | def __init__(self, args): 78 | self.args = args 79 | """ 80 | (Pdb) print(self.args) 81 | Namespace(batch_size=24, cross_attention_scale=0.5, dataset='imagenette', if_SDXL='No', if_compile='No', if_full='Yes', image_strength=0.75, img_size=512, index=0, lora_path='./LoRA/checkpoint/gt_dm_v1', method='SDI2I_LoRA', nchunks=8, ref_version='v120', use_caption='blip2', use_guidance='Yes', version='v1') 82 | """ 83 | self.method = args.method # SDI2I_LoRA 84 | self.if_SDXL = False 85 | self.use_guidance_tokens = True 86 | self.if_full = True 87 | self.if_compile = False 88 | 89 | self.controlnet_scale = 1.0 90 | self.lora_path = args.lora_path 91 | self.inference_step = 30 92 | self.guidance_scale = 2.0 93 | self.cross_attention_scale = args.cross_attention_scale # 0.5 94 | self.init_image_strength = args.image_strength # 0.75 95 | self.scheduler = "UniPC" 96 | self.img_size = args.img_size # 512 97 | 98 | ### Get Pipelines 99 | def get_stablediffusion(self, stablediffusion_path, lora=None): 100 | pipe = StableDiffusionPipeline.from_pretrained( 101 | stablediffusion_path, safety_checker=None, torch_dtype=torch.float16, add_watermarker=False 102 | ) 103 | if lora: 104 | print("Load LoRA:", os.path.join(self.lora_path,lora)) 105 | pipe.unet.load_attn_procs(os.path.join(self.lora_path,lora)) 106 | 107 | pipe = self.set_scheduler(pipe) 108 
| pipe.to("cuda") 109 | if self.if_compile: 110 | print("Compile UNet") 111 | torch._dynamo.config.verbose = True 112 | pipe.unet = torch.compile(pipe.unet) 113 | pipe.enable_model_cpu_offload() 114 | return pipe 115 | 116 | def get_img2img(self,img2img_path, lora=None): 117 | pipe = StableDiffusionImg2ImgPipeline.from_pretrained(img2img_path, safety_checker=None, torch_dtype=torch.float16) 118 | if lora: 119 | print("Load LoRA:", os.path.join(self.lora_path,lora)) 120 | pipe.unet.load_attn_procs(os.path.join(self.lora_path,lora)) 121 | pipe = self.set_scheduler(pipe) 122 | pipe.to("cuda") 123 | if self.if_compile: # False 124 | print("Compile UNet") 125 | torch._dynamo.config.verbose = True 126 | pipe.unet = torch.compile(pipe.unet) 127 | pipe.enable_model_cpu_offload() 128 | 129 | return pipe 130 | 131 | def set_scheduler(self, pipe): 132 | if self.scheduler == "UniPC": #! scheduler 133 | pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) 134 | elif self.scheduler == "DPM++2MKarras": 135 | pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras=True) 136 | elif self.scheduler == "DPM++2MAKarras": 137 | pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras=True, algorithm_type="sde-dpmsolver++") 138 | return pipe 139 | 140 | def get_subdataset_loader(self, real_dst_train, bsz, num_chunks=8): 141 | # split Task 142 | chunk_size = len(real_dst_train) // num_chunks 143 | chunk_index = self.args.index 144 | if chunk_index == num_chunks-1: 145 | subset_indices = range(chunk_index*chunk_size, len(real_dst_train)) 146 | else: 147 | subset_indices = range(chunk_index*chunk_size, (chunk_index+1)*chunk_size) 148 | subset_dataset = Subset(real_dst_train, indices=subset_indices) 149 | dataloader = DataLoader(subset_dataset, batch_size=bsz, shuffle=False, num_workers=4) 150 | return dataloader 151 | 152 | ### Generate 153 | def generate_sd(self,prompts,negative_prompts): 154 | images = self.pipe(prompts, 155 | num_inference_steps=self.inference_step, 156 | negative_prompt=negative_prompts, 157 | prompt_embeds=None, 158 | negative_prompt_embeds=None, 159 | guidance_scale=self.guidance_scale 160 | ).images 161 | 162 | return images 163 | 164 | def generate_sd_lora(self,prompts,negative_prompts, image_names, prev_class_id): 165 | class_ids = [image_name.split("/")[0] for image_name in image_names] 166 | groups = group_lists(class_ids, prompts, negative_prompts, negative_prompts, image_names) 167 | print("Group:",len(groups)) 168 | images = [] 169 | for group in groups: 170 | class_id, prompts, negative_prompts, _, img_names = group 171 | if not class_id == prev_class_id and not self.if_full: 172 | self.pipe = self.get_stablediffusion(class_id) 173 | if self.use_guidance_tokens: 174 | guidance_tokens = self.get_guidance_tokens_v2(class_id, img_names) 175 | else: 176 | guidance_tokens = None 177 | 178 | sub_images = self.pipe(prompts, 179 | num_inference_steps=self.inference_step, 180 | negative_prompt=negative_prompts, 181 | prompt_embeds=None, 182 | negative_prompt_embeds=None, 183 | guidance_scale=self.guidance_scale, 184 | cross_attention_kwargs={"scale": self.cross_attention_scale}, 185 | guidance_tokens = guidance_tokens 186 | ).images 187 | images.extend(sub_images) 188 | return images, class_id 189 | 190 | def generate_img2img(self,prompts,init_images,negative_prompts): 191 | images = self.pipe(prompts, 192 | num_inference_steps=self.inference_step, 193 | image=init_images, 194 | 
negative_prompt=negative_prompts, 195 | prompt_embeds=None, 196 | negative_prompt_embeds=None, 197 | guidance_scale=self.guidance_scale, 198 | strength=self.init_image_strength 199 | ).images 200 | 201 | return images 202 | 203 | def generate_img2img_lora(self,prompts,init_images,negative_prompts, image_names, prev_class_id, class_names=None): 204 | if self.args.dataset in ['imagenette','imagenet100','imagenet1k']: 205 | class_ids = [image_name.split("/")[0] for image_name in image_names] 206 | groups = group_lists(class_ids, prompts, init_images, negative_prompts, image_names) 207 | print("Group:",len(groups)) 208 | images = [] 209 | for group in groups: 210 | class_id, prompts, init_images, negative_prompts, img_names = group 211 | if not class_id == prev_class_id and not self.if_full: 212 | self.pipe = self.get_img2img(class_id) 213 | if self.use_guidance_tokens: 214 | guidance_tokens = self.get_guidance_tokens_v2(class_id, img_names) 215 | else: 216 | guidance_tokens = None 217 | 218 | sub_images = self.pipe(prompts, 219 | num_inference_steps=self.inference_step, 220 | image=init_images, 221 | negative_prompt=negative_prompts, 222 | prompt_embeds=None, 223 | negative_prompt_embeds=None, 224 | guidance_scale=self.guidance_scale, 225 | strength=self.init_image_strength, 226 | cross_attention_kwargs={"scale": self.cross_attention_scale}, 227 | guidance_tokens = guidance_tokens 228 | ).images 229 | images.extend(sub_images) 230 | 231 | return images, class_id 232 | 233 | ### Misc 234 | def get_pipe(self,pid): 235 | if self.method in ['SDI2I_LoRA']: 236 | pipe = self.get_img2img(pid) 237 | elif self.method in ['SDT2I_LoRA']: 238 | pipe = self.get_stablediffusion(pid) 239 | elif self.method in ['SDI2I']: 240 | pipe = self.get_img2img() 241 | elif self.method in ['SDT2I']: 242 | pipe = self.get_stablediffusion() 243 | 244 | return pipe 245 | 246 | def get_misc(self): 247 | img_size = (self.img_size, self.img_size) # (512, 512) 248 | bsz = self.args.batch_size # 24 249 | out_version = self.args.version # "v52" 250 | create_ImageNetFolder(root_dir=f'{self.args.imagenet_path}/train', out_dir=f"{self.args.syn_path}/train") 251 | ImageNetPath = self.args.imagenet_path 252 | dataset = self.args.dataset # "imagenette" 253 | use_caption = True if self.args.use_caption == 'blip2' else False 254 | caption_path = "./ImageNet_BLIP2_caption_json/ImageNet_BLIP2_caption_json" 255 | ### 256 | print('Image Size',img_size) 257 | print('Batch Size',bsz) 258 | print("Use BLIP Caption", use_caption) 259 | ### 260 | 261 | return img_size, bsz, out_version, use_caption, caption_path, ImageNetPath, dataset 262 | 263 | def generate(self, prompts,init_images,negative_prompts, image_names, prev_class_id,class_names=None): 264 | # Generate 265 | if self.method in ['SDI2I_LoRA']: 266 | images, prev_class_id = self.generate_img2img_lora(prompts,init_images,negative_prompts, image_names, prev_class_id,class_names) 267 | elif self.method in ['SDT2I_LoRA']: 268 | images, prev_class_id = self.generate_sd_lora(prompts, negative_prompts, image_names, prev_class_id) 269 | elif self.method in ['SDI2I']: 270 | images = self.generate_img2img(prompts,init_images,negative_prompts) 271 | elif self.method in ['SDT2I']: 272 | images = self.generate_sd(prompts,negative_prompts) 273 | elif self.method in ['SDXLRefine']: 274 | images = self.generate_sdxl_refine(prompts,init_images,negative_prompts) 275 | 276 | return images, prev_class_id 277 | 278 | def get_prompt(self, use_caption, image_names,class_names,caption_path,bs): 
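# Prompt construction (see the branches below): with BLIP-2 captions enabled each image gets
#   "photo of {class_name}, {blip2_caption}, best quality"
# where the caption is looked up by file name in the per-class JSON; otherwise the prompt falls
# back to "{class_name}, photo, best quality".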
279 | if use_caption: 280 | base_prompts = [f"photo of {c}" for c in class_names] 281 | caption_dict = load_caption_dict(image_names, caption_path) 282 | caption_suffix = [caption_dict[f"{image_name.split('/')[-1]}.JPEG"] for image_name in image_names] 283 | prompts = [f"{base_prompts[n]}, {caption_suffix[n]}, best quality" for n in range(bs)] 284 | else: 285 | prompts = [f"{c}, photo, best quality" for c in class_names] 286 | 287 | return prompts 288 | 289 | def get_guidance_tokens_v2(self, class_id, image_names): 290 | if self.args.dataset in ['imagenette','imagenet100','imagenet1k']: 291 | root='./LoRA/CLIPEmbedding/train' 292 | 293 | 294 | dir_path = os.path.join(f"{root}", class_id) 295 | if self.args.dataset in ['imagenette','imagenet100','imagenet1k']: 296 | # sampled_files = [os.path.join(dir_path,f"{img_name.split('/')[-1]}.pt") for img_name in image_names] 297 | sampled_files = [f"{img_name.split('/')[-1]}.pt" for img_name in image_names] 298 | 299 | feature_dist_samples = [torch.load(os.path.join(dir_path, f)) for f in sampled_files] 300 | guidance_tokens = torch.stack(feature_dist_samples, dim=1).view(len(image_names),1,-1) 301 | 302 | return guidance_tokens 303 | 304 | ### Dataset Generation Pipe 305 | def generate_ImageNet1k(self): 306 | print(f"Generation: {self.method}") 307 | print("Full",self.if_full) 308 | 309 | prev_class_id = 'all' if self.if_full else None 310 | self.pipe = self.get_pipe(prev_class_id) # all, sdi2i_lora 311 | img_size, bsz, out_version, use_caption, caption_path, ImageNetPath, dataset = self.get_misc() 312 | real_dst_train = get_generation_dataset(ImageNetPath, split="train",subset=dataset,filelist="file_list.txt") 313 | #! just a subset, 8 gpu for running. 314 | dataloader = self.get_subdataset_loader(real_dst_train, bsz, num_chunks=self.args.nchunks) 315 | 316 | for i, batch in tqdm(enumerate(dataloader), total=len(dataloader)): 317 | targets, image_paths, image_names, class_names = batch 318 | bs = len(image_paths) 319 | out_paths = [os.path.join(f"{self.args.syn_path}/train",f'{image_names[idx]}.jpg') for idx in range(bs)] 320 | if os.path.exists(out_paths[-1]): 321 | continue 322 | 323 | prompts = self.get_prompt(use_caption, image_names,class_names,caption_path,bs) 324 | 325 | negative_prompts = ["distorted, unrealistic, blurry, out of frame, cropped, deformed" for n in range(bs)] 326 | 327 | if self.method in ['SDI2I_LoRA', 'SDI2I', 'SDXLRefine']: 328 | init_images = [Image.open(image_path).convert("RGB").resize(img_size) for image_path in image_paths] 329 | else: 330 | init_images = None 331 | 332 | images, prev_class_id = self.generate(prompts,init_images,negative_prompts, image_names, prev_class_id) 333 | 334 | # Save Image 335 | for idx,image in enumerate(images): 336 | image.save(out_paths[idx]) 337 | 338 | # Copy Label 339 | shutil.copy(f"{ImageNetPath}/train/file_list.txt", f"{self.args.syn_path}/train") 340 | 341 | def generate_pipeline(self): 342 | if self.args.dataset in ['imagenet1k','imagenette','imagenet100']: 343 | self.generate_ImageNet1k() 344 | 345 | 346 | def main(): 347 | args = get_args() 348 | # import pdb; pdb.set_trace() 349 | handler = StableDiffusionHandler(args) 350 | handler.generate_pipeline() 351 | 352 | if __name__ == "__main__": 353 | torch.backends.cudnn.benchmark = True 354 | main() -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.19.0 2 | aiohttp==3.8.4 3 | 
aiosignal==1.3.1 4 | async-timeout==4.0.2 5 | attrs==23.1.0 6 | certifi==2023.5.7 7 | charset-normalizer==3.1.0 8 | cmake==3.26.3 9 | contourpy==1.1.0 10 | cycler==0.11.0 11 | datasets==2.12.0 12 | dill==0.3.6 13 | filelock==3.12.0 14 | fonttools==4.42.1 15 | frozenlist==1.3.3 16 | fsspec==2023.5.0 17 | ftfy==6.1.1 18 | huggingface-hub==0.14.1 19 | idna==3.4 20 | importlib-metadata==6.6.0 21 | importlib-resources==6.0.1 22 | invisible-watermark 23 | Jinja2==3.1.2 24 | kiwisolver==1.4.5 25 | lit==16.0.6 26 | MarkupSafe==2.1.2 27 | matplotlib==3.7.2 28 | mpmath==1.3.0 29 | multidict==6.0.4 30 | multiprocess==0.70.14 31 | mypy-extensions==1.0.0 32 | networkx==3.1 33 | numpy==1.24.4 34 | nvidia-cublas-cu11==11.10.3.66 35 | nvidia-cuda-cupti-cu11==11.7.101 36 | nvidia-cuda-nvrtc-cu11==11.7.99 37 | nvidia-cuda-runtime-cu11==11.7.99 38 | nvidia-cudnn-cu11==8.5.0.96 39 | nvidia-cufft-cu11==10.9.0.58 40 | nvidia-curand-cu11==10.2.10.91 41 | nvidia-cusolver-cu11==11.4.0.1 42 | nvidia-cusparse-cu11==11.7.4.91 43 | nvidia-nccl-cu11==2.14.3 44 | nvidia-nvtx-cu11==11.7.91 45 | open-clip-torch==2.20.0 46 | opencv-python 47 | packaging==23.1 48 | pandas==2.0.3 49 | Pillow==9.5.0 50 | protobuf==3.20.3 51 | psutil==5.9.5 52 | pyarrow==12.0.0 53 | pyparsing==3.0.9 54 | pyre-extensions==0.0.29 55 | python-dateutil==2.8.2 56 | pytz==2023.3 57 | PyWavelets 58 | PyYAML==6.0 59 | regex==2023.5.5 60 | requests==2.31.0 61 | responses==0.18.0 62 | safetensors==0.3.1 63 | scipy==1.10.1 64 | sentencepiece==0.1.99 65 | six==1.16.0 66 | sympy==1.12 67 | timm==0.9.2 68 | tokenizers==0.13.3 69 | torch==2.0.1 70 | torchaudio==2.0.2 71 | torchvision==0.15.2 72 | tqdm==4.65.0 73 | transformers==4.30.2 74 | triton==2.0.0 75 | typing-inspect==0.9.0 76 | typing_extensions==4.7.0 77 | tzdata==2023.3 78 | urllib3==2.0.3 79 | wcwidth==0.2.6 80 | xformers==0.0.20 81 | xxhash==3.2.0 82 | yarl==1.9.2 83 | zipp==3.15.0 84 | -------------------------------------------------------------------------------- /shell_generate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | dataset='imagenette' 4 | versions=('v1') 5 | loras=('gt_dm') 6 | methods=('SDI2I_LoRA') 7 | guidance_tokens=('Yes') 8 | SDXLs=('No') 9 | image_strengths=(0.75) 10 | 11 | length=${#versions[@]} 12 | echo "start Generation Loop" 13 | for ((i=0; i<$length; i++)); do 14 | ver="${versions[$i]}" 15 | lora="./LoRA/checkpoint/${loras[$i]}_${ver}" 16 | method="${methods[$i]}" 17 | guidance_token="${guidance_tokens[$i]}" 18 | SDXL="${SDXLs[$i]}" 19 | cw="${cws[$i]}" 20 | imst="${image_strengths[$i]}" 21 | echo "$ver LoRA: $lora Method $method" 22 | # Iterate from 0-7, cover all case for nchunks <= 8 23 | for j in {0..7}; do 24 | echo $j 25 | CUDA_VISIBLE_DEVICES=$j python generate.py --index ${j} --method $method --version $ver --batch_size 24 \ 26 | --use_caption "blip2" --dataset $dataset --lora_path $lora --if_SDXL $SDXL --use_guidance $guidance_token \ 27 | --img_size 512 --cross_attention_scale 0.5 --image_strength $imst --nchunks 8 > results/gen${j}.out 2>&1 & 28 | done 29 | wait 30 | echo "All processes completed" 31 | done -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ ImageNet Training Script 3 | 4 | This is intended to be a lean and easily modifiable ImageNet training script that reproduces ImageNet 5 | training results with some of the latest networks 
and training techniques. It favours canonical PyTorch 6 | and standard Python style over trying to be able to 'do it all.' That said, it offers quite a few speed 7 | and training result improvements over the usual PyTorch example scripts. Repurpose as you see fit. 8 | 9 | This script was started from an early version of the PyTorch ImageNet example 10 | (https://github.com/pytorch/examples/tree/master/imagenet) 11 | 12 | NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples 13 | (https://github.com/NVIDIA/apex/tree/master/examples/imagenet) 14 | 15 | Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman) 16 | """ 17 | import argparse 18 | import logging 19 | import os 20 | import time 21 | from collections import OrderedDict 22 | from contextlib import suppress 23 | from datetime import datetime 24 | from functools import partial 25 | 26 | import torch 27 | import torch.nn as nn 28 | import torchvision.utils 29 | import yaml 30 | from torch.nn.parallel import DistributedDataParallel as NativeDDP 31 | 32 | from timm import utils 33 | from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset 34 | from timm.layers import convert_splitbn_model, convert_sync_batchnorm, set_fast_norm 35 | from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, LabelSmoothingCrossEntropy 36 | from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, model_parameters 37 | from timm.optim import create_optimizer_v2, optimizer_kwargs 38 | from timm.scheduler import create_scheduler_v2, scheduler_kwargs 39 | from timm.utils import ApexScaler, NativeScaler 40 | 41 | from data.new_load_data import get_dataset, get_combined_dataset 42 | try: 43 | from apex import amp 44 | from apex.parallel import DistributedDataParallel as ApexDDP 45 | from apex.parallel import convert_syncbn_model 46 | has_apex = True 47 | except ImportError: 48 | has_apex = False 49 | 50 | has_native_amp = False 51 | try: 52 | if getattr(torch.cuda.amp, 'autocast') is not None: 53 | has_native_amp = True 54 | except AttributeError: 55 | pass 56 | 57 | try: 58 | import wandb 59 | has_wandb = True 60 | except ImportError: 61 | has_wandb = False 62 | 63 | try: 64 | from functorch.compile import memory_efficient_fusion 65 | has_functorch = True 66 | except ImportError as e: 67 | has_functorch = False 68 | 69 | has_compile = hasattr(torch, 'compile') 70 | 71 | 72 | _logger = logging.getLogger('train') 73 | 74 | # The first arg parser parses out only the --config argument, this argument is used to 75 | # load a yaml file containing key-values that override the defaults for the main parser below 76 | config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False) 77 | parser.add_argument('-c', '--config', default='', type=str, metavar='FILE', 78 | help='YAML config file specifying default arguments') 79 | 80 | 81 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 82 | 83 | # Dataset parameters 84 | group = parser.add_argument_group('Dataset parameters') 85 | # Keep this argument outside the dataset group because it is positional. 
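# Illustrative invocation (paths and values are placeholders, not taken from the repo's train.sh),
# passing the data root via --data-dir rather than the deprecated positional `data` argument:
#   python train.py --data-dir /path/to/synthetic_data --imagenet_path /path/to/ImageNet \
#       --model resnet50 -b 128 --epochs 300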
86 | parser.add_argument('data', nargs='?', metavar='DIR', const=None, 87 | help='path to dataset (positional is *deprecated*, use --data-dir)') 88 | parser.add_argument('--data-dir', metavar='DIR', default="", 89 | help='path to dataset (root dir)') 90 | parser.add_argument('--imagenet_path', metavar='DIR', default="", 91 | help='path to real ImageNet') 92 | parser.add_argument('--dataset', metavar='NAME', default='', 93 | help='dataset type + name ("/") (default: ImageFolder or ImageTar if empty)') 94 | parser.add_argument('--use_caption', action='store_true') 95 | parser.add_argument("--guidance_scale", type=float, default=3, help="guidance scale") 96 | parser.add_argument('--comb_dataset', type=str, nargs='*', default=None, help='A list of dataset') 97 | 98 | 99 | group.add_argument('--train-split', metavar='NAME', default='train', 100 | help='dataset train split (default: train)') 101 | group.add_argument('--val-split', metavar='NAME', default='validation', 102 | help='dataset validation split (default: validation)') 103 | group.add_argument('--dataset-download', action='store_true', default=False, 104 | help='Allow download of dataset for torch/ and tfds/ datasets that support it.') 105 | group.add_argument('--class-map', default='', type=str, metavar='FILENAME', 106 | help='path to class to idx mapping file (default: "")') 107 | 108 | 109 | # Model parameters 110 | group = parser.add_argument_group('Model parameters') 111 | group.add_argument('--model', default='resnet50', type=str, metavar='MODEL', 112 | help='Name of model to train (default: "resnet50")') 113 | group.add_argument('--pretrained', action='store_true', default=False, 114 | help='Start with pretrained version of specified network (if avail)') 115 | group.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH', 116 | help='Initialize model from this checkpoint (default: none)') 117 | group.add_argument('--resume', default='', type=str, metavar='PATH', 118 | help='Resume full model and optimizer state from checkpoint (default: none)') 119 | group.add_argument('--no-resume-opt', action='store_true', default=False, 120 | help='prevent resume of optimizer state when resuming model') 121 | group.add_argument('--num-classes', type=int, default=None, metavar='N', 122 | help='number of label classes (Model default if None)') 123 | group.add_argument('--gp', default=None, type=str, metavar='POOL', 124 | help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.') 125 | group.add_argument('--img-size', type=int, default=None, metavar='N', 126 | help='Image size (default: None => model default)') 127 | group.add_argument('--in-chans', type=int, default=None, metavar='N', 128 | help='Image input channels (default: None => 3)') 129 | group.add_argument('--input-size', default=None, nargs=3, type=int, 130 | metavar='N N N', 131 | help='Input all image dimensions (d h w, e.g. 
--input-size 3 224 224), uses model default if empty') 132 | group.add_argument('--crop-pct', default=None, type=float, 133 | metavar='N', help='Input image center crop percent (for validation only)') 134 | group.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', 135 | help='Override mean pixel value of dataset') 136 | group.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', 137 | help='Override std deviation of dataset') 138 | group.add_argument('--interpolation', default='', type=str, metavar='NAME', 139 | help='Image resize interpolation type (overrides model)') 140 | group.add_argument('-b', '--batch-size', type=int, default=128, metavar='N', 141 | help='Input batch size for training (default: 128)') 142 | group.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N', 143 | help='Validation batch size override (default: None)') 144 | group.add_argument('--channels-last', action='store_true', default=False, 145 | help='Use channels_last memory layout') 146 | group.add_argument('--fuser', default='', type=str, 147 | help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") 148 | group.add_argument('--grad-accum-steps', type=int, default=1, metavar='N', 149 | help='The number of steps to accumulate gradients (default: 1)') 150 | group.add_argument('--grad-checkpointing', action='store_true', default=False, 151 | help='Enable gradient checkpointing through model blocks/stages') 152 | group.add_argument('--fast-norm', default=False, action='store_true', 153 | help='enable experimental fast-norm') 154 | group.add_argument('--model-kwargs', nargs='*', default={}, action=utils.ParseKwargs) 155 | group.add_argument('--head-init-scale', default=None, type=float, 156 | help='Head initialization scale') 157 | group.add_argument('--head-init-bias', default=None, type=float, 158 | help='Head initialization bias value') 159 | 160 | # scripting / codegen 161 | scripting_group = group.add_mutually_exclusive_group() 162 | scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true', 163 | help='torch.jit.script the full model') 164 | scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor', 165 | help="Enable compilation w/ specified backend (default: inductor).") 166 | 167 | # Optimizer parameters 168 | group = parser.add_argument_group('Optimizer parameters') 169 | group.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', 170 | help='Optimizer (default: "sgd")') 171 | group.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON', 172 | help='Optimizer Epsilon (default: None, use opt default)') 173 | group.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', 174 | help='Optimizer Betas (default: None, use opt default)') 175 | group.add_argument('--momentum', type=float, default=0.9, metavar='M', 176 | help='Optimizer momentum (default: 0.9)') 177 | group.add_argument('--weight-decay', type=float, default=2e-5, 178 | help='weight decay (default: 2e-5)') 179 | group.add_argument('--clip-grad', type=float, default=None, metavar='NORM', 180 | help='Clip gradient norm (default: None, no clipping)') 181 | group.add_argument('--clip-mode', type=str, default='norm', 182 | help='Gradient clipping mode. 
One of ("norm", "value", "agc")') 183 | group.add_argument('--layer-decay', type=float, default=None, 184 | help='layer-wise learning rate decay (default: None)') 185 | group.add_argument('--opt-kwargs', nargs='*', default={}, action=utils.ParseKwargs) 186 | 187 | # Learning rate schedule parameters 188 | group = parser.add_argument_group('Learning rate schedule parameters') 189 | group.add_argument('--sched', type=str, default='cosine', metavar='SCHEDULER', 190 | help='LR scheduler (default: "step"') 191 | group.add_argument('--sched-on-updates', action='store_true', default=False, 192 | help='Apply LR scheduler step on update instead of epoch end.') 193 | group.add_argument('--lr', type=float, default=None, metavar='LR', 194 | help='learning rate, overrides lr-base if set (default: None)') 195 | group.add_argument('--lr-base', type=float, default=0.1, metavar='LR', 196 | help='base learning rate: lr = lr_base * global_batch_size / base_size') 197 | group.add_argument('--lr-base-size', type=int, default=256, metavar='DIV', 198 | help='base learning rate batch size (divisor, default: 256).') 199 | group.add_argument('--lr-base-scale', type=str, default='', metavar='SCALE', 200 | help='base learning rate vs batch_size scaling ("linear", "sqrt", based on opt if empty)') 201 | group.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', 202 | help='learning rate noise on/off epoch percentages') 203 | group.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', 204 | help='learning rate noise limit percent (default: 0.67)') 205 | group.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', 206 | help='learning rate noise std-dev (default: 1.0)') 207 | group.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT', 208 | help='learning rate cycle len multiplier (default: 1.0)') 209 | group.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT', 210 | help='amount to decay each learning rate cycle (default: 0.5)') 211 | group.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N', 212 | help='learning rate cycle limit, cycles enabled if > 1') 213 | group.add_argument('--lr-k-decay', type=float, default=1.0, 214 | help='learning rate k-decay for cosine/poly (default: 1.0)') 215 | group.add_argument('--warmup-lr', type=float, default=1e-5, metavar='LR', 216 | help='warmup learning rate (default: 1e-5)') 217 | group.add_argument('--min-lr', type=float, default=0, metavar='LR', 218 | help='lower lr bound for cyclic schedulers that hit 0 (default: 0)') 219 | group.add_argument('--epochs', type=int, default=300, metavar='N', 220 | help='number of epochs to train (default: 300)') 221 | group.add_argument('--epoch-repeats', type=float, default=0., metavar='N', 222 | help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).') 223 | group.add_argument('--start-epoch', default=None, type=int, metavar='N', 224 | help='manual epoch number (useful on restarts)') 225 | # group.add_argument('--decay-milestones', default=[90, 180, 270], type=int, nargs='+', metavar="MILESTONES", 226 | # help='list of decay epoch indices for multistep lr. must be increasing') 227 | group.add_argument('--decay-milestones', default=[50, 100, 150], type=int, nargs='+', metavar="MILESTONES", 228 | help='list of decay epoch indices for multistep lr. 
must be increasing') 229 | group.add_argument('--decay-epochs', type=float, default=90, metavar='N', 230 | help='epoch interval to decay LR') 231 | group.add_argument('--warmup-epochs', type=int, default=5, metavar='N', 232 | help='epochs to warmup LR, if scheduler supports') 233 | group.add_argument('--warmup-prefix', action='store_true', default=False, 234 | help='Exclude warmup period from decay schedule.'), 235 | group.add_argument('--cooldown-epochs', type=int, default=0, metavar='N', 236 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends') 237 | group.add_argument('--patience-epochs', type=int, default=10, metavar='N', 238 | help='patience epochs for Plateau LR scheduler (default: 10)') 239 | group.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', 240 | help='LR decay rate (default: 0.1)') 241 | 242 | # Augmentation & regularization parameters 243 | group = parser.add_argument_group('Augmentation and regularization parameters') 244 | group.add_argument('--no-aug', action='store_true', default=False, 245 | help='Disable all training augmentation, override other train aug args') 246 | group.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT', 247 | help='Random resize scale (default: 0.08 1.0)') 248 | group.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO', 249 | help='Random resize aspect ratio (default: 0.75 1.33)') 250 | group.add_argument('--hflip', type=float, default=0.5, 251 | help='Horizontal flip training aug probability') 252 | group.add_argument('--vflip', type=float, default=0., 253 | help='Vertical flip training aug probability') 254 | group.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', 255 | help='Color jitter factor (default: 0.4)') 256 | group.add_argument('--aa', type=str, default=None, metavar='NAME', 257 | help='Use AutoAugment policy. "v0" or "original". (default: None)'), 258 | group.add_argument('--aug-repeats', type=float, default=0, 259 | help='Number of augmentation repetitions (distributed training only) (default: 0)') 260 | group.add_argument('--aug-splits', type=int, default=0, 261 | help='Number of augmentation splits (default: 0, valid: 0 or >=2)') 262 | group.add_argument('--jsd-loss', action='store_true', default=False, 263 | help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.') 264 | group.add_argument('--bce-loss', action='store_true', default=False, 265 | help='Enable BCE loss w/ Mixup/CutMix use.') 266 | group.add_argument('--bce-target-thresh', type=float, default=None, 267 | help='Threshold for binarizing softened BCE targets (default: None, disabled)') 268 | group.add_argument('--reprob', type=float, default=0., metavar='PCT', 269 | help='Random erase prob (default: 0.)') 270 | group.add_argument('--remode', type=str, default='pixel', 271 | help='Random erase mode (default: "pixel")') 272 | group.add_argument('--recount', type=int, default=1, 273 | help='Random erase count (default: 1)') 274 | group.add_argument('--resplit', action='store_true', default=False, 275 | help='Do not random erase first (clean) augmentation split') 276 | group.add_argument('--mixup', type=float, default=0.0, 277 | help='mixup alpha, mixup enabled if > 0. (default: 0.)') 278 | group.add_argument('--cutmix', type=float, default=0.0, 279 | help='cutmix alpha, cutmix enabled if > 0. 
(default: 0.)') 280 | group.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, 281 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 282 | group.add_argument('--mixup-prob', type=float, default=1.0, 283 | help='Probability of performing mixup or cutmix when either/both is enabled') 284 | group.add_argument('--mixup-switch-prob', type=float, default=0.5, 285 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 286 | group.add_argument('--mixup-mode', type=str, default='batch', 287 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"') 288 | group.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N', 289 | help='Turn off mixup after this epoch, disabled if 0 (default: 0)') 290 | group.add_argument('--smoothing', type=float, default=0.1, 291 | help='Label smoothing (default: 0.1)') 292 | group.add_argument('--train-interpolation', type=str, default='random', 293 | help='Training interpolation (random, bilinear, bicubic default: "random")') 294 | group.add_argument('--drop', type=float, default=0.0, metavar='PCT', 295 | help='Dropout rate (default: 0.)') 296 | group.add_argument('--drop-connect', type=float, default=None, metavar='PCT', 297 | help='Drop connect rate, DEPRECATED, use drop-path (default: None)') 298 | group.add_argument('--drop-path', type=float, default=None, metavar='PCT', 299 | help='Drop path rate (default: None)') 300 | group.add_argument('--drop-block', type=float, default=None, metavar='PCT', 301 | help='Drop block rate (default: None)') 302 | 303 | # Batch norm parameters (only works with gen_efficientnet based models currently) 304 | group = parser.add_argument_group('Batch norm parameters', 'Only works with gen_efficientnet based models currently.') 305 | group.add_argument('--bn-momentum', type=float, default=None, 306 | help='BatchNorm momentum override (if not None)') 307 | group.add_argument('--bn-eps', type=float, default=None, 308 | help='BatchNorm epsilon override (if not None)') 309 | group.add_argument('--sync-bn', action='store_true', 310 | help='Enable NVIDIA Apex or Torch synchronized BatchNorm.') 311 | group.add_argument('--dist-bn', type=str, default='reduce', 312 | help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")') 313 | group.add_argument('--split-bn', action='store_true', 314 | help='Enable separate BN layers per augmentation split.') 315 | 316 | # Model Exponential Moving Average 317 | group = parser.add_argument_group('Model exponential moving average parameters') 318 | group.add_argument('--model-ema', action='store_true', default=False, 319 | help='Enable tracking moving average of model weights') 320 | group.add_argument('--model-ema-force-cpu', action='store_true', default=False, 321 | help='Force ema to be tracked on CPU, rank=0 node only. 
Disables EMA validation.') 322 | group.add_argument('--model-ema-decay', type=float, default=0.9998, 323 | help='decay factor for model weights moving average (default: 0.9998)') 324 | 325 | # Misc 326 | group = parser.add_argument_group('Miscellaneous parameters') 327 | group.add_argument('--seed', type=int, default=42, metavar='S', 328 | help='random seed (default: 42)') 329 | group.add_argument('--worker-seeding', type=str, default='all', 330 | help='worker seed mode (default: all)') 331 | group.add_argument('--log-interval', type=int, default=50, metavar='N', 332 | help='how many batches to wait before logging training status') 333 | group.add_argument('--recovery-interval', type=int, default=0, metavar='N', 334 | help='how many batches to wait before writing recovery checkpoint') 335 | group.add_argument('--checkpoint-hist', type=int, default=10, metavar='N', 336 | help='number of checkpoints to keep (default: 10)') 337 | group.add_argument('-j', '--workers', type=int, default=4, metavar='N', 338 | help='how many training processes to use (default: 4)') 339 | group.add_argument('--save-images', action='store_true', default=False, 340 | help='save images of input batches every log interval for debugging') 341 | group.add_argument('--amp', action='store_true', default=False, 342 | help='use NVIDIA Apex AMP or Native AMP for mixed precision training') 343 | group.add_argument('--amp-dtype', default='float16', type=str, 344 | help='lower precision AMP dtype (default: float16)') 345 | group.add_argument('--amp-impl', default='native', type=str, 346 | help='AMP impl to use, "native" or "apex" (default: native)') 347 | group.add_argument('--no-ddp-bb', action='store_true', default=False, 348 | help='Force broadcast buffers for native DDP to off.') 349 | group.add_argument('--synchronize-step', action='store_true', default=False, 350 | help='torch.cuda.synchronize() at the end of each step') 351 | group.add_argument('--pin-mem', action='store_true', default=False, 352 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 353 | group.add_argument('--no-prefetcher', action='store_true', default=False, 354 | help='disable fast prefetcher') 355 | group.add_argument('--output', default='', type=str, metavar='PATH', 356 | help='path to output folder (default: none, current dir)') 357 | group.add_argument('--experiment', default='', type=str, metavar='NAME', 358 | help='name of train experiment, name of sub-folder for output') 359 | group.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC', 360 | help='Best metric (default: "top1")') 361 | group.add_argument('--tta', type=int, default=0, metavar='N', 362 | help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)') 363 | group.add_argument("--local_rank", default=0, type=int) 364 | group.add_argument('--use-multi-epochs-loader', action='store_true', default=False, 365 | help='use the multi-epochs-loader to save time at the beginning of every epoch') 366 | group.add_argument('--log-wandb', action='store_true', default=False, 367 | help='log training and validation metrics to wandb') 368 | 369 | 370 | def _parse_args(): 371 | # Do we have a config file to parse? 
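    # The YAML file named by the config argument (args_config.config) is parsed first and its
    # keys, which mirror the argparse dests above, become parser defaults; flags passed
    # explicitly on the command line still win. Illustrative example only -- such a file might contain:
    #   model: resnet50
    #   batch_size: 128
    #   sched: multistep
    #   decay_rate: 0.2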
372 | args_config, remaining = config_parser.parse_known_args() 373 | if args_config.config: 374 | with open(args_config.config, 'r') as f: 375 | cfg = yaml.safe_load(f) 376 | parser.set_defaults(**cfg) 377 | 378 | # The main arg parser parses the rest of the args, the usual 379 | # defaults will have been overridden if config file specified. 380 | args = parser.parse_args(remaining) 381 | 382 | # Cache the args as a text string to save them in the output dir later 383 | args_text = yaml.safe_dump(args.__dict__, default_flow_style=False) 384 | return args, args_text 385 | 386 | 387 | def main(): 388 | utils.setup_default_logging() 389 | args, args_text = _parse_args() 390 | 391 | if torch.cuda.is_available(): 392 | torch.backends.cuda.matmul.allow_tf32 = True 393 | torch.backends.cudnn.benchmark = True 394 | 395 | args.prefetcher = not args.no_prefetcher 396 | args.grad_accum_steps = max(1, args.grad_accum_steps) 397 | device = utils.init_distributed_device(args) 398 | if args.distributed: 399 | _logger.info( 400 | 'Training in distributed mode with multiple processes, 1 device per process.' 401 | f'Process {args.rank}, total {args.world_size}, device {args.device}.') 402 | else: 403 | _logger.info(f'Training with a single process on 1 device ({args.device}).') 404 | assert args.rank >= 0 405 | 406 | # resolve AMP arguments based on PyTorch / Apex availability 407 | use_amp = None 408 | amp_dtype = torch.float16 409 | if args.amp: 410 | if args.amp_impl == 'apex': 411 | assert has_apex, 'AMP impl specified as APEX but APEX is not installed.' 412 | use_amp = 'apex' 413 | assert args.amp_dtype == 'float16' 414 | else: 415 | assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).' 416 | use_amp = 'native' 417 | assert args.amp_dtype in ('float16', 'bfloat16') 418 | if args.amp_dtype == 'bfloat16': 419 | amp_dtype = torch.bfloat16 420 | 421 | utils.random_seed(args.seed, args.rank) 422 | 423 | if args.fuser: 424 | utils.set_jit_fuser(args.fuser) 425 | if args.fast_norm: 426 | set_fast_norm() 427 | 428 | in_chans = 3 429 | if args.in_chans is not None: 430 | in_chans = args.in_chans 431 | elif args.input_size is not None: 432 | in_chans = args.input_size[0] 433 | 434 | model = create_model( 435 | args.model, 436 | pretrained=args.pretrained, 437 | in_chans=in_chans, 438 | num_classes=args.num_classes, 439 | drop_rate=args.drop, 440 | drop_path_rate=args.drop_path, 441 | drop_block_rate=args.drop_block, 442 | global_pool=args.gp, 443 | bn_momentum=args.bn_momentum, 444 | bn_eps=args.bn_eps, 445 | scriptable=args.torchscript, 446 | checkpoint_path=args.initial_checkpoint, 447 | **args.model_kwargs, 448 | ) 449 | if args.head_init_scale is not None: 450 | with torch.no_grad(): 451 | model.get_classifier().weight.mul_(args.head_init_scale) 452 | model.get_classifier().bias.mul_(args.head_init_scale) 453 | if args.head_init_bias is not None: 454 | nn.init.constant_(model.get_classifier().bias, args.head_init_bias) 455 | 456 | if args.num_classes is None: 457 | assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.' 
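        # --num-classes was not set on the command line or in a config, so fall back to the
        # classifier head size the model was created with (e.g. 1000 for ImageNet-1k definitions).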
458 | args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly 459 | 460 | if args.grad_checkpointing: 461 | model.set_grad_checkpointing(enable=True) 462 | 463 | if utils.is_primary(args): 464 | _logger.info( 465 | f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}') 466 | 467 | data_config = resolve_data_config(vars(args), model=model, verbose=utils.is_primary(args)) 468 | # setup augmentation batch splits for contrastive loss or split bn 469 | num_aug_splits = 0 470 | if args.aug_splits > 0: 471 | assert args.aug_splits > 1, 'A split of 1 makes no sense' 472 | num_aug_splits = args.aug_splits 473 | 474 | # enable split bn (separate bn stats per batch-portion) 475 | if args.split_bn: 476 | assert num_aug_splits > 1 or args.resplit 477 | model = convert_splitbn_model(model, max(num_aug_splits, 2)) 478 | 479 | # move model to GPU, enable channels last layout if set 480 | model.to(device=device) 481 | if args.channels_last: 482 | model.to(memory_format=torch.channels_last) 483 | 484 | # setup synchronized BatchNorm for distributed training 485 | if args.distributed and args.sync_bn: 486 | args.dist_bn = '' # disable dist_bn when sync BN active 487 | assert not args.split_bn 488 | if has_apex and use_amp == 'apex': 489 | # Apex SyncBN used with Apex AMP 490 | # WARNING this won't currently work with models using BatchNormAct2d 491 | model = convert_syncbn_model(model) 492 | else: 493 | model = convert_sync_batchnorm(model) 494 | if utils.is_primary(args): 495 | _logger.info( 496 | 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' 497 | 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.') 498 | 499 | if args.torchscript: 500 | assert not args.torchcompile 501 | assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model' 502 | assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model' 503 | model = torch.jit.script(model) 504 | 505 | if not args.lr: 506 | global_batch_size = args.batch_size * args.world_size * args.grad_accum_steps 507 | batch_ratio = global_batch_size / args.lr_base_size 508 | if not args.lr_base_scale: 509 | on = args.opt.lower() 510 | args.lr_base_scale = 'sqrt' if any([o in on for o in ('ada', 'lamb')]) else 'linear' 511 | if args.lr_base_scale == 'sqrt': 512 | batch_ratio = batch_ratio ** 0.5 513 | args.lr = args.lr_base * batch_ratio 514 | if utils.is_primary(args): 515 | _logger.info( 516 | f'Learning rate ({args.lr}) calculated from base learning rate ({args.lr_base}) ' 517 | f'and effective global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.') 518 | 519 | optimizer = create_optimizer_v2( 520 | model, 521 | **optimizer_kwargs(cfg=args), 522 | **args.opt_kwargs, 523 | ) 524 | 525 | # setup automatic mixed-precision (AMP) loss scaling and op casting 526 | amp_autocast = suppress # do nothing 527 | loss_scaler = None 528 | if use_amp == 'apex': 529 | assert device.type == 'cuda' 530 | model, optimizer = amp.initialize(model, optimizer, opt_level='O1') 531 | loss_scaler = ApexScaler() 532 | if utils.is_primary(args): 533 | _logger.info('Using NVIDIA APEX AMP. 
Training in mixed precision.') 534 | elif use_amp == 'native': 535 | try: 536 | amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype) 537 | except (AttributeError, TypeError): 538 | # fallback to CUDA only AMP for PyTorch < 1.10 539 | assert device.type == 'cuda' 540 | amp_autocast = torch.cuda.amp.autocast 541 | if device.type == 'cuda' and amp_dtype == torch.float16: 542 | # loss scaler only used for float16 (half) dtype, bfloat16 does not need it 543 | loss_scaler = NativeScaler() 544 | if utils.is_primary(args): 545 | _logger.info('Using native Torch AMP. Training in mixed precision.') 546 | else: 547 | if utils.is_primary(args): 548 | _logger.info('AMP not enabled. Training in float32.') 549 | 550 | # optionally resume from a checkpoint 551 | resume_epoch = None 552 | if args.resume: 553 | resume_epoch = resume_checkpoint( 554 | model, 555 | args.resume, 556 | optimizer=None if args.no_resume_opt else optimizer, 557 | loss_scaler=None if args.no_resume_opt else loss_scaler, 558 | log_info=utils.is_primary(args), 559 | ) 560 | 561 | # setup exponential moving average of model weights, SWA could be used here too 562 | model_ema = None 563 | if args.model_ema: 564 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper 565 | model_ema = utils.ModelEmaV2( 566 | model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None) 567 | if args.resume: 568 | load_checkpoint(model_ema.module, args.resume, use_ema=True) 569 | 570 | # setup distributed training 571 | if args.distributed: 572 | if has_apex and use_amp == 'apex': 573 | # Apex DDP preferred unless native amp is activated 574 | if utils.is_primary(args): 575 | _logger.info("Using NVIDIA APEX DistributedDataParallel.") 576 | model = ApexDDP(model, delay_allreduce=True) 577 | else: 578 | if utils.is_primary(args): 579 | _logger.info("Using native Torch DistributedDataParallel.") 580 | model = NativeDDP(model, device_ids=[device], broadcast_buffers=not args.no_ddp_bb) 581 | # NOTE: EMA model does not need to be wrapped by DDP 582 | 583 | if args.torchcompile: 584 | # torch compile should be done after DDP 585 | assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.' 586 | model = torch.compile(model, backend=args.torchcompile) 587 | 588 | # create the train and eval datasets 589 | if args.data and not args.data_dir: 590 | args.data_dir = args.data 591 | 592 | dataset_train = get_dataset(args.data_dir, split="train",subset=args.dataset,filelist="file_list.txt") 593 | dataset_test = get_dataset(args.imagenet_path, split="val",subset=args.dataset, filelist="new_file_list.txt") 594 | 595 | # setup mixup / cutmix 596 | collate_fn = None 597 | mixup_fn = None 598 | mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None 599 | if mixup_active: 600 | mixup_args = dict( 601 | mixup_alpha=args.mixup, 602 | cutmix_alpha=args.cutmix, 603 | cutmix_minmax=args.cutmix_minmax, 604 | prob=args.mixup_prob, 605 | switch_prob=args.mixup_switch_prob, 606 | mode=args.mixup_mode, 607 | label_smoothing=args.smoothing, 608 | num_classes=args.num_classes 609 | ) 610 | if args.prefetcher: 611 | assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup) 612 | collate_fn = FastCollateMixup(**mixup_args) 613 | else: 614 | mixup_fn = Mixup(**mixup_args) 615 | 616 | # wrap dataset in AugMix helper 617 | if num_aug_splits > 1: 618 | dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) 619 | 620 | # create data loaders w/ augmentation pipeiine 621 | train_interpolation = args.train_interpolation 622 | if args.no_aug or not train_interpolation: 623 | train_interpolation = data_config['interpolation'] 624 | loader_train = create_loader( 625 | dataset_train, 626 | input_size=data_config['input_size'], 627 | batch_size=args.batch_size, 628 | is_training=True, 629 | use_prefetcher=args.prefetcher, 630 | no_aug=args.no_aug, 631 | re_prob=args.reprob, 632 | re_mode=args.remode, 633 | re_count=args.recount, 634 | re_split=args.resplit, 635 | scale=args.scale, 636 | ratio=args.ratio, 637 | hflip=args.hflip, 638 | vflip=args.vflip, 639 | color_jitter=args.color_jitter, 640 | auto_augment=args.aa, 641 | num_aug_repeats=args.aug_repeats, 642 | num_aug_splits=num_aug_splits, 643 | interpolation=train_interpolation, 644 | # mean=(0.4754, 0.4562, 0.4154), 645 | # std=(0.2323, 0.2271, 0.2275), 646 | mean=data_config['mean'], 647 | std=data_config['std'], 648 | num_workers=args.workers, 649 | distributed=args.distributed, 650 | collate_fn=collate_fn, 651 | pin_memory=args.pin_mem, 652 | device=device, 653 | use_multi_epochs_loader=args.use_multi_epochs_loader, 654 | worker_seeding=args.worker_seeding, 655 | ) 656 | 657 | eval_workers = args.workers 658 | if args.distributed and ('tfds' in args.dataset or 'wds' in args.dataset): 659 | # FIXME reduces validation padding issues when using TFDS, WDS w/ workers and distributed training 660 | eval_workers = min(2, args.workers) 661 | loader_eval = create_loader( 662 | dataset_test, 663 | input_size=data_config['input_size'], 664 | batch_size=args.validation_batch_size or args.batch_size, 665 | is_training=False, 666 | use_prefetcher=args.prefetcher, 667 | interpolation=data_config['interpolation'], 668 | mean=data_config['mean'], 669 | std=data_config['std'], 670 | num_workers=eval_workers, 671 | distributed=args.distributed, 672 | crop_pct=data_config['crop_pct'], 673 | pin_memory=args.pin_mem, 674 | device=device, 675 | ) 676 | 677 | # setup loss function 678 | if args.jsd_loss: 679 | assert num_aug_splits > 1 # JSD only valid with aug splits set 680 | train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing) 681 | elif mixup_active: 682 | # smoothing is handled with mixup target transform which outputs sparse, soft targets 683 | if args.bce_loss: 684 | train_loss_fn = BinaryCrossEntropy(target_threshold=args.bce_target_thresh) 685 | else: 686 | train_loss_fn = SoftTargetCrossEntropy() 687 | elif args.smoothing: 688 | if args.bce_loss: 689 | train_loss_fn = BinaryCrossEntropy(smoothing=args.smoothing, target_threshold=args.bce_target_thresh) 690 | else: 691 | train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 692 | else: 693 | train_loss_fn = nn.CrossEntropyLoss() 694 | 
train_loss_fn = train_loss_fn.to(device=device) 695 | validate_loss_fn = nn.CrossEntropyLoss().to(device=device) 696 | 697 | # setup checkpoint saver and eval metric tracking 698 | eval_metric = args.eval_metric 699 | best_metric = None 700 | best_epoch = None 701 | saver = None 702 | output_dir = None 703 | if utils.is_primary(args): 704 | if args.experiment: 705 | exp_name = args.experiment 706 | else: 707 | exp_name = '-'.join([ 708 | datetime.now().strftime("%Y%m%d-%H%M%S"), 709 | safe_model_name(args.model), 710 | str(data_config['input_size'][-1]) 711 | ]) 712 | output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name) 713 | decreasing = True if eval_metric == 'loss' else False 714 | saver = utils.CheckpointSaver( 715 | model=model, 716 | optimizer=optimizer, 717 | args=args, 718 | model_ema=model_ema, 719 | amp_scaler=loss_scaler, 720 | checkpoint_dir=output_dir, 721 | recovery_dir=output_dir, 722 | decreasing=decreasing, 723 | max_history=args.checkpoint_hist 724 | ) 725 | with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: 726 | f.write(args_text) 727 | 728 | if utils.is_primary(args) and args.log_wandb: 729 | if has_wandb: 730 | wandb.init(project=args.experiment, config=args) 731 | else: 732 | _logger.warning( 733 | "You've requested to log metrics to wandb but package not found. " 734 | "Metrics not being logged to wandb, try `pip install wandb`") 735 | 736 | # setup learning rate schedule and starting epoch 737 | updates_per_epoch = (len(loader_train) + args.grad_accum_steps - 1) // args.grad_accum_steps 738 | lr_scheduler, num_epochs = create_scheduler_v2( 739 | optimizer, 740 | **scheduler_kwargs(args), 741 | updates_per_epoch=updates_per_epoch, 742 | ) 743 | start_epoch = 0 744 | if args.start_epoch is not None: 745 | # a specified start_epoch will always override the resume epoch 746 | start_epoch = args.start_epoch 747 | elif resume_epoch is not None: 748 | start_epoch = resume_epoch 749 | if lr_scheduler is not None and start_epoch > 0: 750 | if args.sched_on_updates: 751 | lr_scheduler.step_update(start_epoch * updates_per_epoch) 752 | else: 753 | lr_scheduler.step(start_epoch) 754 | 755 | if utils.is_primary(args): 756 | _logger.info( 757 | f'Scheduled epochs: {num_epochs}. 
LR stepped per {"epoch" if lr_scheduler.t_in_epochs else "update"}.') 758 | 759 | try: 760 | for epoch in range(start_epoch, num_epochs): 761 | if hasattr(dataset_train, 'set_epoch'): 762 | dataset_train.set_epoch(epoch) 763 | elif args.distributed and hasattr(loader_train.sampler, 'set_epoch'): 764 | loader_train.sampler.set_epoch(epoch) 765 | 766 | train_metrics = train_one_epoch( 767 | epoch, 768 | model, 769 | loader_train, 770 | optimizer, 771 | train_loss_fn, 772 | args, 773 | lr_scheduler=lr_scheduler, 774 | saver=saver, 775 | output_dir=output_dir, 776 | amp_autocast=amp_autocast, 777 | loss_scaler=loss_scaler, 778 | model_ema=model_ema, 779 | mixup_fn=mixup_fn, 780 | ) 781 | 782 | if args.distributed and args.dist_bn in ('broadcast', 'reduce'): 783 | if utils.is_primary(args): 784 | _logger.info("Distributing BatchNorm running means and vars") 785 | utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce') 786 | 787 | eval_metrics = validate( 788 | model, 789 | loader_eval, 790 | validate_loss_fn, 791 | args, 792 | amp_autocast=amp_autocast, 793 | ) 794 | 795 | if model_ema is not None and not args.model_ema_force_cpu: 796 | if args.distributed and args.dist_bn in ('broadcast', 'reduce'): 797 | utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') 798 | 799 | ema_eval_metrics = validate( 800 | model_ema.module, 801 | loader_eval, 802 | validate_loss_fn, 803 | args, 804 | amp_autocast=amp_autocast, 805 | log_suffix=' (EMA)', 806 | ) 807 | eval_metrics = ema_eval_metrics 808 | 809 | if output_dir is not None: 810 | lrs = [param_group['lr'] for param_group in optimizer.param_groups] 811 | utils.update_summary( 812 | epoch, 813 | train_metrics, 814 | eval_metrics, 815 | filename=os.path.join(output_dir, 'summary.csv'), 816 | lr=sum(lrs) / len(lrs), 817 | write_header=best_metric is None, 818 | log_wandb=args.log_wandb and has_wandb, 819 | ) 820 | 821 | if saver is not None: 822 | # save proper checkpoint with eval metric 823 | save_metric = eval_metrics[eval_metric] 824 | if epoch >= args.epochs - 1: 825 | best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric) 826 | 827 | if lr_scheduler is not None: 828 | # step LR for next epoch 829 | lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) 830 | 831 | except KeyboardInterrupt: 832 | pass 833 | 834 | if best_metric is not None: 835 | _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch)) 836 | 837 | 838 | def train_one_epoch( 839 | epoch, 840 | model, 841 | loader, 842 | optimizer, 843 | loss_fn, 844 | args, 845 | device=torch.device('cuda'), 846 | lr_scheduler=None, 847 | saver=None, 848 | output_dir=None, 849 | amp_autocast=suppress, 850 | loss_scaler=None, 851 | model_ema=None, 852 | mixup_fn=None, 853 | ): 854 | if args.mixup_off_epoch and epoch >= args.mixup_off_epoch: 855 | if args.prefetcher and loader.mixup_enabled: 856 | loader.mixup_enabled = False 857 | elif mixup_fn is not None: 858 | mixup_fn.mixup_enabled = False 859 | 860 | second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 861 | has_no_sync = hasattr(model, "no_sync") 862 | update_time_m = utils.AverageMeter() 863 | data_time_m = utils.AverageMeter() 864 | losses_m = utils.AverageMeter() 865 | 866 | model.train() 867 | 868 | accum_steps = args.grad_accum_steps 869 | last_accum_steps = len(loader) % accum_steps 870 | updates_per_epoch = (len(loader) + accum_steps - 1) // accum_steps 871 | num_updates = epoch * updates_per_epoch 872 | last_batch_idx = 
len(loader) - 1 873 | last_batch_idx_to_accum = len(loader) - last_accum_steps 874 | 875 | data_start_time = update_start_time = time.time() 876 | optimizer.zero_grad() 877 | update_sample_count = 0 878 | for batch_idx, (input, target) in enumerate(loader): 879 | # print(input.shape,target) 880 | last_batch = batch_idx == last_batch_idx 881 | need_update = last_batch or (batch_idx + 1) % accum_steps == 0 882 | update_idx = batch_idx // accum_steps 883 | if batch_idx >= last_batch_idx_to_accum: 884 | accum_steps = last_accum_steps 885 | 886 | if not args.prefetcher: 887 | input, target = input.to(device), target.to(device) 888 | if mixup_fn is not None: 889 | input, target = mixup_fn(input, target) 890 | if args.channels_last: 891 | input = input.contiguous(memory_format=torch.channels_last) 892 | 893 | # multiply by accum steps to get equivalent for full update 894 | data_time_m.update(accum_steps * (time.time() - data_start_time)) 895 | 896 | def _forward(): 897 | with amp_autocast(): 898 | output = model(input) 899 | loss = loss_fn(output, target) 900 | if accum_steps > 1: 901 | loss /= accum_steps 902 | return loss 903 | 904 | def _backward(_loss): 905 | if loss_scaler is not None: 906 | loss_scaler( 907 | _loss, 908 | optimizer, 909 | clip_grad=args.clip_grad, 910 | clip_mode=args.clip_mode, 911 | parameters=model_parameters(model, exclude_head='agc' in args.clip_mode), 912 | create_graph=second_order, 913 | need_update=need_update, 914 | ) 915 | else: 916 | _loss.backward(create_graph=second_order) 917 | if need_update: 918 | if args.clip_grad is not None: 919 | utils.dispatch_clip_grad( 920 | model_parameters(model, exclude_head='agc' in args.clip_mode), 921 | value=args.clip_grad, 922 | mode=args.clip_mode, 923 | ) 924 | optimizer.step() 925 | 926 | if has_no_sync and not need_update: 927 | with model.no_sync(): 928 | loss = _forward() 929 | _backward(loss) 930 | else: 931 | loss = _forward() 932 | _backward(loss) 933 | 934 | if not args.distributed: 935 | losses_m.update(loss.item() * accum_steps, input.size(0)) 936 | update_sample_count += input.size(0) 937 | 938 | if not need_update: 939 | data_start_time = time.time() 940 | continue 941 | 942 | num_updates += 1 943 | optimizer.zero_grad() 944 | if model_ema is not None: 945 | model_ema.update(model) 946 | 947 | if args.synchronize_step and device.type == 'cuda': 948 | torch.cuda.synchronize() 949 | time_now = time.time() 950 | update_time_m.update(time.time() - update_start_time) 951 | update_start_time = time_now 952 | 953 | if update_idx % args.log_interval == 0: 954 | lrl = [param_group['lr'] for param_group in optimizer.param_groups] 955 | lr = sum(lrl) / len(lrl) 956 | 957 | if args.distributed: 958 | reduced_loss = utils.reduce_tensor(loss.data, args.world_size) 959 | losses_m.update(reduced_loss.item() * accum_steps, input.size(0)) 960 | update_sample_count *= args.world_size 961 | 962 | if utils.is_primary(args): 963 | _logger.info( 964 | f'Train: {epoch} [{update_idx:>4d}/{updates_per_epoch} ' 965 | f'({100. 
* update_idx / (updates_per_epoch - 1):>3.0f}%)] ' 966 | f'Loss: {losses_m.val:#.3g} ({losses_m.avg:#.3g}) ' 967 | f'Time: {update_time_m.val:.3f}s, {update_sample_count / update_time_m.val:>7.2f}/s ' 968 | f'({update_time_m.avg:.3f}s, {update_sample_count / update_time_m.avg:>7.2f}/s) ' 969 | f'LR: {lr:.3e} ' 970 | f'Data: {data_time_m.val:.3f} ({data_time_m.avg:.3f})' 971 | ) 972 | 973 | if args.save_images and output_dir: 974 | torchvision.utils.save_image( 975 | input, 976 | os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx), 977 | padding=0, 978 | normalize=True 979 | ) 980 | 981 | if saver is not None and args.recovery_interval and ( 982 | (update_idx + 1) % args.recovery_interval == 0): 983 | saver.save_recovery(epoch, batch_idx=update_idx) 984 | 985 | if lr_scheduler is not None: 986 | lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg) 987 | 988 | update_sample_count = 0 989 | data_start_time = time.time() 990 | # end for 991 | 992 | if hasattr(optimizer, 'sync_lookahead'): 993 | optimizer.sync_lookahead() 994 | 995 | return OrderedDict([('loss', losses_m.avg)]) 996 | 997 | 998 | def validate( 999 | model, 1000 | loader, 1001 | loss_fn, 1002 | args, 1003 | device=torch.device('cuda'), 1004 | amp_autocast=suppress, 1005 | log_suffix='' 1006 | ): 1007 | batch_time_m = utils.AverageMeter() 1008 | losses_m = utils.AverageMeter() 1009 | top1_m = utils.AverageMeter() 1010 | top5_m = utils.AverageMeter() 1011 | 1012 | model.eval() 1013 | 1014 | end = time.time() 1015 | last_idx = len(loader) - 1 1016 | with torch.no_grad(): 1017 | for batch_idx, (input, target) in enumerate(loader): 1018 | last_batch = batch_idx == last_idx 1019 | if not args.prefetcher: 1020 | input = input.to(device) 1021 | target = target.to(device) 1022 | if args.channels_last: 1023 | input = input.contiguous(memory_format=torch.channels_last) 1024 | 1025 | with amp_autocast(): 1026 | output = model(input) 1027 | if isinstance(output, (tuple, list)): 1028 | output = output[0] 1029 | 1030 | # augmentation reduction 1031 | reduce_factor = args.tta 1032 | if reduce_factor > 1: 1033 | output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2) 1034 | target = target[0:target.size(0):reduce_factor] 1035 | 1036 | loss = loss_fn(output, target) 1037 | acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) 1038 | 1039 | if args.distributed: 1040 | reduced_loss = utils.reduce_tensor(loss.data, args.world_size) 1041 | acc1 = utils.reduce_tensor(acc1, args.world_size) 1042 | acc5 = utils.reduce_tensor(acc5, args.world_size) 1043 | else: 1044 | reduced_loss = loss.data 1045 | 1046 | if device.type == 'cuda': 1047 | torch.cuda.synchronize() 1048 | 1049 | losses_m.update(reduced_loss.item(), input.size(0)) 1050 | top1_m.update(acc1.item(), output.size(0)) 1051 | top5_m.update(acc5.item(), output.size(0)) 1052 | 1053 | batch_time_m.update(time.time() - end) 1054 | end = time.time() 1055 | if utils.is_primary(args) and (last_batch or batch_idx % args.log_interval == 0): 1056 | log_name = 'Test' + log_suffix 1057 | _logger.info( 1058 | f'{log_name}: [{batch_idx:>4d}/{last_idx}] ' 1059 | f'Time: {batch_time_m.val:.3f} ({batch_time_m.avg:.3f}) ' 1060 | f'Loss: {losses_m.val:>7.3f} ({losses_m.avg:>6.3f}) ' 1061 | f'Acc@1: {top1_m.val:>7.3f} ({top1_m.avg:>7.3f}) ' 1062 | f'Acc@5: {top5_m.val:>7.3f} ({top5_m.avg:>7.3f})' 1063 | ) 1064 | 1065 | metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)]) 1066 | 1067 | return metrics 1068 | 1069 | 1070 | if __name__ == 
'__main__': 1071 | main() -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ImageNetPath="path to imagenet" 3 | SyntheticPath="path to synthetic data" 4 | 5 | CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 --rdzv_endpoint localhost:5003 train.py \ 6 | --imagenet_path $ImageNetPath --data-dir $SyntheticPath \ 7 | --dataset imagenette --model resnet50 --num-classes 10 \ 8 | --batch-size 128 --opt sgd --weight-decay 5e-4 --sched multistep --lr 0.1 --decay-rate 0.2 --epochs 200 \ 9 | --amp --output experiments/synthetic --experiment r50_imagenette --workers 8 --pin-mem --use-multi-epochs-loader 10 | 11 | --------------------------------------------------------------------------------
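For reference, train.sh leaves --decay-milestones at the train.py default of [50, 100, 150] and combines it with --sched multistep, --lr 0.1, --decay-rate 0.2 and 200 epochs. A minimal sketch of the step schedule this requests (illustrative only; the actual values come from timm's scheduler, which also applies the 5 default warmup epochs):

    def multistep_lr(epoch, base_lr=0.1, decay_rate=0.2, milestones=(50, 100, 150)):
        # the learning rate is multiplied by decay_rate once for every milestone already reached
        passed = sum(1 for m in milestones if epoch >= m)
        return base_lr * decay_rate ** passed

    # epochs 0-49 -> 0.1, 50-99 -> 0.02, 100-149 -> 0.004, 150-199 -> 0.0008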