├── .gitattributes ├── README.md ├── data ├── ActivityNet │ ├── test.json │ ├── train.json │ ├── val.json │ └── words_vocab_activitynet.json └── TACoS │ ├── test.json │ ├── train.json │ ├── val.json │ └── words_vocab_tacos.json ├── experiments ├── activitynet │ └── MSAT-32.yaml └── tacos │ └── MSAT-128.yaml ├── imgs └── pipeline.jpg ├── lib ├── core │ ├── config.py │ ├── engine.py │ ├── eval.py │ └── utils.py ├── datasets │ ├── __init__.py │ ├── activitynet.py │ └── tacos.py └── models │ ├── __init__.py │ ├── bert_modules │ ├── __init__.py │ ├── file_utils.py │ ├── modeling.py │ ├── visual_linguistic_bert.py │ └── vlbert.py │ ├── frame_modules │ ├── __init__.py │ └── frame_pool.py │ ├── loss.py │ └── tan.py └── moment_localization ├── _init_paths.py ├── test.py └── train.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MSAT 2 | 3 | This is the code for the paper "Multi-stage Aggregated Transformer Network for Temporal Language Localization in Videos". We appreciate the contribution of [2D-TAN](https://github.com/microsoft/2D-TAN). 4 | 5 | ## Framework 6 | ![alt text](imgs/pipeline.jpg) 7 | 8 | ## Prerequisites 9 | - python 3 10 | - pytorch 1.6.0 11 | - torchvision 0.7.0 12 | - torchtext 0.7.0 13 | - easydict 14 | - terminaltables 15 | 16 | 17 | ## Quick Start 18 | 19 | Please download the visual features from [box drive](https://rochester.box.com/s/8znalh6y5e82oml2lr7to8s6ntab6mav) and save them to the `data/` folder. 20 | 21 | 22 | #### Training 23 | Use the following commands for training: 24 | ``` 25 | # For ActivityNet Captions 26 | python moment_localization/train.py --cfg experiments/activitynet/MSAT-32.yaml --verbose 27 | 28 | # For TACoS 29 | python moment_localization/train.py --cfg experiments/tacos/MSAT-128.yaml --verbose 30 | ``` 31 | 32 | #### Testing 33 | Our trained models are provided in [Baidu Yun](https://pan.baidu.com/s/1l9O7Csg479kmQB8hsqYM8w) (access code: rc2m). Please download them to the `checkpoints` folder.
34 | 35 | Then, run the following commands for evaluation: 36 | ``` 37 | # For ActivityNet Captions 38 | python moment_localization/test.py --cfg experiments/activitynet/MSAT-32.yaml --verbose --split test 39 | 40 | # For TACoS 41 | python moment_localization/test.py --cfg experiments/tacos/MSAT-128.yaml --verbose --split test 42 | ``` 43 | 44 | ## Citation 45 | If any part of our paper and code is helpful to your work, please generously cite with: 46 | ``` 47 | @inproceedings{zhang2021multi, 48 | title={Multi-Stage Aggregated Transformer Network for Temporal Language Localization in Videos}, 49 | author={Zhang, Mingxing and Yang, Yang and Chen, Xinghan and Ji, Yanli and Xu, Xing and Li, Jingjing and Shen, Heng Tao}, 50 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 51 | pages={12669--12678}, 52 | year={2021} 53 | } 54 | ``` 55 | -------------------------------------------------------------------------------- /data/TACoS/words_vocab_tacos.json: -------------------------------------------------------------------------------- 1 | {"words": ["PAD", "he", "took", "out", "a", "pan", "an", "egg", "cup", "fork", "cracked", "the", "added", "salt", "scrambled", "washed", "his", "hands", "turned", "on", "stove", "put", "oil", "in", "spread", "poured", "into", "man", "enters", "kitchen", "removes", "frying", "from", "drawer", "and", "places", "it", "retrieves", "refrigerator", "counter", "coffee", "cupboard", "placing", "them", "both", "cracks", "throws", "shells", "trash", "seasons", "eggs", "using", "wisks", "washes", "turns", "gets", "pantry", "adds", "some", "to", "returns", "puts", "sink", "bowl", "of", "herbs", "add", "plate", "sets", "takes", "insides", "pepper", "whisks", "prepares", "by", "turning", "heat", "adding", "continues", "whisk", "knife", "got", "all", "cooking", "utensils", "prepared", "mug", "including", "spices", "set", "temperature", "for", "placed", "mixture", "sprinkled", "ingredient", "now", "cooked", "onto", "scrambles", "skillet", "cabinet", "open", "seasoning", "pours", "ingredients", "person", "away", "eggshell", "stirs", "wooden", "spatula", "kind", "herb", "scrapes", "range", "uses", "scramble", "olive", "chopped", "transfers", "is", "done", "spoon", "breaks", "discards", "shell", "then", "off", "prepare", "crack", "one", "pour", "little", "bit", "as", "well", "with", "mix", "together", "contents", "small", "amount", "hot", "patiently", "wait", "cooks", "next", "top", "carefully", "push", "recycling", "bin", "burner", "under", "while", "waiting", "up", "more", "heated", "greens", "when", "empties", "additional", "move", "fridge", "rinses", "their", "swirls", "around", "waits", "few", "moments", "adjusts", "stares", "at", "stare", "chives", "keeps", "staring", "cook", "white", "does", "head", "bobs", "forever", "be", "omelette", "not", "fyi", "threw", "mixed", "cutting", "board", "plum", "cut", "pit", "procures", "slices", "down", "middle", "repeatedly", "tries", "tear", "apart", "sections", "finally", "so", "disposes", "core", "chunk", "another", "part", "goes", "opens", "couple", "drawers", "half", "pry", "fingers", "cuts", "till", "successfully", "has", "several", "pieces", "bad", "parts", "rest", "pull", "piece", "break", "second", "remaining", "quarters", "separate", "closes", "walks", "door", "crisper", "faucet", "fruit", "finds", "large", "begins", "remove", "garbage", "opened", "apple", "section", "take", "wash", "halves", "your", "hand", "additionally", "if", "needed", "smaller", "discard", "rinse", "four", 
"get", "place", "use", "2", "perpendicular", "circles", "until", "you", "are", "able", "that", "still", "attached", "throw", "chopping", "peels", "gather", "ginger", "bring", "find", "good", "end", "thinly", "slice", "appropriate", "further", "slivers", "root", "portion", "fragment", "desired", "sized", "width", "wise", "strips", "body", "latter", "side", "chops", "grabs", "chunks", "guy", "dices", "fresh", "round", "sharp", "skin", "about", "5", "even", "selects", "moves", "searches", "proper", "eventually", "choosing", "paring", "roughly", "she", "carrots", "carrot", "peeler", "peeled", "front", "back", "ends", "peel", "chop", "once", "leftover", "bunch", "vegetable", "dish", "water", "discarded", "tip", "peelings", "video", "properly", "container", "bundle", "peeling", "tool", "remains", "removed", "serving", "trimmed", "selected", "scraps", "sliced", "returned", "ready", "over", "picks", "rinds", "stem", "lancashire", "inch", "kitchenware", "dumps", "talks", "woman", "extra", "sprinkles", "beats", "oven", "distributes", "again", "flips", "completely", "opposite", "oiled", "cleans", "they", "metal", "two", "edge", "before", "discarding", "gathers", "yolks", "fully", "preheats", "retrieving", "mixes", "lets", "without", "flipping", "mixing", "after", "approximately", "minutes", "clean", "sponge", "soap", "uncooked", "evenly", "greases", "checks", "fluffs", "had", "heats", "salts", "finishes", "scrambling", "stainless", "steel", "was", "taken", "were", "whisked", "cleaned", "flipped", "moved", "rinsed", "seasoned", "shakes", "dry", "throwing", "pats", "sprinkling", "detergent", "season", "stir", "they're", "pair", "each", "wastebin", "washing", "going", "getting", "bottle", "condiments", "first", "beat", "utensil", "oils", "starts", "inside", "fine", "layer", "making", "sure", "receives", "grain", "swishes", "or", "poke", "peppers", "incredibly", "thoughtful", "manner", "pokes", "folds", "reassemble", "'scrambled'", "pineapple", "knives", "storage", "bottom", "sides", "hard", "spots", "trimming", "cleaning", "rounds", "scrap", "emptied", "equipment", "last", "bits", "plastic", "its", "trim", "any", "unwanted", "these", "lay", "1", "thick", "disks", "bowls", "tupperware", "outside", "zester", "i", "thought", "stands", "upright", "embedded", "lays", "big", "circle", "shapes", "rind", "her", "trims", "putting", "standing", "slicing", "downward", "excess", "circular", "everything", "dispose", "waste", "station", "grabbing", "pine", "outer", "pits", "quickly", "knifes", "crown", "cutlery", "countertop", "tub", "fronds", "scaly", "left", "thin", "cucumber", "chef's", "lengthwise", "hold", "cubes", "six", "times", "stacking", "twenty", "crosswise", "larger", "other", "bigger", "selecting", "long", "same", "through", "creating", "been", "ways", "diced", "keeping", "lenghtwise", "also", "holds", "squares", "picked", "chooses", "sideways", "form", "completed", "narrow", "stacks", "went", "pulled", "began", "either", "partially", "table", "pulls", "parer", "skins", "tools", "grabbed", "green", "entire", "obtain", "leaves", "heel", "coins", "stems", "girl", "unused", "rip", "plant", "shake", "edible", "arranges", "equal", "wilted", "lady", "reaches", "below", "preparing", "knocks", "much", "can", "finished", "want", "wipes", "stalks", "stalk", "herbes", "select", "blade", "pile", "withered", "4", "something", "handful", "vegetables", "perpendicularly", "sticking", "refrigerate", "uncut", "leek", "pot", "fills", "lip", "usable", "leeks", "aside", "lid", "length", "onions", "onion", "boil", "very", 
"just", "groups", "boiling", "trying", "handfuls", "filled", "cover", "drops", "sauce", "pics", "repeats", "exact", "process", "what", "look", "like", "roots", "leaving", "time", "medium", "covers", "preparer", "started", "covered", "retrieved", "chili", "dries", "cabinets", "straight", "line", "awaiting", "towel", "finely", "scoops", "dried", "cuppards", "where", "cubbard", "drys", "lateral", "whole", "slides", "higher", "chilli", "transfer", "tap", "clove", "looks", "rises", "medallions", "cupbard", "peice", "potato", "closing", "knees", "avocado", "serve", "sticker", "se", "pries", "closed", "make", "single", "axial", "way", "seperates", "axially", "seed", "proceeds", "minces", "bag", "running", "slightly", "kictchen", "chilies", "smells", "red", "thrown", "this", "boy", "chile", "cauliflower", "silver", "blue", "stuff", "used", "stray", "trimmings", "colander", "base", "flowerettes", "floret", "main", "work", "surface", "crumbs", "receptacle", "individual", "florets", "sweeps", "remnants", "cuttings", "trashbin", "pried", "disposed", "cubed", "separates", "ending", "which", "split", "him", "halfs", "loaf", "bread", "unwraps", "holding", "wrapping", "package", "wrap", "wraps", "reseals", "bread's", "packaging", "access", "serrated", "sawing", "motion", "re", "paper", "covering", "arrange", "existing", "makes", "pushes", "deeper", "previous", "region", "rewrapped", "figs", "fig", "tosses", "tops", "run", "dirt", "horizontal", "vertical", "kiwi", "dice", "tips", "thoroughly", "enjoy", "fourths", "persons", "yellow", "bell", "inner", "later", "seeds", "fours", "center", "plates", "cuttingboard", "hollows", "husk", "chosen", "ribs", "camera", "doesnt", "show", "but", "retrieve", "returning", "works", "offscreen", "fails", "notice", "pointing", "messes", "picture", "impossible", "tell", "doing", "seems", "presumably", "appears", "resumes", "dicing", "shot", "only", "pack", "chilis", "starting", "refridgerator", "horizontally", "pointed", "keep", "moving", "stop", "you've", "reached", "broccoli", "salted", "wiped", "attempts", "stops", "area", "lidded", "condiment", "refuse", "wrapper", "see", "fill", "third", "ten", "fifteen", "dashes", "begin", "portions", "fetches", "saucepan", "inches", "worktop", "spice", "rack", "gives", "unusable", "shrink", "broccolli", "juicer", "lime", "twists", "juice", "press", "squeezes", "rendered", "juices", "juicing", "halve", "twisted", "taking", "forth", "liquid", "obtained", "check", "collected", "citrus", "containing", "laterally", "extracts", "replaces", "juiced", "chilled", "carries", "presses", "firmly", "twisting", "express", "switches", "expressed", "action", "unable", "room", "immediately", "twirls", "drain", "drains", "3", "smile", "lovely", "would", "will", "need", "repeat", "have", "slowly", "unwrapped", "step", "grab", "butcher", "block", "fourth", "pressing", "rings", "include", "space", "neatly", "steps", "beans", "men", "refigerator", "bite", "size", "freshly", "broad", "bean", "wants", "picking", "locate", "checking", "freshness", "vigorously", "segments", "runs", "upon", "box", "limes", "draw", "squeezing", "manual", "strainer", "grinds", "squeezer", "remainders", "cloth", "squeeze", "quarter", "segment", "carve", "splits", "separating", "price", "plum's", "setting", "working", "free", "stone", "circumference", "producing", "bisects", "directions", "stood", "beside", "across", "rotating", "cannot", "do", "succeeds", "rests", "removing", "plumb", "obtains", "peach", "five", "mango", "slicer", "brand", "avoiding", "dishwasher", "magno", 
"briefly", "brings", "dirty", "he's", "how", "label", "s", "flesh", "washer", "pee", "pick", "stacked", "diagonally", "tiny", "bunches", "easier", "laying", "flat", "changes", "mind", "hollow", "gut", "carves", "cap", "fleshy", "it's", "materials", "satisfactory", "hollowed", "rectangular", "scooped", "orange", "separated", "oranges", "broke", "help", "gap", "between", "along", "forming", "wedges", "skinned", "tears", "women", "spends", "pulp", "brought", "#1", "#2", "potatoes", "potted", "skinning", "enough", "heating", "peelers", "filling", "turn", "gas", "deftly", "plums", "plumbs", "supplies", "many", "there", "walk", "beneath", "slit", "tastes", "presents", "product", "rearranges", "items", "bringing", "above", "skinless", "tasting", "centers", "stickers", "item", "soak", "preparation", "receiving", "return", "edges", "smooth", "wet", "scrubbing", "knots", "minced", "cabinent", "ultimately", "thinner", "displays", "collecting", "inedible", "bone", "similar", "rid", "aren't", "clear", "pitt", "most", "cups", "yolk", "pouring", "whites", "empty", "drips", "eggshells", "containers", "original", "mugs", "cubbord", "seperate", "ended", "tea", "allowing", "drip", "alternates", "gently", "rubs", "vertically", "silverware", "walked", "naked", "drained", "anything", "didn't", "potatoe", "cores", "'eyes'", "potatoes:", "actually", "debris", "eyes", "potato's", "eye", "herb's", "ness", "cool", "rubber", "band", "leafy", "banded", "undo", "lightly", "neat", "completes", "assigned", "task", "quartered", "twice", "date", "rag", "hanging", "aranges", "garlic", "pressed", "beak", "scrape", "bulb", "cloves", "saucer", "separation", "peals", "three", "individually", "utencil", "lingering", "handle", "comes", "scoop", "meat", "handles", "release", "mince", "cream", "ads", "stirring", "hardware", "milk", "finish", "dunks", "sees", "too", "improvised", "dishes", "occasionally", "progress", "assorted", "dips", "considers", "former", "dipped", "straightened", "counters", "contemplates", "cabbage", "grille", "sause", "blend", "roll", "grill", "creates", "preps", "simmer", "behind", "failed", "attempt", "finishing", "knobby", "diagonal", "rotate", "odor", "shark", "exterior", "various", "angles", "asian", "says", "carrot's", "rewashes", "carving", "rewash", "no", "than", "fruits", "cold", "kiwis", "unpeeling", "different", "serves", "wood", "reach", "continue", "kiwifruit", "final", "widthwise", "wish", "basket", "dark", "eighth", "10", "12", "total", "ripe", "nearby", "giving", "proceed", "contains", "easy", "removal", "workstation", "brown", "lenghwise", "bruised", "damaged", "areas", "upside", "grocery", "tag", "pearing", "semicircles", "semicircle", "knive", "complete", "pineapple's", "face", "rough", "eats", "halfway", "cucumbers", "perosn", "opening", "walking", "digs", "fashion", "circling", "lift", "relocates", "those", "combined", "unwrap", "recipe", "wipe", "rewrap", "wrapped", "type", "2nd", "source", "screen", "exactly", "fairly", "undesired", "mystery", "combines", "collects", "veggies", "direction", "remainder", "extractor", "seconds", "firm", "twist", "extracted", "extract", "glass", "rolls", "young", "draws", "rolling", "sifts", "scoups", "juicers", "itself", "tenderizes", "palm", "pushing", "mid", "navel", "exposed", "against", "grinding", "strain", "softens", "rubbing", "forcefully", "butter", "pat", "melted", "melt", "fry", "crops", "planter", "parcels", "choice", "such", "start", "taste", "stick", "spreads", "melting", "allows", "every", "pad", "melts", "let", "coat", "periodically", 
"during", "garnishes", "rewashed", "leak", "sheaf", "sheaves", "halved", "leek's", "nice", "chef", "leafs", "dinner", "heads", "jar", "boiled", "necessary", "case", "uniform", "may", "served", "ones", "don't", "recognize", "shaking", "crowns", "rounded", "slotted", "pots", "awhile", "decides", "faster", "pickle", "pasta", "scooper", "longer", "recovers", "ladle", "recover", "platter", "allow", "checked", "dash", "tender", "corner", "cupboards", "thirds", "right", "seperately", "3rd", "replaced", "hit", "stores", "unneeded", "concludes", "demonstration", "crush", "alternating", "smashes", "garlic's", "lining", "finger", "clung", "stuck", "completing", "flattens", "taps", "loose", "stiff", "bases", "inward", "avoid", "pod", "pepper's", "wastebasket", "beginning", "stemless", "scoring", "directly", "assemble", "slits", "exits", "near", "possible", "procedure", "loosen", "drinking", "shallow", "squeezed", "stored", "rotates", "grind", "afterwards", "lemon", "boards", "tray", "string", "strings", "required", "peas", "pea", "needs", "facet", "sitting", "pomegranate", "deep", "arils", "weaken", "spare", "one:", "two:", "three:", "four:", "five:", "six:", "insert", "thumb", "seven:", "eight:", "nine:", "ten:", "painstakingly", "unedible", "larges", "flower", "new", "stoned", "califlower", "external", "undesirable", "rips", "broken", "care", "thing", "fifth", "fallen", "choose", "longways", "pulling", "spiraling", "inserts", "thumbs", "majority", "divides", "th", "hald", "peices", "pith", "subject", "shut", "reopened", "looked", "realized", "sticks", "simultaneously", "following", "materials:", "sizes", "membrane", "that's", "whisker", "think", "flour", "regularly", "things", "spoons", "stovetop", "bouillon", "powder", "stock", "useful", "broth", "corn", "starch", "lids", "thrice", "beater", "sort", "30", "shave", "unsavory", "leave", "uncovered", "pot's", "replace", "browned", "determine", "appropriately", "tenderness", "creamsauce", "plated", "shaves", "home", "drop", "eat", "UNK"]} -------------------------------------------------------------------------------- /experiments/activitynet/MSAT-32.yaml: -------------------------------------------------------------------------------- 1 | WORKERS: 4 2 | 3 | MODEL_DIR: ./checkpoints 4 | RESULT_DIR: ./results 5 | LOG_DIR: ./log 6 | DATA_DIR: ./data/ActivityNet 7 | 8 | DATASET: 9 | NAME: ActivityNet 10 | VIS_INPUT_TYPE: c3d 11 | NO_VAL: False 12 | NUM_SAMPLE_CLIPS: 256 13 | TARGET_STRIDE: 8 14 | NORMALIZE: True 15 | RANDOM_SAMPLING: False 16 | 17 | TEST: 18 | BATCH_SIZE: 16 19 | RECALL: 1,5 20 | TIOU: 0.3,0.5,0.7 21 | EVAL_TRAIN: False 22 | NMS_THRESH: 0.51 23 | INTERVAL: 0.25 24 | 25 | CUDNN: 26 | DETERMINISTIC: False 27 | BENCHMARK: True 28 | 29 | TRAIN: 30 | BATCH_SIZE: 16 31 | LR: 0.0001 32 | WEIGHT_DECAY: 0.0000 33 | MAX_EPOCH: 100 34 | CONTINUE: False 35 | 36 | LOSS: 37 | NAME: bce_rescale_loss 38 | PARAMS: 39 | W1: 0.4 40 | W2: 1.0 41 | W3: 10.0 42 | W4: 0.1 43 | 44 | TAN: 45 | 46 | FRAME_MODULE: 47 | NAME: FrameAvgPool 48 | PARAMS: 49 | INPUT_SIZE: 500 50 | HIDDEN_SIZE: 512 51 | KERNEL_SIZE: 8 52 | STRIDE: 8 53 | 54 | VLBERT_MODULE: 55 | NAME: TLocVLBERT 56 | PARAMS: 57 | object_word_embed_mode: 2 58 | input_transform_type: 1 59 | visual_size: 500 60 | hidden_size: 512 61 | num_hidden_layers: 6 62 | num_attention_heads: 16 63 | intermediate_size: 512 64 | hidden_act: "gelu" 65 | hidden_dropout_prob: 0.1 66 | attention_probs_dropout_prob: 0.1 67 | max_position_embeddings: 512 68 | type_vocab_size: 2 69 | vocab_size: 10728 70 | initializer_range: 
0.02 71 | visual_scale_text_init: 1.0 72 | visual_scale_object_init: 1.0 73 | visual_ln: False 74 | word_embedding_frozen: False 75 | with_pooler: True 76 | 77 | BERT_MODEL_NAME: './model/pretrained_model/bert-base-uncased' 78 | BERT_PRETRAINED: '' 79 | BERT_PRETRAINED_EPOCH: 0 80 | 81 | CLASSIFIER_TYPE: "2fc" 82 | CLASSIFIER_PRETRAINED: True 83 | CLASSIFIER_DROPOUT: 0.1 84 | CLASSIFIER_HIDDEN_SIZE: 512 85 | NO_GROUNDING: True 86 | 87 | MODEL: 88 | NAME: TAN 89 | CHECKPOINT: ./checkpoints/ActivityNet/TAN_c3d/iter028032-0.6150-0.8634.pkl 90 | -------------------------------------------------------------------------------- /experiments/tacos/MSAT-128.yaml: -------------------------------------------------------------------------------- 1 | WORKERS: 4 2 | 3 | MODEL_DIR: ./checkpoints 4 | RESULT_DIR: ./results 5 | LOG_DIR: ./log 6 | DATA_DIR: ./data/TACoS 7 | 8 | DATASET: 9 | NAME: TACoS 10 | VIS_INPUT_TYPE: c3d 11 | NO_VAL: False 12 | NUM_SAMPLE_CLIPS: 256 13 | TARGET_STRIDE: 2 14 | NORMALIZE: True 15 | RANDOM_SAMPLING: False 16 | 17 | TEST: 18 | BATCH_SIZE: 16 19 | RECALL: 1,5 20 | TIOU: 0.3,0.5,0.7 21 | EVAL_TRAIN: False 22 | NMS_THRESH: 0.37 23 | INTERVAL: 1.0 24 | 25 | CUDNN: 26 | DETERMINISTIC: False 27 | BENCHMARK: True 28 | 29 | TRAIN: 30 | BATCH_SIZE: 16 31 | LR: 0.0001 32 | WEIGHT_DECAY: 0.0000 33 | MAX_EPOCH: 100 34 | CONTINUE: False 35 | 36 | LOSS: 37 | NAME: bce_rescale_loss 38 | PARAMS: 39 | W1: 0.3 40 | W2: 1.0 41 | W3: 200.0 42 | W4: 0.25 43 | 44 | TAN: 45 | 46 | FRAME_MODULE: 47 | NAME: FrameAvgPool 48 | PARAMS: 49 | INPUT_SIZE: 4096 50 | HIDDEN_SIZE: 512 51 | KERNEL_SIZE: 2 52 | STRIDE: 2 53 | 54 | VLBERT_MODULE: 55 | NAME: TLocVLBERT 56 | PARAMS: 57 | object_word_embed_mode: 2 58 | input_transform_type: 1 59 | visual_size: 4096 60 | hidden_size: 512 61 | num_hidden_layers: 6 62 | num_attention_heads: 32 63 | intermediate_size: 512 64 | hidden_act: "gelu" 65 | hidden_dropout_prob: 0.1 66 | attention_probs_dropout_prob: 0.1 67 | max_position_embeddings: 512 68 | type_vocab_size: 2 69 | vocab_size: 1514 70 | initializer_range: 0.02 71 | visual_scale_text_init: 1.0 72 | visual_scale_object_init: 1.0 73 | visual_ln: false 74 | word_embedding_frozen: False 75 | with_pooler: True 76 | 77 | BERT_MODEL_NAME: './model/pretrained_model/bert-base-uncased' 78 | BERT_PRETRAINED: '' 79 | BERT_PRETRAINED_EPOCH: 0 80 | 81 | CLASSIFIER_TYPE: "2fc" 82 | CLASSIFIER_PRETRAINED: True 83 | CLASSIFIER_DROPOUT: 0.1 84 | CLASSIFIER_HIDDEN_SIZE: 512 85 | NO_GROUNDING: True 86 | 87 | MODEL: 88 | NAME: TAN 89 | CHECKPOINT: ./checkpoints/TACoS/TAN_c3d/iter017108-0.4879-0.6763.pkl 90 | -------------------------------------------------------------------------------- /imgs/pipeline.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mxingzhang90/MSAT/d9a467ca814fef3f4fee9bb6e50675e0c35b7dc2/imgs/pipeline.jpg -------------------------------------------------------------------------------- /lib/core/config.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import yaml 6 | from easydict import EasyDict as edict 7 | 8 | config = edict() 9 | 10 | config.WORKERS = 2 11 | config.LOG_DIR = '' 12 | config.MODEL_DIR = '' 13 | config.RESULT_DIR = '' 14 | config.DATA_DIR = '' 15 | config.VERBOSE = False 16 | config.TAG = '' 17 | 18 | # CUDNN related params 19 | config.CUDNN = edict() 20 | 
config.CUDNN.BENCHMARK = True 21 | config.CUDNN.DETERMINISTIC = False 22 | config.CUDNN.ENABLED = True 23 | 24 | # TAN related params 25 | config.TAN = edict() 26 | config.TAN.FRAME_MODULE = edict() 27 | config.TAN.FRAME_MODULE.NAME = '' 28 | config.TAN.FRAME_MODULE.PARAMS = None 29 | config.TAN.PROP_MODULE = edict() 30 | config.TAN.PROP_MODULE.NAME = '' 31 | config.TAN.PROP_MODULE.PARAMS = None 32 | config.TAN.FUSION_MODULE = edict() 33 | config.TAN.FUSION_MODULE.NAME = '' 34 | config.TAN.FUSION_MODULE.PARAMS = None 35 | config.TAN.MAP_MODULE = edict() 36 | config.TAN.MAP_MODULE.NAME = '' 37 | config.TAN.MAP_MODULE.PARAMS = None 38 | config.TAN.VLBERT_MODULE = edict() 39 | config.TAN.VLBERT_MODULE.NAME = '' 40 | config.TAN.VLBERT_MODULE.PARAMS = None 41 | config.TAN.PRED_INPUT_SIZE = 512 42 | 43 | # common params for NETWORK 44 | config.MODEL = edict() 45 | config.MODEL.NAME = '' 46 | config.MODEL.CHECKPOINT = '' # The checkpoint for the best performance 47 | 48 | # DATASET related params 49 | config.DATASET = edict() 50 | config.DATASET.ROOT = '' 51 | config.DATASET.NAME = '' 52 | config.DATASET.MODALITY = '' 53 | config.DATASET.VIS_INPUT_TYPE = '' 54 | config.DATASET.NO_VAL = False 55 | config.DATASET.BIAS = 0 56 | config.DATASET.NUM_SAMPLE_CLIPS = 256 57 | config.DATASET.TARGET_STRIDE = 16 58 | config.DATASET.DOWNSAMPLING_STRIDE = 16 59 | config.DATASET.SPLIT = '' 60 | config.DATASET.NORMALIZE = False 61 | config.DATASET.RANDOM_SAMPLING = False 62 | 63 | # train 64 | config.TRAIN = edict() 65 | config.TRAIN.LR = 0.001 66 | config.TRAIN.WEIGHT_DECAY = 0 67 | config.TRAIN.FACTOR = 0.8 68 | config.TRAIN.PATIENCE = 20 69 | config.TRAIN.MAX_EPOCH = 20 70 | config.TRAIN.BATCH_SIZE = 4 71 | config.TRAIN.SHUFFLE = True 72 | config.TRAIN.CONTINUE = False 73 | 74 | config.LOSS = edict() 75 | config.LOSS.NAME = 'bce_loss' 76 | config.LOSS.PARAMS = None 77 | 78 | # test 79 | config.TEST = edict() 80 | config.TEST.RECALL = [] 81 | config.TEST.TIOU = [] 82 | config.TEST.NMS_THRESH = 0.4 83 | config.TEST.INTERVAL = 0.25 84 | config.TEST.EVAL_TRAIN = False 85 | config.TEST.BATCH_SIZE = 1 86 | config.TEST.TOP_K = 10 87 | 88 | def _update_dict(cfg, value): 89 | for k, v in value.items(): 90 | if k in cfg: 91 | if k == 'PARAMS': 92 | cfg[k] = v 93 | elif isinstance(v, dict): 94 | _update_dict(cfg[k],v) 95 | else: 96 | cfg[k] = v 97 | else: 98 | raise ValueError("{} not exist in config.py".format(k)) 99 | 100 | def update_config(config_file): 101 | with open(config_file) as f: 102 | exp_config = edict(yaml.load(f, Loader=yaml.FullLoader)) 103 | for k, v in exp_config.items(): 104 | if k in config: 105 | if isinstance(v, dict): 106 | _update_dict(config[k], v) 107 | else: 108 | config[k] = v 109 | else: 110 | raise ValueError("{} not exist in config.py".format(k)) 111 | -------------------------------------------------------------------------------- /lib/core/engine.py: -------------------------------------------------------------------------------- 1 | class Engine(object): 2 | def __init__(self): 3 | self.hooks = {} 4 | 5 | def hook(self, name, state): 6 | 7 | if name in self.hooks: 8 | self.hooks[name](state) 9 | 10 | def train(self, network, iterator, maxepoch, optimizer, scheduler): 11 | state = { 12 | 'network': network, 13 | 'iterator': iterator, 14 | 'maxepoch': maxepoch, 15 | 'optimizer': optimizer, 16 | 'scheduler': scheduler, 17 | 'epoch': 0, 18 | 't': 0, 19 | 'train': True, 20 | } 21 | 22 | self.hook('on_start', state) 23 | while state['epoch'] < state['maxepoch']: 24 | 
self.hook('on_start_epoch', state) 25 | for sample in state['iterator']: 26 | state['sample'] = sample 27 | self.hook('on_sample', state) 28 | 29 | def closure(): 30 | loss, output = state['network'](state['sample']) 31 | state['output'] = output 32 | state['loss'] = loss 33 | loss.backward() 34 | self.hook('on_forward', state) 35 | # to free memory in save_for_backward 36 | state['output'] = None 37 | state['loss'] = None 38 | return loss 39 | 40 | state['optimizer'].zero_grad() 41 | state['optimizer'].step(closure) 42 | self.hook('on_update', state) 43 | state['t'] += 1 44 | state['epoch'] += 1 45 | self.hook('on_end_epoch', state) 46 | self.hook('on_end', state) 47 | return state 48 | 49 | def test(self, network, iterator, split): 50 | state = { 51 | 'network': network, 52 | 'iterator': iterator, 53 | 'split': split, 54 | 't': 0, 55 | 'train': False, 56 | } 57 | 58 | self.hook('on_test_start', state) 59 | for sample in state['iterator']: 60 | state['sample'] = sample 61 | self.hook('on_test_sample', state) 62 | 63 | def closure(): 64 | loss, output = state['network'](state['sample']) 65 | state['output'] = output 66 | state['loss'] = loss 67 | self.hook('on_test_forward', state) 68 | # to free memory in save_for_backward 69 | state['output'] = None 70 | state['loss'] = None 71 | 72 | closure() 73 | state['t'] += 1 74 | self.hook('on_test_end', state) 75 | return state -------------------------------------------------------------------------------- /lib/core/eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import numpy as np 4 | from terminaltables import AsciiTable 5 | 6 | from core.config import config, update_config 7 | 8 | def iou(pred, gt): # require pred and gt is numpy 9 | assert isinstance(pred, list) and isinstance(gt,list) 10 | pred_is_list = isinstance(pred[0],list) 11 | gt_is_list = isinstance(gt[0],list) 12 | if not pred_is_list: pred = [pred] 13 | if not gt_is_list: gt = [gt] 14 | pred, gt = np.array(pred), np.array(gt) 15 | inter_left = np.maximum(pred[:,0,None], gt[None,:,0]) 16 | inter_right = np.minimum(pred[:,1,None], gt[None,:,1]) 17 | inter = np.maximum(0.0, inter_right - inter_left) 18 | union_left = np.minimum(pred[:,0,None], gt[None,:,0]) 19 | union_right = np.maximum(pred[:,1,None], gt[None,:,1]) 20 | union = np.maximum(0.0, union_right - union_left) 21 | overlap = 1.0 * inter / union 22 | if not gt_is_list: 23 | overlap = overlap[:,0] 24 | if not pred_is_list: 25 | overlap = overlap[0] 26 | return overlap 27 | 28 | def rank(pred, gt): 29 | return pred.index(gt) + 1 30 | 31 | def nms(dets, thresh=0.4, top_k=-1): 32 | """Pure Python NMS baseline.""" 33 | if len(dets) == 0: return [] 34 | order = np.arange(0,len(dets),1) 35 | dets = np.array(dets) 36 | x1 = dets[:, 0] 37 | x2 = dets[:, 1] 38 | lengths = x2 - x1 39 | keep = [] 40 | while order.size > 0: 41 | i = order[0] 42 | keep.append(i) 43 | if len(keep) == top_k: 44 | break 45 | xx1 = np.maximum(x1[i], x1[order[1:]]) 46 | xx2 = np.minimum(x2[i], x2[order[1:]]) 47 | inter = np.maximum(0.0, xx2 - xx1) 48 | ovr = inter / (lengths[i] + lengths[order[1:]] - inter) 49 | inds = np.where(ovr <= thresh)[0] 50 | order = order[inds + 1] 51 | 52 | return dets[keep] 53 | 54 | def eval(segments, data): 55 | tious = [float(i) for i in config.TEST.TIOU.split(',')] if isinstance(config.TEST.TIOU,str) else [config.TEST.TIOU] 56 | recalls = [int(i) for i in config.TEST.RECALL.split(',')] if isinstance(config.TEST.RECALL,str) else [config.TEST.RECALL] 
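# Annotation (a reading aid, not part of the upstream file): TEST.TIOU and
# TEST.RECALL arrive from the YAML either as comma-separated strings
# (e.g. TIOU: 0.3,0.5,0.7 and RECALL: 1,5) or as single numbers; the two
# lines above normalize both to lists, so the grid below reports Rank@r
# at every tIoU threshold t, plus the overall mean IoU.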
57 | 58 | eval_result = [[[] for _ in recalls] for _ in tious] 59 | max_recall = max(recalls) 60 | average_iou = [] 61 | for seg, dat in zip(segments, data): 62 | seg = nms(seg, thresh=config.TEST.NMS_THRESH, top_k=max_recall).tolist() 63 | overlap = iou(seg, [dat['times']]) 64 | average_iou.append(np.mean(np.sort(overlap[0])[-3:])) 65 | 66 | for i,t in enumerate(tious): 67 | for j,r in enumerate(recalls): 68 | eval_result[i][j].append((overlap > t)[:r].any()) 69 | eval_result = np.array(eval_result).mean(axis=-1) 70 | miou = np.mean(average_iou) 71 | 72 | 73 | return eval_result, miou 74 | 75 | def eval_predictions(segments, data, verbose=True): 76 | eval_result, miou = eval(segments, data) 77 | if verbose: 78 | print(display_results(eval_result, miou, '')) 79 | 80 | return eval_result, miou 81 | 82 | def display_results(eval_result, miou, title=None): 83 | tious = [float(i) for i in config.TEST.TIOU.split(',')] if isinstance(config.TEST.TIOU,str) else [config.TEST.TIOU] 84 | recalls = [int(i) for i in config.TEST.RECALL.split(',')] if isinstance(config.TEST.RECALL,str) else [config.TEST.RECALL] 85 | 86 | display_data = [['Rank@{},mIoU@{}'.format(i,j) for i in recalls for j in tious]+['mIoU']] 87 | eval_result = eval_result*100 88 | miou = miou*100 89 | display_data.append(['{:.02f}'.format(eval_result[j][i]) for i in range(len(recalls)) for j in range(len(tious))] 90 | +['{:.02f}'.format(miou)]) 91 | table = AsciiTable(display_data, title) 92 | for i in range(len(tious)*len(recalls)): 93 | table.justify_columns[i] = 'center' 94 | return table.table 95 | 96 | 97 | def parse_args(): 98 | parser = argparse.ArgumentParser(description='Train localization network') 99 | 100 | # general 101 | parser.add_argument('--cfg', help='experiment configure file name', required=True, type=str) 102 | args, rest = parser.parse_known_args() 103 | 104 | # update config 105 | update_config(args.cfg) 106 | 107 | parser.add_argument('--verbose', default=False, action="store_true", help='print progress bar') 108 | args = parser.parse_args() 109 | 110 | return args 111 | 112 | def reset_config(config, args): 113 | if args.verbose: 114 | config.VERBOSE = args.verbose 115 | 116 | if __name__ == '__main__': 117 | args = parse_args() 118 | reset_config(config, args) 119 | train_data = json.load(open('/data/home2/hacker01/Data/DiDeMo/train_data.json', 'r')) 120 | val_data = json.load(open('/data/home2/hacker01/Data/DiDeMo/val_data.json', 'r')) 121 | 122 | moment_frequency_dict = {} 123 | for d in train_data: 124 | times = [t for t in d['times']] 125 | for time in times: 126 | time = tuple(time) 127 | if time not in moment_frequency_dict.keys(): 128 | moment_frequency_dict[time] = 0 129 | moment_frequency_dict[time] += 1 130 | 131 | prior = sorted(moment_frequency_dict, key=moment_frequency_dict.get, reverse=True) 132 | prior = [list(item) for item in prior] 133 | prediction = [prior for d in val_data] 134 | 135 | eval_predictions(prediction, val_data) -------------------------------------------------------------------------------- /lib/core/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import logging 7 | import time 8 | from pathlib import Path 9 | 10 | class AverageMeter(object): 11 | """Computes and stores the average and current value""" 12 | def __init__(self): 13 | self.reset() 14 | 15 | def reset(self): 16 | self.val = 0 17 | 
self.avg = 0 18 | self.sum = 0 19 | self.count = 0 20 | 21 | def update(self, val, n=1): 22 | self.val = val 23 | self.sum += val * n 24 | self.count += n 25 | self.avg = self.sum / self.count 26 | 27 | def create_logger(cfg, cfg_name, tag='train'): 28 | root_log_dir = Path(cfg.LOG_DIR) 29 | # set up logger 30 | if not root_log_dir.exists(): 31 | print('=> creating {}'.format(root_log_dir)) 32 | root_log_dir.mkdir() 33 | 34 | dataset = cfg.DATASET.NAME 35 | cfg_name = os.path.basename(cfg_name).split('.yaml')[0] 36 | 37 | final_log_dir = root_log_dir / dataset / cfg_name 38 | 39 | print('=> creating {}'.format(final_log_dir)) 40 | final_log_dir.mkdir(parents=True, exist_ok=True) 41 | 42 | time_str = time.strftime('%Y-%m-%d-%H-%M') 43 | log_file = '{}_{}_{}.log'.format(cfg_name, time_str, tag) 44 | final_log_file = final_log_dir / log_file 45 | head = '%(asctime)-15s %(message)s' 46 | logging.basicConfig(filename=str(final_log_file), format=head) 47 | logger = logging.getLogger() 48 | logger.setLevel(logging.INFO) 49 | console = logging.StreamHandler() 50 | logging.getLogger('').addHandler(console) 51 | 52 | return logger, str(final_log_dir) 53 | -------------------------------------------------------------------------------- /lib/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from core.config import config 4 | 5 | def collate_fn(batch): 6 | batch_word_vectors = [b['word_vectors'] for b in batch] 7 | batch_txt_mask = [b['txt_mask'] for b in batch] 8 | batch_map_gt = [b['map_gt'] for b in batch] 9 | batch_anno_idxs = [b['anno_idx'] for b in batch] 10 | batch_vis_feats = [b['visual_input'] for b in batch] 11 | batch_duration = [b['duration'] for b in batch] 12 | batch_word_label = [b['word_label'] for b in batch] 13 | batch_word_mask = [b['word_mask'] for b in batch] 14 | batch_gt_times = [b['gt_times'].unsqueeze(0) for b in batch] 15 | 16 | max_num_clips = max([map_gt.shape[-1] for map_gt in batch_map_gt]) 17 | padded_batch_map_gt = torch.zeros(len(batch_map_gt), 5, max_num_clips) 18 | for i, map_gt in enumerate(batch_map_gt): 19 | num_clips = map_gt.shape[-1] 20 | padded_batch_map_gt[i][:,:num_clips] = map_gt 21 | 22 | batch_data = { 23 | 'batch_anno_idxs': batch_anno_idxs, 24 | 'batch_word_vectors': nn.utils.rnn.pad_sequence(batch_word_vectors, batch_first=True), 25 | 'batch_txt_mask': nn.utils.rnn.pad_sequence(batch_txt_mask, batch_first=True), 26 | 'batch_map_gt': padded_batch_map_gt, 27 | 'batch_vis_input': nn.utils.rnn.pad_sequence(batch_vis_feats, batch_first=True).float(), 28 | 'batch_duration': batch_duration, 29 | 'batch_word_label': nn.utils.rnn.pad_sequence(batch_word_label, batch_first=True).long(), 30 | 'batch_word_mask': nn.utils.rnn.pad_sequence(batch_word_mask, batch_first=True).float(), 31 | 'batch_gt_times': torch.cat(batch_gt_times, 0) 32 | } 33 | 34 | return batch_data 35 | 36 | def average_to_fixed_length(visual_input): 37 | num_sample_clips = config.DATASET.NUM_SAMPLE_CLIPS 38 | num_clips = visual_input.shape[0] 39 | idxs = torch.arange(0, num_sample_clips+1, 1.0)/num_sample_clips*num_clips 40 | idxs = torch.min(torch.round(idxs).long(),torch.tensor(num_clips-1)) 41 | new_visual_input = [] 42 | for i in range(num_sample_clips): 43 | s_idx, e_idx = idxs[i].item(), idxs[i+1].item() 44 | if s_idx < e_idx: 45 | new_visual_input.append(torch.mean(visual_input[s_idx:e_idx],dim=0)) 46 | else: 47 | new_visual_input.append(visual_input[s_idx]) 48 | new_visual_input = 
torch.stack(new_visual_input, dim=0) 49 | return new_visual_input 50 | 51 | from datasets.activitynet import ActivityNet 52 | from datasets.tacos import TACoS 53 | -------------------------------------------------------------------------------- /lib/datasets/activitynet.py: -------------------------------------------------------------------------------- 1 | """ Dataset loader for the ActivityNet Captions dataset """ 2 | import os 3 | import json 4 | from collections import OrderedDict 5 | import numpy as np 6 | 7 | import h5py 8 | import torch 9 | from torch import nn 10 | import torch.nn.functional as F 11 | import torch.utils.data as data 12 | import torchtext 13 | 14 | from . import average_to_fixed_length 15 | from core.eval import iou 16 | from core.config import config 17 | 18 | class ActivityNet(data.Dataset): 19 | 20 | vocab = torchtext.vocab.pretrained_aliases["glove.6B.300d"]() 21 | vocab.itos.extend(['']) 22 | vocab.stoi[''] = vocab.vectors.shape[0] 23 | vocab.vectors = torch.cat([vocab.vectors, torch.zeros(1, vocab.dim)], dim=0) 24 | word_embedding = nn.Embedding.from_pretrained(vocab.vectors) 25 | 26 | def __init__(self, split): 27 | super(ActivityNet, self).__init__() 28 | 29 | self.vis_input_type = config.DATASET.VIS_INPUT_TYPE 30 | self.data_dir = config.DATA_DIR 31 | self.split = split 32 | 33 | # self.itos = ['PAD'] 34 | # self.ston = OrderedDict() 35 | # self.ston['PAD'] = 0 36 | 37 | with open('./data/ActivityNet/words_vocab_activitynet.json', 'r') as f: 38 | tmp = json.load(f) 39 | self.itos = tmp['words'] 40 | self.stoi = OrderedDict() 41 | for i, w in enumerate(self.itos): 42 | self.stoi[w] = i 43 | print(len(self.stoi)) 44 | 45 | # val_1.json is renamed as val.json, val_2.json is renamed as test.json 46 | with open(os.path.join(self.data_dir, '{}.json'.format(split)),'r') as f: 47 | annotations = json.load(f) 48 | anno_pairs = [] 49 | max_sent_len = 0 50 | for vid, video_anno in annotations.items(): 51 | duration = video_anno['duration'] 52 | for timestamp, sentence in zip(video_anno['timestamps'], video_anno['sentences']): 53 | if timestamp[0] < timestamp[1]: 54 | sentence = sentence.replace(',',' ').replace('/',' ').replace('\"',' ').replace('-',' ').replace(';',' ').replace('.',' ').replace('&',' ').replace('?',' ').replace('!',' ').replace('(',' ').replace(')',' ') 55 | anno_pairs.append( 56 | { 57 | 'video': vid, 58 | 'duration': duration, 59 | 'times':[max(timestamp[0],0),min(timestamp[1],duration)], 60 | 'description':sentence, 61 | } 62 | ) 63 | if len(sentence.split()) > max_sent_len: 64 | max_sent_len = len(sentence.split()) 65 | 66 | # for w in sentence.split(): 67 | # if w.lower() not in self.ston.keys(): 68 | # self.itos.append(w.lower()) 69 | # self.ston[w.lower()] = 1 70 | # else: 71 | # self.ston[w.lower()] += 1 72 | 73 | self.annotations = anno_pairs 74 | print('max_sent_len', max_sent_len) 75 | 76 | # self.itos.extend(['UNK']) 77 | # self.ston['UNK'] = 0 78 | # print('total words:', len(self.itos)) 79 | # # ston = sorted(self.ston.items(),key = lambda x:x[1],reverse = True) 80 | # # print(ston) 81 | # if len(self.itos) > 10000: 82 | # with open('words_vocab.json', 'w') as f: 83 | # json.dump({'words':self.itos}, f) 84 | 85 | def __getitem__(self, index): 86 | video_id = self.annotations[index]['video'] 87 | gt_s_time, gt_e_time = self.annotations[index]['times'] 88 | sentence = self.annotations[index]['description'] 89 | duration = self.annotations[index]['duration'] 90 | 91 | word_label = [self.stoi.get(w.lower(), 10727) for w in sentence.split()] 
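# Annotation (a reading aid, not part of the upstream file): the block below
# builds masked-language-modeling targets for the query. Each word is masked
# independently with probability 0.15; the two corrective checks that follow
# guarantee at least one word is masked and at least one is left visible.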
92 | range_i = range(len(word_label)) 93 | # if np.random.uniform(0,1)<0.8: 94 | word_mask = [1. if np.random.uniform(0,1)<0.15 else 0. for _ in range_i] 95 | if np.sum(word_mask) == 0.: 96 | mask_i = np.random.choice(range_i) 97 | word_mask[mask_i] = 1. 98 | if np.sum(word_mask) == len(word_mask): 99 | unmask_i = np.random.choice(range_i) 100 | word_mask[unmask_i] = 0. 101 | # else: 102 | # word_mask = [0. for _ in range_i] 103 | 104 | word_label = torch.tensor(word_label, dtype=torch.long) 105 | word_mask = torch.tensor(word_mask, dtype=torch.float) 106 | 107 | word_idxs = torch.tensor([self.vocab.stoi.get(w.lower(), 400000) for w in sentence.split()], dtype=torch.long) 108 | word_vectors = self.word_embedding(word_idxs) 109 | 110 | visual_input, visual_mask = self.get_video_features(video_id) 111 | 112 | 113 | # Time scaled to same size 114 | if config.DATASET.NUM_SAMPLE_CLIPS > 0: 115 | # visual_input = sample_to_fixed_length(visual_input, random_sampling=True) 116 | visual_input = average_to_fixed_length(visual_input) 117 | num_clips = config.DATASET.NUM_SAMPLE_CLIPS//config.DATASET.TARGET_STRIDE 118 | # Time unscaled NEED FIXED WINDOW SIZE 119 | else: 120 | num_clips = visual_input.shape[0]//config.DATASET.TARGET_STRIDE 121 | raise NotImplementedError 122 | # torch.arange(0,) 123 | 124 | map_gt = np.zeros((5, num_clips+1), dtype=np.float32) 125 | 126 | clip_duration = duration/num_clips 127 | gt_s = gt_s_time/clip_duration 128 | gt_e = gt_e_time/clip_duration 129 | gt_length = gt_e - gt_s 130 | gt_center = (gt_e + gt_s) / 2. 131 | map_gt[0, :] = np.exp( -0.5 * np.square( (np.arange(num_clips+1)-gt_s)/(0.25*gt_length) ) ) 132 | map_gt[0, map_gt[0, :]>=0.6] = 1. 133 | map_gt[0, map_gt[0, :]<0.1353] = 0. 134 | map_gt[1, :] = np.exp( -0.5 * np.square( (np.arange(num_clips+1)-gt_e)/(0.25*gt_length) ) ) 135 | map_gt[1, map_gt[1, :]>=0.6] = 1. 136 | map_gt[1, map_gt[1, :]<0.1353] = 0. 137 | # map_gt[2, gt_s_idx:gt_e_idx] = 1. 138 | map_gt[2, :] = np.exp( -0.5 * np.square( (np.arange(num_clips+1)-gt_center)/(0.21233*gt_length) ) ) 139 | map_gt[2, map_gt[2, :]>=0.78] = 1. 140 | map_gt[2, map_gt[2, :]<0.0625] = 0. 141 | map_gt[3, :] = gt_s - np.arange(num_clips+1) 142 | map_gt[4, :] = gt_e - np.arange(num_clips+1) 143 | if (map_gt[0, :]>0.4).sum() == 0: 144 | p = np.exp( -0.5 * np.square( (np.arange(num_clips+1)-gt_s)/(0.25*gt_length) ) ) 145 | idx = np.argsort(p) 146 | map_gt[0, idx[-1]] = 1. 147 | if (map_gt[1, :]>0.4).sum() == 0: 148 | p = np.exp( -0.5 * np.square( (np.arange(num_clips+1)-gt_e)/(0.25*gt_length) ) ) 149 | idx = np.argsort(p) 150 | map_gt[1, idx[-1]] = 1. 151 | if map_gt[2, :].sum() == 0: 152 | p = np.exp( -0.5 * np.square( (np.arange(num_clips+1)-gt_center)/(0.21233*gt_length) ) ) 153 | idx = np.argmax(p) 154 | map_gt[2, idx] = 1. 
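# Annotation (a reading aid, not part of the upstream file): map_gt has shape
# (5, num_clips+1). Rows 0/1 are Gaussian heatmaps over clip boundaries for the
# ground-truth start/end (saturated to 1 at >=0.6, zeroed below 0.1353), row 2
# is a sharper heatmap centered on the segment midpoint, and rows 3/4 are signed
# offsets from each boundary index to the start/end. The three fallback blocks
# above ensure every heatmap has at least one positive location.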
155 | 156 | item = { 157 | 'visual_input': visual_input, 158 | 'vis_mask': visual_mask, 159 | 'anno_idx': index, 160 | 'word_vectors': word_vectors, 161 | 'duration': duration, 162 | 'txt_mask': torch.ones(word_vectors.shape[0],), 163 | 'map_gt': torch.from_numpy(map_gt), 164 | 'word_label': word_label, 165 | 'word_mask': word_mask, 166 | 'gt_times': torch.from_numpy(np.array([gt_s, gt_e], dtype=np.float32)) 167 | } 168 | 169 | return item 170 | 171 | def __len__(self): 172 | return len(self.annotations) 173 | 174 | def get_video_features(self, vid): 175 | assert config.DATASET.VIS_INPUT_TYPE == 'c3d' 176 | with h5py.File(os.path.join(self.data_dir, 'sub_activitynet_v1-3.c3d.hdf5'), 'r') as f: 177 | features = torch.from_numpy(f[vid]['c3d_features'][:]) 178 | if config.DATASET.NORMALIZE: 179 | features = F.normalize(features,dim=1) 180 | vis_mask = torch.ones((features.shape[0], 1)) 181 | return features, vis_mask -------------------------------------------------------------------------------- /lib/datasets/tacos.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import os 3 | import json 4 | from collections import OrderedDict 5 | import numpy as np 6 | 7 | import h5py 8 | import torch 9 | from torch import nn 10 | import torch.nn.functional as F 11 | import torch.utils.data as data 12 | import torchtext 13 | 14 | from . import average_to_fixed_length 15 | from core.eval import iou 16 | from core.config import config 17 | 18 | class TACoS(data.Dataset): 19 | 20 | vocab = torchtext.vocab.pretrained_aliases["glove.6B.300d"]() 21 | vocab.itos.extend(['']) 22 | vocab.stoi[''] = vocab.vectors.shape[0] 23 | vocab.vectors = torch.cat([vocab.vectors, torch.zeros(1, vocab.dim)], dim=0) 24 | word_embedding = nn.Embedding.from_pretrained(vocab.vectors) 25 | 26 | def __init__(self, split): 27 | super(TACoS, self).__init__() 28 | 29 | self.vis_input_type = config.DATASET.VIS_INPUT_TYPE 30 | self.data_dir = config.DATA_DIR 31 | self.split = split 32 | 33 | # self.itos = ['PAD'] 34 | # self.ston = OrderedDict() 35 | # self.ston['PAD'] = 0 36 | 37 | with open('./data/TACoS/words_vocab_tacos.json', 'r') as f: 38 | tmp = json.load(f) 39 | self.itos = tmp['words'] 40 | self.stoi = OrderedDict() 41 | for i, w in enumerate(self.itos): 42 | self.stoi[w] = i 43 | print(len(self.stoi)) 44 | 45 | # val_1.json is renamed as val.json, val_2.json is renamed as test.json 46 | with open(os.path.join(self.data_dir, '{}.json'.format(split)),'r') as f: 47 | annotations = json.load(f) 48 | anno_pairs = [] 49 | max_sent_len = 0 50 | for vid, video_anno in annotations.items(): 51 | duration = video_anno['num_frames']/video_anno['fps'] 52 | for timestamp, sentence in zip(video_anno['timestamps'], video_anno['sentences']): 53 | if timestamp[0] < timestamp[1]: 54 | sentence = sentence.replace(',',' ').replace('/',' ').replace('\"',' ').replace('-',' ').replace(';',' ').replace('.',' ').replace('&',' ').replace('?',' ').replace('!',' ').replace('(',' ').replace(')',' ') 55 | anno_pairs.append( 56 | { 57 | 'video': vid, 58 | 'duration': duration, 59 | 'times':[max(timestamp[0]/video_anno['fps'],0),min(timestamp[1]/video_anno['fps'],duration)], 60 | 'description':sentence, 61 | } 62 | ) 63 | if len(sentence.split()) > max_sent_len: 64 | max_sent_len = len(sentence.split()) 65 | 66 | # for w in sentence.split(): 67 | # if w.lower() not in self.ston.keys(): 68 | # self.itos.append(w.lower()) 69 | # self.ston[w.lower()] = 1 70 | # else: 71 | # 
self.ston[w.lower()] += 1 72 | 73 | self.annotations = anno_pairs 74 | print('max_sent_len', max_sent_len) 75 | 76 | # self.itos.extend(['UNK']) 77 | # self.ston['UNK'] = 0 78 | # print('total words:', len(self.itos)) 79 | # # ston = sorted(self.ston.items(),key = lambda x:x[1],reverse = True) 80 | # # print(ston) 81 | # if len(self.itos) > 1500: 82 | # with open('words_vocab.json', 'w') as f: 83 | # json.dump({'words':self.itos}, f) 84 | 85 | def __getitem__(self, index): 86 | video_id = self.annotations[index]['video'] 87 | gt_s_time, gt_e_time = self.annotations[index]['times'] 88 | sentence = self.annotations[index]['description'] 89 | duration = self.annotations[index]['duration'] 90 | 91 | word_label = [self.stoi.get(w.lower(), 1513) for w in sentence.split()] 92 | range_i = range(len(word_label)) 93 | word_mask = [1. if np.random.uniform(0,1)<0.15 else 0. for _ in range_i] 94 | if np.sum(word_mask) == 0.: 95 | mask_i = np.random.choice(range_i) 96 | word_mask[mask_i] = 1. 97 | if np.sum(word_mask) == len(word_mask): 98 | unmask_i = np.random.choice(range_i) 99 | word_mask[unmask_i] = 0. 100 | word_label = torch.tensor(word_label, dtype=torch.long) 101 | word_mask = torch.tensor(word_mask, dtype=torch.float) 102 | 103 | word_idxs = torch.tensor([self.vocab.stoi.get(w.lower(), 400000) for w in sentence.split()], dtype=torch.long) 104 | word_vectors = self.word_embedding(word_idxs) 105 | 106 | visual_input, visual_mask = self.get_video_features(video_id) 107 | 108 | # visual_input = sample_to_fixed_length(visual_input, random_sampling=config.DATASET.RANDOM_SAMPLING) 109 | visual_input = average_to_fixed_length(visual_input) 110 | num_clips = config.DATASET.NUM_SAMPLE_CLIPS//config.DATASET.TARGET_STRIDE 111 | 112 | map_gt = np.zeros((5, num_clips+1), dtype=np.float32) 113 | 114 | clip_duration = duration/num_clips 115 | gt_s = gt_s_time/clip_duration 116 | gt_e = gt_e_time/clip_duration 117 | gt_length = gt_e - gt_s 118 | gt_center = (gt_e + gt_s) / 2. 119 | map_gt[0, :] = np.exp( -0.5 * np.square( (np.arange(num_clips+1)-gt_s)/(0.25*gt_length) ) ) 120 | map_gt[0, map_gt[0, :]>=0.7] = 1. 121 | map_gt[0, map_gt[0, :]<0.1353] = 0. 122 | map_gt[1, :] = np.exp( -0.5 * np.square( (np.arange(num_clips+1)-gt_e)/(0.25*gt_length) ) ) 123 | map_gt[1, map_gt[1, :]>=0.7] = 1. 124 | map_gt[1, map_gt[1, :]<0.1353] = 0. 125 | # map_gt[2, gt_s_idx:gt_e_idx] = 1. 126 | map_gt[2, :] = np.exp( -0.5 * np.square( (np.arange(num_clips+1)-gt_center)/(0.21233*gt_length) ) ) 127 | map_gt[2, map_gt[2, :]>=0.78] = 1. 128 | map_gt[2, map_gt[2, :]<0.0625] = 0. 129 | map_gt[3, :] = gt_s - np.arange(num_clips+1) 130 | map_gt[4, :] = gt_e - np.arange(num_clips+1) 131 | if (map_gt[0, :]>0.4).sum() == 0: 132 | p = np.exp( -0.5 * np.square( (np.arange(num_clips+1)-gt_s)/(0.25*gt_length) ) ) 133 | idx = np.argsort(p) 134 | map_gt[0, idx[-1]] = 1. 135 | if (map_gt[1, :]>0.4).sum() == 0: 136 | p = np.exp( -0.5 * np.square( (np.arange(num_clips+1)-gt_e)/(0.25*gt_length) ) ) 137 | idx = np.argsort(p) 138 | map_gt[1, idx[-1]] = 1. 139 | if map_gt[2, :].sum() == 0: 140 | p = np.exp( -0.5 * np.square( (np.arange(num_clips+1)-gt_center)/(0.21233*gt_length) ) ) 141 | idx = np.argmax(p) 142 | map_gt[2, idx] = 1. 
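# Annotation (a reading aid, not part of the upstream file): the ground-truth
# map construction above mirrors activitynet.py, except that the start/end
# heatmaps here saturate at the higher threshold of 0.7.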
143 | 144 | item = { 145 | 'visual_input': visual_input, 146 | 'vis_mask': visual_mask, 147 | 'anno_idx': index, 148 | 'word_vectors': word_vectors, 149 | 'duration': duration, 150 | 'txt_mask': torch.ones(word_vectors.shape[0],), 151 | 'map_gt': torch.from_numpy(map_gt), 152 | 'word_label': word_label, 153 | 'word_mask': word_mask, 154 | 'gt_times': torch.from_numpy(np.array([gt_s, gt_e], dtype=np.float32)) 155 | } 156 | 157 | return item 158 | 159 | def __len__(self): 160 | return len(self.annotations) 161 | 162 | def get_video_features(self, vid): 163 | assert config.DATASET.VIS_INPUT_TYPE == 'c3d' 164 | with h5py.File(os.path.join(self.data_dir, 'tall_c3d_features.hdf5'), 'r') as f: 165 | features = torch.from_numpy(f[vid][:]) 166 | if config.DATASET.NORMALIZE: 167 | features = F.normalize(features,dim=1) 168 | vis_mask = torch.ones((features.shape[0], 1)) 169 | return features, vis_mask -------------------------------------------------------------------------------- /lib/models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.tan import TAN -------------------------------------------------------------------------------- /lib/models/bert_modules/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | def get_padded_mask_and_weight(*args): 4 | if len(args) == 2: 5 | mask, conv = args 6 | masked_weight = torch.round(F.conv2d(mask.clone().float(), torch.ones(1, 1, *conv.kernel_size).cuda(), 7 | stride=conv.stride, padding=conv.padding, dilation=conv.dilation)) 8 | elif len(args) == 5: 9 | mask, k, s, p, d = args 10 | masked_weight = torch.round(F.conv2d(mask.clone().float(), torch.ones(1, 1, k, k).cuda(), stride=s, padding=p, dilation=d)) 11 | else: 12 | raise NotImplementedError 13 | 14 | masked_weight[masked_weight > 0] = 1 / masked_weight[masked_weight > 0] #conv.kernel_size[0] * conv.kernel_size[1] 15 | padded_mask = masked_weight > 0 16 | 17 | return padded_mask, masked_weight 18 | 19 | from .vlbert import TLocVLBERT -------------------------------------------------------------------------------- /lib/models/bert_modules/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 
5 | """ 6 | from __future__ import (absolute_import, division, print_function, unicode_literals) 7 | 8 | import json 9 | import logging 10 | import os 11 | import shutil 12 | import tempfile 13 | from functools import wraps 14 | from hashlib import sha256 15 | import sys 16 | from io import open 17 | 18 | import boto3 19 | import requests 20 | from botocore.exceptions import ClientError 21 | from tqdm import tqdm 22 | 23 | try: 24 | from urllib.parse import urlparse 25 | except ImportError: 26 | from urlparse import urlparse 27 | 28 | try: 29 | from pathlib import Path 30 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 31 | Path.home() / '.pytorch_pretrained_bert')) 32 | except AttributeError: 33 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 34 | os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert')) 35 | 36 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 37 | 38 | 39 | def url_to_filename(url, etag=None): 40 | """ 41 | Convert `url` into a hashed filename in a repeatable way. 42 | If `etag` is specified, append its hash to the url's, delimited 43 | by a period. 44 | """ 45 | url_bytes = url.encode('utf-8') 46 | url_hash = sha256(url_bytes) 47 | filename = url_hash.hexdigest() 48 | 49 | if etag: 50 | etag_bytes = etag.encode('utf-8') 51 | etag_hash = sha256(etag_bytes) 52 | filename += '.' + etag_hash.hexdigest() 53 | 54 | return filename 55 | 56 | 57 | def filename_to_url(filename, cache_dir=None): 58 | """ 59 | Return the url and etag (which may be ``None``) stored for `filename`. 60 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. 61 | """ 62 | if cache_dir is None: 63 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 64 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 65 | cache_dir = str(cache_dir) 66 | 67 | cache_path = os.path.join(cache_dir, filename) 68 | if not os.path.exists(cache_path): 69 | raise EnvironmentError("file {} not found".format(cache_path)) 70 | 71 | meta_path = cache_path + '.json' 72 | if not os.path.exists(meta_path): 73 | raise EnvironmentError("file {} not found".format(meta_path)) 74 | 75 | with open(meta_path, encoding="utf-8") as meta_file: 76 | metadata = json.load(meta_file) 77 | url = metadata['url'] 78 | etag = metadata['etag'] 79 | 80 | return url, etag 81 | 82 | 83 | def cached_path(url_or_filename, cache_dir=None): 84 | """ 85 | Given something that might be a URL (or might be a local path), 86 | determine which. If it's a URL, download the file and cache it, and 87 | return the path to the cached file. If it's already a local path, 88 | make sure the file exists and then return the path. 89 | """ 90 | if cache_dir is None: 91 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 92 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): 93 | url_or_filename = str(url_or_filename) 94 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 95 | cache_dir = str(cache_dir) 96 | 97 | parsed = urlparse(url_or_filename) 98 | 99 | if parsed.scheme in ('http', 'https', 's3'): 100 | # URL, so get it from the cache (downloading if necessary) 101 | return get_from_cache(url_or_filename, cache_dir) 102 | elif os.path.exists(url_or_filename): 103 | # File, and it exists. 104 | return url_or_filename 105 | elif parsed.scheme == '': 106 | # File, but it doesn't exist. 
107 | raise EnvironmentError("file {} not found".format(url_or_filename)) 108 | else: 109 | # Something unknown 110 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 111 | 112 | 113 | def split_s3_path(url): 114 | """Split a full s3 path into the bucket name and path.""" 115 | parsed = urlparse(url) 116 | if not parsed.netloc or not parsed.path: 117 | raise ValueError("bad s3 path {}".format(url)) 118 | bucket_name = parsed.netloc 119 | s3_path = parsed.path 120 | # Remove '/' at beginning of path. 121 | if s3_path.startswith("/"): 122 | s3_path = s3_path[1:] 123 | return bucket_name, s3_path 124 | 125 | 126 | def s3_request(func): 127 | """ 128 | Wrapper function for s3 requests in order to create more helpful error 129 | messages. 130 | """ 131 | 132 | @wraps(func) 133 | def wrapper(url, *args, **kwargs): 134 | try: 135 | return func(url, *args, **kwargs) 136 | except ClientError as exc: 137 | if int(exc.response["Error"]["Code"]) == 404: 138 | raise EnvironmentError("file {} not found".format(url)) 139 | else: 140 | raise 141 | 142 | return wrapper 143 | 144 | 145 | @s3_request 146 | def s3_etag(url): 147 | """Check ETag on S3 object.""" 148 | s3_resource = boto3.resource("s3") 149 | bucket_name, s3_path = split_s3_path(url) 150 | s3_object = s3_resource.Object(bucket_name, s3_path) 151 | return s3_object.e_tag 152 | 153 | 154 | @s3_request 155 | def s3_get(url, temp_file): 156 | """Pull a file directly from S3.""" 157 | s3_resource = boto3.resource("s3") 158 | bucket_name, s3_path = split_s3_path(url) 159 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 160 | 161 | 162 | def http_get(url, temp_file): 163 | req = requests.get(url, stream=True) 164 | content_length = req.headers.get('Content-Length') 165 | total = int(content_length) if content_length is not None else None 166 | progress = tqdm(unit="B", total=total) 167 | for chunk in req.iter_content(chunk_size=1024): 168 | if chunk: # filter out keep-alive new chunks 169 | progress.update(len(chunk)) 170 | temp_file.write(chunk) 171 | progress.close() 172 | 173 | 174 | def get_from_cache(url, cache_dir=None): 175 | """ 176 | Given a URL, look for the corresponding dataset in the local cache. 177 | If it's not there, download it. Then return the path to the cached file. 178 | """ 179 | if cache_dir is None: 180 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 181 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 182 | cache_dir = str(cache_dir) 183 | 184 | if not os.path.exists(cache_dir): 185 | os.makedirs(cache_dir) 186 | 187 | # Get eTag to add to filename, if it exists. 188 | if url.startswith("s3://"): 189 | etag = s3_etag(url) 190 | else: 191 | response = requests.head(url, allow_redirects=True) 192 | if response.status_code != 200: 193 | raise IOError("HEAD request failed for url {} with status code {}" 194 | .format(url, response.status_code)) 195 | etag = response.headers.get("ETag") 196 | 197 | filename = url_to_filename(url, etag) 198 | 199 | # get cache path to put the file 200 | cache_path = os.path.join(cache_dir, filename) 201 | 202 | if not os.path.exists(cache_path): 203 | # Download to temporary file, then copy to cache dir once finished. 204 | # Otherwise you get corrupt cache entries if the download gets interrupted. 
205 | with tempfile.NamedTemporaryFile() as temp_file: 206 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name) 207 | 208 | # GET file object 209 | if url.startswith("s3://"): 210 | s3_get(url, temp_file) 211 | else: 212 | http_get(url, temp_file) 213 | 214 | # we are copying the file before closing it, so flush to avoid truncation 215 | temp_file.flush() 216 | # shutil.copyfileobj() starts at the current position, so go to the start 217 | temp_file.seek(0) 218 | 219 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 220 | with open(cache_path, 'wb') as cache_file: 221 | shutil.copyfileobj(temp_file, cache_file) 222 | 223 | logger.info("creating metadata file for %s", cache_path) 224 | meta = {'url': url, 'etag': etag} 225 | meta_path = cache_path + '.json' 226 | with open(meta_path, 'w', encoding="utf-8") as meta_file: 227 | json.dump(meta, meta_file) 228 | 229 | logger.info("removing temp file %s", temp_file.name) 230 | 231 | return cache_path 232 | 233 | 234 | def read_set_from_file(filename): 235 | ''' 236 | Extract a de-duped collection (set) of text from a file. 237 | Expected file format is one item per line. 238 | ''' 239 | collection = set() 240 | with open(filename, 'r', encoding='utf-8') as file_: 241 | for line in file_: 242 | collection.add(line.rstrip()) 243 | return collection 244 | 245 | 246 | def get_file_extension(path, dot=True, lower=True): 247 | ext = os.path.splitext(path)[1] 248 | ext = ext if dot else ext[1:] 249 | return ext.lower() if lower else ext -------------------------------------------------------------------------------- /lib/models/bert_modules/modeling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
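# Note: the archive URLs in PRETRAINED_MODEL_ARCHIVE_MAP below are resolved through
# `cached_path` from `file_utils.py` above. A rough sketch of that flow (an
# illustrative example, not part of the original file):
#
#     from .file_utils import cached_path
#     local_file = cached_path(
#         "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz")
#     # the first call downloads and stores the file under a sha256-derived name in
#     # ~/.pytorch_pretrained_bert, with a .json sidecar recording url and ETag;
#     # later calls return the cached path directly.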
16 | """PyTorch BERT model.""" 17 | 18 | from __future__ import absolute_import, division, print_function, unicode_literals 19 | 20 | import copy 21 | import json 22 | import logging 23 | import math 24 | import os 25 | import shutil 26 | import tarfile 27 | import tempfile 28 | import sys 29 | from io import open 30 | 31 | import torch 32 | from torch import nn 33 | from torch.nn import CrossEntropyLoss 34 | 35 | from .file_utils import cached_path 36 | 37 | logger = logging.getLogger(__name__) 38 | 39 | PRETRAINED_MODEL_ARCHIVE_MAP = { 40 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", 41 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", 42 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", 43 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", 44 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", 45 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", 46 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", 47 | } 48 | 49 | CONFIG_NAME = 'bert_config.json' 50 | WEIGHTS_NAME = 'pytorch_model.bin' 51 | TF_WEIGHTS_NAME = 'model.ckpt' 52 | 53 | def load_tf_weights_in_bert(model, tf_checkpoint_path): 54 | """ Load tf checkpoints in a pytorch model 55 | """ 56 | try: 57 | import re 58 | import numpy as np 59 | import tensorflow as tf 60 | except ImportError: 61 | print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see " 62 | "https://www.tensorflow.org/install/ for installation instructions.") 63 | raise 64 | tf_path = os.path.abspath(tf_checkpoint_path) 65 | print("Converting TensorFlow checkpoint from {}".format(tf_path)) 66 | # Load weights from TF model 67 | init_vars = tf.train.list_variables(tf_path) 68 | names = [] 69 | arrays = [] 70 | for name, shape in init_vars: 71 | print("Loading TF weight {} with shape {}".format(name, shape)) 72 | array = tf.train.load_variable(tf_path, name) 73 | names.append(name) 74 | arrays.append(array) 75 | 76 | for name, array in zip(names, arrays): 77 | name = name.split('/') 78 | # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v 79 | # which are not required for using pretrained model 80 | if any(n in ["adam_v", "adam_m"] for n in name): 81 | print("Skipping {}".format("/".join(name))) 82 | continue 83 | pointer = model 84 | for m_name in name: 85 | if re.fullmatch(r'[A-Za-z]+_\d+', m_name): 86 | l = re.split(r'_(\d+)', m_name) 87 | else: 88 | l = [m_name] 89 | if l[0] == 'kernel' or l[0] == 'gamma': 90 | pointer = getattr(pointer, 'weight') 91 | elif l[0] == 'output_bias' or l[0] == 'beta': 92 | pointer = getattr(pointer, 'bias') 93 | elif l[0] == 'output_weights': 94 | pointer = getattr(pointer, 'weight') 95 | else: 96 | pointer = getattr(pointer, l[0]) 97 | if len(l) >= 2: 98 | num = int(l[1]) 99 | pointer = pointer[num] 100 | if m_name[-11:] == '_embeddings': 101 | pointer = getattr(pointer, 'weight') 102 | elif m_name == 'kernel': 103 | array = np.transpose(array) 104 | try: 105 | assert pointer.shape == array.shape 106 | except AssertionError as e: 107 | e.args += (pointer.shape, array.shape) 108 | raise 109 | print("Initialize PyTorch weight {}".format(name)) 110 | 
pointer.data = torch.from_numpy(array)
111 |     return model
112 | 
113 | 
114 | def gelu(x):
115 |     """Implementation of the gelu activation function.
116 |         For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
117 |         0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
118 |         Also see https://arxiv.org/abs/1606.08415
119 |     """
120 |     return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
121 | 
122 | 
123 | def swish(x):
124 |     return x * torch.sigmoid(x)
125 | 
126 | 
127 | ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
128 | 
129 | 
130 | class BertConfig(object):
131 |     """Configuration class to store the configuration of a `BertModel`.
132 |     """
133 |     def __init__(self,
134 |                  vocab_size_or_config_json_file,
135 |                  hidden_size=768,
136 |                  num_hidden_layers=12,
137 |                  num_attention_heads=12,
138 |                  intermediate_size=3072,
139 |                  hidden_act="gelu",
140 |                  hidden_dropout_prob=0.1,
141 |                  attention_probs_dropout_prob=0.1,
142 |                  max_position_embeddings=512,
143 |                  type_vocab_size=2,
144 |                  initializer_range=0.02):
145 |         """Constructs BertConfig.
146 |         Args:
147 |             vocab_size_or_config_json_file: Vocabulary size of `input_ids` in `BertModel`.
148 |             hidden_size: Size of the encoder layers and the pooler layer.
149 |             num_hidden_layers: Number of hidden layers in the Transformer encoder.
150 |             num_attention_heads: Number of attention heads for each attention layer in
151 |                 the Transformer encoder.
152 |             intermediate_size: The size of the "intermediate" (i.e., feed-forward)
153 |                 layer in the Transformer encoder.
154 |             hidden_act: The non-linear activation function (function or string) in the
155 |                 encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
156 |             hidden_dropout_prob: The dropout probability for all fully connected
157 |                 layers in the embeddings, encoder, and pooler.
158 |             attention_probs_dropout_prob: The dropout ratio for the attention
159 |                 probabilities.
160 |             max_position_embeddings: The maximum sequence length that this model might
161 |                 ever be used with. Typically set this to something large just in case
162 |                 (e.g., 512 or 1024 or 2048).
163 |             type_vocab_size: The vocabulary size of the `token_type_ids` passed into
164 |                 `BertModel`.
165 |             initializer_range: The stddev of the truncated_normal_initializer for
166 |                 initializing all weight matrices.
167 | """ 168 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 169 | and isinstance(vocab_size_or_config_json_file, unicode)): 170 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: 171 | json_config = json.loads(reader.read()) 172 | for key, value in json_config.items(): 173 | self.__dict__[key] = value 174 | elif isinstance(vocab_size_or_config_json_file, int): 175 | self.vocab_size = vocab_size_or_config_json_file 176 | self.hidden_size = hidden_size 177 | self.num_hidden_layers = num_hidden_layers 178 | self.num_attention_heads = num_attention_heads 179 | self.hidden_act = hidden_act 180 | self.intermediate_size = intermediate_size 181 | self.hidden_dropout_prob = hidden_dropout_prob 182 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 183 | self.max_position_embeddings = max_position_embeddings 184 | self.type_vocab_size = type_vocab_size 185 | self.initializer_range = initializer_range 186 | else: 187 | raise ValueError("First argument must be either a vocabulary size (int)" 188 | "or the path to a pretrained model config file (str)") 189 | 190 | @classmethod 191 | def from_dict(cls, json_object): 192 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 193 | config = BertConfig(vocab_size_or_config_json_file=-1) 194 | for key, value in json_object.items(): 195 | config.__dict__[key] = value 196 | return config 197 | 198 | @classmethod 199 | def from_json_file(cls, json_file): 200 | """Constructs a `BertConfig` from a json file of parameters.""" 201 | with open(json_file, "r", encoding='utf-8') as reader: 202 | text = reader.read() 203 | return cls.from_dict(json.loads(text)) 204 | 205 | def __repr__(self): 206 | return str(self.to_json_string()) 207 | 208 | def to_dict(self): 209 | """Serializes this instance to a Python dictionary.""" 210 | output = copy.deepcopy(self.__dict__) 211 | return output 212 | 213 | def to_json_string(self): 214 | """Serializes this instance to a JSON string.""" 215 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 216 | 217 | try: 218 | from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm 219 | except ImportError: 220 | print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.") 221 | class BertLayerNorm(nn.Module): 222 | def __init__(self, hidden_size, eps=1e-12): 223 | """Construct a layernorm module in the TF style (epsilon inside the square root). 224 | """ 225 | super(BertLayerNorm, self).__init__() 226 | self.weight = nn.Parameter(torch.ones(hidden_size)) 227 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 228 | self.variance_epsilon = eps 229 | 230 | def forward(self, x): 231 | u = x.mean(-1, keepdim=True) 232 | s = (x - u).pow(2).mean(-1, keepdim=True) 233 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 234 | return self.weight * x + self.bias 235 | 236 | class BertEmbeddings(nn.Module): 237 | """Construct the embeddings from word, position and token_type embeddings. 
238 | """ 239 | def __init__(self, config): 240 | super(BertEmbeddings, self).__init__() 241 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) 242 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 243 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 244 | 245 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 246 | # any TensorFlow checkpoint file 247 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 248 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 249 | 250 | def forward(self, input_ids, token_type_ids=None): 251 | seq_length = input_ids.size(1) 252 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) 253 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 254 | if token_type_ids is None: 255 | token_type_ids = torch.zeros_like(input_ids) 256 | 257 | words_embeddings = self.word_embeddings(input_ids) 258 | position_embeddings = self.position_embeddings(position_ids) 259 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 260 | 261 | embeddings = words_embeddings + position_embeddings + token_type_embeddings 262 | embeddings = self.LayerNorm(embeddings) 263 | embeddings = self.dropout(embeddings) 264 | return embeddings 265 | 266 | 267 | class BertSelfAttention(nn.Module): 268 | def __init__(self, seq_type, config): 269 | super(BertSelfAttention, self).__init__() 270 | if config.hidden_size % config.num_attention_heads != 0: 271 | raise ValueError( 272 | "The hidden size (%d) is not a multiple of the number of attention " 273 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 274 | self.num_attention_heads = config.num_attention_heads 275 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 276 | self.all_head_size = self.num_attention_heads * self.attention_head_size 277 | 278 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 279 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 280 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 281 | 282 | self.query_other = nn.Linear(config.hidden_size, self.all_head_size) 283 | self.key_other = nn.Linear(config.hidden_size, self.all_head_size) 284 | self.value_other = nn.Linear(config.hidden_size, self.all_head_size) 285 | 286 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 287 | 288 | self.seq_type = seq_type 289 | 290 | def transpose_for_scores(self, x): 291 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 292 | x = x.view(*new_x_shape) 293 | return x.permute(0, 2, 1, 3) 294 | 295 | def forward(self, hidden_states, hidden_states_other, attention_mask, output_attention_probs=False): 296 | mixed_query_layer = self.query(hidden_states) 297 | mixed_key_layer = self.key(hidden_states) 298 | mixed_value_layer = self.value(hidden_states) 299 | 300 | query_layer = self.transpose_for_scores(mixed_query_layer) 301 | key_layer = self.transpose_for_scores(mixed_key_layer) 302 | value_layer = self.transpose_for_scores(mixed_value_layer) 303 | 304 | other_query_layer = self.query_other(hidden_states) 305 | other_key_layer = self.key_other(hidden_states_other) 306 | other_value_layer = self.value_other(hidden_states_other) 307 | 308 | other_query_layer = self.transpose_for_scores(other_query_layer) 309 | other_key_layer = self.transpose_for_scores(other_key_layer) 310 | 
other_value_layer = self.transpose_for_scores(other_value_layer)
311 | 
312 |         # Take the dot product between "query" and "key" to get the raw attention scores.
313 |         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
314 |         attention_scores = attention_scores / math.sqrt(self.attention_head_size)
315 |         # Apply the attention mask (precomputed for all layers in the model's forward() function).
316 | 
317 |         other_attention_scores = torch.matmul(other_query_layer, other_key_layer.transpose(-1, -2))
318 |         other_attention_scores = other_attention_scores / math.sqrt(self.attention_head_size)
319 | 
320 |         if self.seq_type == 'TXT':
321 |             attention_scores = attention_scores + attention_mask
322 |         elif self.seq_type == 'VIS':
323 |             other_attention_scores = other_attention_scores + attention_mask
324 |         else:
325 |             raise ValueError(
326 |                 "unknown seq_type: {}".format(self.seq_type))
327 | 
328 |         attention_scores = torch.cat([attention_scores, other_attention_scores], dim=-1)
329 |         value_layer = torch.cat([value_layer, other_value_layer], dim=-2)
330 | 
331 |         # Normalize the attention scores to probabilities.
332 |         attention_probs = nn.Softmax(dim=-1)(attention_scores)
333 |         # This is actually dropping out entire tokens to attend to, which might
334 |         # seem a bit unusual, but is taken from the original Transformer paper.
335 |         attention_probs = self.dropout(attention_probs)
336 | 
337 |         context_layer = torch.matmul(attention_probs, value_layer)
338 | 
339 |         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
340 |         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
341 |         context_layer = context_layer.view(*new_context_layer_shape)
342 |         if output_attention_probs:
343 |             return context_layer, attention_probs
344 |         else:
345 |             return context_layer
346 | 
347 | 
348 | class BertSelfOutput(nn.Module):
349 |     def __init__(self, config):
350 |         super(BertSelfOutput, self).__init__()
351 |         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
352 |         self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
353 |         self.dropout = nn.Dropout(config.hidden_dropout_prob)
354 | 
355 |     def forward(self, hidden_states, input_tensor):
356 |         hidden_states = self.dense(hidden_states)
357 |         hidden_states = self.dropout(hidden_states)
358 |         hidden_states = self.LayerNorm(hidden_states + input_tensor)
359 |         return hidden_states
360 | 
361 | 
362 | class BertAttention(nn.Module):
363 |     def __init__(self, seq_type, config):
364 |         super(BertAttention, self).__init__()
365 |         self.self = BertSelfAttention(seq_type, config)
366 |         self.output = BertSelfOutput(config)
367 | 
368 |     def forward(self, input_tensor, input_tensor_other, attention_mask, output_attention_probs=False):
369 |         self_output = self.self(input_tensor, input_tensor_other, attention_mask, output_attention_probs=output_attention_probs)
370 |         if output_attention_probs:
371 |             self_output, attention_probs = self_output
372 |         attention_output = self.output(self_output, input_tensor)
373 |         if output_attention_probs:
374 |             return attention_output, attention_probs
375 |         return attention_output
376 | 
377 | 
378 | class BertIntermediate(nn.Module):
379 |     def __init__(self, config):
380 |         super(BertIntermediate, self).__init__()
381 |         self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
382 |         if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
383 |             self.intermediate_act_fn = ACT2FN[config.hidden_act]
384 |         else:
385 |             self.intermediate_act_fn = config.hidden_act
386 | 
387 |     def
forward(self, hidden_states): 388 | hidden_states = self.dense(hidden_states) 389 | hidden_states = self.intermediate_act_fn(hidden_states) 390 | return hidden_states 391 | 392 | 393 | class BertOutput(nn.Module): 394 | def __init__(self, config): 395 | super(BertOutput, self).__init__() 396 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 397 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 398 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 399 | 400 | def forward(self, hidden_states, input_tensor): 401 | hidden_states = self.dense(hidden_states) 402 | hidden_states = self.dropout(hidden_states) 403 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 404 | return hidden_states 405 | 406 | 407 | class BertLayer(nn.Module): 408 | def __init__(self, config): 409 | super(BertLayer, self).__init__() 410 | self.attention_text = BertAttention('TXT', config) 411 | self.intermediate_text = BertIntermediate(config) 412 | self.output_text = BertOutput(config) 413 | 414 | self.attention_visual = BertAttention('VIS', config) 415 | self.intermediate_visual = BertIntermediate(config) 416 | self.output_visual = BertOutput(config) 417 | 418 | def forward(self, hidden_states, hidden_states_other, attention_mask, output_attention_probs=False): 419 | attention_output = self.attention_text(hidden_states, hidden_states_other, attention_mask, output_attention_probs=output_attention_probs) 420 | attention_output_other = self.attention_visual(hidden_states_other, hidden_states, attention_mask, output_attention_probs=output_attention_probs) 421 | if output_attention_probs: 422 | attention_output, attention_probs = attention_output 423 | attention_output_other, attention_probs_other = attention_output_other 424 | 425 | intermediate_output = self.intermediate_text(attention_output) 426 | layer_output = self.output_text(intermediate_output, attention_output) 427 | 428 | intermediate_output_other = self.intermediate_visual(attention_output_other) 429 | layer_output_other = self.output_visual(intermediate_output_other, attention_output_other) 430 | if output_attention_probs: 431 | return layer_output, layer_output_other, attention_probs, attention_probs_other 432 | else: 433 | return layer_output, layer_output_other 434 | 435 | 436 | class BertEncoder(nn.Module): 437 | def __init__(self, config): 438 | super(BertEncoder, self).__init__() 439 | layer = BertLayer(config) 440 | self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) 441 | 442 | def forward(self, hidden_states, hidden_states_other, attention_mask, output_all_encoded_layers=False, output_attention_probs=False): 443 | all_encoder_layers = [] 444 | all_attention_probs = [] 445 | for layer_module in self.layer: 446 | layer_out = layer_module(hidden_states, hidden_states_other, attention_mask, output_attention_probs=output_attention_probs) 447 | if output_attention_probs: 448 | hidden_states, hidden_states_other, attention_probs, attention_probs_other = layer_out 449 | all_attention_probs.append([attention_probs, attention_probs_other]) 450 | else: 451 | hidden_states, hidden_states_other = layer_out 452 | if output_all_encoded_layers: 453 | all_encoder_layers.append([hidden_states, hidden_states_other]) 454 | if not output_all_encoded_layers: 455 | all_encoder_layers.append([hidden_states, hidden_states_other]) 456 | if output_attention_probs: 457 | return all_encoder_layers, all_attention_probs 458 | else: 459 | return all_encoder_layers 460 | 461 | 462 | class 
BertPooler(nn.Module): 463 | def __init__(self, config): 464 | super(BertPooler, self).__init__() 465 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 466 | self.activation = nn.Tanh() 467 | 468 | def forward(self, hidden_states): 469 | # We "pool" the model by simply taking the hidden state corresponding 470 | # to the first token. 471 | first_token_tensor = hidden_states[:, 0] 472 | pooled_output = self.dense(first_token_tensor) 473 | pooled_output = self.activation(pooled_output) 474 | return pooled_output 475 | 476 | 477 | class BertPredictionHeadTransform(nn.Module): 478 | def __init__(self, config): 479 | super(BertPredictionHeadTransform, self).__init__() 480 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 481 | if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)): 482 | self.transform_act_fn = ACT2FN[config.hidden_act] 483 | else: 484 | self.transform_act_fn = config.hidden_act 485 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 486 | 487 | def forward(self, hidden_states): 488 | hidden_states = self.dense(hidden_states) 489 | hidden_states = self.transform_act_fn(hidden_states) 490 | hidden_states = self.LayerNorm(hidden_states) 491 | return hidden_states 492 | 493 | 494 | class BertLMPredictionHead(nn.Module): 495 | def __init__(self, config, bert_model_embedding_weights): 496 | super(BertLMPredictionHead, self).__init__() 497 | self.transform = BertPredictionHeadTransform(config) 498 | 499 | # The output weights are the same as the input embeddings, but there is 500 | # an output-only bias for each token. 501 | self.decoder = nn.Linear(bert_model_embedding_weights.size(1), 502 | bert_model_embedding_weights.size(0), 503 | bias=False) 504 | self.decoder.weight = bert_model_embedding_weights 505 | self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0))) 506 | 507 | def forward(self, hidden_states): 508 | hidden_states = self.transform(hidden_states) 509 | hidden_states = self.decoder(hidden_states) + self.bias 510 | return hidden_states 511 | 512 | 513 | class BertOnlyMLMHead(nn.Module): 514 | def __init__(self, config, bert_model_embedding_weights): 515 | super(BertOnlyMLMHead, self).__init__() 516 | self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) 517 | 518 | def forward(self, sequence_output): 519 | prediction_scores = self.predictions(sequence_output) 520 | return prediction_scores 521 | 522 | 523 | class BertOnlyNSPHead(nn.Module): 524 | def __init__(self, config): 525 | super(BertOnlyNSPHead, self).__init__() 526 | self.seq_relationship = nn.Linear(config.hidden_size, 2) 527 | 528 | def forward(self, pooled_output): 529 | seq_relationship_score = self.seq_relationship(pooled_output) 530 | return seq_relationship_score 531 | 532 | 533 | class BertPreTrainingHeads(nn.Module): 534 | def __init__(self, config, bert_model_embedding_weights): 535 | super(BertPreTrainingHeads, self).__init__() 536 | self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) 537 | self.seq_relationship = nn.Linear(config.hidden_size, 2) 538 | 539 | def forward(self, sequence_output, pooled_output): 540 | prediction_scores = self.predictions(sequence_output) 541 | seq_relationship_score = self.seq_relationship(pooled_output) 542 | return prediction_scores, seq_relationship_score 543 | 544 | 545 | class BertPreTrainedModel(nn.Module): 546 | """ An abstract class to handle weights initialization and 547 | a simple interface for 
downloading and loading pretrained models.
548 |     """
549 |     def __init__(self, config, *inputs, **kwargs):
550 |         super(BertPreTrainedModel, self).__init__()
551 |         if not isinstance(config, BertConfig):
552 |             raise ValueError(
553 |                 "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
554 |                 "To create a model from a Google pretrained model use "
555 |                 "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
556 |                     self.__class__.__name__, self.__class__.__name__
557 |                 ))
558 |         self.config = config
559 | 
560 |     def init_bert_weights(self, module):
561 |         """ Initialize the weights.
562 |         """
563 |         if isinstance(module, (nn.Linear, nn.Embedding)):
564 |             # Slightly different from the TF version which uses truncated_normal for initialization
565 |             # cf https://github.com/pytorch/pytorch/pull/5617
566 |             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
567 |         elif isinstance(module, BertLayerNorm):
568 |             module.bias.data.zero_()
569 |             module.weight.data.fill_(1.0)
570 |         if isinstance(module, nn.Linear) and module.bias is not None:
571 |             module.bias.data.zero_()
572 | 
573 |     @classmethod
574 |     def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None,
575 |                         from_tf=False, *inputs, **kwargs):
576 |         """
577 |         Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
578 |         Download and cache the pre-trained model file if needed.
579 |         Params:
580 |             pretrained_model_name_or_path: either:
581 |                 - a str with the name of a pre-trained model to load selected in the list of:
582 |                     . `bert-base-uncased`
583 |                     . `bert-large-uncased`
584 |                     . `bert-base-cased`
585 |                     . `bert-large-cased`
586 |                     . `bert-base-multilingual-uncased`
587 |                     . `bert-base-multilingual-cased`
588 |                     . `bert-base-chinese`
589 |                 - a path or url to a pretrained model archive containing:
590 |                     . `bert_config.json` a configuration file for the model
591 |                     . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
592 |                 - a path or url to a pretrained model archive containing:
593 |                     . `bert_config.json` a configuration file for the model
594 |                     . `model.ckpt` a TensorFlow checkpoint
595 |             from_tf: should we load the weights from a locally saved TensorFlow checkpoint
596 |             cache_dir: an optional path to a folder in which the pre-trained models will be cached.
597 |             state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
598 |             *inputs, **kwargs: additional input for the specific Bert class
599 |                 (ex: num_labels for BertForSequenceClassification)
600 |         """
601 |         if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
602 |             archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
603 |         else:
604 |             archive_file = pretrained_model_name_or_path
605 |         # redirect to the cache, if necessary
606 |         try:
607 |             resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
608 |         except EnvironmentError:
609 |             logger.error(
610 |                 "Model name '{}' was not found in model name list ({}). 
" 611 | "We assumed '{}' was a path or url but couldn't find any file " 612 | "associated to this path or url.".format( 613 | pretrained_model_name_or_path, 614 | ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), 615 | archive_file)) 616 | return None 617 | if resolved_archive_file == archive_file: 618 | logger.info("loading archive file {}".format(archive_file)) 619 | else: 620 | logger.info("loading archive file {} from cache at {}".format( 621 | archive_file, resolved_archive_file)) 622 | tempdir = None 623 | if os.path.isdir(resolved_archive_file) or from_tf: 624 | serialization_dir = resolved_archive_file 625 | else: 626 | # Extract archive to temp dir 627 | tempdir = tempfile.mkdtemp() 628 | logger.info("extracting archive file {} to temp dir {}".format( 629 | resolved_archive_file, tempdir)) 630 | with tarfile.open(resolved_archive_file, 'r:gz') as archive: 631 | archive.extractall(tempdir) 632 | serialization_dir = tempdir 633 | # Load config 634 | config_file = os.path.join(serialization_dir, CONFIG_NAME) 635 | config = BertConfig.from_json_file(config_file) 636 | logger.info("Model config {}".format(config)) 637 | # Instantiate model. 638 | model = cls(config, *inputs, **kwargs) 639 | if state_dict is None and not from_tf: 640 | weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) 641 | state_dict = torch.load(weights_path, map_location='cpu' if not torch.cuda.is_available() else None) 642 | if tempdir: 643 | # Clean up temp dir 644 | shutil.rmtree(tempdir) 645 | if from_tf: 646 | # Directly load from a TensorFlow checkpoint 647 | weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME) 648 | return load_tf_weights_in_bert(model, weights_path) 649 | # Load from a PyTorch state_dict 650 | old_keys = [] 651 | new_keys = [] 652 | for key in state_dict.keys(): 653 | new_key = None 654 | if 'gamma' in key: 655 | new_key = key.replace('gamma', 'weight') 656 | if 'beta' in key: 657 | new_key = key.replace('beta', 'bias') 658 | if new_key: 659 | old_keys.append(key) 660 | new_keys.append(new_key) 661 | for old_key, new_key in zip(old_keys, new_keys): 662 | state_dict[new_key] = state_dict.pop(old_key) 663 | 664 | missing_keys = [] 665 | unexpected_keys = [] 666 | error_msgs = [] 667 | # copy state_dict so _load_from_state_dict can modify it 668 | metadata = getattr(state_dict, '_metadata', None) 669 | state_dict = state_dict.copy() 670 | if metadata is not None: 671 | state_dict._metadata = metadata 672 | 673 | def load(module, prefix=''): 674 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) 675 | module._load_from_state_dict( 676 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 677 | for name, child in module._modules.items(): 678 | if child is not None: 679 | load(child, prefix + name + '.') 680 | start_prefix = '' 681 | if not hasattr(model, 'bert') and any(s.startswith('bert.') for s in state_dict.keys()): 682 | start_prefix = 'bert.' 
683 |         load(model, prefix=start_prefix)
684 |         if len(missing_keys) > 0:
685 |             logger.info("Weights of {} not initialized from pretrained model: {}".format(
686 |                 model.__class__.__name__, missing_keys))
687 |         if len(unexpected_keys) > 0:
688 |             logger.info("Weights from pretrained model not used in {}: {}".format(
689 |                 model.__class__.__name__, unexpected_keys))
690 |         if len(error_msgs) > 0:
691 |             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
692 |                 model.__class__.__name__, "\n\t".join(error_msgs)))
693 |         return model
694 | 
695 | 
696 | class BertModel(BertPreTrainedModel):
697 |     """BERT model ("Bidirectional Encoder Representations from Transformers").
698 |     Params:
699 |         config: a BertConfig class instance with the configuration to build a new model
700 |     Inputs:
701 |         `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
702 |             with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
703 |             `extract_features.py`, `run_classifier.py` and `run_squad.py`)
704 |         `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
705 |             types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
706 |             a `sentence B` token (see BERT paper for more details).
707 |         `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
708 |             selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
709 |             input sequence length in the current batch. It's the mask that we typically use for attention when
710 |             a batch has varying length sentences.
711 |         `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
712 |     Outputs: Tuple of (encoded_layers, pooled_output)
713 |         `encoded_layers`: controlled by `output_all_encoded_layers` argument:
714 |             - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
715 |                 of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
716 |                 encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
717 |             - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
718 |                 to the last attention block of shape [batch_size, sequence_length, hidden_size],
719 |         `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
720 |             classifier pretrained on top of the hidden state associated to the first token of the
721 |             input (`CLS`) to train on the Next-Sentence task (see BERT's paper).
722 | Example usage: 723 | ```python 724 | # Already been converted into WordPiece token ids 725 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 726 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 727 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 728 | config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 729 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 730 | model = modeling.BertModel(config=config) 731 | all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) 732 | ``` 733 | """ 734 | def __init__(self, config): 735 | super(BertModel, self).__init__(config) 736 | self.embeddings = BertEmbeddings(config) 737 | self.encoder = BertEncoder(config) 738 | self.pooler = BertPooler(config) 739 | self.apply(self.init_bert_weights) 740 | 741 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True): 742 | if attention_mask is None: 743 | attention_mask = torch.ones_like(input_ids) 744 | if token_type_ids is None: 745 | token_type_ids = torch.zeros_like(input_ids) 746 | 747 | # We create a 3D attention mask from a 2D tensor mask. 748 | # Sizes are [batch_size, 1, 1, to_seq_length] 749 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] 750 | # this attention mask is more simple than the triangular masking of causal attention 751 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 752 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 753 | 754 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 755 | # masked positions, this operation will create a tensor which is 0.0 for 756 | # positions we want to attend and -10000.0 for masked positions. 757 | # Since we are adding it to the raw scores before the softmax, this is 758 | # effectively the same as removing these entirely. 759 | extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 760 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 761 | 762 | embedding_output = self.embeddings(input_ids, token_type_ids) 763 | encoded_layers = self.encoder(embedding_output, 764 | extended_attention_mask, 765 | output_all_encoded_layers=output_all_encoded_layers) 766 | sequence_output = encoded_layers[-1] 767 | pooled_output = self.pooler(sequence_output) 768 | if not output_all_encoded_layers: 769 | encoded_layers = encoded_layers[-1] 770 | return encoded_layers, pooled_output 771 | 772 | 773 | class BertForPreTraining(BertPreTrainedModel): 774 | """BERT model with pre-training heads. 775 | This module comprises the BERT model followed by the two pre-training heads: 776 | - the masked language modeling head, and 777 | - the next sentence classification head. 778 | Params: 779 | config: a BertConfig class instance with the configuration to build a new model. 780 | Inputs: 781 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 782 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 783 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 784 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 785 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 786 | a `sentence B` token (see BERT paper for more details). 
787 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 788 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 789 | input sequence length in the current batch. It's the mask that we typically use for attention when 790 | a batch has varying length sentences. 791 | `masked_lm_labels`: optional masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] 792 | with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss 793 | is only computed for the labels set in [0, ..., vocab_size] 794 | `next_sentence_label`: optional next sentence classification loss: torch.LongTensor of shape [batch_size] 795 | with indices selected in [0, 1]. 796 | 0 => next sentence is the continuation, 1 => next sentence is a random sentence. 797 | Outputs: 798 | if `masked_lm_labels` and `next_sentence_label` are not `None`: 799 | Outputs the total_loss which is the sum of the masked language modeling loss and the next 800 | sentence classification loss. 801 | if `masked_lm_labels` or `next_sentence_label` is `None`: 802 | Outputs a tuple comprising 803 | - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and 804 | - the next sentence classification logits of shape [batch_size, 2]. 805 | Example usage: 806 | ```python 807 | # Already been converted into WordPiece token ids 808 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 809 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 810 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 811 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 812 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 813 | model = BertForPreTraining(config) 814 | masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask) 815 | ``` 816 | """ 817 | def __init__(self, config): 818 | super(BertForPreTraining, self).__init__(config) 819 | self.bert = BertModel(config) 820 | self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight) 821 | self.apply(self.init_bert_weights) 822 | 823 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None): 824 | sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, 825 | output_all_encoded_layers=False) 826 | prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) 827 | 828 | if masked_lm_labels is not None and next_sentence_label is not None: 829 | loss_fct = CrossEntropyLoss(ignore_index=-1) 830 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) 831 | next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) 832 | total_loss = masked_lm_loss + next_sentence_loss 833 | return total_loss 834 | else: 835 | return prediction_scores, seq_relationship_score 836 | 837 | 838 | class BertForMaskedLM(BertPreTrainedModel): 839 | """BERT model with the masked language modeling head. 840 | This module comprises the BERT model followed by the masked language modeling head. 841 | Params: 842 | config: a BertConfig class instance with the configuration to build a new model. 
843 | Inputs: 844 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 845 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 846 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 847 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 848 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 849 | a `sentence B` token (see BERT paper for more details). 850 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 851 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 852 | input sequence length in the current batch. It's the mask that we typically use for attention when 853 | a batch has varying length sentences. 854 | `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] 855 | with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss 856 | is only computed for the labels set in [0, ..., vocab_size] 857 | Outputs: 858 | if `masked_lm_labels` is not `None`: 859 | Outputs the masked language modeling loss. 860 | if `masked_lm_labels` is `None`: 861 | Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size]. 862 | Example usage: 863 | ```python 864 | # Already been converted into WordPiece token ids 865 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 866 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 867 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 868 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 869 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 870 | model = BertForMaskedLM(config) 871 | masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask) 872 | ``` 873 | """ 874 | def __init__(self, config): 875 | super(BertForMaskedLM, self).__init__(config) 876 | self.bert = BertModel(config) 877 | self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight) 878 | self.apply(self.init_bert_weights) 879 | 880 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None): 881 | sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, 882 | output_all_encoded_layers=False) 883 | prediction_scores = self.cls(sequence_output) 884 | 885 | if masked_lm_labels is not None: 886 | loss_fct = CrossEntropyLoss(ignore_index=-1) 887 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) 888 | return masked_lm_loss 889 | else: 890 | return prediction_scores 891 | 892 | 893 | class BertForNextSentencePrediction(BertPreTrainedModel): 894 | """BERT model with next sentence prediction head. 895 | This module comprises the BERT model followed by the next sentence classification head. 896 | Params: 897 | config: a BertConfig class instance with the configuration to build a new model. 898 | Inputs: 899 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 900 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 901 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 902 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 903 | types indices selected in [0, 1]. 
Type 0 corresponds to a `sentence A` and type 1 corresponds to
904 |             a `sentence B` token (see BERT paper for more details).
905 |         `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
906 |             selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
907 |             input sequence length in the current batch. It's the mask that we typically use for attention when
908 |             a batch has varying length sentences.
909 |         `next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size]
910 |             with indices selected in [0, 1].
911 |             0 => next sentence is the continuation, 1 => next sentence is a random sentence.
912 |     Outputs:
913 |         if `next_sentence_label` is not `None`:
914 |             Outputs the next sentence classification loss, i.e. the CrossEntropy of the
915 |             seq_relationship logits against `next_sentence_label`.
916 |         if `next_sentence_label` is `None`:
917 |             Outputs the next sentence classification logits of shape [batch_size, 2].
918 |     Example usage:
919 |     ```python
920 |     # Already been converted into WordPiece token ids
921 |     input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
922 |     input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
923 |     token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
924 |     config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
925 |         num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
926 |     model = BertForNextSentencePrediction(config)
927 |     seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
928 |     ```
929 |     """
930 |     def __init__(self, config):
931 |         super(BertForNextSentencePrediction, self).__init__(config)
932 |         self.bert = BertModel(config)
933 |         self.cls = BertOnlyNSPHead(config)
934 |         self.apply(self.init_bert_weights)
935 | 
936 |     def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None):
937 |         _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
938 |                                      output_all_encoded_layers=False)
939 |         seq_relationship_score = self.cls(pooled_output)
940 | 
941 |         if next_sentence_label is not None:
942 |             loss_fct = CrossEntropyLoss(ignore_index=-1)
943 |             next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
944 |             return next_sentence_loss
945 |         else:
946 |             return seq_relationship_score
947 | 
948 | 
949 | class BertForSequenceClassification(BertPreTrainedModel):
950 |     """BERT model for classification.
951 |     This module is composed of the BERT model with a linear layer on top of
952 |     the pooled output.
953 |     Params:
954 |         `config`: a BertConfig class instance with the configuration to build a new model.
955 |         `num_labels`: the number of classes for the classifier. Default = 2.
956 |     Inputs:
957 |         `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
958 |             with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
959 |             `extract_features.py`, `run_classifier.py` and `run_squad.py`)
960 |         `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
961 |             types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
962 |             a `sentence B` token (see BERT paper for more details).
963 |         `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
964 |             selected in [0, 1]. 
It's a mask to be used if the input sequence length is smaller than the max 965 | input sequence length in the current batch. It's the mask that we typically use for attention when 966 | a batch has varying length sentences. 967 | `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] 968 | with indices selected in [0, ..., num_labels]. 969 | Outputs: 970 | if `labels` is not `None`: 971 | Outputs the CrossEntropy classification loss of the output with the labels. 972 | if `labels` is `None`: 973 | Outputs the classification logits of shape [batch_size, num_labels]. 974 | Example usage: 975 | ```python 976 | # Already been converted into WordPiece token ids 977 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 978 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 979 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 980 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 981 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 982 | num_labels = 2 983 | model = BertForSequenceClassification(config, num_labels) 984 | logits = model(input_ids, token_type_ids, input_mask) 985 | ``` 986 | """ 987 | def __init__(self, config, num_labels): 988 | super(BertForSequenceClassification, self).__init__(config) 989 | self.num_labels = num_labels 990 | self.bert = BertModel(config) 991 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 992 | self.classifier = nn.Linear(config.hidden_size, num_labels) 993 | self.apply(self.init_bert_weights) 994 | 995 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): 996 | _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) 997 | pooled_output = self.dropout(pooled_output) 998 | logits = self.classifier(pooled_output) 999 | 1000 | if labels is not None: 1001 | loss_fct = CrossEntropyLoss() 1002 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 1003 | return loss 1004 | else: 1005 | return logits 1006 | 1007 | 1008 | class BertForMultipleChoice(BertPreTrainedModel): 1009 | """BERT model for multiple choice tasks. 1010 | This module is composed of the BERT model with a linear layer on top of 1011 | the pooled output. 1012 | Params: 1013 | `config`: a BertConfig class instance with the configuration to build a new model. 1014 | `num_choices`: the number of classes for the classifier. Default = 2. 1015 | Inputs: 1016 | `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] 1017 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 1018 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 1019 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] 1020 | with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` 1021 | and type 1 corresponds to a `sentence B` token (see BERT paper for more details). 1022 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with indices 1023 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 1024 | input sequence length in the current batch. It's the mask that we typically use for attention when 1025 | a batch has varying length sentences. 
1026 |         `labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
1027 |             with indices selected in [0, ..., num_choices].
1028 |     Outputs:
1029 |         if `labels` is not `None`:
1030 |             Outputs the CrossEntropy classification loss of the output with the labels.
1031 |         if `labels` is `None`:
1032 |             Outputs the classification logits of shape [batch_size, num_choices].
1033 |     Example usage:
1034 |     ```python
1035 |     # Already been converted into WordPiece token ids
1036 |     input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]])
1037 |     input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]], [[1, 1, 0], [1, 0, 0]]])
1038 |     token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]], [[0, 1, 1], [0, 0, 1]]])
1039 |     config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
1040 |         num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
1041 |     num_choices = 2
1042 |     model = BertForMultipleChoice(config, num_choices)
1043 |     logits = model(input_ids, token_type_ids, input_mask)
1044 |     ```
1045 |     """
1046 |     def __init__(self, config, num_choices):
1047 |         super(BertForMultipleChoice, self).__init__(config)
1048 |         self.num_choices = num_choices
1049 |         self.bert = BertModel(config)
1050 |         self.dropout = nn.Dropout(config.hidden_dropout_prob)
1051 |         self.classifier = nn.Linear(config.hidden_size, 1)
1052 |         self.apply(self.init_bert_weights)
1053 | 
1054 |     def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
1055 |         flat_input_ids = input_ids.view(-1, input_ids.size(-1))
1056 |         flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
1057 |         flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1))
1058 |         _, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False)
1059 |         pooled_output = self.dropout(pooled_output)
1060 |         logits = self.classifier(pooled_output)
1061 |         reshaped_logits = logits.view(-1, self.num_choices)
1062 | 
1063 |         if labels is not None:
1064 |             loss_fct = CrossEntropyLoss()
1065 |             loss = loss_fct(reshaped_logits, labels)
1066 |             return loss
1067 |         else:
1068 |             return reshaped_logits
1069 | 
1070 | 
1071 | class BertForTokenClassification(BertPreTrainedModel):
1072 |     """BERT model for token-level classification.
1073 |     This module is composed of the BERT model with a linear layer on top of
1074 |     the full hidden state of the last layer.
1075 |     Params:
1076 |         `config`: a BertConfig class instance with the configuration to build a new model.
1077 |         `num_labels`: the number of classes for the classifier. Default = 2.
1078 |     Inputs:
1079 |         `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
1080 |             with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
1081 |             `extract_features.py`, `run_classifier.py` and `run_squad.py`)
1082 |         `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
1083 |             types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
1084 |             a `sentence B` token (see BERT paper for more details).
1085 |         `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
1086 |             selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
1087 |             input sequence length in the current batch. It's the mask that we typically use for attention when
1088 |             a batch has varying length sentences. 
1089 | `labels`: labels for the classification output: torch.LongTensor of shape [batch_size, sequence_length] 1090 | with indices selected in [0, ..., num_labels]. 1091 | Outputs: 1092 | if `labels` is not `None`: 1093 | Outputs the CrossEntropy classification loss of the output with the labels. 1094 | if `labels` is `None`: 1095 | Outputs the classification logits of shape [batch_size, sequence_length, num_labels]. 1096 | Example usage: 1097 | ```python 1098 | # Already been converted into WordPiece token ids 1099 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 1100 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 1101 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 1102 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 1103 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 1104 | num_labels = 2 1105 | model = BertForTokenClassification(config, num_labels) 1106 | logits = model(input_ids, token_type_ids, input_mask) 1107 | ``` 1108 | """ 1109 | def __init__(self, config, num_labels): 1110 | super(BertForTokenClassification, self).__init__(config) 1111 | self.num_labels = num_labels 1112 | self.bert = BertModel(config) 1113 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 1114 | self.classifier = nn.Linear(config.hidden_size, num_labels) 1115 | self.apply(self.init_bert_weights) 1116 | 1117 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): 1118 | sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) 1119 | sequence_output = self.dropout(sequence_output) 1120 | logits = self.classifier(sequence_output) 1121 | 1122 | if labels is not None: 1123 | loss_fct = CrossEntropyLoss() 1124 | # Only keep active parts of the loss 1125 | if attention_mask is not None: 1126 | active_loss = attention_mask.view(-1) == 1 1127 | active_logits = logits.view(-1, self.num_labels)[active_loss] 1128 | active_labels = labels.view(-1)[active_loss] 1129 | loss = loss_fct(active_logits, active_labels) 1130 | else: 1131 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 1132 | return loss 1133 | else: 1134 | return logits 1135 | 1136 | 1137 | class BertForQuestionAnswering(BertPreTrainedModel): 1138 | """BERT model for Question Answering (span extraction). 1139 | This module is composed of the BERT model with a linear layer on top of 1140 | the sequence output that computes start_logits and end_logits 1141 | Params: 1142 | `config`: a BertConfig class instance with the configuration to build a new model. 1143 | Inputs: 1144 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 1145 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 1146 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 1147 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 1148 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 1149 | a `sentence B` token (see BERT paper for more details). 1150 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 1151 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 1152 | input sequence length in the current batch. It's the mask that we typically use for attention when 1153 | a batch has varying length sentences. 
1154 | `start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size].
1155 | Positions are clamped to the length of the sequence and positions outside of the sequence are not taken
1156 | into account for computing the loss.
1157 | `end_positions`: position of the last token for the labeled span: torch.LongTensor of shape [batch_size].
1158 | Positions are clamped to the length of the sequence and positions outside of the sequence are not taken
1159 | into account for computing the loss.
1160 | Outputs:
1161 | if `start_positions` and `end_positions` are not `None`:
1162 | Outputs the total_loss which is the average of the CrossEntropy losses for the start and end token positions.
1163 | if `start_positions` or `end_positions` is `None`:
1164 | Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end
1165 | position tokens of shape [batch_size, sequence_length].
1166 | Example usage:
1167 | ```python
1168 | # Already been converted into WordPiece token ids
1169 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
1170 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
1171 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
1172 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
1173 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
1174 | model = BertForQuestionAnswering(config)
1175 | start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
1176 | ```
1177 | """
1178 | def __init__(self, config):
1179 | super(BertForQuestionAnswering, self).__init__(config)
1180 | self.bert = BertModel(config)
1181 | # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
1182 | # self.dropout = nn.Dropout(config.hidden_dropout_prob)
1183 | self.qa_outputs = nn.Linear(config.hidden_size, 2)
1184 | self.apply(self.init_bert_weights)
1185 |
1186 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None):
1187 | sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
1188 | logits = self.qa_outputs(sequence_output)
1189 | start_logits, end_logits = logits.split(1, dim=-1)
1190 | start_logits = start_logits.squeeze(-1)
1191 | end_logits = end_logits.squeeze(-1)
1192 |
1193 | if start_positions is not None and end_positions is not None:
1194 | # If we are on multi-GPU, split adds a dimension
1195 | if len(start_positions.size()) > 1:
1196 | start_positions = start_positions.squeeze(-1)
1197 | if len(end_positions.size()) > 1:
1198 | end_positions = end_positions.squeeze(-1)
1199 | # sometimes the start/end positions are outside our model inputs; we ignore these terms
1200 | ignored_index = start_logits.size(1)
1201 | start_positions.clamp_(0, ignored_index)
1202 | end_positions.clamp_(0, ignored_index)
1203 |
1204 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1205 | start_loss = loss_fct(start_logits, start_positions)
1206 | end_loss = loss_fct(end_logits, end_positions)
1207 | total_loss = (start_loss + end_loss) / 2
1208 | return total_loss
1209 | else:
1210 | return start_logits, end_logits
1211 |
--------------------------------------------------------------------------------
/lib/models/bert_modules/visual_linguistic_bert.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import torch
3 | import torch.nn as nn
4 | from .modeling import BertLayerNorm, BertEncoder, BertPooler, ACT2FN, BertOnlyMLMHead
5 | import numpy as np
6 | import math
7 |
8 | # todo: add this to config
9 | # NUM_SPECIAL_WORDS = 1000
10 |
11 | class PositionalEncoding(nn.Module):
12 |
13 | def __init__(self, d_hid, n_position=116):
14 | super(PositionalEncoding, self).__init__()
15 |
16 | # Not a parameter
17 | self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))
18 |
19 | def _get_sinusoid_encoding_table(self, n_position, d_hid):
20 | ''' Sinusoid position encoding table (positions are normalized to [0, 1) and scaled by 2*pi, a variant of the standard table) '''
21 | # TODO: make it with torch instead of numpy
22 |
23 | def get_position_angle_vec(position):
24 | return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
25 |
26 | sinusoid_table = np.array([get_position_angle_vec(pos_i / n_position) for pos_i in range(n_position)])
27 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2] * 2 * math.pi) # dim 2i
28 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2] * 2 * math.pi) # dim 2i+1
29 |
30 | return torch.FloatTensor(sinusoid_table).unsqueeze(0)
31 |
32 | def forward(self, x):
33 | return x + self.pos_table[:, :x.size(1)].clone().detach()
34 |
35 | class BaseModel(nn.Module):
36 | def __init__(self, config, **kwargs):
37 | self.config = config
38 | super(BaseModel, self).__init__()
39 |
40 | def init_weights(self, module):
41 | """ Initialize the weights.
42 | """
43 | if isinstance(module, (nn.Linear, nn.Embedding)):
44 | # Slightly different from the TF version which uses truncated_normal for initialization
45 | # cf https://github.com/pytorch/pytorch/pull/5617
46 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
47 | elif isinstance(module, BertLayerNorm):
48 | module.bias.data.zero_()
49 | module.weight.data.fill_(1.0)
50 | if isinstance(module, nn.Linear) and module.bias is not None:
51 | module.bias.data.zero_()
52 |
53 | def forward(self, *args, **kwargs):
54 | raise NotImplementedError
55 |
56 |
57 | class VisualLinguisticBert(BaseModel):
58 | def __init__(self, dataset, config, language_pretrained_model_path=None):
59 | super(VisualLinguisticBert, self).__init__(config)
60 |
61 | self.config = config
62 |
63 | # embeddings
64 | self.mask_embeddings = nn.Embedding(1, config.hidden_size)
65 | self.word_mapping = nn.Linear(300, config.hidden_size) # 300 is the dimensionality of the GloVe word vectors
66 | self.text_embedding_LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
67 | self.text_embedding_dropout = nn.Dropout(config.hidden_dropout_prob)
68 | self.visual_embedding_LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
69 | self.visual_embedding_dropout = nn.Dropout(config.hidden_dropout_prob)
70 |
71 | if dataset == "ActivityNet":
72 | self.postion_encoding = PositionalEncoding(config.hidden_size, n_position=116) # 'postion' (sic): attribute name kept for checkpoint compatibility
73 | elif dataset == "TACoS":
74 | self.postion_encoding = PositionalEncoding(config.hidden_size, n_position=194)
75 | else:
76 | print('DATASET ERROR: unsupported dataset {}'.format(dataset))
77 | exit()
78 |
79 | # visual transform
80 | self.visual_1x1_text = None
81 | self.visual_1x1_object = None
82 | if config.visual_size != config.hidden_size:
83 | self.visual_1x1_text = nn.Linear(config.visual_size, config.hidden_size)
84 | self.visual_1x1_object = nn.Linear(config.visual_size, config.hidden_size)
85 | if config.visual_ln:
86 | self.visual_ln_text = BertLayerNorm(config.hidden_size, eps=1e-12)
87 | self.visual_ln_object = BertLayerNorm(config.hidden_size, eps=1e-12)
88 |
89 | self.encoder = BertEncoder(config)
90 |
91 | # init weights
92 | self.apply(self.init_weights)
93 | if config.visual_ln:
94 | self.visual_ln_text.weight.data.fill_(self.config.visual_scale_text_init)
95 | self.visual_ln_object.weight.data.fill_(self.config.visual_scale_object_init)
96 |
97 | # load language pretrained model
98 | if language_pretrained_model_path is not None:
99 | print('load language pretrained model')
100 | self.load_language_pretrained_model(language_pretrained_model_path)
101 |
102 | def forward(self,
103 | text_input_feats,
104 | text_mask,
105 | word_mask,
106 | object_visual_embeddings,
107 | output_all_encoded_layers=False,
108 | output_attention_probs=False):
109 |
110 | # get seamlessly concatenated embeddings and masks
111 | text_embeddings, visual_embeddings = self.embedding(text_input_feats,
112 | text_mask, word_mask,
113 | object_visual_embeddings)
114 |
115 | # We create a 3D attention mask from a 2D tensor mask.
116 | # Sizes are [batch_size, 1, 1, to_seq_length]
117 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
118 | # this attention mask is simpler than the triangular masking of causal attention
119 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
120 | extended_attention_mask = text_mask.unsqueeze(1).unsqueeze(2)
121 |
122 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
123 | # masked positions, this operation will create a tensor which is 0.0 for
124 | # positions we want to attend and -1000000.0 for masked positions.
125 | # Since we are adding it to the raw scores before the softmax, this is
126 | # effectively the same as removing these entirely.
127 | extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
128 | extended_attention_mask = (1.0 - extended_attention_mask) * -1000000.0
129 | # extended_attention_mask = 1.0 - extended_attention_mask
130 | # extended_attention_mask[extended_attention_mask != 0] = float('-inf')
131 |
132 | if output_attention_probs:
133 | encoded_layers, attention_probs = self.encoder(text_embeddings,
134 | visual_embeddings,
135 | extended_attention_mask,
136 | output_all_encoded_layers=output_all_encoded_layers,
137 | output_attention_probs=output_attention_probs)
138 | else:
139 | encoded_layers = self.encoder(text_embeddings,
140 | visual_embeddings,
141 | extended_attention_mask,
142 | output_all_encoded_layers=output_all_encoded_layers,
143 | output_attention_probs=output_attention_probs)
144 |
145 | # sequence_output = encoded_layers[-1]
146 | # pooled_output = self.pooler(sequence_output) if self.config.with_pooler else None
147 | if output_all_encoded_layers:
148 | encoded_layers_text = []
149 | encoded_layers_object = []
150 | for encoded_layer in encoded_layers:
151 | encoded_layers_text.append(encoded_layer[0])
152 | encoded_layers_object.append(encoded_layer[1])
153 | if output_attention_probs:
154 | attention_probs_text = []
155 | attention_probs_object = []
156 | for attention_prob in attention_probs:
157 | attention_probs_text.append(attention_prob[0])
158 | attention_probs_object.append(attention_prob[1])
159 | return encoded_layers_text, encoded_layers_object, attention_probs_text, attention_probs_object
160 | else:
161 | return encoded_layers_text, encoded_layers_object
162 | else:
163 | encoded_layers = encoded_layers[-1]
164 | if output_attention_probs:
165 | attention_probs = attention_probs[-1]
166 | return encoded_layers[0], encoded_layers[1], attention_probs[0], attention_probs[1]
167 | else:
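# final layer only: hand back the (text, visual) hidden state pair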
168 | return encoded_layers[0], encoded_layers[1]
169 |
170 | def embedding(self,
171 | text_input_feats,
172 | text_mask,
173 | word_mask,
174 | object_visual_embeddings):
175 |
176 | text_linguistic_embedding = self.word_mapping(text_input_feats)
177 | text_input_feats_temp = text_input_feats.clone()
178 | mask_word_mean = text_mask
179 | if self.training:
180 | text_input_feats_temp[word_mask>0] = 0 # zero out the features of masked words
181 | mask_word_mean = text_mask * (1. - word_mask)
182 | _zero_id = torch.zeros(text_linguistic_embedding.shape[:2], dtype=torch.long, device=text_linguistic_embedding.device)
183 | text_linguistic_embedding[word_mask>0] = self.mask_embeddings(_zero_id)[word_mask>0] # replace masked words with the learned [MASK] embedding
184 |
185 | if self.visual_1x1_object is not None:
186 | object_visual_embeddings = self.visual_1x1_object(object_visual_embeddings)
187 | if self.config.visual_ln:
188 | object_visual_embeddings = self.visual_ln_object(object_visual_embeddings)
189 |
190 | embeddings = torch.cat([object_visual_embeddings, text_linguistic_embedding], dim=1) # one shared positional encoding across the visual and text streams
191 | embeddings = self.postion_encoding(embeddings)
192 | visual_embeddings, text_embeddings = torch.split(embeddings, [object_visual_embeddings.size(1),text_linguistic_embedding.size(1)], 1)
193 |
194 | text_embeddings = self.text_embedding_LayerNorm(text_embeddings)
195 | text_embeddings = self.text_embedding_dropout(text_embeddings)
196 |
197 | visual_embeddings = self.visual_embedding_LayerNorm(visual_embeddings)
198 | visual_embeddings = self.visual_embedding_dropout(visual_embeddings)
199 |
200 | return text_embeddings, visual_embeddings
201 |
202 | def load_language_pretrained_model(self, language_pretrained_model_path): # note: inherited from the original VL-BERT code; it assumes embedding attributes (word_embeddings, position_embeddings, token_type_embeddings, embedding_LayerNorm) that this modified class does not define
203 | pretrained_state_dict = torch.load(language_pretrained_model_path, map_location=lambda storage, loc: storage)
204 | encoder_pretrained_state_dict = {}
205 | pooler_pretrained_state_dict = {}
206 | embedding_ln_pretrained_state_dict = {}
207 | unexpected_keys = []
208 | for k, v in pretrained_state_dict.items():
209 | if k.startswith('bert.'):
210 | k = k[len('bert.'):]
211 | elif k.startswith('roberta.'):
212 | k = k[len('roberta.'):]
213 | else:
214 | unexpected_keys.append(k)
215 | continue
216 | if 'gamma' in k:
217 | k = k.replace('gamma', 'weight')
218 | if 'beta' in k:
219 | k = k.replace('beta', 'bias')
220 | if k.startswith('encoder.'):
221 | k_ = k[len('encoder.'):]
222 | if k_ in self.encoder.state_dict():
223 | encoder_pretrained_state_dict[k_] = v
224 | else:
225 | unexpected_keys.append(k)
226 | elif k.startswith('embeddings.'):
227 | k_ = k[len('embeddings.'):]
228 | if k_ == 'word_embeddings.weight':
229 | self.word_embeddings.weight.data = v.to(dtype=self.word_embeddings.weight.data.dtype,
230 | device=self.word_embeddings.weight.data.device)
231 | elif k_ == 'position_embeddings.weight':
232 | self.position_embeddings.weight.data = v.to(dtype=self.position_embeddings.weight.data.dtype,
233 | device=self.position_embeddings.weight.data.device)
234 | elif k_ == 'token_type_embeddings.weight':
235 | self.token_type_embeddings.weight.data[:v.size(0)] = v.to(
236 | dtype=self.token_type_embeddings.weight.data.dtype,
237 | device=self.token_type_embeddings.weight.data.device)
238 | if v.size(0) == 1:
239 | # Todo: roberta token type embedding
240 | self.token_type_embeddings.weight.data[1] = v[0].clone().to(
241 | dtype=self.token_type_embeddings.weight.data.dtype,
242 | device=self.token_type_embeddings.weight.data.device)
243 | self.token_type_embeddings.weight.data[2] = v[0].clone().to(
244 | dtype=self.token_type_embeddings.weight.data.dtype,
245 | device=self.token_type_embeddings.weight.data.device)
246 |
247 | elif k_.startswith('LayerNorm.'):
248 | k__ = k_[len('LayerNorm.'):]
249 | if k__ in self.embedding_LayerNorm.state_dict():
250 | embedding_ln_pretrained_state_dict[k__] = v
251 | else:
252 | unexpected_keys.append(k)
253 | else:
254 | unexpected_keys.append(k)
255 | elif self.config.with_pooler and k.startswith('pooler.'):
256 | k_ = k[len('pooler.'):]
257 | if k_ in self.pooler.state_dict():
258 | pooler_pretrained_state_dict[k_] = v
259 | else:
260 | unexpected_keys.append(k)
261 | else:
262 | unexpected_keys.append(k)
263 | if len(unexpected_keys) > 0:
264 | print("Warnings: Unexpected keys: {}.".format(unexpected_keys))
265 | self.embedding_LayerNorm.load_state_dict(embedding_ln_pretrained_state_dict)
266 | self.encoder.load_state_dict(encoder_pretrained_state_dict)
267 | if self.config.with_pooler and len(pooler_pretrained_state_dict) > 0:
268 | self.pooler.load_state_dict(pooler_pretrained_state_dict)
269 |
270 |
271 | class VisualLinguisticBertForPretraining(VisualLinguisticBert): # kept from the original VL-BERT codebase; appears unused in this repository
272 | def __init__(self, dataset, config, language_pretrained_model_path=None,
273 | with_rel_head=True, with_mlm_head=True, with_mvrc_head=True):
274 |
275 | super(VisualLinguisticBertForPretraining, self).__init__(dataset, config, language_pretrained_model_path=None)
276 |
277 | self.with_rel_head = with_rel_head
278 | self.with_mlm_head = with_mlm_head
279 | self.with_mvrc_head = with_mvrc_head
280 | if with_rel_head:
281 | self.relationsip_head = VisualLinguisticBertRelationshipPredictionHead(config) # 'relationsip' (sic): name kept for checkpoint compatibility
282 | if with_mlm_head:
283 | self.mlm_head = BertOnlyMLMHead(config, self.word_embeddings.weight)
284 | if with_mvrc_head:
285 | self.mvrc_head = VisualLinguisticBertMVRCHead(config)
286 |
287 | # init weights
288 | self.apply(self.init_weights)
289 | if config.visual_ln:
290 | self.visual_ln_text.weight.data.fill_(self.config.visual_scale_text_init)
291 | self.visual_ln_object.weight.data.fill_(self.config.visual_scale_object_init)
292 |
293 | # load language pretrained model
294 | if language_pretrained_model_path is not None:
295 | self.load_language_pretrained_model(language_pretrained_model_path)
296 |
297 | if config.word_embedding_frozen:
298 | for p in self.word_embeddings.parameters():
299 | p.requires_grad = False
300 |
301 | if config.pos_embedding_frozen:
302 | for p in self.position_embeddings.parameters():
303 | p.requires_grad = False
304 |
305 | def forward(self, # note: this follows the original VL-BERT interface, not the modified VisualLinguisticBert.forward above
306 | text_input_ids,
307 | text_token_type_ids,
308 | text_visual_embeddings,
309 | text_mask,
310 | object_vl_embeddings,
311 | object_mask,
312 | output_all_encoded_layers=True,
313 | output_text_and_object_separately=False):
314 |
315 | text_out, object_out, pooled_rep = super(VisualLinguisticBertForPretraining, self).forward(
316 | text_input_ids,
317 | text_token_type_ids,
318 | text_visual_embeddings,
319 | text_mask,
320 | object_vl_embeddings,
321 | object_mask,
322 | output_all_encoded_layers=False,
323 | output_text_and_object_separately=True
324 | )
325 |
326 | if self.with_rel_head:
327 | relationship_logits = self.relationsip_head(pooled_rep)
328 | else:
329 | relationship_logits = None
330 | if self.with_mlm_head:
331 | mlm_logits = self.mlm_head(text_out)
332 | else:
333 | mlm_logits = None
334 | if self.with_mvrc_head:
335 | mvrc_logits = self.mvrc_head(object_out)
336 | else:
337 | mvrc_logits = None
338 |
339 | return relationship_logits, mlm_logits, mvrc_logits
340 |
341 | def load_language_pretrained_model(self, language_pretrained_model_path): # maps HuggingFace BERT/RoBERTa checkpoint keys onto this model: strips 'bert.'/'roberta.' prefixes and renames TF-style 'gamma'/'beta' to 'weight'/'bias'
342 |
pretrained_state_dict = torch.load(language_pretrained_model_path, map_location=lambda storage, loc: storage) 343 | encoder_pretrained_state_dict = {} 344 | pooler_pretrained_state_dict = {} 345 | embedding_ln_pretrained_state_dict = {} 346 | relationship_head_pretrained_state_dict = {} 347 | mlm_head_pretrained_state_dict = {} 348 | unexpected_keys = [] 349 | for _k, v in pretrained_state_dict.items(): 350 | if _k.startswith('bert.') or _k.startswith('roberta.'): 351 | k = _k[len('bert.'):] if _k.startswith('bert.') else _k[len('roberta.'):] 352 | if 'gamma' in k: 353 | k = k.replace('gamma', 'weight') 354 | if 'beta' in k: 355 | k = k.replace('beta', 'bias') 356 | if k.startswith('encoder.'): 357 | k_ = k[len('encoder.'):] 358 | if k_ in self.encoder.state_dict(): 359 | encoder_pretrained_state_dict[k_] = v 360 | else: 361 | unexpected_keys.append(_k) 362 | elif k.startswith('embeddings.'): 363 | k_ = k[len('embeddings.'):] 364 | if k_ == 'word_embeddings.weight': 365 | self.word_embeddings.weight.data = v.to(dtype=self.word_embeddings.weight.data.dtype, 366 | device=self.word_embeddings.weight.data.device) 367 | elif k_ == 'position_embeddings.weight': 368 | self.position_embeddings.weight.data = v.to(dtype=self.position_embeddings.weight.data.dtype, 369 | device=self.position_embeddings.weight.data.device) 370 | elif k_ == 'token_type_embeddings.weight': 371 | self.token_type_embeddings.weight.data[:v.size(0)] = v.to( 372 | dtype=self.token_type_embeddings.weight.data.dtype, 373 | device=self.token_type_embeddings.weight.data.device) 374 | if v.size(0) == 1: 375 | # Todo: roberta token type embedding 376 | self.token_type_embeddings.weight.data[1] = v[0].to( 377 | dtype=self.token_type_embeddings.weight.data.dtype, 378 | device=self.token_type_embeddings.weight.data.device) 379 | elif k_.startswith('LayerNorm.'): 380 | k__ = k_[len('LayerNorm.'):] 381 | if k__ in self.embedding_LayerNorm.state_dict(): 382 | embedding_ln_pretrained_state_dict[k__] = v 383 | else: 384 | unexpected_keys.append(_k) 385 | else: 386 | unexpected_keys.append(_k) 387 | elif self.config.with_pooler and k.startswith('pooler.'): 388 | k_ = k[len('pooler.'):] 389 | if k_ in self.pooler.state_dict(): 390 | pooler_pretrained_state_dict[k_] = v 391 | else: 392 | unexpected_keys.append(_k) 393 | elif _k.startswith('cls.seq_relationship.') and self.with_rel_head: 394 | k_ = _k[len('cls.seq_relationship.'):] 395 | if 'gamma' in k_: 396 | k_ = k_.replace('gamma', 'weight') 397 | if 'beta' in k_: 398 | k_ = k_.replace('beta', 'bias') 399 | if k_ in self.relationsip_head.caption_image_relationship.state_dict(): 400 | relationship_head_pretrained_state_dict[k_] = v 401 | else: 402 | unexpected_keys.append(_k) 403 | elif (_k.startswith('cls.predictions.') or _k.startswith('lm_head.')) and self.with_mlm_head: 404 | k_ = _k[len('cls.predictions.'):] if _k.startswith('cls.predictions.') else _k[len('lm_head.'):] 405 | if _k.startswith('lm_head.'): 406 | if 'dense' in k_ or 'layer_norm' in k_: 407 | k_ = 'transform.' 
+ k_ 408 | if 'layer_norm' in k_: 409 | k_ = k_.replace('layer_norm', 'LayerNorm') 410 | if 'gamma' in k_: 411 | k_ = k_.replace('gamma', 'weight') 412 | if 'beta' in k_: 413 | k_ = k_.replace('beta', 'bias') 414 | if k_ in self.mlm_head.predictions.state_dict(): 415 | mlm_head_pretrained_state_dict[k_] = v 416 | else: 417 | unexpected_keys.append(_k) 418 | else: 419 | unexpected_keys.append(_k) 420 | if len(unexpected_keys) > 0: 421 | print("Warnings: Unexpected keys: {}.".format(unexpected_keys)) 422 | self.embedding_LayerNorm.load_state_dict(embedding_ln_pretrained_state_dict) 423 | self.encoder.load_state_dict(encoder_pretrained_state_dict) 424 | if self.config.with_pooler and len(pooler_pretrained_state_dict) > 0: 425 | self.pooler.load_state_dict(pooler_pretrained_state_dict) 426 | if self.with_rel_head and len(relationship_head_pretrained_state_dict) > 0: 427 | self.relationsip_head.caption_image_relationship.load_state_dict(relationship_head_pretrained_state_dict) 428 | if self.with_mlm_head: 429 | self.mlm_head.predictions.load_state_dict(mlm_head_pretrained_state_dict) 430 | 431 | 432 | class VisualLinguisticBertMVRCHeadTransform(BaseModel): 433 | def __init__(self, config): 434 | super(VisualLinguisticBertMVRCHeadTransform, self).__init__(config) 435 | 436 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 437 | self.act = ACT2FN[config.hidden_act] 438 | 439 | self.apply(self.init_weights) 440 | 441 | def forward(self, hidden_states): 442 | hidden_states = self.dense(hidden_states) 443 | hidden_states = self.act(hidden_states) 444 | 445 | return hidden_states 446 | 447 | 448 | class VisualLinguisticBertMVRCHead(BaseModel): 449 | def __init__(self, config): 450 | super(VisualLinguisticBertMVRCHead, self).__init__(config) 451 | 452 | self.transform = VisualLinguisticBertMVRCHeadTransform(config) 453 | self.region_cls_pred = nn.Linear(config.hidden_size, config.visual_region_classes) 454 | self.apply(self.init_weights) 455 | 456 | def forward(self, hidden_states): 457 | 458 | hidden_states = self.transform(hidden_states) 459 | logits = self.region_cls_pred(hidden_states) 460 | 461 | return logits 462 | 463 | 464 | class VisualLinguisticBertRelationshipPredictionHead(BaseModel): 465 | def __init__(self, config): 466 | super(VisualLinguisticBertRelationshipPredictionHead, self).__init__(config) 467 | 468 | self.caption_image_relationship = nn.Linear(config.hidden_size, 2) 469 | self.apply(self.init_weights) 470 | 471 | def forward(self, pooled_rep): 472 | 473 | relationship_logits = self.caption_image_relationship(pooled_rep) 474 | 475 | return relationship_logits 476 | 477 | -------------------------------------------------------------------------------- /lib/models/bert_modules/vlbert.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from .modeling import BertPredictionHeadTransform 6 | from .visual_linguistic_bert import VisualLinguisticBert 7 | 8 | BERT_WEIGHTS_NAME = 'pytorch_model.bin' 9 | 10 | 11 | class TLocVLBERT(nn.Module): 12 | def __init__(self, dataset, config): 13 | 14 | super(TLocVLBERT, self).__init__() 15 | 16 | self.config = config 17 | 18 | language_pretrained_model_path = None 19 | if config.BERT_PRETRAINED != '': 20 | language_pretrained_model_path = '{}-{:04d}.model'.format(config.BERT_PRETRAINED, 21 | config.BERT_PRETRAINED_EPOCH) 22 | elif os.path.isdir(config.BERT_MODEL_NAME): 23 | weight_path = 
os.path.join(config.BERT_MODEL_NAME, BERT_WEIGHTS_NAME)
24 | if os.path.isfile(weight_path):
25 | language_pretrained_model_path = weight_path
26 | self.language_pretrained_model_path = language_pretrained_model_path
27 | if language_pretrained_model_path is None:
28 | print("Warning: no pretrained language model found, training from scratch!!!")
29 |
30 | if dataset == "ActivityNet":
31 | iou_mask_map = torch.zeros(33,33).float() # iou_mask_map[s, e] == 1 marks (start, end) index pairs kept as candidate moments
32 | for i in range(0,32,1):
33 | iou_mask_map[i,i+1:min(i+17,33)] = 1. # spans of up to 16 clips, densely enumerated
34 | for i in range(0,32-16,2):
35 | iou_mask_map[i,range(18+i,33,2)] = 1. # longer spans, sampled at stride 2
36 | elif dataset == "TACoS":
37 | iou_mask_map = torch.zeros(129,129).float()
38 | for i in range(0,128,1):
39 | iou_mask_map[i,1+i:min(i+17,129)] = 1.
40 | for i in range(0,128-16,2):
41 | iou_mask_map[i,range(18+i,min(33+i,129),2)] = 1.
42 | for i in range(0,128-32,4):
43 | iou_mask_map[i,range(36+i,min(65+i,129),4)] = 1. # progressively coarser strides for longer spans
44 | for i in range(0,128-64,8):
45 | iou_mask_map[i,range(72+i,129,8)] = 1.
46 | else:
47 | print('DATASET ERROR: unsupported dataset {}'.format(dataset))
48 | exit()
49 |
50 | self.register_buffer('iou_mask_map', iou_mask_map)
51 |
52 | self.vlbert = VisualLinguisticBert(dataset, config,
53 | language_pretrained_model_path=language_pretrained_model_path)
54 |
55 | dim = config.hidden_size
56 | if config.CLASSIFIER_TYPE == "2fc":
57 | self.final_mlp = torch.nn.Sequential(
58 | torch.nn.Dropout(config.CLASSIFIER_DROPOUT, inplace=False),
59 | torch.nn.Linear(dim, config.CLASSIFIER_HIDDEN_SIZE),
60 | torch.nn.ReLU(inplace=True),
61 | torch.nn.Dropout(config.CLASSIFIER_DROPOUT, inplace=False),
62 | torch.nn.Linear(config.CLASSIFIER_HIDDEN_SIZE, config.vocab_size)
63 | )
64 | self.final_mlp_2 = torch.nn.Sequential(
65 | torch.nn.Dropout(config.CLASSIFIER_DROPOUT, inplace=False),
66 | torch.nn.Linear(dim, dim*3),
67 | torch.nn.ReLU(inplace=True),
68 | torch.nn.Dropout(config.CLASSIFIER_DROPOUT, inplace=False),
69 | )
70 | self.final_mlp_3 = torch.nn.Sequential(
71 | torch.nn.Linear(dim*3, config.CLASSIFIER_HIDDEN_SIZE),
72 | torch.nn.ReLU(inplace=True),
73 | torch.nn.Dropout(config.CLASSIFIER_DROPOUT, inplace=False),
74 | torch.nn.Linear(config.CLASSIFIER_HIDDEN_SIZE, 3)
75 | )
76 | self.final_mlp_s = torch.nn.Sequential(
77 | torch.nn.Linear(dim, config.CLASSIFIER_HIDDEN_SIZE),
78 | torch.nn.ReLU(inplace=True),
79 | torch.nn.Dropout(config.CLASSIFIER_DROPOUT, inplace=False),
80 | torch.nn.Linear(config.CLASSIFIER_HIDDEN_SIZE, 1)
81 | )
82 | self.final_mlp_e = torch.nn.Sequential(
83 | torch.nn.Linear(dim, config.CLASSIFIER_HIDDEN_SIZE),
84 | torch.nn.ReLU(inplace=True),
85 | torch.nn.Dropout(config.CLASSIFIER_DROPOUT, inplace=False),
86 | torch.nn.Linear(config.CLASSIFIER_HIDDEN_SIZE, 1)
87 | )
88 | self.final_mlp_c = torch.nn.Sequential(
89 | torch.nn.Linear(dim, config.CLASSIFIER_HIDDEN_SIZE),
90 | torch.nn.ReLU(inplace=True),
91 | torch.nn.Dropout(config.CLASSIFIER_DROPOUT, inplace=False),
92 | torch.nn.Linear(config.CLASSIFIER_HIDDEN_SIZE, 1)
93 | )
94 |
95 | # elif config.CLASSIFIER_TYPE == 'mlm':
96 | # transform = BertPredictionHeadTransform(config.PARAMS)
97 | # linear = nn.Linear(config.hidden_size, config.DATASET.ANSWER_VOCAB_SIZE)
98 | # self.final_mlp = nn.Sequential(
99 | # transform,
100 | # nn.Dropout(config.CLASSIFIER_DROPOUT, inplace=False),
101 | # linear
102 | # )
103 | else:
104 | raise ValueError("Unsupported classifier type: {}!".format(config.CLASSIFIER_TYPE))
105 |
106 | # init weights
107 | self.init_weight()
108 |
109 | self.fix_params()
110 |
111 | def init_weight(self):
112 | for m in self.final_mlp.modules():
113 | if isinstance(m, torch.nn.Linear):
114 | torch.nn.init.xavier_uniform_(m.weight)
115 | torch.nn.init.constant_(m.bias, 0)
116 | for m in self.final_mlp_2.modules():
117 | if isinstance(m, torch.nn.Linear):
118 | torch.nn.init.xavier_uniform_(m.weight)
119 | torch.nn.init.constant_(m.bias, 0)
120 | for m in self.final_mlp_3.modules():
121 | if isinstance(m, torch.nn.Linear):
122 | torch.nn.init.xavier_uniform_(m.weight)
123 | torch.nn.init.constant_(m.bias, 0)
124 |
125 | def fix_params(self):
126 | pass
127 |
128 |
129 | def forward(self, text_input_feats, text_mask, word_mask, object_visual_feats):
130 | ###########################################
131 |
132 | # Visual Linguistic BERT
133 |
134 | hidden_states_text, hidden_states_object = self.vlbert(text_input_feats,
135 | text_mask,
136 | word_mask,
137 | object_visual_feats,
138 | output_all_encoded_layers=False)
139 |
140 | logits_text = self.final_mlp(hidden_states_text)
141 | hidden_states_object = self.final_mlp_2(hidden_states_object)
142 | hidden_s, hidden_e, hidden_c = torch.split(hidden_states_object, self.config.hidden_size, dim=-1) # split into start / end / content streams
143 |
144 | T = hidden_states_object.size(1)
145 | s_idx = torch.arange(T, device=hidden_states_object.device)
146 | e_idx = torch.arange(T, device=hidden_states_object.device)
147 | c_point = hidden_c[:,(0.5*(s_idx[:,None] + e_idx[None,:])).long().flatten(),:].view(hidden_c.size(0),T,T,hidden_c.size(-1)) # content feature of each (s, e) pair, sampled at the midpoint clip
148 | s_c_e_points = torch.cat((hidden_s[:,:,None,:].repeat(1,1,T,1), c_point, hidden_e[:,None,:,:].repeat(1,T,1,1)), -1)
149 | logits_iou = self.final_mlp_3(s_c_e_points).permute(0,3,1,2).contiguous() # [batch, 3, T, T]: start/end offsets and an IoU score per proposal
150 |
151 | logits_visual = torch.cat((self.final_mlp_s(hidden_s), self.final_mlp_e(hidden_e), self.final_mlp_c(hidden_c)), -1)
152 | # logits_visual = logits_visual.permute(0,2,1).contiguous()
153 |
154 | return logits_text, logits_visual, logits_iou, self.iou_mask_map.clone().detach()
155 |
--------------------------------------------------------------------------------
/lib/models/frame_modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .frame_pool import FrameAvgPool, FrameMaxPool
2 |
--------------------------------------------------------------------------------
/lib/models/frame_modules/frame_pool.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | class FrameAvgPool(nn.Module):
5 |
6 | def __init__(self, cfg):
7 | super(FrameAvgPool, self).__init__()
8 | input_size = cfg.INPUT_SIZE
9 | hidden_size = cfg.HIDDEN_SIZE
10 | kernel_size = cfg.KERNEL_SIZE
11 | stride = cfg.STRIDE
12 | self.avg_pool = nn.AvgPool1d(kernel_size, stride, int(kernel_size/2)) # input_size and hidden_size are read but unused here
13 |
14 | def forward(self, visual_input):
15 | vis_h = self.avg_pool(visual_input)
16 | return vis_h
17 |
18 | class FrameMaxPool(nn.Module):
19 |
20 | def __init__(self, input_size, hidden_size, stride):
21 | super(FrameMaxPool, self).__init__()
22 | self.max_pool = nn.MaxPool1d(stride) # kernel size equals stride: non-overlapping pooling; input_size and hidden_size are unused
23 |
24 | def forward(self, visual_input):
25 | vis_h = self.max_pool(visual_input)
26 | return vis_h
27 |
--------------------------------------------------------------------------------
/lib/models/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | def bce_rescale_loss(config, logits_text, logits_visual, logits_iou, iou_mask_map, gt_maps, gt_times, word_label, word_mask):
5 | T = gt_maps.shape[-1]
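# Layout note: gt_maps appears to pack per-clip targets as channels [start, end, content, start-offset, end-offset];
# channels 0-2 pair with logits_visual below and channels 3-4 with the boundary regression.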
6 | joint_prob = torch.sigmoid(logits_visual[:,:3,:]) # per-clip start / end / content probabilities
7 | gt_p = gt_maps[:,:3,:]
8 | loss = F.binary_cross_entropy_with_logits(logits_visual[:,:3,:], gt_p, reduction='none') * (joint_prob-gt_p) * (joint_prob-gt_p) # BCE weighted by the squared probability error (focal-style)
9 |
10 | reg_mask = (gt_maps[:,0:1,:T,None] >= 0.4) * (gt_maps[:,1:2,None,:] >= 0.4) # regress boundaries only where both start and end ground-truth scores are confident
11 | gt_tmp = torch.cat((gt_maps[:,3:4,:T,None].repeat(1,1,1,T), gt_maps[:,4:5,None,:].repeat(1,1,T,1)), 1)
12 | loss_reg = (torch.abs(logits_iou[:,:2,:,:] - gt_tmp) * reg_mask).sum((2,3)) / reg_mask.sum((2,3)) # L1 loss on the predicted start/end offsets
13 |
14 | idxs = torch.arange(T, device=logits_iou.device)
15 | s_e_idx = torch.cat((idxs[None,None,:T,None].repeat(1,1,1,T), idxs[None,None,None,:].repeat(1,1,T,1)), 1)
16 | s_e_time = (s_e_idx + logits_iou[:,:2,:,:]).clone().detach() # refined (start, end) times for every candidate cell
17 |
18 | iou = torch.clamp(torch.min(gt_times[:,1][:,None,None], s_e_time[:,1,:,:]) - torch.max(gt_times[:,0][:,None,None], s_e_time[:,0,:,:]), min=1e-10) / torch.clamp(torch.max(gt_times[:,1][:,None,None], s_e_time[:,1,:,:]) - torch.min(gt_times[:,0][:,None,None], s_e_time[:,0,:,:]), min=1e-7) # temporal IoU with the ground-truth span; tiny epsilons avoid division by zero
19 |
20 | temp = (s_e_time[:,0,:,:] < s_e_time[:,1,:,:]) * iou_mask_map[None,:,:] # valid cells: start before end and inside the multi-scale proposal mask
21 | # iou[iou > 0.7] = 1.
22 | iou[iou < 0.5] = 0.
23 | loss_iou = (F.binary_cross_entropy_with_logits(logits_iou[:,2,:,:], iou, reduction='none') * temp * torch.pow(torch.sigmoid(logits_iou[:,2,:,:]) - iou, 2)).sum((1,2)) / temp.sum((1,2))
24 |
25 | log_p = F.log_softmax(logits_text, -1)*word_mask.unsqueeze(2)
26 |
27 | grid = torch.arange(log_p.shape[-1], device=log_p.device).repeat(log_p.shape[0], log_p.shape[1], 1)
28 |
29 | text_loss = torch.sum(-log_p[grid==word_label.unsqueeze(2)]) / torch.clamp((word_mask.sum(1)>0).sum(), min=1e-8) # cross-entropy of the masked-word predictions
30 |
31 | loss_value = config.W1*loss.sum(-1).mean() + config.W2*loss_reg.mean() + config.W3*loss_iou.mean() + config.W4*text_loss
32 |
33 | return loss_value, joint_prob, torch.sigmoid(logits_iou[:,2,:,:])*temp, s_e_time
34 |
--------------------------------------------------------------------------------
/lib/models/tan.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from core.config import config
4 | import models.frame_modules as frame_modules
5 | import models.bert_modules as bert_modules
6 |
7 | class TAN(nn.Module):
8 | def __init__(self):
9 | super(TAN, self).__init__()
10 |
11 | self.frame_layer = getattr(frame_modules, config.TAN.FRAME_MODULE.NAME)(config.TAN.FRAME_MODULE.PARAMS)
12 | self.bert_layer = getattr(bert_modules, config.TAN.VLBERT_MODULE.NAME)(config.DATASET.NAME, config.TAN.VLBERT_MODULE.PARAMS)
13 |
14 | def forward(self, textual_input, textual_mask, word_mask, visual_input):
15 |
16 | vis_h = self.frame_layer(visual_input.transpose(1, 2))
17 | vis_h = vis_h.transpose(1, 2)
18 | logits_text, logits_visual, logits_iou, iou_mask_map = self.bert_layer(textual_input, textual_mask, word_mask, vis_h)
19 | # logits_text = logits_text.transpose(1, 2)
20 | logits_visual = logits_visual.transpose(1, 2)
21 |
22 | return logits_text, logits_visual, logits_iou, iou_mask_map
23 |
24 | def extract_features(self, textual_input, textual_mask, visual_input):
25 | pass
26 |
--------------------------------------------------------------------------------
/moment_localization/_init_paths.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Houwen Peng and Zhipeng Zhang
5 | # Details: import other paths
6 | # ------------------------------------------------------------------------------
7 |
8 | from __future__ import absolute_import
9 | from __future__ import division
10 | from __future__ import print_function
11 |
12 | import os.path as osp
13 | import sys
14 |
15 |
16 | def add_path(path):
17 | if path not in sys.path:
18 | sys.path.insert(0, path)
19 |
20 |
21 | this_dir = osp.dirname(__file__)
22 |
23 | lib_path = osp.join(this_dir, '..', 'lib')
24 | add_path(lib_path)
--------------------------------------------------------------------------------
/moment_localization/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import argparse
4 | import pickle as pkl
5 |
6 | from tqdm import tqdm
7 | import numpy as np
8 | import torch
9 | from torch.utils.data import DataLoader
10 |
11 | import _init_paths
12 | from core.engine import Engine
13 | import datasets
14 | import models
15 | from core.utils import AverageMeter
16 | from core.config import config, update_config
17 | from core.eval import eval_predictions, display_results
18 | import models.loss as loss
19 |
20 | torch.manual_seed(0)
21 | torch.cuda.manual_seed(0)
22 |
23 | torch.set_printoptions(precision=2, sci_mode=False)
24 |
25 | def parse_args():
26 | parser = argparse.ArgumentParser(description='Test localization network')
27 |
28 | # general
29 | parser.add_argument('--cfg', help='experiment configure file name', required=True, type=str)
30 | args, rest = parser.parse_known_args()
31 |
32 | # update config
33 | update_config(args.cfg)
34 |
35 | # testing
36 | parser.add_argument('--gpus', help='gpus', type=str)
37 | parser.add_argument('--workers', help='num of dataloader workers', type=int)
38 | parser.add_argument('--dataDir', help='data path', type=str)
39 | parser.add_argument('--modelDir', help='model path', type=str)
40 | parser.add_argument('--logDir', help='log path', type=str)
41 | parser.add_argument('--split', required=True, choices=['train', 'val', 'test'], type=str)
42 | parser.add_argument('--verbose', default=False, action="store_true", help='print progress bar')
43 | args = parser.parse_args()
44 |
45 | return args
46 |
47 | def reset_config(config, args):
48 | if args.gpus:
49 | config.GPUS = args.gpus
50 | if args.workers:
51 | config.WORKERS = args.workers
52 | if args.dataDir:
53 | config.DATA_DIR = args.dataDir
54 | if args.modelDir:
55 | config.OUTPUT_DIR = args.modelDir
56 | if args.logDir:
57 | config.LOG_DIR = args.logDir
58 | if args.verbose:
59 | config.VERBOSE = args.verbose
60 |
61 | def save_scores(scores, data, dataset_name, split):
62 | results = {}
63 | for i, d in enumerate(data):
64 | results[d['video']] = scores[i]
65 | pkl.dump(results,open(os.path.join(config.RESULT_DIR, dataset_name, '{}_{}_{}.pkl'.format(config.MODEL.NAME,config.DATASET.VIS_INPUT_TYPE,
66 | split)),'wb'))
67 |
68 | if __name__ == '__main__':
69 | args = parse_args()
70 | reset_config(config, args)
71 |
72 | device = ("cuda" if torch.cuda.is_available() else "cpu")
73 | model = getattr(models, config.MODEL.NAME)()
74 | model_checkpoint = torch.load(config.MODEL.CHECKPOINT)
75 | model.load_state_dict(model_checkpoint)
76 | if torch.cuda.device_count() > 1:
77 | print("Using", torch.cuda.device_count(), "GPUs")
78 | model = torch.nn.DataParallel(model)
79 | model = model.to(device)
80 | model.eval()
81 |
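# The evaluation loader below keeps annotation order (shuffle=False), so on_test_forward can realign predictions by batch_anno_idxs.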
82 | test_dataset = getattr(datasets, config.DATASET.NAME)(args.split)
83 | dataloader = DataLoader(test_dataset,
84 | batch_size=config.TRAIN.BATCH_SIZE,
85 | shuffle=False,
86 | num_workers=config.WORKERS,
87 | pin_memory=False,
88 | collate_fn=datasets.collate_fn)
89 |
90 | def network(sample):
91 | anno_idxs = sample['batch_anno_idxs']
92 | textual_input = sample['batch_word_vectors'].cuda()
93 | textual_mask = sample['batch_txt_mask'].cuda()
94 | visual_input = sample['batch_vis_input'].cuda()
95 | map_gt = sample['batch_map_gt'].cuda()
96 | duration = sample['batch_duration']
97 | word_label = sample['batch_word_label'].cuda()
98 | word_mask = sample['batch_word_mask'].cuda()
99 | gt_times = sample['batch_gt_times'].cuda()
100 |
101 | logits_text, logits_visual, logits_iou, iou_mask_map = model(textual_input, textual_mask, word_mask, visual_input)
102 | loss_value, joint_prob, iou_scores, regress = getattr(loss, config.LOSS.NAME)(config.LOSS.PARAMS, logits_text, logits_visual, logits_iou, iou_mask_map, map_gt, gt_times, word_label, word_mask)
103 |
104 | sorted_times = None if model.training else get_proposal_results(iou_scores, regress, duration)
105 |
106 | return loss_value, sorted_times
107 |
108 | def get_proposal_results(scores, regress, durations):
109 | # rank every (start, end) cell by its predicted IoU score, highest first
110 | out_sorted_times = []
111 | T = scores.shape[-1]
112 |
113 | regress = regress.cpu().detach().numpy()
114 |
115 | for score, reg, duration in zip(scores, regress, durations):
116 | sorted_indices = np.dstack(np.unravel_index(np.argsort(score.cpu().detach().numpy().ravel())[::-1], (T, T))).tolist() # flatten the T x T score map and recover (s, e) index pairs
117 | sorted_indices = np.array([ [reg[0,item[0],item[1]], reg[1,item[0],item[1]]] for item in sorted_indices[0] if reg[0,item[0],item[1]] < reg[1,item[0],item[1]] ]).astype(float) # keep only proposals whose regressed start precedes the end
118 | sorted_indices = torch.from_numpy(sorted_indices).cuda()
119 | target_size = config.DATASET.NUM_SAMPLE_CLIPS // config.DATASET.TARGET_STRIDE
120 | out_sorted_times.append((sorted_indices.float() / target_size * duration).tolist()) # convert clip indices to seconds
121 |
122 | return out_sorted_times
123 |
124 |
125 | def on_test_start(state):
126 | state['loss_meter'] = AverageMeter()
127 | state['sorted_segments_list'] = []
128 | state['output'] = []
129 | if config.VERBOSE:
130 | state['progress_bar'] = tqdm(total=math.ceil(len(test_dataset)/config.TRAIN.BATCH_SIZE))
131 |
132 | def on_test_forward(state):
133 | if config.VERBOSE:
134 | state['progress_bar'].update(1)
135 | state['loss_meter'].update(state['loss'].item(), 1)
136 |
137 | min_idx = min(state['sample']['batch_anno_idxs'])
138 | batch_indexs = [idx - min_idx for idx in state['sample']['batch_anno_idxs']]
139 | sorted_segments = [state['output'][i] for i in batch_indexs]
140 | state['sorted_segments_list'].extend(sorted_segments)
141 |
142 | def on_test_end(state):
143 | if config.VERBOSE:
144 | state['progress_bar'].close()
145 | print()
146 |
147 | annotations = test_dataset.annotations
148 | state['Rank@N,mIoU@M'], state['miou'] = eval_predictions(state['sorted_segments_list'], annotations, verbose=True)
149 |
150 | loss_message = '\ntest loss {:.4f}'.format(state['loss_meter'].avg)
151 | print(loss_message)
152 | state['loss_meter'].reset()
153 | test_table = display_results(state['Rank@N,mIoU@M'], state['miou'],
154 | 'performance on testing set')
155 | table_message = '\n'+test_table
156 | print(table_message)
157 |
158 | # save_scores(state['sorted_segments_list'], annotations, config.DATASET.NAME, args.split)
159 |
160 |
161 | engine = Engine()
162 | engine.hooks['on_test_start'] = on_test_start
163 | engine.hooks['on_test_forward'] = on_test_forward
164 | engine.hooks['on_test_end'] = on_test_end
165 | engine.test(network, dataloader, args.split)
166 |
--------------------------------------------------------------------------------
/moment_localization/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import _init_paths
6 | import os
7 | import pprint
8 | import argparse
9 | import numpy as np
10 | import torch
11 | import torch.backends.cudnn as cudnn
12 | from torch.utils.data import DataLoader
13 | import torch.optim as optim
14 | from tqdm import tqdm
15 | import datasets
16 | import models
17 | from core.config import config, update_config
18 | from core.engine import Engine
19 | from core.utils import AverageMeter
20 | from core import eval
21 | from core.utils import create_logger
22 | import models.loss as loss
23 | import math
24 |
25 | torch.manual_seed(0)
26 | torch.cuda.manual_seed(0)
27 | torch.autograd.set_detect_anomaly(True) # note: anomaly detection adds significant overhead; disable it for full-speed training
28 | def parse_args():
29 | parser = argparse.ArgumentParser(description='Train localization network')
30 |
31 | # general
32 | parser.add_argument('--cfg', help='experiment configure file name', required=True, type=str)
33 | args, rest = parser.parse_known_args()
34 |
35 | # update config
36 | update_config(args.cfg)
37 |
38 | # training
39 | parser.add_argument('--gpus', help='gpus', type=str)
40 | parser.add_argument('--workers', help='num of dataloader workers', type=int)
41 | parser.add_argument('--dataDir', help='data path', type=str)
42 | parser.add_argument('--modelDir', help='model path', type=str)
43 | parser.add_argument('--logDir', help='log path', type=str)
44 | parser.add_argument('--verbose', default=False, action="store_true", help='print progress bar')
45 | parser.add_argument('--tag', help='tags shown in log', type=str)
46 | args = parser.parse_args()
47 |
48 | return args
49 |
50 | def reset_config(config, args):
51 | if args.gpus:
52 | config.GPUS = args.gpus
53 | if args.workers:
54 | config.WORKERS = args.workers
55 | if args.dataDir:
56 | config.DATA_DIR = args.dataDir
57 | if args.modelDir:
58 | config.MODEL_DIR = args.modelDir
59 | if args.logDir:
60 | config.LOG_DIR = args.logDir
61 | if args.verbose:
62 | config.VERBOSE = args.verbose
63 | if args.tag:
64 | config.TAG = args.tag
65 |
66 |
67 | if __name__ == '__main__':
68 |
69 | args = parse_args()
70 | reset_config(config, args)
71 |
72 | logger, final_output_dir = create_logger(config, args.cfg, config.TAG)
73 | logger.info('\n'+pprint.pformat(args))
74 | logger.info('\n'+pprint.pformat(config))
75 |
76 | # cudnn related setting
77 | cudnn.benchmark = config.CUDNN.BENCHMARK
78 | torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
79 | torch.backends.cudnn.enabled = config.CUDNN.ENABLED
80 |
81 | dataset_name = config.DATASET.NAME
82 | model_name = config.MODEL.NAME
83 |
84 | train_dataset = getattr(datasets, dataset_name)('train')
85 | if config.TEST.EVAL_TRAIN:
86 | eval_train_dataset = getattr(datasets, dataset_name)('train')
87 | if not config.DATASET.NO_VAL:
88 | val_dataset = getattr(datasets, dataset_name)('val')
89 | test_dataset = getattr(datasets, dataset_name)('test')
90 |
91 | model = getattr(models, model_name)()
92 | if config.MODEL.CHECKPOINT and config.TRAIN.CONTINUE:
93 | model_checkpoint = torch.load(config.MODEL.CHECKPOINT)
94 | model.load_state_dict(model_checkpoint)
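# DataParallel below prefixes parameter names with 'module.'; on_update therefore saves model.module.state_dict() so checkpoints load without the prefix.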
95 | if torch.cuda.device_count() > 1:
96 | print("Using", torch.cuda.device_count(), "GPUs")
97 | model = torch.nn.DataParallel(model)
98 | device = ("cuda" if torch.cuda.is_available() else "cpu")
99 | model = model.to(device)
100 |
101 | optimizer = optim.AdamW(model.parameters(),lr=config.TRAIN.LR, betas=(0.9, 0.999), weight_decay=config.TRAIN.WEIGHT_DECAY)
102 | # optimizer = optim.SGD(model.parameters(), lr=config.TRAIN.LR, momentum=0.9, weight_decay=config.TRAIN.WEIGHT_DECAY)
103 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=config.TRAIN.FACTOR, patience=config.TRAIN.PATIENCE, verbose=config.VERBOSE)
104 |
105 |
106 | def iterator(split):
107 | if split == 'train':
108 | dataloader = DataLoader(train_dataset,
109 | batch_size=config.TRAIN.BATCH_SIZE,
110 | shuffle=config.TRAIN.SHUFFLE,
111 | num_workers=config.WORKERS,
112 | pin_memory=False,
113 | collate_fn=datasets.collate_fn)
114 | elif split == 'val':
115 | dataloader = DataLoader(val_dataset,
116 | batch_size=config.TEST.BATCH_SIZE,
117 | shuffle=False,
118 | num_workers=config.WORKERS,
119 | pin_memory=False,
120 | collate_fn=datasets.collate_fn)
121 | elif split == 'test':
122 | dataloader = DataLoader(test_dataset,
123 | batch_size=config.TEST.BATCH_SIZE,
124 | shuffle=False,
125 | num_workers=config.WORKERS,
126 | pin_memory=False,
127 | collate_fn=datasets.collate_fn)
128 | elif split == 'train_no_shuffle':
129 | dataloader = DataLoader(eval_train_dataset,
130 | batch_size=config.TEST.BATCH_SIZE,
131 | shuffle=False,
132 | num_workers=config.WORKERS,
133 | pin_memory=False,
134 | collate_fn=datasets.collate_fn)
135 | else:
136 | raise NotImplementedError
137 |
138 | return dataloader
139 |
140 | def network(sample):
141 | anno_idxs = sample['batch_anno_idxs']
142 | textual_input = sample['batch_word_vectors'].cuda()
143 | textual_mask = sample['batch_txt_mask'].cuda()
144 | visual_input = sample['batch_vis_input'].cuda()
145 | map_gt = sample['batch_map_gt'].cuda()
146 | duration = sample['batch_duration']
147 | word_label = sample['batch_word_label'].cuda()
148 | word_mask = sample['batch_word_mask'].cuda()
149 | gt_times = sample['batch_gt_times'].cuda()
150 |
151 | logits_text, logits_visual, logits_iou, iou_mask_map = model(textual_input, textual_mask, word_mask, visual_input)
152 | loss_value, joint_prob, iou_scores, regress = getattr(loss, config.LOSS.NAME)(config.LOSS.PARAMS, logits_text, logits_visual, logits_iou, iou_mask_map, map_gt, gt_times, word_label, word_mask)
153 |
154 | sorted_times = None if model.training else get_proposal_results(iou_scores, regress, duration)
155 |
156 | return loss_value, sorted_times
157 |
158 | def get_proposal_results(scores, regress, durations):
159 | out_sorted_times = []
160 | T = scores.shape[-1]
161 |
162 | regress = regress.cpu().detach().numpy()
163 |
164 | for score, reg, duration in zip(scores, regress, durations):
165 | sorted_indices = np.dstack(np.unravel_index(np.argsort(score.cpu().detach().numpy().ravel())[::-1], (T, T))).tolist() # rank (s, e) cells by predicted IoU score
166 | sorted_indices = np.array([ [reg[0,item[0],item[1]], reg[1,item[0],item[1]]] for item in sorted_indices[0] if reg[0,item[0],item[1]] < reg[1,item[0],item[1]] ]).astype(float) # keep proposals whose regressed start precedes the end
167 | sorted_indices = torch.from_numpy(sorted_indices).cuda()
168 | target_size = config.DATASET.NUM_SAMPLE_CLIPS // config.DATASET.TARGET_STRIDE
169 | out_sorted_times.append((sorted_indices.float() / target_size * duration).tolist()) # convert clip indices to seconds
170 |
171 | return out_sorted_times
172 |
173 | def on_start(state):
174 | state['loss_meter'] = AverageMeter()
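# test_interval below equals (iterations per epoch) * TEST.INTERVAL, so evaluation and checkpointing run every TEST.INTERVAL fraction of an epoch.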
175 | state['test_interval'] = int(len(train_dataset)/config.TRAIN.BATCH_SIZE*config.TEST.INTERVAL)
176 | state['t'] = 1
177 | model.train()
178 | if config.VERBOSE:
179 | state['progress_bar'] = tqdm(total=state['test_interval'])
180 |
181 | def on_forward(state):
182 | torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
183 | state['loss_meter'].update(state['loss'].item(), 1)
184 |
185 | def on_update(state):# Save All
186 | if config.VERBOSE:
187 | state['progress_bar'].update(1)
188 |
189 | if state['t'] % state['test_interval'] == 0:
190 | model.eval()
191 | if config.VERBOSE:
192 | state['progress_bar'].close()
193 |
194 | loss_message = '\niter: {} train loss {:.4f}'.format(state['t'], state['loss_meter'].avg)
195 | table_message = ''
196 | if config.TEST.EVAL_TRAIN:
197 | train_state = engine.test(network, iterator('train_no_shuffle'), 'train')
198 | train_table = eval.display_results(train_state['Rank@N,mIoU@M'], train_state['miou'],
199 | 'performance on training set')
200 | table_message += '\n'+ train_table
201 | if not config.DATASET.NO_VAL:
202 | val_state = engine.test(network, iterator('val'), 'val')
203 | state['scheduler'].step(-val_state['loss_meter'].avg) # note: passes the negated validation loss to ReduceLROnPlateau, i.e. the loss is tracked in 'max' mode
204 | loss_message += ' val loss {:.4f}'.format(val_state['loss_meter'].avg)
205 | val_state['loss_meter'].reset()
206 | val_table = eval.display_results(val_state['Rank@N,mIoU@M'], val_state['miou'],
207 | 'performance on validation set')
208 | table_message += '\n'+ val_table
209 |
210 | test_state = engine.test(network, iterator('test'), 'test')
211 | loss_message += ' test loss {:.4f}'.format(test_state['loss_meter'].avg)
212 | test_state['loss_meter'].reset()
213 | test_table = eval.display_results(test_state['Rank@N,mIoU@M'], test_state['miou'],
214 | 'performance on testing set')
215 | table_message += '\n' + test_table
216 |
217 | message = loss_message+table_message+'\n'
218 | logger.info(message)
219 |
220 | saved_model_filename = os.path.join(config.MODEL_DIR,'{}/{}/iter{:06d}-{:.4f}-{:.4f}.pkl'.format(
221 | dataset_name, model_name+'_'+config.DATASET.VIS_INPUT_TYPE,
222 | state['t'], test_state['Rank@N,mIoU@M'][0,0], test_state['Rank@N,mIoU@M'][0,1]))
223 |
224 | rootfolder1 = os.path.dirname(saved_model_filename)
225 | rootfolder2 = os.path.dirname(rootfolder1)
226 | rootfolder3 = os.path.dirname(rootfolder2)
227 | if not os.path.exists(rootfolder3):
228 | print('Make directory %s ...' % rootfolder3)
229 | os.mkdir(rootfolder3)
230 | if not os.path.exists(rootfolder2):
231 | print('Make directory %s ...' % rootfolder2)
232 | os.mkdir(rootfolder2)
233 | if not os.path.exists(rootfolder1):
234 | print('Make directory %s ...' % rootfolder1)
235 | os.mkdir(rootfolder1)
236 |
237 | if torch.cuda.device_count() > 1:
238 | torch.save(model.module.state_dict(), saved_model_filename)
239 | else:
240 | torch.save(model.state_dict(), saved_model_filename)
241 |
242 |
243 | if config.VERBOSE:
244 | state['progress_bar'] = tqdm(total=state['test_interval'])
245 | model.train()
246 | state['loss_meter'].reset()
247 |
248 | def on_end(state):
249 | if config.VERBOSE:
250 | state['progress_bar'].close()
251 |
252 |
253 | def on_test_start(state):
254 | state['loss_meter'] = AverageMeter()
255 | state['sorted_segments_list'] = []
256 | if config.VERBOSE:
257 | if state['split'] == 'train':
258 | state['progress_bar'] = tqdm(total=math.ceil(len(train_dataset)/config.TEST.BATCH_SIZE))
259 | elif state['split'] == 'val':
260 | state['progress_bar'] = tqdm(total=math.ceil(len(val_dataset)/config.TEST.BATCH_SIZE))
261 | elif state['split'] == 'test':
262 | state['progress_bar'] = tqdm(total=math.ceil(len(test_dataset)/config.TEST.BATCH_SIZE))
263 | else:
264 | raise NotImplementedError
265 |
266 | def on_test_forward(state):
267 | if config.VERBOSE:
268 | state['progress_bar'].update(1)
269 | state['loss_meter'].update(state['loss'].item(), 1)
270 |
271 | min_idx = min(state['sample']['batch_anno_idxs'])
272 | batch_indexs = [idx - min_idx for idx in state['sample']['batch_anno_idxs']]
273 | sorted_segments = [state['output'][i] for i in batch_indexs]
274 | state['sorted_segments_list'].extend(sorted_segments)
275 |
276 | def on_test_end(state):
277 | annotations = state['iterator'].dataset.annotations
278 | state['Rank@N,mIoU@M'], state['miou'] = eval.eval_predictions(state['sorted_segments_list'], annotations, verbose=False)
279 | if config.VERBOSE:
280 | state['progress_bar'].close()
281 |
282 | engine = Engine()
283 | engine.hooks['on_start'] = on_start
284 | engine.hooks['on_forward'] = on_forward
285 | engine.hooks['on_update'] = on_update
286 | engine.hooks['on_end'] = on_end
287 | engine.hooks['on_test_start'] = on_test_start
288 | engine.hooks['on_test_forward'] = on_test_forward
289 | engine.hooks['on_test_end'] = on_test_end
290 | engine.train(network,
291 | iterator('train'),
292 | maxepoch=config.TRAIN.MAX_EPOCH,
293 | optimizer=optimizer,
294 | scheduler=scheduler)
295 |
--------------------------------------------------------------------------------
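As a reading aid (not part of the repository), here is a minimal, self-contained sketch that rebuilds the ActivityNet `iou_mask_map` from `lib/models/bert_modules/vlbert.py` and prints its multi-scale proposal pattern: spans of up to 16 clips are enumerated densely, while longer spans are kept only at even start indices with stride-2 end points.

```python
import torch

# Same construction as TLocVLBERT.__init__ for dataset == "ActivityNet".
iou_mask_map = torch.zeros(33, 33)
for i in range(0, 32):
    iou_mask_map[i, i + 1:min(i + 17, 33)] = 1.          # spans of 1..16 clips, stride 1
for i in range(0, 32 - 16, 2):
    iou_mask_map[i, list(range(18 + i, 33, 2))] = 1.     # longer spans, stride-2 endpoints

print(int(iou_mask_map.sum()))                           # number of candidate (start, end) cells
print(iou_mask_map[0].nonzero().squeeze(-1).tolist())    # end indices paired with start index 0
```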