├── Data
│   ├── FHM_rationale.json
│   ├── README.md
│   ├── chatgpt_abductive_reasoning.py
│   ├── harmc_rationale.json
│   ├── harmp_rationale.json
│   └── training and inference data
│       ├── FHM
│       │   ├── captions.pkl
│       │   ├── test.jsonl
│       │   └── train.jsonl
│       ├── harmc
│       │   ├── captions.pkl
│       │   ├── test.jsonl
│       │   └── train.jsonl
│       └── harmp
│           ├── captions.pkl
│           ├── test.jsonl
│           └── train.jsonl
├── LICENSE
├── README.md
├── requirements.txt
├── run.py
└── src
    ├── __init__.py
    ├── config.py
    ├── datamodules
    │   ├── __init__.py
    │   ├── datamodule_base.py
    │   ├── meme_datamodule.py
    │   └── multitask_datamodule.py
    ├── datasets
    │   ├── __init__.py
    │   ├── base_dataset.py
    │   └── meme.py
    ├── gadgets
    │   ├── __init__.py
    │   └── my_metrics.py
    ├── modules
    │   ├── __init__.py
    │   ├── dist_utils.py
    │   ├── heads.py
    │   ├── mm_module.py
    │   ├── mm_utils.py
    │   ├── objectives.py
    │   └── t5_model.py
    ├── transforms
    │   ├── __init__.py
    │   ├── randaug.py
    │   ├── transform.py
    │   └── utils.py
    └── utils
        ├── __init__.py
        └── glossary.py

/Data/README.md:
--------------------------------------------------------------------------------
 1 | ## Rationale Generation
 2 | 
 3 | You can use the script `chatgpt_abductive_reasoning.py` to generate the rationales. Remember to set the dataset paths and your OpenAI API key before running it.
 4 | 
 5 | We also provide the rationales we generated with ChatGPT (gpt-3.5-turbo): `harmc_rationale.json`, `harmp_rationale.json`, and `FHM_rationale.json`.
 6 | 
 7 | 
 8 | ## Image Data
 9 | 
10 | The original image files can be found at [Harm-C](https://drive.google.com/file/d/1dxMrnyXcED-85HCcQiA_d5rr8acwl6lp/view?usp=sharing), [Harm-P](https://drive.google.com/file/d/1fw850yxKNqzpRpQKH88D13yfrwX1MLde/view?usp=sharing), and [FHM](https://hatefulmemeschallenge.com/#download).
11 | 
12 | ## Data Preprocess
13 | 
14 | To separate the text and the image in each meme, we first in-paint the memes by combining MMOCR (Kuang et al., 2021) with SAM (Kirillov et al., 2023), extracting the embedded text and recovering a clean image. Since the focus of this work is on multimodal reasoning for harmful meme detection from a fresh perspective of harnessing LLMs, we then apply the pre-trained image captioning model ClipCap (Mokady et al., 2021), also used in recent work (Cao et al., 2022), to generate a textual description of the dominant objects or events in each meme image; this caption serves as the input to the LLM for abductive reasoning. To generate the rationale for each meme, we employ ChatGPT (Ouyang et al., 2022), a widely used LLM developed by OpenAI, specifically the “gpt-3.5-turbo” version. Following the practice of previous work such as MaskPrompt (Cao et al., 2022) on FHM data preprocessing, the input text for FHM is augmented with the image caption, entities, and demographic information for a fair comparison with that baseline.
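For reference, `captions.pkl` in each dataset folder is a pickled Python dictionary that maps an image identifier (the image filename with the `.png` suffix removed, mirroring what `chatgpt_abductive_reasoning.py` expects) to its ClipCap caption. Below is a minimal sketch for inspecting the captions and pairing them with the `train.jsonl` entries; the exact key convention and the `train.jsonl` field names are assumptions to verify against your copy of the files.

```python
import json
import pickle

# Load the ClipCap captions: a dict of image id (filename without ".png") -> caption string.
with open("training and inference data/FHM/captions.pkl", "rb") as f:
    caption_dict = pickle.load(f)

# Pair each meme in the training split with its generated caption.
with open("training and inference data/FHM/train.jsonl", "r", encoding="utf8") as f:
    for line in f:
        item = json.loads(line)  # assumed fields: id, img, label, text (as in test.jsonl)
        key = item["img"].replace(".png", "")  # assumed key format; adjust if it differs
        caption = caption_dict.get(key, "<no caption found>")
        print(item["id"], "|", item["text"], "|", caption)
        break  # remove this to iterate over the whole split
```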
15 | 
--------------------------------------------------------------------------------
/Data/chatgpt_abductive_reasoning.py:
--------------------------------------------------------------------------------
 1 | import json
 2 | import time
 3 | import openai
 4 | import pickle
 5 | openai.api_key = "XXXXX"
 6 | 
 7 | data = dict()
 8 | id = 0
 9 | with open("./FHM/mem_train.json", "r", encoding='utf8') as fin:
10 |     data_list = json.load(fin)
11 |     # for item in jsonlines.Reader(fin):
12 |     #     data[id]=item
13 |     #     id += 1
14 | fin.close()
15 | for data_item in data_list:
16 |     data[id] = data_item
17 |     id += 1
18 | 
19 | with open('./FHM/captions.pkl','rb') as f:
20 |     caption_dict = pickle.load(f)
21 | f.close()
22 | 
23 | cids = list(data.keys())
24 | pred = {}
25 | 
26 | 
27 | system_prompt = "You have been specially designed to perform abductive reasoning for the harmful meme detection task. Your primary function is that, according to a Harmfulness label about an Image with a text embedded, " \
28 |                 "please provide me a streamlined rationale, without explicitly indicating the label, why it is classified as the given Harmfulness label. " \
29 |                 "The image and the textual content in the meme are often uncorrelated, but its overall semantics is presented holistically. Thus it is important to note that you are prohibited from relying on your own imagination, as your goal is to provide the most accurate and reliable rationale possible " \
30 |                 "so that people can infer the harmfulness according to your reasoning about the background context and relationship between the given text and image caption."
31 | count = 0
32 | while count < len(cids):
33 | 
34 |     cid = cids[count]
35 |     # if data[cid]["id"] in pred:
36 |     #     count+=1
37 |     #     print(count)
38 |     #     continue
39 |     try:
40 |         text = data[cid]["clean_sent"].replace('\n', ' ').strip('\n')
41 |         caption = caption_dict[data[cid]["img"].strip('.png')].strip('\n')
42 |         if data[cid]["label"] == 1:
43 |             label = 'harmful'
44 |         else:
45 |             label = 'harmless'
46 | 
47 |         user_prompt = f"Given a Text: '{text}', which is embedded in an Image: '{caption}'; and a harmfulness label '{label}', please give me a streamlined rationale associated with the meme, without explicitly indicating the label, why it is reasoned as {label}."
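        # The block below sends the system and user prompts to "gpt-3.5-turbo" via the
        # legacy (pre-1.0) `openai` ChatCompletion API (on openai>=1.0 the equivalent is
        # OpenAI().chat.completions.create), stores the rationale keyed by image id, and
        # checkpoints the pickle after every item; any error (e.g. a rate limit) triggers
        # a 61-second sleep and a retry of the same item.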
48 | 
49 |         reply = openai.ChatCompletion.create(
50 |             model="gpt-3.5-turbo",
51 |             messages=[
52 |                 {"role": "system", "content": system_prompt},
53 |                 {"role": "user", "content": user_prompt},
54 |             ],
55 |             temperature=0,
56 |             max_tokens=256,
57 |         )
58 |         ans = reply["choices"][0]["message"]["content"]
59 |         print(user_prompt)
60 |         print(ans)
61 |         print(count)
62 | 
63 |         pred[data[cid]["img"].strip('.png')] = ans.lower()
64 |         with open("clipcap_FHM_rationale_.pkl", "wb") as fout:
65 |             pickle.dump(pred, fout)
66 |         fout.close()
67 |         count += 1
68 |     except:
69 |         print("Let's have a sleep.")
70 |         time.sleep(61)
--------------------------------------------------------------------------------
/Data/training and inference data/FHM/captions.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKBUNLP/Mr.Harm-EMNLP2023/84a116fb7483621ca6c1b051b7055b443344817c/Data/training and inference data/FHM/captions.pkl
--------------------------------------------------------------------------------
/Data/training and inference data/FHM/test.jsonl:
--------------------------------------------------------------------------------
  1 | {"id":8291,"img":"img\/08291.png","label":1,"text":"white people is this a shooting range"}
  2 | {"id":46971,"img":"img\/46971.png","label":1,"text":"bravery at its finest"}
  3 | {"id":3745,"img":"img\/03745.png","label":1,"text":"your order comes to $37.50 and your white privilege discount brings the total to $37.50"}
  4 | {"id":83745,"img":"img\/83745.png","label":1,"text":"it is time.. to send these parasites back to the desert"}
  5 | {"id":80243,"img":"img\/80243.png","label":1,"text":"mississippi wind chime"}
  6 | {"id":5279,"img":"img\/05279.png","label":1,"text":"knowing white people , that's probably the baby father"}
  7 | {"id":1796,"img":"img\/01796.png","label":1,"text":"life hack #23 how to get stoned with no weed"}
  8 | {"id":53046,"img":"img\/53046.png","label":1,"text":"you've heard of elf on a shelf, now get ready for"}
  9 | {"id":82301,"img":"img\/82301.png","label":1,"text":"cooooooooooooon!!!!"}
 10 | {"id":31752,"img":"img\/31752.png","label":1,"text":"when you get to choose your own mental illness"}
 11 | {"id":27635,"img":"img\/27635.png","label":1,"text":"dat ass ahmarbkrich*d"}
 12 | {"id":80597,"img":"img\/80597.png","label":1,"text":"look! it says it right here! we can fuck goats!"}
 13 | {"id":45368,"img":"img\/45368.png","label":1,"text":"enough is enough children are more important than freaks"}
 14 | {"id":17963,"img":"img\/17963.png","label":1,"text":"removes polish with chemicals nobody bats an eye removes polish with chemicals everybody loses his shit"}
 15 | {"id":53968,"img":"img\/53968.png","label":1,"text":"when you have an 80 point iq but it's the highest in your village"}
 16 | {"id":10749,"img":"img\/10749.png","label":1,"text":"my sense of humor is so dark it picks cotton"}
 17 | {"id":25149,"img":"img\/25149.png","label":1,"text":"if this offends you you might be lack toes intolerant"}
 18 | {"id":87520,"img":"img\/87520.png","label":1,"text":"yes, we know trump said \"mexico will pay for the wall\" but now mexicans are the wall.
that boy good!!!!"} 19 | {"id":89071,"img":"img\/89071.png","label":1,"text":"they don't wash their ass with water and they eat their daughters' ass"} 20 | {"id":9563,"img":"img\/09563.png","label":1,"text":"safe sex isis style"} 21 | {"id":72048,"img":"img\/72048.png","label":1,"text":"hurry up barack will be home soon"} 22 | {"id":49826,"img":"img\/49826.png","label":1,"text":"white people on cinco de mayo vs. every other day"} 23 | {"id":26453,"img":"img\/26453.png","label":1,"text":"a ww2 jewish prisoner trying to escape the auchschwitz showers (1943, colorized) wwii jew escaping hitler's rath"} 24 | {"id":12650,"img":"img\/12650.png","label":1,"text":"which one has tranny problems?"} 25 | {"id":2568,"img":"img\/02568.png","label":1,"text":"whatsthe difference between heroin and i? cops think twice before shooting heroin"} 26 | {"id":83954,"img":"img\/83954.png","label":1,"text":"head teachers given the power to ban the burqa in british schools share if you want a full ban in britain!"} 27 | {"id":24316,"img":"img\/24316.png","label":1,"text":"sharks don't attack him why? they think it's whale shit"} 28 | {"id":62035,"img":"img\/62035.png","label":1,"text":"thank god my ass needed a break"} 29 | {"id":65342,"img":"img\/65342.png","label":1,"text":"barn army enlist, and bomb ant hills \"grab your pitchfork and meet at the barn\""} 30 | {"id":92058,"img":"img\/92058.png","label":1,"text":"how black people became christians"} 31 | {"id":58672,"img":"img\/58672.png","label":1,"text":"did you know that the chinese eat dogs savages disgusting"} 32 | {"id":54069,"img":"img\/54069.png","label":1,"text":"when you decide to go full gas"} 33 | {"id":91586,"img":"img\/91586.png","label":1,"text":"left the kitchen too many times"} 34 | {"id":87130,"img":"img\/87130.png","label":1,"text":"finally finding the right tutorial you need indian accent"} 35 | {"id":95640,"img":"img\/95640.png","label":1,"text":"syrian rape survivor speaks out"} 36 | {"id":47819,"img":"img\/47819.png","label":1,"text":"oh, addiction isn't a disease because it started with a choice?? tell me more about how the chlamydia in your crotch isn't a disease, yet it started with a choice"} 37 | {"id":59140,"img":"img\/59140.png","label":1,"text":"wanna go masturbate? fuck yeah! masturbating!"} 38 | {"id":13647,"img":"img\/13647.png","label":1,"text":"she said girls run the world... so i told her get back in the kitchen"} 39 | {"id":65832,"img":"img\/65832.png","label":1,"text":"why love one but eat the other ? i don't get it"} 40 | {"id":57621,"img":"img\/57621.png","label":1,"text":"if this gets 1k likes i'll post to islam how we see muslims how muslims see us"} 41 | {"id":40982,"img":"img\/40982.png","label":1,"text":"and just like that... this sandwich maker doubles as ironing board"} 42 | {"id":43275,"img":"img\/43275.png","label":1,"text":"my husband is always pushing me around... and talking behind my back"} 43 | {"id":12973,"img":"img\/12973.png","label":1,"text":"someone added to much soap to the dishwasher again"} 44 | {"id":74350,"img":"img\/74350.png","label":1,"text":"a head diaper is required when you have shit for brains"} 45 | {"id":2478,"img":"img\/02478.png","label":1,"text":"must..resist...urge to.. 
kill whitey."} 46 | {"id":51607,"img":"img\/51607.png","label":1,"text":"49 people were killed at christchurch muslim 49 people were killed at christchurch mosques muslim"} 47 | {"id":65403,"img":"img\/65403.png","label":1,"text":"money is evil give it to us"} 48 | {"id":29750,"img":"img\/29750.png","label":1,"text":"this cat is getting raped tonight"} 49 | {"id":26547,"img":"img\/26547.png","label":1,"text":"that's the way allah allah i like it"} 50 | {"id":2143,"img":"img\/02143.png","label":1,"text":"doesnt have food, water, electricity proud of nuclear weapons"} 51 | {"id":68192,"img":"img\/68192.png","label":1,"text":"scientist are working hard to cure them all"} 52 | {"id":52603,"img":"img\/52603.png","label":1,"text":"feeling kinda cute might steal your shit later.. idk"} 53 | {"id":5349,"img":"img\/05349.png","label":0,"text":"normal day in a islamic country"} 54 | {"id":6491,"img":"img\/06491.png","label":1,"text":"tranny for sale"} 55 | {"id":74013,"img":"img\/74013.png","label":1,"text":"nurse: yes sir, we are having mashed potatoes today alzheimer's patient: i don't remember asking you a god damn thing"} 56 | {"id":27485,"img":"img\/27485.png","label":1,"text":"thats the way allah allah i like it!!!"} 57 | {"id":13026,"img":"img\/13026.png","label":1,"text":"under the name of multiculturalism we let terrorist into this country"} 58 | {"id":17265,"img":"img\/17265.png","label":0,"text":"meanwhile in baltimore"} 59 | {"id":91405,"img":"img\/91405.png","label":0,"text":"wears rag on her head so her brother husband doesnt want to molest her"} 60 | {"id":14026,"img":"img\/14026.png","label":1,"text":"told girlfriend that mom is deaf so speak loud and slow told mom that girlfriend is retarded"} 61 | {"id":18945,"img":"img\/18945.png","label":1,"text":"when you decide to go full gas"} 62 | {"id":83920,"img":"img\/83920.png","label":1,"text":"just under 17 and a half years ago, 19 people were allowed into the us who should not have been all these years later, we are told that the answer to safety is to allow everyone in"} 63 | {"id":83497,"img":"img\/83497.png","label":0,"text":"different types of radicalization"} 64 | {"id":4569,"img":"img\/04569.png","label":0,"text":"when is the time to eat i am hungry"} 65 | {"id":39607,"img":"img\/39607.png","label":1,"text":"those stupid basterds thought it was a real shower"} 66 | {"id":9715,"img":"img\/09715.png","label":1,"text":"i started a business that sells landmines disguised as prayer mats.. ..prophets are going through the roof"} 67 | {"id":1925,"img":"img\/01925.png","label":1,"text":"i am not racist i just don't like brown sugar"} 68 | {"id":94813,"img":"img\/94813.png","label":0,"text":"thats the way allah allah i like it!!!"} 69 | {"id":29437,"img":"img\/29437.png","label":1,"text":"what can happen if you eat wheaties for 38 years"} 70 | {"id":95038,"img":"img\/95038.png","label":0,"text":"when you take these pics we think about cumming in your mouth"} 71 | {"id":64891,"img":"img\/64891.png","label":1,"text":"lebanon: before muslims after muslims but don't worry! after islam takes over there'll be a golden age.. apparently"} 72 | {"id":63280,"img":"img\/63280.png","label":1,"text":"so you're never going to forgive the germans because of the holocaust? same"} 73 | {"id":17682,"img":"img\/17682.png","label":1,"text":"give them what they deserve kill them all"} 74 | {"id":84510,"img":"img\/84510.png","label":1,"text":"bartender: how about a nice 12 yr old? 
him: boy or girl?"} 75 | {"id":28690,"img":"img\/28690.png","label":0,"text":"if you're gonna talk shit about someone... you better make sure they aren't crazier than you are"} 76 | {"id":3567,"img":"img\/03567.png","label":1,"text":"potato is my favorite color of the alphabet"} 77 | {"id":95830,"img":"img\/95830.png","label":1,"text":"face of terrorism in us"} 78 | {"id":29873,"img":"img\/29873.png","label":1,"text":"who knew that this country is full of white trash"} 79 | {"id":70953,"img":"img\/70953.png","label":1,"text":"\"god's chosen people\" according to the book they wrote"} 80 | {"id":56124,"img":"img\/56124.png","label":0,"text":"how to cook turkey in usa how to cook turkey in romania"} 81 | {"id":10386,"img":"img\/10386.png","label":1,"text":"the book about terrorism"} 82 | {"id":19530,"img":"img\/19530.png","label":1,"text":"they shot the wrong gorilla"} 83 | {"id":94185,"img":"img\/94185.png","label":1,"text":"how a black woman can fight crime"} 84 | {"id":48370,"img":"img\/48370.png","label":1,"text":"72 virgins are ready to get fucked up"} 85 | {"id":74058,"img":"img\/74058.png","label":0,"text":"he made millions of children homeless"} 86 | {"id":45062,"img":"img\/45062.png","label":0,"text":"me and my kids are going to rob this country blind at the expense of the very people that voted for me. and they're going to love it. believe me! bigly!"} 87 | {"id":94170,"img":"img\/94170.png","label":0,"text":"when you decide to go full gas"} 88 | {"id":4538,"img":"img\/04538.png","label":0,"text":"bows here, bows there, bows everywhere"} 89 | {"id":83946,"img":"img\/83946.png","label":1,"text":"when a white kid shoots up the school the same day you planned to bomb it"} 90 | {"id":68127,"img":"img\/68127.png","label":0,"text":"all you guys think about is dope"} 91 | {"id":97453,"img":"img\/97453.png","label":0,"text":"potato is my favorite color of the alphabet"} 92 | {"id":39578,"img":"img\/39578.png","label":1,"text":"one of the bullets they dug out of a victim from the orland gay night club shooting!"} 93 | {"id":49360,"img":"img\/49360.png","label":0,"text":"who wants ice cream me me me"} 94 | {"id":87034,"img":"img\/87034.png","label":0,"text":"you don't need a parachute to go skydiving twice you need a parachute to go skydiving twice"} 95 | {"id":6123,"img":"img\/06123.png","label":1,"text":"a group of men"} 96 | {"id":17028,"img":"img\/17028.png","label":1,"text":"feeling bored, let's go to a mosque"} 97 | {"id":42903,"img":"img\/42903.png","label":0,"text":"those stupid basterds thought it was a real shower"} 98 | {"id":71680,"img":"img\/71680.png","label":1,"text":"muslim strip club"} 99 | {"id":24135,"img":"img\/24135.png","label":0,"text":"good night keep your butthole tight"} 100 | {"id":32875,"img":"img\/32875.png","label":0,"text":"if he wears number one he's most likely the fastes on the team"} 101 | {"id":79085,"img":"img\/79085.png","label":1,"text":"today we are making vegetable soup"} 102 | {"id":61872,"img":"img\/61872.png","label":1,"text":"don't be a fool it is a bomb"} 103 | {"id":30579,"img":"img\/30579.png","label":0,"text":"white power!"} 104 | {"id":62504,"img":"img\/62504.png","label":0,"text":"we've taken this country! now fuck off! 
we're taking it back!"} 105 | {"id":43810,"img":"img\/43810.png","label":1,"text":"how to cook turkey in usa how to cook turkey in romania"} 106 | {"id":19730,"img":"img\/19730.png","label":1,"text":"jamal is practicing for class sport"} 107 | {"id":37420,"img":"img\/37420.png","label":0,"text":"like lebanon terrorism can destroy a beautiful country"} 108 | {"id":82509,"img":"img\/82509.png","label":0,"text":"72 virgins are ready to get fucked up"} 109 | {"id":96284,"img":"img\/96284.png","label":1,"text":"welcome to melbourne"} 110 | {"id":97068,"img":"img\/97068.png","label":1,"text":"these hitler jokes are rude anne frankly i don't care"} 111 | {"id":3214,"img":"img\/03214.png","label":0,"text":"bartender: how about a nice 12 yr old? him: boy or girl?"} 112 | {"id":23810,"img":"img\/23810.png","label":1,"text":"these aren't people"} 113 | {"id":91836,"img":"img\/91836.png","label":1,"text":"one of my 5 boyfriends slept with another girl i'll never trust men again!"} 114 | {"id":5126,"img":"img\/05126.png","label":1,"text":"can you smell what the reich is cooking"} 115 | {"id":76092,"img":"img\/76092.png","label":0,"text":"everybody is afraid of gangs"} 116 | {"id":64072,"img":"img\/64072.png","label":1,"text":"trump said immigrants come here to steal your wife is a great example she stole my speech"} 117 | {"id":91058,"img":"img\/91058.png","label":1,"text":"get in loser we're invading poland"} 118 | {"id":21643,"img":"img\/21643.png","label":0,"text":"haha pig man is peeing standing up"} 119 | {"id":19523,"img":"img\/19523.png","label":0,"text":"they see them rollin.... they hating.."} 120 | {"id":32691,"img":"img\/32691.png","label":1,"text":"this is how muslims celebrate new year"} 121 | {"id":87169,"img":"img\/87169.png","label":1,"text":"albanians talking photo with a porn star"} 122 | {"id":92738,"img":"img\/92738.png","label":1,"text":"they are the best in basketball because they can steal, shoot and run"} 123 | {"id":53027,"img":"img\/53027.png","label":0,"text":"putting bows on your pet"} 124 | {"id":27195,"img":"img\/27195.png","label":0,"text":"there's a reason we keep the beer in the bottom of the cooler"} 125 | {"id":43910,"img":"img\/43910.png","label":1,"text":"i just started a company selling landmines that look like prayer mats prophets are going through the roof"} 126 | {"id":5213,"img":"img\/05213.png","label":0,"text":"the families of the hillsborough tragedy were denied legal aid.... but shamima begum gets it to fight for her uk citizenship our system is a sick joke! - martin costello"} 127 | {"id":73945,"img":"img\/73945.png","label":0,"text":"jamal is holding a gun"} 128 | {"id":56207,"img":"img\/56207.png","label":1,"text":"we broke up but she said we could still be cousins"} 129 | {"id":68459,"img":"img\/68459.png","label":1,"text":"family trip in mexico"} 130 | {"id":63175,"img":"img\/63175.png","label":0,"text":"if i was uncircumcised i'd hide candy in my dick skin so when a girl about to suck me up, she has a pleasant surprise.. oh my god, skittles"} 131 | {"id":82945,"img":"img\/82945.png","label":0,"text":"be careful who you call ugly in highschool"} 132 | {"id":34975,"img":"img\/34975.png","label":0,"text":"what if the only reason mary got pregnant... 
is because god is gay and needed a surrogate"} 133 | {"id":61973,"img":"img\/61973.png","label":1,"text":"what do alcoholics and necrophiliacs have in common they both love to crack open a cold one."} 134 | {"id":16420,"img":"img\/16420.png","label":0,"text":"abnormal vs normal"} 135 | {"id":73069,"img":"img\/73069.png","label":1,"text":"immigration invasion learn the difference"} 136 | {"id":47056,"img":"img\/47056.png","label":1,"text":"illegal aliens are not welcome"} 137 | {"id":56428,"img":"img\/56428.png","label":0,"text":"if the shocker don't rock her spock her"} 138 | {"id":54206,"img":"img\/54206.png","label":0,"text":"you can smell disgust and hate from this photo"} 139 | {"id":93172,"img":"img\/93172.png","label":0,"text":"this is a terrorist"} 140 | {"id":56980,"img":"img\/56980.png","label":0,"text":"coffee isn't helping get the jumper cables"} 141 | {"id":79615,"img":"img\/79615.png","label":0,"text":"what women think men want what men really want"} 142 | {"id":38047,"img":"img\/38047.png","label":0,"text":"\"stupid bitch\" is not a race"} 143 | {"id":94560,"img":"img\/94560.png","label":1,"text":"and for my next class project!"} 144 | {"id":96250,"img":"img\/96250.png","label":0,"text":"i'm going home i'm russian"} 145 | {"id":28951,"img":"img\/28951.png","label":0,"text":"if the brim of your hat is flat and has a sticker on it i automatically assume you're a complete tool"} 146 | {"id":27614,"img":"img\/27614.png","label":0,"text":"maybe if you think before you act you'd be treated better"} 147 | {"id":4857,"img":"img\/04857.png","label":1,"text":"...even the dumb jews know that! trump is a racist."} 148 | {"id":79603,"img":"img\/79603.png","label":1,"text":"its friday cunts lets get on the gas"} 149 | {"id":26985,"img":"img\/26985.png","label":1,"text":"i just heard my girlfriend is pregnant"} 150 | {"id":14873,"img":"img\/14873.png","label":1,"text":"human dna found in ancient goat carcass scientists believe this is first proof muhammad existed"} 151 | {"id":54893,"img":"img\/54893.png","label":0,"text":"old people at weddings always poke me and say \"you're next.\" so, i started doing the same thing to them at funerals"} 152 | {"id":48236,"img":"img\/48236.png","label":0,"text":"that blood you donated could be in someone's boner right now"} 153 | {"id":74908,"img":"img\/74908.png","label":0,"text":"she is condemning shooting the gorilla"} 154 | {"id":96472,"img":"img\/96472.png","label":1,"text":"white power!"} 155 | {"id":89430,"img":"img\/89430.png","label":0,"text":"when you know you don't have to jackoff tonight"} 156 | {"id":76921,"img":"img\/76921.png","label":1,"text":"she probably asked a wrong question"} 157 | {"id":68409,"img":"img\/68409.png","label":1,"text":"for the last time im not michelle obama"} 158 | {"id":16354,"img":"img\/16354.png","label":1,"text":"we've taken this country! now fuck off! we're taking it back!"} 159 | {"id":69548,"img":"img\/69548.png","label":1,"text":"wait, mohammad come back i didn't mean to laugh, really size dosen't matter to me lol"} 160 | {"id":1456,"img":"img\/01456.png","label":1,"text":"they see them rollin..... 
they hating.."} 161 | {"id":17950,"img":"img\/17950.png","label":0,"text":"i was excoriated for a harmless \"tweet\" and \"my show\" was cancelled"} 162 | {"id":76295,"img":"img\/76295.png","label":0,"text":"islamic extremist planning new project"} 163 | {"id":42580,"img":"img\/42580.png","label":1,"text":"jack and jill went up the hill so jack could lick her candy but jack got a shock and a mouthful of cock cause jills real name was randy"} 164 | {"id":96180,"img":"img\/96180.png","label":1,"text":"putting bows on your pet"} 165 | {"id":98701,"img":"img\/98701.png","label":1,"text":"oh no i can't steal anymore"} 166 | {"id":15243,"img":"img\/15243.png","label":0,"text":"4 am. everyone is asleep perfect time to run from one end of the house to the other at full speed as loudly as possible"} 167 | {"id":41058,"img":"img\/41058.png","label":0,"text":"john wilkes booth be like: yall sure abraham lincoln dead? he aint up here"} 168 | {"id":62375,"img":"img\/62375.png","label":1,"text":"im not saying i'm racist im just saying all muslims are terrorists"} 169 | {"id":63507,"img":"img\/63507.png","label":0,"text":"evolution from peeling bananas to busting caps"} 170 | {"id":31208,"img":"img\/31208.png","label":1,"text":"maybe if we weren't consistent sources of crime we'd be treated better"} 171 | {"id":62319,"img":"img\/62319.png","label":1,"text":"america in 2045 \"grandpa...why didn't you stop them?!\""} 172 | {"id":68253,"img":"img\/68253.png","label":1,"text":"free blanket maybe white man not so bad after all"} 173 | {"id":34687,"img":"img\/34687.png","label":1,"text":"everyone is afraid of dark"} 174 | {"id":34209,"img":"img\/34209.png","label":0,"text":"how about a nice 12 yr old? bartender: him: boy i would love that"} 175 | {"id":39827,"img":"img\/39827.png","label":1,"text":"kick them out of country"} 176 | {"id":46920,"img":"img\/46920.png","label":1,"text":"we want sharia law !! well.... go home"} 177 | {"id":84273,"img":"img\/84273.png","label":0,"text":"when you're the reason for the company safety video"} 178 | {"id":39018,"img":"img\/39018.png","label":0,"text":"she said \"i want a ring\" i said \"bitch take your phone off silent!\""} 179 | {"id":54780,"img":"img\/54780.png","label":1,"text":"travel size coffee maker"} 180 | {"id":69150,"img":"img\/69150.png","label":1,"text":"\"white privilege the privilege of being called \"racist\" by other people who see nothing else about you except the color of your skin"} 181 | {"id":29843,"img":"img\/29843.png","label":0,"text":"the world's most wanted terrorist obama bin lying"} 182 | {"id":19385,"img":"img\/19385.png","label":0,"text":"rose's are red violet's are blue little stupid ass bitch i ain't fuckin witchu"} 183 | {"id":64510,"img":"img\/64510.png","label":0,"text":"mohammad won the contest in putting underwear on goats"} 184 | {"id":47103,"img":"img\/47103.png","label":0,"text":"maybe we weren't consistent sources of crime we'd be treated better"} 185 | {"id":61038,"img":"img\/61038.png","label":0,"text":"no wonder why the pipes are not working"} 186 | {"id":71620,"img":"img\/71620.png","label":1,"text":"we are returning our broken coffee maker"} 187 | {"id":91602,"img":"img\/91602.png","label":0,"text":"upper downer"} 188 | {"id":30145,"img":"img\/30145.png","label":0,"text":"before my presidency, there was no isis. there was no caliphate there was no blm. there was no antifa. 
there was no war on police i built that"} 189 | {"id":91756,"img":"img\/91756.png","label":1,"text":"abnormal vs normal"} 190 | {"id":63745,"img":"img\/63745.png","label":1,"text":"we could destroy all of them but we let some of them survive so you know why we did it"} 191 | {"id":47162,"img":"img\/47162.png","label":1,"text":"what do these two have in common? they both are good at sucking they both come with a bag"} 192 | {"id":59806,"img":"img\/59806.png","label":0,"text":"on tuesday, november 3, 2020, we make america great again by voting this nightmare out of office!"} 193 | {"id":7135,"img":"img\/07135.png","label":1,"text":"this is the greatest news manufacturing in 20th century"} 194 | {"id":35497,"img":"img\/35497.png","label":0,"text":"old people at weddings always poke me and say \"you're next.\" so, i started doing the same thing to them at funerals"} 195 | {"id":84362,"img":"img\/84362.png","label":0,"text":"unleash the inner beast on the field"} 196 | {"id":84015,"img":"img\/84015.png","label":1,"text":"my irony meter just exploded"} 197 | {"id":87251,"img":"img\/87251.png","label":0,"text":"i love apes they are both ugly and cute"} 198 | {"id":82590,"img":"img\/82590.png","label":1,"text":"burning them was a great public service"} 199 | {"id":59738,"img":"img\/59738.png","label":1,"text":"what's the difference between a refugee and e.t e.t learned english and wanted to go home"} 200 | {"id":35470,"img":"img\/35470.png","label":0,"text":"can you smell what the reigh is cooking"} 201 | {"id":72061,"img":"img\/72061.png","label":1,"text":"i like my rice crispy please"} 202 | {"id":3524,"img":"img\/03524.png","label":1,"text":"terrorist can come to this country way to easy"} 203 | {"id":86170,"img":"img\/86170.png","label":1,"text":"how to keep the thefts out"} 204 | {"id":50841,"img":"img\/50841.png","label":0,"text":"give them what they deserve kill them all"} 205 | {"id":46087,"img":"img\/46087.png","label":0,"text":"give that man a suicide"} 206 | {"id":42058,"img":"img\/42058.png","label":0,"text":"for lunch vlad the impaler loved turkey"} 207 | {"id":1967,"img":"img\/01967.png","label":0,"text":"when you decide to stop racism by burning its symbol"} 208 | {"id":54930,"img":"img\/54930.png","label":1,"text":"not all muslims are isis! some are al-qaida the rest are taliban"} 209 | {"id":38076,"img":"img\/38076.png","label":1,"text":"\"he's a squirter\" doesn't sound as sexy when you are gay"} 210 | {"id":38910,"img":"img\/38910.png","label":1,"text":"oh boy i sure love being muslim in the new zealand city of christchurch on march 15"} 211 | {"id":6723,"img":"img\/06723.png","label":1,"text":"the hardest part to eat in a vegetable"} 212 | {"id":2157,"img":"img\/02157.png","label":1,"text":"mohammed i'm pregnant.... don't sass me, you know it's funny! ~cupcakes"} 213 | {"id":62948,"img":"img\/62948.png","label":1,"text":"black lives matter >\"how\",\"why\"and \"to whom\" is unclear"} 214 | {"id":21075,"img":"img\/21075.png","label":0,"text":"size of a shuttle bus legroom of a pinto"} 215 | {"id":73601,"img":"img\/73601.png","label":1,"text":"when you are white and stole your first bike"} 216 | {"id":32579,"img":"img\/32579.png","label":0,"text":"come to the dark side we have weed"} 217 | {"id":21567,"img":"img\/21567.png","label":1,"text":"our dad is a potato farmer"} 218 | {"id":1765,"img":"img\/01765.png","label":1,"text":"after a girl dies, what organ in her body stays warm? 
my dick"} 219 | {"id":53172,"img":"img\/53172.png","label":0,"text":"straight outta the oven step one :turn on the gas"} 220 | {"id":32568,"img":"img\/32568.png","label":1,"text":"america taught us how to be happy slaves the same way school teach us how to be happy employees"} 221 | {"id":1742,"img":"img\/01742.png","label":0,"text":"in just one hour from now i'll only have 4 hours left until i have to work a mere 3 hours"} 222 | {"id":56098,"img":"img\/56098.png","label":1,"text":"i said \"a glass of juice\"."} 223 | {"id":78612,"img":"img\/78612.png","label":0,"text":"go see fifty shades darker in theaters alone and throw mayonnaise at couples in the dark"} 224 | {"id":34018,"img":"img\/34018.png","label":1,"text":"fuck the catholics fuck the catholic church fuck the pope"} 225 | {"id":53609,"img":"img\/53609.png","label":0,"text":"when your friend insists on paying for dinner but you still want to contribute - just the tip, come on... let me"} 226 | {"id":89362,"img":"img\/89362.png","label":1,"text":"has extra chromosome probably stole it"} 227 | {"id":80426,"img":"img\/80426.png","label":1,"text":"a talking ape"} 228 | {"id":83045,"img":"img\/83045.png","label":1,"text":"so much trash, we need a garbage collector"} 229 | {"id":61503,"img":"img\/61503.png","label":0,"text":"someone else used kids to push an agenda"} 230 | {"id":41890,"img":"img\/41890.png","label":1,"text":"house of pedophile"} 231 | {"id":52091,"img":"img\/52091.png","label":1,"text":"women ask for equal rights, so i give them equal lefts aswell"} 232 | {"id":24396,"img":"img\/24396.png","label":1,"text":"so like, why are men even gay? girls have buttholes too you know."} 233 | {"id":72168,"img":"img\/72168.png","label":1,"text":"upper downer"} 234 | {"id":85761,"img":"img\/85761.png","label":1,"text":"no wonder why the pipes are not working"} 235 | {"id":35719,"img":"img\/35719.png","label":0,"text":"jamal is practicing for class sport"} 236 | {"id":90256,"img":"img\/90256.png","label":1,"text":"whenever you feel stupid, remember there are people who throw stones at a wall! 
aiming for an invisible devil, told by an invisible god for an invisible heaven!"} 237 | {"id":20984,"img":"img\/20984.png","label":0,"text":"when you are white and stole your first bike"} 238 | {"id":95086,"img":"img\/95086.png","label":1,"text":"novaky concentration camp these jews can claim holocaust pensions"} 239 | {"id":49028,"img":"img\/49028.png","label":0,"text":"i finally found my girlfriends g-spot turns out it was in her sister the whole time"} 240 | {"id":18356,"img":"img\/18356.png","label":0,"text":"i took a shower this morning"} 241 | {"id":52634,"img":"img\/52634.png","label":0,"text":"the secret to a successful marriage is to keep his stomach full and his balls empty"} 242 | {"id":80512,"img":"img\/80512.png","label":1,"text":"you can make jokes about anything, just not illegals that's crossing the border"} 243 | {"id":62703,"img":"img\/62703.png","label":0,"text":"they shot the wrong gorilla"} 244 | {"id":95176,"img":"img\/95176.png","label":0,"text":"notice how the kenyan skidmark has been silent about the mueller report?"} 245 | {"id":27384,"img":"img\/27384.png","label":0,"text":"that face you make when you're a bully and just got beat up by a girl"} 246 | {"id":82437,"img":"img\/82437.png","label":1,"text":"we should learn from ocean how to treat refugees"} 247 | {"id":26943,"img":"img\/26943.png","label":1,"text":"my incesters built this country!!!!"} 248 | {"id":9152,"img":"img\/09152.png","label":0,"text":"i don't have ex's. i have y's. like, \"y the hell did i date you?!?\""} 249 | {"id":7429,"img":"img\/07429.png","label":1,"text":"i'm about to have my 5th abortion fuck you batman i'm the real crime fighter"} 250 | {"id":75639,"img":"img\/75639.png","label":1,"text":"changing every single country it touches"} 251 | {"id":46812,"img":"img\/46812.png","label":0,"text":"goat meat is very popular in islamic countries"} 252 | {"id":63827,"img":"img\/63827.png","label":1,"text":"teacher of the year for black community"} 253 | {"id":41276,"img":"img\/41276.png","label":1,"text":"evolution from peeling bananas to busting caps"} 254 | {"id":56413,"img":"img\/56413.png","label":1,"text":"feeling cute might shoot a darkie later idk"} 255 | {"id":69512,"img":"img\/69512.png","label":0,"text":"happy birthday you sexy beast"} 256 | {"id":43175,"img":"img\/43175.png","label":1,"text":"when was the last time you saw so many mentally ill people on the streat"} 257 | {"id":54129,"img":"img\/54129.png","label":0,"text":"dump pipe to the ocean"} 258 | {"id":23645,"img":"img\/23645.png","label":1,"text":"what do they have in common? you enjoy pushing them down the stairs"} 259 | {"id":89425,"img":"img\/89425.png","label":1,"text":"hello chocolate paradise princess i'm albert"} 260 | {"id":16923,"img":"img\/16923.png","label":1,"text":"do you like goat anus? smell my finger"} 261 | {"id":4769,"img":"img\/04769.png","label":1,"text":"if she's pouting like this and you don't immediately drop the argument , motherfucker you gay"} 262 | {"id":69815,"img":"img\/69815.png","label":1,"text":"thank you for allowing me into your country now we need to talk about the things that need to change around here to accommodate my beliefs"} 263 | {"id":91468,"img":"img\/91468.png","label":0,"text":"you want a whore? buy one. you want a queen? earn her. you want a nickel bag? fifteen bucks, little man"} 264 | {"id":52104,"img":"img\/52104.png","label":0,"text":"lebanon: before islam after islam but don't worry! after islam takes over, there'll be a golden age... 
apparently"} 265 | {"id":52031,"img":"img\/52031.png","label":0,"text":"theriously?? that's methed up!"} 266 | {"id":8795,"img":"img\/08795.png","label":0,"text":"when you are white and working in oil field"} 267 | {"id":50317,"img":"img\/50317.png","label":1,"text":"what in the hell is wrong with america voting towel heads into congress"} 268 | {"id":73914,"img":"img\/73914.png","label":1,"text":"be careful who you call ugly in highschool"} 269 | {"id":3197,"img":"img\/03197.png","label":1,"text":"i'm no bird expert but i'm guessing it's 4 females and 1 male"} 270 | {"id":49621,"img":"img\/49621.png","label":0,"text":"wait, mohammad come back i didn't mean to laugh, really size dosen't matter to me lol"} 271 | {"id":68530,"img":"img\/68530.png","label":1,"text":"a fast sandwich maker"} 272 | {"id":98547,"img":"img\/98547.png","label":0,"text":"waves mexican flag doesn't want to live in mexico"} 273 | {"id":3519,"img":"img\/03519.png","label":1,"text":"i'm going home i'm russian"} 274 | {"id":43698,"img":"img\/43698.png","label":0,"text":"don't be late for work get the protestor plow"} 275 | {"id":84756,"img":"img\/84756.png","label":1,"text":"when you ask the kid that misses class a lot why he is bald and he starts crying all right. then keep vour secrets"} 276 | {"id":93051,"img":"img\/93051.png","label":0,"text":"and for my next class project!"} 277 | {"id":68257,"img":"img\/68257.png","label":0,"text":"everyone is afraid of dark"} 278 | {"id":12834,"img":"img\/12834.png","label":1,"text":"in west philadelphia, born and raised in the anus is where i got most of my aids"} 279 | {"id":57208,"img":"img\/57208.png","label":0,"text":"i had potato for lunch"} 280 | {"id":26439,"img":"img\/26439.png","label":0,"text":"parental opioid use has negative effect on children's intelligence"} 281 | {"id":53769,"img":"img\/53769.png","label":0,"text":"in 1979 jimmy carter created the department of education since then the u.s. has gone from 1st to 17th in education"} 282 | {"id":28406,"img":"img\/28406.png","label":0,"text":"when she gives you that look"} 283 | {"id":53418,"img":"img\/53418.png","label":0,"text":"a real man loads the dishwasher every night"} 284 | {"id":10785,"img":"img\/10785.png","label":1,"text":"time to put the \"panic\" back in hispanic"} 285 | {"id":84302,"img":"img\/84302.png","label":0,"text":"why do we share more when these people are hurt but not these people"} 286 | {"id":76015,"img":"img\/76015.png","label":0,"text":"everyone celebrates christmas !"} 287 | {"id":34189,"img":"img\/34189.png","label":0,"text":"fight for what you believe in fight for the future"} 288 | {"id":52079,"img":"img\/52079.png","label":0,"text":"just bought a house! now we're house-poor"} 289 | {"id":63987,"img":"img\/63987.png","label":0,"text":"the migrant caravan arrives!"} 290 | {"id":73526,"img":"img\/73526.png","label":0,"text":"if she's still making noises, you didn't hit her hard enough"} 291 | {"id":2145,"img":"img\/02145.png","label":0,"text":"and then i asked mom, what's for dinner?"} 292 | {"id":84102,"img":"img\/84102.png","label":0,"text":"the definition of utter disgust in plain black and white"} 293 | {"id":32981,"img":"img\/32981.png","label":0,"text":"the latest and greatest. 
a truck that comes with a dishwasher!"} 294 | {"id":73605,"img":"img\/73605.png","label":0,"text":"no matter what smile"} 295 | {"id":7382,"img":"img\/07382.png","label":1,"text":"wouldn't be the last time those people used the oven this way"} 296 | {"id":64071,"img":"img\/64071.png","label":0,"text":"when you have someone to help and take care of you what could happen if you didn't"} 297 | {"id":80912,"img":"img\/80912.png","label":0,"text":"mr obama it is an honor"} 298 | {"id":84762,"img":"img\/84762.png","label":1,"text":"do you know how to turn your dishwasher into a snowblower? throw em a shovel"} 299 | {"id":39076,"img":"img\/39076.png","label":0,"text":"playing with the monkeys in thailand"} 300 | {"id":85679,"img":"img\/85679.png","label":1,"text":"mr obama it is an honor"} 301 | {"id":38095,"img":"img\/38095.png","label":0,"text":"haters will say it's photoshop"} 302 | {"id":52394,"img":"img\/52394.png","label":0,"text":"i forgot what are we fighting over again?"} 303 | {"id":14865,"img":"img\/14865.png","label":0,"text":"terrible as hitler was, he did enjoy watching sports"} 304 | {"id":54819,"img":"img\/54819.png","label":0,"text":"beauty shot of girlfriends focusing on different thoughts and ideas"} 305 | {"id":46082,"img":"img\/46082.png","label":0,"text":"the future of the democratic party"} 306 | {"id":73962,"img":"img\/73962.png","label":1,"text":"you should treat that the way you treat you vacuum cleaner when it stops sucking change the bag"} 307 | {"id":78462,"img":"img\/78462.png","label":0,"text":"in the last days there will be scoffers, ridiculers, deniers and mockers following after their own lusts"} 308 | {"id":49805,"img":"img\/49805.png","label":0,"text":"imagine being so disgusting there have to be laws to try to stop normal people from hating you"} 309 | {"id":24098,"img":"img\/24098.png","label":1,"text":"thanksgiving in china"} 310 | {"id":67435,"img":"img\/67435.png","label":1,"text":"when you ask a jewish girl for her number and she starts rolling her sleeve"} 311 | {"id":60893,"img":"img\/60893.png","label":0,"text":"busy, like a hive of bees, controlling what your family sees and edits all the news your hear so you'll know what to think and fear till all the truth there is, is gone. that is unless you stand, en masse. \""} 312 | {"id":93148,"img":"img\/93148.png","label":0,"text":"the proper way to pop a blackhead"} 313 | {"id":7351,"img":"img\/07351.png","label":0,"text":"\"plays uno steals all the green cards\""} 314 | {"id":16850,"img":"img\/16850.png","label":0,"text":"okay okay okay here's my serious face do i look fuhrerious yet?"} 315 | {"id":96312,"img":"img\/96312.png","label":0,"text":"when you support your little one every step of the way"} 316 | {"id":86195,"img":"img\/86195.png","label":0,"text":"dolled up and rolling with the gang like"} 317 | {"id":84162,"img":"img\/84162.png","label":0,"text":"i could have killed all of them. but i let some of them survive so you know why i killed them"} 318 | {"id":12067,"img":"img\/12067.png","label":0,"text":"pre-heat to 350 degrees"} 319 | {"id":5938,"img":"img\/05938.png","label":1,"text":"in the muslim world all gay men are well hung"} 320 | {"id":43092,"img":"img\/43092.png","label":0,"text":"pretty much sums up islam"} 321 | {"id":30148,"img":"img\/30148.png","label":1,"text":"you mean i don't have to pick it?!?"} 322 | {"id":12045,"img":"img\/12045.png","label":0,"text":"learn to hunt. 
it's a valuable skill"} 323 | {"id":92068,"img":"img\/92068.png","label":0,"text":"still better than mexican"} 324 | {"id":73021,"img":"img\/73021.png","label":0,"text":"when mom asks where all the jews went"} 325 | {"id":56241,"img":"img\/56241.png","label":1,"text":"94% of them test positive for rh blood type which means they have monkey blood as a result they are born with a tail and other animal traits"} 326 | {"id":28905,"img":"img\/28905.png","label":1,"text":"hey, i just met you and this is crazy but here's your number.. so camping, maybe?"} 327 | {"id":13750,"img":"img\/13750.png","label":0,"text":"wouldn't be the last time those people used the oven this way"} 328 | {"id":6582,"img":"img\/06582.png","label":1,"text":"they are terrorists anyone who doesn't agree is an idiot"} 329 | {"id":75142,"img":"img\/75142.png","label":0,"text":"you should treat that the way you treat your vacuum cleaner, when it stops sucking change the bag"} 330 | {"id":56149,"img":"img\/56149.png","label":0,"text":"the original scarecrow"} 331 | {"id":41796,"img":"img\/41796.png","label":0,"text":"nein danke i'm very happy with my gas supplier"} 332 | {"id":31570,"img":"img\/31570.png","label":1,"text":"busy, like a hive of bees, controlling what your family sees and edits all the news your hear so you'll know what to think and fear till all the truth there is, is gone that is unless you stand, en masse. \""} 333 | {"id":1726,"img":"img\/01726.png","label":1,"text":"a real man loads the dishwasher every night"} 334 | {"id":23785,"img":"img\/23785.png","label":0,"text":"hansel and gretel was a scary childrens story"} 335 | {"id":56473,"img":"img\/56473.png","label":0,"text":"knows about you. your family your children everything"} 336 | {"id":62970,"img":"img\/62970.png","label":0,"text":"if you want to know how they were treated illustration of man with four slaves chained together"} 337 | {"id":41035,"img":"img\/41035.png","label":0,"text":"my new muslim girlfriend not much to look at"} 338 | {"id":10285,"img":"img\/10285.png","label":1,"text":"you are what you eat meanwhile in africa"} 339 | {"id":95604,"img":"img\/95604.png","label":1,"text":"they want to be accepted for who they are, yet theren't able to accept themselves for who they are"} 340 | {"id":68401,"img":"img\/68401.png","label":0,"text":"when you date an asian boy and you trynna get his family to accept you"} 341 | {"id":50984,"img":"img\/50984.png","label":0,"text":"a cat in a pot"} 342 | {"id":17908,"img":"img\/17908.png","label":0,"text":"making a quick 3pm snack"} 343 | {"id":42631,"img":"img\/42631.png","label":1,"text":"i believe they are like racism just shouldn't exist"} 344 | {"id":61349,"img":"img\/61349.png","label":1,"text":"the proper way to pop a blackhead"} 345 | {"id":65093,"img":"img\/65093.png","label":0,"text":"cookie kunty seen in paris"} 346 | {"id":86357,"img":"img\/86357.png","label":0,"text":"apparently not erdogan"} 347 | {"id":74965,"img":"img\/74965.png","label":0,"text":"some people some other people"} 348 | {"id":75918,"img":"img\/75918.png","label":0,"text":"a picture of my great great great grand grandfather's tractors they don't make them like they used to"} 349 | {"id":14975,"img":"img\/14975.png","label":0,"text":"hitting is a crime. 
it's about time people understood that"} 350 | {"id":92317,"img":"img\/92317.png","label":0,"text":"you either die a hero, or live long enough to become the villain"} 351 | {"id":19536,"img":"img\/19536.png","label":1,"text":"when your jewish friend smells a stash of coins in public"} 352 | {"id":90236,"img":"img\/90236.png","label":0,"text":"voila! fresh batch of baked goods !"} 353 | {"id":30642,"img":"img\/30642.png","label":0,"text":"rabbi sitting in synagogue chillin and doing his thing"} 354 | {"id":59206,"img":"img\/59206.png","label":0,"text":"traveling through chimney"} 355 | {"id":19075,"img":"img\/19075.png","label":0,"text":"when you smell cake"} 356 | {"id":79042,"img":"img\/79042.png","label":0,"text":"the door to the crematorium and gas chambers are closed"} 357 | {"id":98720,"img":"img\/98720.png","label":0,"text":"when your wife just died giving birth to your daughter and you're alone in the room with your naked just born daughter"} 358 | {"id":2634,"img":"img\/02634.png","label":1,"text":"and just like that... this sandwich maker doubles as an ironing board"} 359 | {"id":84107,"img":"img\/84107.png","label":0,"text":"i ordered a dishwasher and they sent me a sandwich maker"} 360 | {"id":7198,"img":"img\/07198.png","label":1,"text":"islam is a religion of peace stop criticizing my religion"} 361 | {"id":45702,"img":"img\/45702.png","label":1,"text":"if she's still making noises, you didn't hit her hard enough"} 362 | {"id":46085,"img":"img\/46085.png","label":0,"text":"city workers repair statue cover it up with black sheet"} 363 | {"id":8451,"img":"img\/08451.png","label":0,"text":"when you ask a jewish girl for her number and she starts rolling her sleeve"} 364 | {"id":73482,"img":"img\/73482.png","label":0,"text":"it's time we took guns away and put them out of the reach of our children"} 365 | {"id":28017,"img":"img\/28017.png","label":1,"text":"say it! say it ! islam is the religion of peace"} 366 | {"id":2364,"img":"img\/02364.png","label":0,"text":"a jewish shop owner, wears a kippah or yarmulke whilst in his shop in tehran"} 367 | {"id":42538,"img":"img\/42538.png","label":0,"text":"and at that moment everything else in the world just melted away"} 368 | {"id":47183,"img":"img\/47183.png","label":0,"text":"taking the liberals to a whole new level"} 369 | {"id":72904,"img":"img\/72904.png","label":1,"text":"terrorist trojan horse enemy of the state"} 370 | {"id":90243,"img":"img\/90243.png","label":1,"text":"best gamer ever 6 million kills, 1 death"} 371 | {"id":41382,"img":"img\/41382.png","label":0,"text":"they want to be accepted for who they are, yet theren't able to accept themselves for who they are"} 372 | {"id":78156,"img":"img\/78156.png","label":0,"text":"you mean i don't have to pick it?!?"} 373 | {"id":74250,"img":"img\/74250.png","label":0,"text":"white kids kids dreaming about their first day of school"} 374 | {"id":97132,"img":"img\/97132.png","label":0,"text":"in the muslim world all gay men are well hung"} 375 | {"id":35840,"img":"img\/35840.png","label":1,"text":"i ordered a dishwasher and they sent me a sandwich maker"} 376 | {"id":78134,"img":"img\/78134.png","label":0,"text":"in other countries in pakistan"} 377 | {"id":93620,"img":"img\/93620.png","label":1,"text":"in other countries in pakistan"} 378 | {"id":57369,"img":"img\/57369.png","label":0,"text":"lose 15 lbs with this 1 simple trick photoshop"} 379 | {"id":10362,"img":"img\/10362.png","label":1,"text":"german british danish austrian diversity. 
apparently we don't have it and these guys do: iraqi pakistani syrian egyptian"} 380 | {"id":36058,"img":"img\/36058.png","label":0,"text":"love is love"} 381 | {"id":9467,"img":"img\/09467.png","label":1,"text":"when mom asks where all the jews went"} 382 | {"id":60183,"img":"img\/60183.png","label":1,"text":"nein danke i'm very happy with my gas supplier"} 383 | {"id":47263,"img":"img\/47263.png","label":1,"text":"pre-heat to 350 degrees"} 384 | {"id":18726,"img":"img\/18726.png","label":0,"text":"when your bed is short and your feed stick out"} 385 | {"id":50379,"img":"img\/50379.png","label":0,"text":"do you know how to turn your dishwasher into a snowblower? throw em a shovel"} 386 | {"id":6352,"img":"img\/06352.png","label":0,"text":"and then i asked mom, can we go play?"} 387 | {"id":58093,"img":"img\/58093.png","label":0,"text":"i treat them like java treats its generic parameters at run time as objects"} 388 | {"id":19508,"img":"img\/19508.png","label":1,"text":"the definition of utter disgust in plain black and white"} 389 | {"id":2783,"img":"img\/02783.png","label":1,"text":"a picture of my great great great grand grandfather's tractors they don't make them like they used to"} 390 | {"id":41728,"img":"img\/41728.png","label":0,"text":"i believe they are like racism. just shouldn't exist"} 391 | {"id":51306,"img":"img\/51306.png","label":0,"text":"eat some cookies you're not yourself when you're hungry"} 392 | {"id":50241,"img":"img\/50241.png","label":1,"text":"bitches be like no weave, no makeup, no lashes, no eyebrows"} 393 | {"id":49650,"img":"img\/49650.png","label":1,"text":"and then i asked mom, what's for dinner?"} 394 | {"id":17045,"img":"img\/17045.png","label":0,"text":"celebrating the presidential election"} 395 | {"id":89536,"img":"img\/89536.png","label":0,"text":"so many lives lost because of this man"} 396 | {"id":26187,"img":"img\/26187.png","label":1,"text":"i love my slow cooker"} 397 | {"id":74906,"img":"img\/74906.png","label":1,"text":"are you a boy or a girl i'm non binary"} 398 | {"id":28190,"img":"img\/28190.png","label":1,"text":"when finishing a race is your passion"} 399 | {"id":89642,"img":"img\/89642.png","label":1,"text":"i treat them like java treats its generic parameters at run time as objects"} 400 | {"id":59170,"img":"img\/59170.png","label":0,"text":"and just like that... this sandwich maker doubles as an ironing board"} 401 | {"id":30927,"img":"img\/30927.png","label":0,"text":"doesnt let anything stop her have a great day"} 402 | {"id":32049,"img":"img\/32049.png","label":0,"text":"german british danish austrian diversity. 
apparently we don't have it and these guys do: iraqi, pakistani, syrian, egyptian"} 403 | {"id":95173,"img":"img\/95173.png","label":0,"text":"take the third reich and then the first left?"} 404 | {"id":50198,"img":"img\/50198.png","label":0,"text":"you don't need cocaine to be happy"} 405 | {"id":37658,"img":"img\/37658.png","label":1,"text":"when your dishwasher is broken so you take it back to walmart to get a new one"} 406 | {"id":95613,"img":"img\/95613.png","label":1,"text":"what a weird way to spell fat and single"} 407 | {"id":43805,"img":"img\/43805.png","label":1,"text":"the original scarecrow"} 408 | {"id":18306,"img":"img\/18306.png","label":1,"text":"white kids kids dreaming about their first day of school"} 409 | {"id":46518,"img":"img\/46518.png","label":1,"text":"it's the goat humper, looks like a bunch of pigs just walked over your grave"} 410 | {"id":18547,"img":"img\/18547.png","label":1,"text":"1st day of 4th grade sandy hook elementary school"} 411 | {"id":2518,"img":"img\/02518.png","label":0,"text":"when the dog bites you and you bite back to assert dominance"} 412 | {"id":4621,"img":"img\/04621.png","label":0,"text":"you are what you eat meanwhile in africa"} 413 | {"id":98543,"img":"img\/98543.png","label":1,"text":"upper downer"} 414 | {"id":85621,"img":"img\/85621.png","label":0,"text":"no, that's not his daughter... that's his wife! yet the world is silent.."} 415 | {"id":18367,"img":"img\/18367.png","label":0,"text":"i love my slow cooker"} 416 | {"id":84036,"img":"img\/84036.png","label":0,"text":"militant dressed as martyr"} 417 | {"id":86417,"img":"img\/86417.png","label":1,"text":"i could have killed all of them. but i let some of them survive so you know why i killed them"} 418 | {"id":92567,"img":"img\/92567.png","label":0,"text":"pray for peace pray for wellbeing"} 419 | {"id":63921,"img":"img\/63921.png","label":0,"text":"taking a photo with family"} 420 | {"id":72864,"img":"img\/72864.png","label":1,"text":"traveling through chimney"} 421 | {"id":30586,"img":"img\/30586.png","label":0,"text":"just can't get rid of the stain. this is why i hate beets"} 422 | {"id":83264,"img":"img\/83264.png","label":1,"text":"randomly deciding to bake something"} 423 | {"id":93041,"img":"img\/93041.png","label":0,"text":"\"who's gonna make the sandwiches? what do we do with all these sandwiches?\""} 424 | {"id":97305,"img":"img\/97305.png","label":1,"text":"when your wife just died giving birth to your daughter and you're alone in the room with your naked just born daughter"} 425 | {"id":3568,"img":"img\/03568.png","label":0,"text":"time to play rich vs the constitution"} 426 | {"id":7653,"img":"img\/07653.png","label":1,"text":"in the last days there will be god hating scoffers, ridiculers, deniers and mockers following after their own lusts"} 427 | {"id":26397,"img":"img\/26397.png","label":0,"text":"there are still parts of the world where people celebrate capital punishment"} 428 | {"id":3217,"img":"img\/03217.png","label":0,"text":"they are like sperm you get millions, but only one works"} 429 | {"id":40618,"img":"img\/40618.png","label":0,"text":"when your company valuation hits 1 billion"} 430 | {"id":67208,"img":"img\/67208.png","label":0,"text":"does nike also sell stones in case they take it off?"} 431 | {"id":3798,"img":"img\/03798.png","label":0,"text":"hey, i just met you and this is crazy but here's your number... so camping, maybe?"} 432 | {"id":12468,"img":"img\/12468.png","label":0,"text":"say it! say it ! 
islam is the religion of peace"} 433 | {"id":16842,"img":"img\/16842.png","label":0,"text":"every kitchen needs a sandwich maker just like this one"} 434 | {"id":68043,"img":"img\/68043.png","label":1,"text":"plays uno steals all the green cards"} 435 | {"id":51807,"img":"img\/51807.png","label":0,"text":"are you a boy or a girl i'm non binary"} 436 | {"id":78914,"img":"img\/78914.png","label":0,"text":"terrorist trojan horse enemy of the state"} 437 | {"id":47950,"img":"img\/47950.png","label":0,"text":"democrats ilhan omar alexandria ocasio-cortez ayanna pressley rashida tlais we're new, bold, and we hate you! we hate whites, jews, christians, conservatives, republicans,"} 438 | {"id":95487,"img":"img\/95487.png","label":0,"text":"watching vegetables grow"} 439 | {"id":49260,"img":"img\/49260.png","label":0,"text":"it's the goat humper looks like a bunch of pigs just walked over your grave"} 440 | {"id":20861,"img":"img\/20861.png","label":1,"text":"pretty much sums up islam"} 441 | {"id":38914,"img":"img\/38914.png","label":0,"text":"cotton candy is what makes the theme park wonderful"} 442 | {"id":47016,"img":"img\/47016.png","label":1,"text":"last summer camp was so \"lit\""} 443 | {"id":14389,"img":"img\/14389.png","label":1,"text":"democrats we're new, bold, and we hate you! alexandria ocasio-cortez rashida tlais ilhan omar ayanna pressley we hate whites, jews, christians, conservatives, republicans"} 444 | {"id":35487,"img":"img\/35487.png","label":0,"text":"randomly deciding to bake something"} 445 | {"id":84150,"img":"img\/84150.png","label":0,"text":"when your jewish friend smells a stash of coins in public"} 446 | {"id":13809,"img":"img\/13809.png","label":1,"text":"doing o's with the smoke"} 447 | {"id":50261,"img":"img\/50261.png","label":1,"text":"haters will say it's photoshop"} 448 | {"id":75286,"img":"img\/75286.png","label":1,"text":"look at this sandwich maker club i found on wish on clearance"} 449 | {"id":64312,"img":"img\/64312.png","label":1,"text":"no, that's not his daughter.. that's his wife! yet the world is silent.."} 450 | {"id":67103,"img":"img\/67103.png","label":1,"text":"still better than mexican"} 451 | {"id":18059,"img":"img\/18059.png","label":1,"text":"they are like sperm you get millions, but only one works"} 452 | {"id":90643,"img":"img\/90643.png","label":0,"text":"she represents the new wave of democrats"} 453 | {"id":16749,"img":"img\/16749.png","label":0,"text":"broken dishwasher"} 454 | {"id":37814,"img":"img\/37814.png","label":0,"text":"this is el paso shooter's father every. single. 
time"} 455 | {"id":92046,"img":"img\/92046.png","label":0,"text":"they are terrorists anyone who doesn't agree is an idiot"} 456 | {"id":90843,"img":"img\/90843.png","label":0,"text":"time to put the \"panic\" back in hispanic"} 457 | {"id":85237,"img":"img\/85237.png","label":0,"text":"these are all the weapons that were seized from isis"} 458 | {"id":98235,"img":"img\/98235.png","label":0,"text":"fashion forward !"} 459 | {"id":90267,"img":"img\/90267.png","label":0,"text":"i've been driving this baby for years the mileage isn't great but still gets the job done"} 460 | {"id":48296,"img":"img\/48296.png","label":1,"text":"when you date an asian boy and you trynna get his family to accept you"} 461 | {"id":29174,"img":"img\/29174.png","label":0,"text":"the crimes against humanity should not be swiftly forgotten"} 462 | {"id":65801,"img":"img\/65801.png","label":0,"text":"first friday prayers of the islamic holy day"} 463 | {"id":91763,"img":"img\/91763.png","label":1,"text":"imagine being so disgusting there have to be laws to try to stop normal people from hating you"} 464 | {"id":94387,"img":"img\/94387.png","label":0,"text":"i'm a fan of the color black on you although the rose gold looks pretty good as well"} 465 | {"id":80947,"img":"img\/80947.png","label":1,"text":"my new muslim girlfriend not much to look at"} 466 | {"id":78251,"img":"img\/78251.png","label":1,"text":"watching vegetables grow"} 467 | {"id":94738,"img":"img\/94738.png","label":0,"text":"shopping is boring. why not have some fun at the walmart parking lot?"} 468 | {"id":57823,"img":"img\/57823.png","label":0,"text":"bitches be like no weave, no makeup, no lashes, no eyebrows & no filter. all natural"} 469 | {"id":32415,"img":"img\/32415.png","label":0,"text":"when finishing a race is your passion"} 470 | {"id":5316,"img":"img\/05316.png","label":0,"text":"doing o's with the smoke"} 471 | {"id":61085,"img":"img\/61085.png","label":0,"text":"upper downer"} 472 | {"id":27498,"img":"img\/27498.png","label":1,"text":"who's gonna make the sandwiches? what do we do with all these sandwiches?"} 473 | {"id":37160,"img":"img\/37160.png","label":0,"text":"best gamer ever 6 million kills, 1 death\""} 474 | {"id":18742,"img":"img\/18742.png","label":0,"text":"i'm proud of it! come at me!"} 475 | {"id":19243,"img":"img\/19243.png","label":0,"text":"surprise!! wait, where is everyone?"} 476 | {"id":54108,"img":"img\/54108.png","label":1,"text":"does nike also sell stones in case they take it off?"} 477 | {"id":93528,"img":"img\/93528.png","label":1,"text":"portable dishwasher slash sandwich maker"} 478 | {"id":78659,"img":"img\/78659.png","label":0,"text":"what a weird way to spell fat and single"} 479 | {"id":67082,"img":"img\/67082.png","label":0,"text":"\"1st day of 4th grade sandy hook elementary school\""} 480 | {"id":64125,"img":"img\/64125.png","label":0,"text":"ooof. 
that's gotta hurt"} 481 | {"id":41296,"img":"img\/41296.png","label":0,"text":"94% of them test positive for rh blood type which means they have monkey blood as a result they are born with a tail and other animal traits"} 482 | {"id":53491,"img":"img\/53491.png","label":0,"text":"and in 3 days he came back to life bringing salvation"} 483 | {"id":37692,"img":"img\/37692.png","label":0,"text":"\"islam is a religion of peace stop criticizing my religion\""} 484 | {"id":36201,"img":"img\/36201.png","label":0,"text":"look at me bling up my prosthetic"} 485 | {"id":29054,"img":"img\/29054.png","label":1,"text":"broken dishwasher"} 486 | {"id":6273,"img":"img\/06273.png","label":1,"text":"the jihad squad"} 487 | {"id":25061,"img":"img\/25061.png","label":1,"text":"this is el paso shooter's father. every. single. time"} 488 | {"id":20437,"img":"img\/20437.png","label":0,"text":"thanksgiving in china"} 489 | {"id":5439,"img":"img\/05439.png","label":0,"text":"portable dishwasher slash sandwich maker"} 490 | {"id":34528,"img":"img\/34528.png","label":1,"text":"every kitchen needs a sandwich maker just like this one"} 491 | {"id":93541,"img":"img\/93541.png","label":0,"text":"the jihad squad"} 492 | {"id":16704,"img":"img\/16704.png","label":0,"text":"last summer camp was so \"lit\""} 493 | {"id":48792,"img":"img\/48792.png","label":0,"text":"leaving a trail of death and destruction"} 494 | {"id":7528,"img":"img\/07528.png","label":1,"text":"the latest and greatest. a truck that comes with a dishwasher!"} 495 | {"id":1268,"img":"img\/01268.png","label":0,"text":"father and daughter bonding time is the best like and share if you agree"} 496 | {"id":83675,"img":"img\/83675.png","label":0,"text":"i'm gonna be like phelps one day"} 497 | {"id":37198,"img":"img\/37198.png","label":0,"text":"when you're so relaxed you can feel yourself go to the next level"} 498 | {"id":48670,"img":"img\/48670.png","label":0,"text":"look at this sandwich maker club i found on wish on clearance"} 499 | {"id":9863,"img":"img\/09863.png","label":0,"text":"diverse group of women"} 500 | {"id":97320,"img":"img\/97320.png","label":0,"text":"\"when your dishwasher is broken so you take it back to walmart to get a new one\""} -------------------------------------------------------------------------------- /Data/training and inference data/harmc/captions.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKBUNLP/Mr.Harm-EMNLP2023/84a116fb7483621ca6c1b051b7055b443344817c/Data/training and inference data/harmc/captions.pkl -------------------------------------------------------------------------------- /Data/training and inference data/harmp/captions.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKBUNLP/Mr.Harm-EMNLP2023/84a116fb7483621ca6c1b051b7055b443344817c/Data/training and inference data/harmp/captions.pkl -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Mr.Harm 2 | Official PyTorch implementation for the paper - **Beneath the Surface: Unveiling Harmful Memes with Multimodal Reasoning Distilled from Large Language Models**. 3 | 4 | (**EMNLP 2023**: *The 2023 Conference on Empirical Methods in Natural Language Processing (Findings), Dec 2023, Singapore*.) [[`paper`](https://aclanthology.org/2023.findings-emnlp.611/)] 5 | 6 | 7 | ## Install 8 | 9 | ```bash 10 | conda create -n meme python=3.8 11 | conda activate meme 12 | pip install -r requirements.txt 13 | ``` 14 | 15 | ## Data 16 | 17 | Please refer to [Data](https://github.com/HKBUNLP/Mr.Harm-EMNLP2023/tree/main/Data). 18 | 19 | ## Training 20 | - Learn from LLMs 21 | ```bash 22 | export DATA="/path/to/data/folder" 23 | export LOG="/path/to/save/ckpts/name" 24 | 25 | rm -rf $LOG 26 | mkdir $LOG 27 | 28 | CUDA_VISIBLE_DEVICES=0 python run.py with data_root=$DATA \ 29 | num_gpus=1 num_nodes=1 task_train per_gpu_batchsize=32 batch_size=32 \ 30 | clip32_base224 text_t5_base image_size=224 vit_randaug mode="rationale" \ 31 | log_dir=$LOG precision=32 max_epoch=10 learning_rate=5e-5 32 | ``` 33 | 34 | - Learn from Labels 35 | ```bash 36 | export DATA="/path/to/data/folder" 37 | export LOG="/path/to/save/ckpts/name" 38 | 39 | rm -rf $LOG 40 | mkdir $LOG 41 | 42 | CUDA_VISIBLE_DEVICES=0 python run.py with data_root=$DATA \ 43 | num_gpus=1 num_nodes=1 task_train per_gpu_batchsize=32 batch_size=32 \ 44 | clip32_base224 text_t5_base image_size=224 vit_randaug mode="label" \ 45 | log_dir=$LOG precision=32 max_epoch=30 learning_rate=5e-5 \ 46 | load_path="/path/to/distill_LLMs.ckpt" 47 | ``` 48 | 49 | ## Inference 50 | 51 | ```bash 52 | export DATA="/path/to/data/folder" 53 | export LOG="/path/to/log/folder" 54 | 55 | CUDA_VISIBLE_DEVICES=0 python run.py with data_root=$DATA \ 56 | num_gpus=1 num_nodes=1 task_train per_gpu_batchsize=1 batch_size=1 \ 57 | clip32_base224 text_t5_base image_size=224 vit_randaug \ 58 | log_dir=$LOG precision=32 test_only=True \ 59 | load_path="/path/to/label_learn.ckpt" \ 60 | out_path="/path/to/save/label_pred.json" 61 | ``` 62 | Then, you can use the `/path/to/save/label_pred.json` and the gold labels to get the scores. 
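The released code does not ship a scoring script, so the snippet below is only a minimal sketch of how the scores could be computed. It assumes `label_pred.json` maps each test id to the generated answer string (e.g. `"The answer is: harmful"`) and that the gold labels come from the corresponding `test.jsonl` with the `id`/`labels` fields read by `MemeDataset`; these field names are assumptions, so adjust them to match your actual files.

```python
# score_sketch.py -- hypothetical scoring helper, not part of the released code
import json
import jsonlines
from sklearn.metrics import accuracy_score, f1_score

def load_gold(path):
    # assumes each jsonl line carries an "id" and a "labels" list, as read by MemeDataset
    gold = {}
    with jsonlines.open(path) as reader:
        for obj in reader:
            gold[str(obj["id"])] = str(obj["labels"][0]).lower()
    return gold

def load_pred(path):
    # assumes out_path was written as {id: "The answer is: <label>"} entries
    with open(path, "r", encoding="utf8") as fin:
        raw = json.load(fin)
    return {str(k): str(v).lower().replace("the answer is:", "").strip() for k, v in raw.items()}

gold = load_gold("/path/to/data/folder/test.jsonl")
pred = load_pred("/path/to/save/label_pred.json")
ids = sorted(set(gold) & set(pred))
y_true = [gold[i] for i in ids]
y_pred = [pred[i] for i in ids]
print("Accuracy:", accuracy_score(y_true, y_pred))
print("Macro-F1:", f1_score(y_true, y_pred, average="macro"))
```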
63 | 64 | ## Citation 65 | 66 | ``` 67 | @inproceedings{lin-etal-2023-beneath, 68 | title = "Beneath the Surface: Unveiling Harmful Memes with Multimodal Reasoning Distilled from Large Language Models", 69 | author = "Lin, Hongzhan and 70 | Luo, Ziyang and 71 | Ma, Jing and 72 | Chen, Long", 73 | editor = "Bouamor, Houda and 74 | Pino, Juan and 75 | Bali, Kalika", 76 | booktitle = "Findings of the Association for Computational Linguistics: EMNLP 2023", 77 | month = dec, 78 | year = "2023", 79 | address = "Singapore", 80 | publisher = "Association for Computational Linguistics", 81 | url = "https://aclanthology.org/2023.findings-emnlp.611", 82 | doi = "10.18653/v1/2023.findings-emnlp.611", 83 | pages = "9114--9128", 84 | } 85 | ``` 86 | 87 | ## Acknowledgements 88 | 89 | The code is based on [ViLT](https://github.com/dandelin/ViLT) and [METER](https://github.com/zdou0830/METER/tree/main). 90 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch_lightning==1.3.2 2 | transformers==4.24.0 3 | Pillow==8.1.0 4 | tqdm==4.56.0 5 | ipdb==0.13.4 6 | numpy==1.19.5 7 | einops==0.3.0 8 | pyarrow==2.0.0 9 | sacred==0.8.2 10 | pandas==1.1.5 11 | torchmetrics==0.6.0 12 | ftfy 13 | torchvision==0.13.0 14 | jsonlines 15 | chardet 16 | torch==1.12.0 17 | sentencepiece 18 | scikit-learn -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import pytorch_lightning as pl 4 | import os 5 | os.environ["NCCL_DEBUG"] = "INFO" 6 | 7 | from src.config import ex 8 | from src.modules import MMTransformerSS 9 | from src.datamodules.multitask_datamodule import MTDataModule 10 | 11 | import resource 12 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) 13 | # resource.setrlimit(resource.RLIMIT_NOFILE, (20480, rlimit[1])) 14 | 15 | @ex.automain 16 | def main(_config): 17 | _config = copy.deepcopy(_config) 18 | pl.seed_everything(_config["seed"]) 19 | 20 | dm = MTDataModule(_config, dist=True) 21 | 22 | model = MMTransformerSS(_config) 23 | exp_name = f'{_config["exp_name"]}' 24 | 25 | os.makedirs(_config["log_dir"], exist_ok=True) 26 | checkpoint_callback = pl.callbacks.ModelCheckpoint( 27 | save_top_k=1, 28 | verbose=True, 29 | monitor="val/the_metric", 30 | mode="max" if _config["mode"] != "rationale" else "min", 31 | save_last=True, 32 | ) 33 | logger = pl.loggers.TensorBoardLogger( 34 | _config["log_dir"], 35 | name=f'{exp_name}_seed{_config["seed"]}_from_{_config["load_path"].split("/")[-1][:-5]}', 36 | ) 37 | 38 | lr_callback = pl.callbacks.LearningRateMonitor(logging_interval="step") 39 | callbacks = [checkpoint_callback, lr_callback] 40 | 41 | num_gpus = ( 42 | _config["num_gpus"] 43 | if isinstance(_config["num_gpus"], int) 44 | else len(_config["num_gpus"]) 45 | ) 46 | 47 | grad_steps = max(_config["batch_size"] // ( 48 | _config["per_gpu_batchsize"] * num_gpus * _config["num_nodes"] 49 | ), 1) 50 | 51 | max_steps = _config["max_steps"] if _config["max_steps"] is not None else None 52 | 53 | trainer = pl.Trainer( 54 | gpus=_config["num_gpus"], 55 | num_nodes=_config["num_nodes"], 56 | precision=_config["precision"], 57 | accelerator="ddp", 58 | benchmark=True, 59 | deterministic=True, 60 | max_epochs=_config["max_epoch"] if max_steps is None else 1000, 61 | max_steps=max_steps, 62 | callbacks=callbacks, 63 | logger=logger, 
64 | prepare_data_per_node=False, 65 | replace_sampler_ddp=False, 66 | accumulate_grad_batches=grad_steps, 67 | log_every_n_steps=10, 68 | flush_logs_every_n_steps=10, 69 | resume_from_checkpoint=_config["resume_from"], 70 | weights_summary="top", 71 | fast_dev_run=_config["fast_dev_run"], 72 | val_check_interval=_config["val_check_interval"], 73 | ) 74 | 75 | if not _config["test_only"]: 76 | trainer.fit(model, datamodule=dm) 77 | else: 78 | trainer.test(model, datamodule=dm) 79 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKBUNLP/Mr.Harm-EMNLP2023/84a116fb7483621ca6c1b051b7055b443344817c/src/__init__.py -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | from sacred import Experiment 2 | 3 | ex = Experiment("Meme", save_git_info=False) 4 | 5 | 6 | def _loss_names(d): 7 | ret = { 8 | "car": 0, 9 | } 10 | ret.update(d) 11 | return ret 12 | 13 | 14 | @ex.config 15 | def config(): 16 | exp_name = "Meme" 17 | seed = 0 18 | datasets = ["meme"] 19 | loss_names = _loss_names({"clm": 1}) 20 | batch_size = 4096 # this is a desired batch size; pl trainer will accumulate gradients when per step batch is smaller. 21 | temperature = 0.05 22 | 23 | # Image setting 24 | train_transform_keys = ["vit"] 25 | val_transform_keys = ["vit"] 26 | image_size = 224 27 | patch_size = 16 28 | 29 | # Text Setting 30 | max_text_len = 40 31 | tokenizer = "t5-small" 32 | vocab_size = 32128 33 | whole_word_masking = False # note that whole_word_masking does not work for RoBERTa 34 | mlm_prob = 0.3 35 | 36 | # Transformer Setting 37 | input_image_embed_size = 768 38 | input_text_embed_size = 768 39 | vit = 'google/vit-base-patch32-224-in21k' 40 | hidden_size = 768 41 | num_heads = 12 42 | mlp_ratio = 4 43 | drop_rate = 0.1 44 | 45 | # Optimizer Setting 46 | optim_type = "adamw" 47 | learning_rate = 1e-5 48 | weight_decay = 0.01 49 | decay_power = 1 50 | max_epoch = 100 51 | max_steps = 100000 52 | warmup_steps = 10000 53 | end_lr = 0 54 | 55 | # PL Trainer Setting 56 | resume_from = None 57 | fast_dev_run = False 58 | val_check_interval = 1.0 59 | test_only = False 60 | get_recall_metric = False 61 | 62 | # below params varies with the environment 63 | data_root = "" 64 | log_dir = "result" 65 | per_gpu_batchsize = 0 # you should define this manually with per_gpu_batch_size=# 66 | num_gpus = 8 67 | num_nodes = 1 68 | load_path = "" 69 | num_workers = 8 70 | precision = 32 71 | # resume_from = "" 72 | mode = "rationale" 73 | out_path="" 74 | 75 | @ex.named_config 76 | def task_train(): 77 | exp_name = "MEME" 78 | datasets = ["meme"] 79 | loss_names = _loss_names({ 80 | "clm": 1, 81 | }) 82 | batch_size = 256 83 | temperature = 0.05 84 | max_epoch = 30 85 | max_steps = None 86 | warmup_steps = 0.1 87 | whole_word_masking = False 88 | 89 | vocab_size = 32128 90 | max_text_len = 40 91 | image_size = 224 92 | tokenizer = "bert-base-uncased" 93 | train_transform_keys = ["vit"] 94 | val_transform_keys = ["vit"] 95 | learning_rate = 5e-5 96 | val_check_interval = 1.0 97 | hidden_size = 768 98 | num_heads = 12 99 | 100 | 101 | # visual encoder 102 | @ex.named_config 103 | def vit32_base224(): 104 | vit = "google/vit-base-patch32-224-in21k" 105 | patch_size = 32 106 | image_size = 224 107 | train_transform_keys = ["vit"] 108 | 
val_transform_keys = ["vit"] 109 | input_image_embed_size = 768 110 | 111 | @ex.named_config 112 | def vit16_base224(): 113 | vit = "google/vit-base-patch16-224-in21k" 114 | patch_size = 16 115 | image_size = 224 116 | train_transform_keys = ["vit"] 117 | val_transform_keys = ["vit"] 118 | input_image_embed_size = 768 119 | 120 | @ex.named_config 121 | def vit16_base384(): 122 | vit = "google/vit-base-patch16-384" 123 | patch_size = 16 124 | image_size = 384 125 | train_transform_keys = ["vit"] 126 | val_transform_keys = ["vit"] 127 | input_image_embed_size = 768 128 | 129 | @ex.named_config 130 | def clip32_base224(): 131 | vit = "openai/clip-vit-base-patch32" 132 | patch_size = 32 133 | image_size = 224 134 | train_transform_keys = ["vit"] 135 | val_transform_keys = ["vit"] 136 | input_image_embed_size = 768 137 | 138 | @ex.named_config 139 | def clip16_base224(): 140 | vit = "openai/clip-vit-base-patch16" 141 | patch_size = 16 142 | image_size = 224 143 | train_transform_keys = ["vit"] 144 | val_transform_keys = ["vit"] 145 | input_image_embed_size = 768 146 | 147 | # text encoder 148 | @ex.named_config 149 | def text_bert(): 150 | tokenizer = "bert-base-uncased" 151 | vocab_size = 30522 152 | input_text_embed_size = 768 153 | 154 | # text encoder 155 | @ex.named_config 156 | def text_t5_small(): 157 | tokenizer = "google/flan-t5-small" 158 | vocab_size = 32128 159 | input_text_embed_size = 512 160 | 161 | @ex.named_config 162 | def text_t5_base(): 163 | tokenizer = "google/flan-t5-base" 164 | vocab_size = 32128 165 | input_text_embed_size = 768 166 | 167 | @ex.named_config 168 | def text_t5_large(): 169 | tokenizer = "google/flan-t5-large" 170 | vocab_size = 32128 171 | input_text_embed_size = 1024 172 | 173 | # random augmentation 174 | @ex.named_config 175 | def vit_randaug(): 176 | train_transform_keys = ["vit_randaug"] 177 | -------------------------------------------------------------------------------- /src/datamodules/__init__.py: -------------------------------------------------------------------------------- 1 | from .meme_datamodule import MemeDataModule 2 | 3 | _datamodules = { 4 | "meme": MemeDataModule, 5 | } 6 | -------------------------------------------------------------------------------- /src/datamodules/datamodule_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from pytorch_lightning import LightningDataModule 4 | from torch.utils.data import DataLoader 5 | from transformers import ( 6 | DataCollatorForLanguageModeling, 7 | DataCollatorForWholeWordMask, 8 | AutoTokenizer, 9 | ) 10 | 11 | 12 | def get_pretrained_tokenizer(from_pretrained): 13 | if torch.distributed.is_initialized(): 14 | if torch.distributed.get_rank() == 0: 15 | AutoTokenizer.from_pretrained(from_pretrained) 16 | torch.distributed.barrier() 17 | 18 | return AutoTokenizer.from_pretrained(from_pretrained) 19 | 20 | 21 | class BaseDataModule(LightningDataModule): 22 | def __init__(self, _config): 23 | super().__init__() 24 | 25 | self.data_dir = _config["data_root"] 26 | 27 | self.num_workers = _config["num_workers"] 28 | self.batch_size = _config["per_gpu_batchsize"] 29 | self.eval_batch_size = self.batch_size 30 | 31 | self.image_size = _config["image_size"] 32 | self.patch_size = _config["patch_size"] 33 | self.max_text_len = _config["max_text_len"] 34 | 35 | self.train_transform_keys = ( 36 | ["default_train"] 37 | if len(_config["train_transform_keys"]) == 0 38 | else _config["train_transform_keys"] 39 | ) 40 | 41 | 
self.val_transform_keys = ( 42 | ["default_val"] 43 | if len(_config["val_transform_keys"]) == 0 44 | else _config["val_transform_keys"] 45 | ) 46 | 47 | tokenizer = _config["tokenizer"] 48 | self.tokenizer = get_pretrained_tokenizer(tokenizer) 49 | self.vocab_size = self.tokenizer.vocab_size 50 | 51 | collator = ( 52 | DataCollatorForWholeWordMask 53 | if _config["whole_word_masking"] 54 | else DataCollatorForLanguageModeling 55 | ) 56 | 57 | self.mlm_collator = collator( 58 | tokenizer=self.tokenizer, mlm=False, mlm_probability=_config["mlm_prob"] 59 | ) 60 | self.setup_flag = False 61 | 62 | @property 63 | def dataset_cls(self): 64 | raise NotImplementedError("return tuple of dataset class") 65 | 66 | @property 67 | def dataset_name(self): 68 | raise NotImplementedError("return name of dataset") 69 | 70 | def set_train_dataset(self): 71 | self.train_dataset = self.dataset_cls( 72 | data_dir=self.data_dir, 73 | transform_keys=self.train_transform_keys, 74 | split="train", 75 | image_size=self.image_size, 76 | patch_size=self.patch_size, 77 | max_text_len=self.max_text_len, 78 | tokenizer=self.tokenizer, 79 | ) 80 | 81 | def set_val_dataset(self): 82 | self.val_dataset = self.dataset_cls( 83 | data_dir=self.data_dir, 84 | transform_keys=self.val_transform_keys, 85 | split="val", 86 | image_size=self.image_size, 87 | patch_size=self.patch_size, 88 | max_text_len=self.max_text_len, 89 | tokenizer=self.tokenizer, 90 | ) 91 | 92 | def set_test_dataset(self): 93 | self.test_dataset = self.dataset_cls( 94 | data_dir=self.data_dir, 95 | transform_keys=self.val_transform_keys, 96 | split="test", 97 | image_size=self.image_size, 98 | patch_size=self.patch_size, 99 | max_text_len=self.max_text_len, 100 | tokenizer=self.tokenizer, 101 | ) 102 | 103 | def make_val_dset(self, image_only=False): 104 | return self.dataset_cls( 105 | data_dir=self.data_dir, 106 | transform_keys=self.val_transform_keys, 107 | split="test", 108 | image_size=self.image_size, 109 | patch_size=self.patch_size, 110 | max_text_len=self.max_text_len, 111 | image_only=image_only, 112 | tokenizer=self.tokenizer, 113 | ) 114 | 115 | def setup(self, stage): 116 | if not self.setup_flag: 117 | self.set_train_dataset() 118 | self.set_val_dataset() 119 | self.set_test_dataset() 120 | 121 | self.train_dataset.tokenizer = self.tokenizer 122 | self.val_dataset.tokenizer = self.tokenizer 123 | self.test_dataset.tokenizer = self.tokenizer 124 | 125 | self.setup_flag = True 126 | 127 | def train_dataloader(self): 128 | loader = DataLoader( 129 | self.train_dataset, 130 | batch_size=self.batch_size, 131 | shuffle=True, 132 | num_workers=self.num_workers, 133 | pin_memory=True, 134 | collate_fn=self.train_dataset.collate, 135 | drop_last=False, 136 | ) 137 | return loader 138 | 139 | def val_dataloader(self): 140 | loader = DataLoader( 141 | self.val_dataset, 142 | batch_size=self.eval_batch_size, 143 | shuffle=False, 144 | num_workers=self.num_workers, 145 | pin_memory=True, 146 | collate_fn=self.val_dataset.collate, 147 | drop_last=False 148 | ) 149 | return loader 150 | 151 | def test_dataloader(self): 152 | loader = DataLoader( 153 | self.test_dataset, 154 | batch_size=self.eval_batch_size, 155 | shuffle=False, 156 | num_workers=self.num_workers, 157 | pin_memory=True, 158 | collate_fn=self.test_dataset.collate, 159 | drop_last=False 160 | ) 161 | return loader 162 | -------------------------------------------------------------------------------- /src/datamodules/meme_datamodule.py: 
-------------------------------------------------------------------------------- 1 | from ..datasets import MemeDataset 2 | from .datamodule_base import BaseDataModule 3 | 4 | 5 | class MemeDataModule(BaseDataModule): 6 | def __init__(self, *args, **kwargs): 7 | super().__init__(*args, **kwargs) 8 | 9 | @property 10 | def dataset_cls(self): 11 | return MemeDataset 12 | 13 | @property 14 | def dataset_cls_no_false(self): 15 | return MemeDataset 16 | 17 | @property 18 | def dataset_name(self): 19 | return "meme" 20 | -------------------------------------------------------------------------------- /src/datamodules/multitask_datamodule.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | from pytorch_lightning import LightningDataModule 4 | from torch.utils.data import DataLoader 5 | from torch.utils.data.dataset import ConcatDataset 6 | from torch.utils.data.distributed import DistributedSampler 7 | 8 | from . import _datamodules 9 | 10 | 11 | class MTDataModule(LightningDataModule): 12 | def __init__(self, _config, dist=False): 13 | datamodule_keys = _config["datasets"] 14 | assert len(datamodule_keys) > 0 15 | 16 | super().__init__() 17 | 18 | self.dm_keys = datamodule_keys 19 | self.dm_dicts = {key: _datamodules[key](_config) for key in datamodule_keys} 20 | self.dms = [v for k, v in self.dm_dicts.items()] 21 | 22 | self.batch_size = self.dms[0].batch_size 23 | self.vocab_size = self.dms[0].vocab_size 24 | self.num_workers = self.dms[0].num_workers 25 | 26 | self.dist = dist 27 | 28 | def prepare_data(self): 29 | for dm in self.dms: 30 | dm.prepare_data() 31 | 32 | def setup(self, stage): 33 | for dm in self.dms: 34 | dm.setup(stage) 35 | 36 | self.train_dataset = ConcatDataset([dm.train_dataset for dm in self.dms]) 37 | self.val_dataset = ConcatDataset([dm.val_dataset for dm in self.dms]) 38 | self.test_dataset = ConcatDataset([dm.test_dataset for dm in self.dms]) 39 | self.tokenizer = self.dms[0].tokenizer 40 | 41 | self.collate = functools.partial( 42 | self.dms[0].train_dataset.collate, mlm_collator=self.dms[0].mlm_collator, 43 | ) 44 | 45 | if self.dist: 46 | self.train_sampler = DistributedSampler(self.train_dataset, shuffle=True) 47 | self.val_sampler = DistributedSampler(self.val_dataset, shuffle=True) 48 | self.test_sampler = DistributedSampler(self.test_dataset, shuffle=False) 49 | else: 50 | self.train_sampler = None 51 | self.val_sampler = None 52 | self.test_sampler = None 53 | 54 | def train_dataloader(self): 55 | loader = DataLoader( 56 | self.train_dataset, 57 | batch_size=self.batch_size, 58 | sampler=self.train_sampler, 59 | num_workers=self.num_workers, 60 | collate_fn=self.collate, 61 | ) 62 | return loader 63 | 64 | def val_dataloader(self, batch_size=None): 65 | loader = DataLoader( 66 | self.val_dataset, 67 | batch_size=batch_size if batch_size is not None else self.batch_size, 68 | sampler=self.val_sampler, 69 | num_workers=self.num_workers, 70 | collate_fn=self.collate, 71 | ) 72 | return loader 73 | 74 | def test_dataloader(self): 75 | loader = DataLoader( 76 | self.test_dataset, 77 | batch_size=self.batch_size, 78 | sampler=self.test_sampler, 79 | num_workers=self.num_workers, 80 | collate_fn=self.collate, 81 | ) 82 | return loader 83 | -------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .meme import MemeDataset 2 | 
-------------------------------------------------------------------------------- /src/datasets/base_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import io 4 | import pandas as pd 5 | import os 6 | import jsonlines 7 | 8 | from PIL import Image 9 | from ..transforms import keys_to_transforms 10 | 11 | def jsonl_reader(path): 12 | data = [] 13 | with jsonlines.open(path) as reader: 14 | for obj in reader: 15 | data.append(obj) 16 | return data 17 | 18 | class JsonDataset(torch.utils.data.Dataset): 19 | def __init__( 20 | self, 21 | data_dir, 22 | input_filename, 23 | transform_keys, 24 | image_size, 25 | patch_size, 26 | img_key, 27 | text_key, 28 | label_key, 29 | rationale_key, 30 | tokenizer=None, 31 | max_text_len=50, 32 | image_only=False, 33 | ): 34 | """ 35 | data_dir : where dataset file *.arrow lives; existence should be guaranteed via DataModule.prepare_data 36 | transform_keys : keys for generating augmented views of images 37 | """ 38 | assert len(transform_keys) >= 1 39 | super().__init__() 40 | self.data_dir = data_dir 41 | self.image_only = image_only 42 | self.data = jsonl_reader(f"{data_dir}/{input_filename}") 43 | self.img_key = img_key 44 | self.text_key = text_key 45 | self.label_key = label_key 46 | self.rationale_key = rationale_key 47 | self.transforms = keys_to_transforms(transform_keys, size=image_size) 48 | self.max_text_len = max_text_len 49 | self.image_size = image_size 50 | self.patch_size = patch_size 51 | self.tokenizer = tokenizer 52 | 53 | def __len__(self): 54 | return len(self.data) if self.data else 0 55 | 56 | def get_image(self, idx): 57 | image_features = self.transforms[0](Image.open(f"{self.data_dir}/images/{str(self.data[idx][self.img_key])}")).unsqueeze(0) 58 | return { 59 | "image_features": image_features, # [1, 3, H, W] 60 | "raw_index": idx, 61 | "img_path": f"{self.data_dir}/images/{str(self.data[idx][self.img_key])}", 62 | "img_index": self.data[idx]["id"], 63 | } 64 | 65 | def get_text(self, idx): 66 | text = str(self.data[idx][self.text_key]).lower() 67 | encoding = self.tokenizer( 68 | text, 69 | padding="max_length", 70 | truncation=True, 71 | max_length=self.max_text_len, 72 | return_special_tokens_mask=True, 73 | ) 74 | return { 75 | "text": (text, encoding), 76 | "raw_index": idx, 77 | } 78 | 79 | def get_label(self, idx): 80 | text = "The answer is: " + str(self.data[idx][self.label_key][0]).lower() 81 | encoding = self.tokenizer( 82 | text, 83 | padding="max_length", 84 | truncation=True, 85 | max_length=32, 86 | return_special_tokens_mask=True, 87 | ) 88 | return { 89 | "label": (text, encoding), 90 | "raw_index": idx, 91 | } 92 | 93 | def get_rationale(self, idx): 94 | text = "Output: " + str(self.data[idx][self.rationale_key][0]).lower() 95 | encoding = self.tokenizer( 96 | text, 97 | padding="max_length", 98 | truncation=True, 99 | max_length=self.max_text_len, 100 | return_special_tokens_mask=True, 101 | ) 102 | return { 103 | "rationale": (text, encoding), 104 | "raw_index": idx, 105 | } 106 | 107 | def get_suite(self, idx): 108 | result = None 109 | while result is None: 110 | try: 111 | ret = dict() 112 | ret.update(self.get_image(idx)) 113 | if not self.image_only: 114 | ret.update(self.get_text(idx)) 115 | ret.update(self.get_label(idx)) 116 | try: 117 | ret.update(self.get_rationale(idx)) 118 | except: 119 | pass 120 | result = True 121 | except Exception as e: 122 | print(f"Error while read file idx {idx} in 
{self.data_dir}/{str(self.data[idx][self.img_key])} -> {e}") 123 | idx = random.randint(0, len(self.data) - 1) 124 | 125 | return ret 126 | 127 | def collate(self, batch, mlm_collator): 128 | # the job of collate is to repack a batch of samples into batched tensors 129 | # at this point, batch is a list of dictionaries 130 | batch_size = len(batch) 131 | # collect all the keys 132 | keys = set([key for b in batch for key in b.keys()]) 133 | dict_batch = {k: [dic[k] if k in dic else None for dic in batch] for k in keys} 134 | 135 | batch_image_features = torch.cat(dict_batch["image_features"], dim=0) # [bs, 3, H, W] 136 | dict_batch["image_features"] = batch_image_features 137 | 138 | txt_keys = [k for k in list(dict_batch.keys()) if "text" in k] 139 | 140 | if len(txt_keys) != 0: 141 | texts = [[d[0] for d in dict_batch[txt_key]] for txt_key in txt_keys] 142 | encodings = [[d[1] for d in dict_batch[txt_key]] for txt_key in txt_keys] 143 | flatten_encodings = [e for encoding in encodings for e in encoding] 144 | 145 | # Prepare for text encoder 146 | flatten_mlms = mlm_collator(flatten_encodings) 147 | 148 | for i, txt_key in enumerate(txt_keys): 149 | texts, encodings = ( 150 | [d[0] for d in dict_batch[txt_key]], 151 | [d[1] for d in dict_batch[txt_key]], 152 | ) 153 | 154 | mlm_ids, mlm_labels = ( 155 | flatten_mlms["input_ids"][batch_size * (i) : batch_size * (i + 1)], 156 | flatten_mlms["labels"][batch_size * (i) : batch_size * (i + 1)], 157 | ) 158 | 159 | input_ids = torch.zeros_like(mlm_ids) 160 | attention_mask = torch.zeros_like(mlm_ids) 161 | for _i, encoding in enumerate(encodings): 162 | _input_ids, _attention_mask = ( 163 | torch.tensor(encoding["input_ids"]), 164 | torch.tensor(encoding["attention_mask"]), 165 | ) 166 | input_ids[_i, : len(_input_ids)] = _input_ids 167 | attention_mask[_i, : len(_attention_mask)] = _attention_mask 168 | 169 | dict_batch[txt_key] = texts 170 | dict_batch[f"{txt_key}_ids"] = input_ids 171 | dict_batch[f"{txt_key}_masks"] = attention_mask 172 | 173 | # Prepare for text decoder, labels 174 | label_keys = [k for k in list(dict_batch.keys()) if "label" in k] 175 | 176 | if len(label_keys) != 0: 177 | labels = [[d[0] for d in dict_batch[label_key]] for label_key in label_keys] 178 | encodings = [[d[1] for d in dict_batch[label_key]] for label_key in label_keys] 179 | flatten_encodings = [e for encoding in encodings for e in encoding] 180 | 181 | flatten_mlms = mlm_collator(flatten_encodings) 182 | 183 | for i, label_key in enumerate(label_keys): 184 | labels, encodings = ( 185 | [d[0] for d in dict_batch[label_key]], 186 | [d[1] for d in dict_batch[label_key]], 187 | ) 188 | 189 | mlm_ids, mlm_labels = ( 190 | flatten_mlms["input_ids"][batch_size * (i) : batch_size * (i + 1)], 191 | flatten_mlms["labels"][batch_size * (i) : batch_size * (i + 1)], 192 | ) 193 | 194 | input_ids = torch.zeros_like(mlm_ids) 195 | attention_mask = torch.zeros_like(mlm_ids) 196 | for _i, encoding in enumerate(encodings): 197 | _input_ids, _attention_mask = ( 198 | torch.tensor(encoding["input_ids"]), 199 | torch.tensor(encoding["attention_mask"]), 200 | ) 201 | input_ids[_i, : len(_input_ids)] = _input_ids 202 | attention_mask[_i, : len(_attention_mask)] = _attention_mask 203 | 204 | dict_batch[label_key] = labels 205 | dict_batch[f"{label_key}_ids"] = input_ids 206 | dict_batch[f"{label_key}_masks"] = attention_mask 207 | 208 | # Prepare for text decoder, rationale 209 | rationale_keys = [k for k in list(dict_batch.keys()) if "rationale" in k] 210 | 211 | if len(rationale_keys) != 0: 212 | rationales = [[d[0] for d in dict_batch[rationale_key]]
for rationale_key in rationale_keys] 213 | encodings = [[d[1] for d in dict_batch[rationale_key]] for rationale_key in rationale_keys] 214 | flatten_encodings = [e for encoding in encodings for e in encoding] 215 | 216 | flatten_mlms = mlm_collator(flatten_encodings) 217 | 218 | for i, rationale_key in enumerate(rationale_keys): 219 | rationales, encodings = ( 220 | [d[0] for d in dict_batch[rationale_key]], 221 | [d[1] for d in dict_batch[rationale_key]], 222 | ) 223 | 224 | mlm_ids, mlm_labels = ( 225 | flatten_mlms["input_ids"][batch_size * (i) : batch_size * (i + 1)], 226 | flatten_mlms["labels"][batch_size * (i) : batch_size * (i + 1)], 227 | ) 228 | 229 | input_ids = torch.zeros_like(mlm_ids) 230 | attention_mask = torch.zeros_like(mlm_ids) 231 | for _i, encoding in enumerate(encodings): 232 | _input_ids, _attention_mask = ( 233 | torch.tensor(encoding["input_ids"]), 234 | torch.tensor(encoding["attention_mask"]), 235 | ) 236 | input_ids[_i, : len(_input_ids)] = _input_ids 237 | attention_mask[_i, : len(_attention_mask)] = _attention_mask 238 | 239 | dict_batch[rationale_key] = rationales 240 | dict_batch[f"{rationale_key}_ids"] = input_ids 241 | dict_batch[f"{rationale_key}_masks"] = attention_mask 242 | 243 | return dict_batch -------------------------------------------------------------------------------- /src/datasets/meme.py: -------------------------------------------------------------------------------- 1 | from .base_dataset import JsonDataset 2 | import io 3 | from PIL import Image 4 | 5 | class MemeDataset(JsonDataset): 6 | def __init__(self, *args, split="", **kwargs): 7 | assert split in ["train", "val", "test"] 8 | self.split = split 9 | 10 | if split == "train": 11 | input_filename = "train.rationale.label.jsonl" 12 | elif split == "val": 13 | input_filename = "val.rationale.label.jsonl" 14 | elif split == "test": 15 | input_filename = "test.jsonl" 16 | 17 | img_key = "image" 18 | text_key = "text" 19 | label_key = "labels" 20 | rationale_key = "rationale" 21 | 22 | super().__init__( 23 | *args, 24 | **kwargs, 25 | input_filename=input_filename, 26 | img_key=img_key, 27 | text_key=text_key, 28 | label_key=label_key, 29 | rationale_key=rationale_key, 30 | ) 31 | 32 | 33 | def __getitem__(self, index): 34 | suite = self.get_suite(index) 35 | return suite 36 | -------------------------------------------------------------------------------- /src/gadgets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKBUNLP/Mr.Harm-EMNLP2023/84a116fb7483621ca6c1b051b7055b443344817c/src/gadgets/__init__.py -------------------------------------------------------------------------------- /src/gadgets/my_metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytorch_lightning.metrics import Metric 3 | 4 | 5 | class Accuracy(Metric): 6 | def __init__(self, dist_sync_on_step=False): 7 | super().__init__(dist_sync_on_step=dist_sync_on_step) 8 | self.add_state("correct", default=torch.tensor(0.0), dist_reduce_fx="sum") 9 | self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum") 10 | 11 | def update(self, logits, target): 12 | logits, target = ( 13 | logits.detach().to(self.correct.device), 14 | target.detach().to(self.correct.device), 15 | ) 16 | preds = logits.argmax(dim=-1) 17 | preds = preds[target != -100] 18 | target = target[target != -100] 19 | if target.numel() == 0: 20 | return 1 21 | 22 | assert preds.shape == target.shape 23 | 
24 | self.correct += torch.sum(preds == target) 25 | self.total += target.numel() 26 | 27 | def compute(self): 28 | return self.correct / self.total 29 | 30 | 31 | class Scalar(Metric): 32 | def __init__(self, dist_sync_on_step=False): 33 | super().__init__(dist_sync_on_step=dist_sync_on_step) 34 | self.add_state("scalar", default=torch.tensor(0.0), dist_reduce_fx="sum") 35 | self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum") 36 | 37 | def update(self, scalar): 38 | if isinstance(scalar, torch.Tensor): 39 | scalar = scalar.detach().to(self.scalar.device) 40 | else: 41 | scalar = torch.tensor(scalar).float().to(self.scalar.device) 42 | self.scalar += scalar 43 | self.total += 1 44 | 45 | def compute(self): 46 | return self.scalar / self.total 47 | 48 | 49 | class VQAScore(Metric): 50 | def __init__(self, dist_sync_on_step=False): 51 | super().__init__(dist_sync_on_step=dist_sync_on_step) 52 | self.add_state("score", default=torch.tensor(0.0), dist_reduce_fx="sum") 53 | self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum") 54 | 55 | def update(self, logits, target): 56 | logits, target = ( 57 | logits.detach().float().to(self.score.device), 58 | target.detach().float().to(self.score.device), 59 | ) 60 | logits = torch.max(logits, 1)[1] 61 | one_hots = torch.zeros(*target.size()).to(target) 62 | one_hots.scatter_(1, logits.view(-1, 1), 1) 63 | scores = one_hots * target 64 | 65 | self.score += scores.sum() 66 | self.total += len(logits) 67 | 68 | def compute(self): 69 | return self.score / self.total 70 | -------------------------------------------------------------------------------- /src/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .mm_module import MMTransformerSS -------------------------------------------------------------------------------- /src/modules/dist_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | This file contains primitives for multi-gpu communication. 4 | This is useful when doing distributed training. 5 | """ 6 | 7 | import functools 8 | import logging 9 | import numpy as np 10 | import pickle 11 | import torch 12 | import torch.distributed as dist 13 | 14 | import torch 15 | 16 | _LOCAL_PROCESS_GROUP = None 17 | """ 18 | A torch process group which only includes processes that on the same machine as the current process. 19 | This variable is set when processes are spawned by `launch()` in "engine/launch.py". 20 | """ 21 | 22 | 23 | def get_world_size() -> int: 24 | if not dist.is_available(): 25 | return 1 26 | if not dist.is_initialized(): 27 | return 1 28 | return dist.get_world_size() 29 | 30 | 31 | def get_rank() -> int: 32 | if not dist.is_available(): 33 | return 0 34 | if not dist.is_initialized(): 35 | return 0 36 | return dist.get_rank() 37 | 38 | 39 | def get_local_rank() -> int: 40 | """ 41 | Returns: 42 | The rank of the current process within the local (per-machine) process group. 43 | """ 44 | if not dist.is_available(): 45 | return 0 46 | if not dist.is_initialized(): 47 | return 0 48 | assert _LOCAL_PROCESS_GROUP is not None 49 | return dist.get_rank(group=_LOCAL_PROCESS_GROUP) 50 | 51 | 52 | def get_local_size() -> int: 53 | """ 54 | Returns: 55 | The size of the per-machine process group, 56 | i.e. the number of processes per machine. 
57 | """ 58 | if not dist.is_available(): 59 | return 1 60 | if not dist.is_initialized(): 61 | return 1 62 | return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) 63 | 64 | 65 | def is_main_process() -> bool: 66 | return get_rank() == 0 67 | 68 | 69 | def synchronize(): 70 | """ 71 | Helper function to synchronize (barrier) among all processes when 72 | using distributed training 73 | """ 74 | if not dist.is_available(): 75 | return 76 | if not dist.is_initialized(): 77 | return 78 | world_size = dist.get_world_size() 79 | if world_size == 1: 80 | return 81 | dist.barrier() 82 | 83 | 84 | @functools.lru_cache() 85 | def _get_global_gloo_group(): 86 | """ 87 | Return a process group based on gloo backend, containing all the ranks 88 | The result is cached. 89 | """ 90 | if dist.get_backend() == "nccl": 91 | return dist.new_group(backend="gloo") 92 | else: 93 | return dist.group.WORLD 94 | 95 | 96 | def _serialize_to_tensor(data, group): 97 | backend = dist.get_backend(group) 98 | assert backend in ["gloo", "nccl"] 99 | device = torch.device("cpu" if backend == "gloo" else "cuda") 100 | 101 | buffer = pickle.dumps(data) 102 | if len(buffer) > 1024 ** 3: 103 | logger = logging.getLogger(__name__) 104 | logger.warning( 105 | "Rank {} trying to all-gather {:.2f} GB of data on device {}".format( 106 | get_rank(), len(buffer) / (1024 ** 3), device 107 | ) 108 | ) 109 | storage = torch.ByteStorage.from_buffer(buffer) 110 | tensor = torch.ByteTensor(storage).to(device=device) 111 | return tensor 112 | 113 | 114 | def _pad_to_largest_tensor(tensor, group): 115 | """ 116 | Returns: 117 | list[int]: size of the tensor, on each rank 118 | Tensor: padded tensor that has the max size 119 | """ 120 | world_size = dist.get_world_size(group=group) 121 | assert ( 122 | world_size >= 1 123 | ), "comm.gather/all_gather must be called from ranks within the given group!" 124 | local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device) 125 | size_list = [ 126 | torch.zeros([1], dtype=torch.int64, device=tensor.device) 127 | for _ in range(world_size) 128 | ] 129 | dist.all_gather(size_list, local_size, group=group) 130 | size_list = [int(size.item()) for size in size_list] 131 | 132 | max_size = max(size_list) 133 | 134 | # we pad the tensor because torch all_gather does not support 135 | # gathering tensors of different shapes 136 | if local_size != max_size: 137 | padding = torch.zeros( 138 | (max_size - local_size,), dtype=torch.uint8, device=tensor.device 139 | ) 140 | tensor = torch.cat((tensor, padding), dim=0) 141 | return size_list, tensor 142 | 143 | 144 | def all_gather(data, group=None): 145 | """ 146 | Run all_gather on arbitrary picklable data (not necessarily tensors). 147 | 148 | Args: 149 | data: any picklable object 150 | group: a torch process group. By default, will use a group which 151 | contains all ranks on gloo backend. 
152 | 153 | Returns: 154 | list[data]: list of data gathered from each rank 155 | """ 156 | if get_world_size() == 1: 157 | return [data] 158 | if group is None: 159 | group = _get_global_gloo_group() 160 | if dist.get_world_size(group) == 1: 161 | return [data] 162 | 163 | tensor = _serialize_to_tensor(data, group) 164 | 165 | size_list, tensor = _pad_to_largest_tensor(tensor, group) 166 | max_size = max(size_list) 167 | 168 | # receiving Tensor from all ranks 169 | tensor_list = [ 170 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) 171 | for _ in size_list 172 | ] 173 | dist.all_gather(tensor_list, tensor, group=group) 174 | 175 | data_list = [] 176 | for size, tensor in zip(size_list, tensor_list): 177 | buffer = tensor.cpu().numpy().tobytes()[:size] 178 | data_list.append(pickle.loads(buffer)) 179 | 180 | return data_list 181 | 182 | 183 | def gather(data, dst=0, group=None): 184 | """ 185 | Run gather on arbitrary picklable data (not necessarily tensors). 186 | 187 | Args: 188 | data: any picklable object 189 | dst (int): destination rank 190 | group: a torch process group. By default, will use a group which 191 | contains all ranks on gloo backend. 192 | 193 | Returns: 194 | list[data]: on dst, a list of data gathered from each rank. Otherwise, 195 | an empty list. 196 | """ 197 | if get_world_size() == 1: 198 | return [data] 199 | if group is None: 200 | group = _get_global_gloo_group() 201 | if dist.get_world_size(group=group) == 1: 202 | return [data] 203 | rank = dist.get_rank(group=group) 204 | 205 | tensor = _serialize_to_tensor(data, group) 206 | size_list, tensor = _pad_to_largest_tensor(tensor, group) 207 | 208 | # receiving Tensor from all ranks 209 | if rank == dst: 210 | max_size = max(size_list) 211 | tensor_list = [ 212 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) 213 | for _ in size_list 214 | ] 215 | dist.gather(tensor, tensor_list, dst=dst, group=group) 216 | 217 | data_list = [] 218 | for size, tensor in zip(size_list, tensor_list): 219 | buffer = tensor.cpu().numpy().tobytes()[:size] 220 | data_list.append(pickle.loads(buffer)) 221 | return data_list 222 | else: 223 | dist.gather(tensor, [], dst=dst, group=group) 224 | return [] 225 | 226 | 227 | def shared_random_seed(): 228 | """ 229 | Returns: 230 | int: a random number that is the same across all workers. 231 | If workers need a shared RNG, they can use this shared seed to 232 | create one. 233 | 234 | All workers must call this function, otherwise it will deadlock. 235 | """ 236 | ints = np.random.randint(2 ** 31) 237 | all_ints = all_gather(ints) 238 | return all_ints[0] 239 | 240 | 241 | def reduce_dict(input_dict, average=True): 242 | """ 243 | Reduce the values in the dictionary from all processes so that process with rank 244 | 0 has the reduced results. 245 | 246 | Args: 247 | input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor. 248 | average (bool): whether to do average or sum 249 | 250 | Returns: 251 | a dict with the same keys as input_dict, after reduction. 
252 | """ 253 | world_size = get_world_size() 254 | if world_size < 2: 255 | return input_dict 256 | with torch.no_grad(): 257 | names = [] 258 | values = [] 259 | # sort the keys so that they are consistent across processes 260 | for k in sorted(input_dict.keys()): 261 | names.append(k) 262 | values.append(input_dict[k]) 263 | values = torch.stack(values, dim=0) 264 | dist.reduce(values, dst=0) 265 | if dist.get_rank() == 0 and average: 266 | # only main process gets accumulated, so only divide by 267 | # world_size in this case 268 | values /= world_size 269 | reduced_dict = {k: v for k, v in zip(names, values)} 270 | return reduced_dict 271 | -------------------------------------------------------------------------------- /src/modules/heads.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from typing import Optional, Tuple 6 | 7 | from transformers.models.bert.modeling_bert import ( 8 | BertPredictionHeadTransform, 9 | #BertLayer 10 | ) 11 | from transformers.modeling_utils import ( 12 | PreTrainedModel, 13 | apply_chunking_to_forward, 14 | find_pruneable_heads_and_indices, 15 | prune_linear_layer, 16 | ) 17 | from transformers.activations import ACT2FN 18 | from transformers.pytorch_utils import Conv1D 19 | 20 | ################################################################################## 21 | class BertSelfAttention(nn.Module): 22 | def __init__(self, config, position_embedding_type=None): 23 | super().__init__() 24 | if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): 25 | raise ValueError( 26 | f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " 27 | f"heads ({config.num_attention_heads})" 28 | ) 29 | 30 | self.num_attention_heads = config.num_attention_heads 31 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 32 | self.all_head_size = self.num_attention_heads * self.attention_head_size 33 | 34 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 35 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 36 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 37 | 38 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 39 | self.position_embedding_type = position_embedding_type or getattr( 40 | config, "position_embedding_type", "absolute" 41 | ) 42 | if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": 43 | self.max_position_embeddings = config.max_position_embeddings 44 | self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) 45 | 46 | self.is_decoder = config.is_decoder 47 | 48 | def transpose_for_scores(self, x): 49 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 50 | x = x.view(new_x_shape) 51 | return x.permute(0, 2, 1, 3) 52 | 53 | def forward( 54 | self, 55 | hidden_states: torch.Tensor, 56 | attention_mask: Optional[torch.FloatTensor] = None, 57 | head_mask: Optional[torch.FloatTensor] = None, 58 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 59 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 60 | past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, 61 | output_attentions: Optional[bool] = False, 62 | ) -> Tuple[torch.Tensor]: 63 | mixed_query_layer = self.query(hidden_states) 64 | 65 | # If this is 
instantiated as a cross-attention module, the keys 66 | # and values come from an encoder; the attention mask needs to be 67 | # such that the encoder's padding tokens are not attended to. 68 | is_cross_attention = encoder_hidden_states is not None 69 | 70 | if is_cross_attention and past_key_value is not None: 71 | # reuse k,v, cross_attentions 72 | key_layer = past_key_value[0] 73 | value_layer = past_key_value[1] 74 | attention_mask = encoder_attention_mask 75 | elif is_cross_attention: 76 | key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) 77 | value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) 78 | attention_mask = encoder_attention_mask 79 | elif past_key_value is not None: 80 | key_layer = self.transpose_for_scores(self.key(hidden_states)) 81 | value_layer = self.transpose_for_scores(self.value(hidden_states)) 82 | key_layer = torch.cat([past_key_value[0], key_layer], dim=2) 83 | value_layer = torch.cat([past_key_value[1], value_layer], dim=2) 84 | else: 85 | key_layer = self.transpose_for_scores(self.key(hidden_states)) 86 | value_layer = self.transpose_for_scores(self.value(hidden_states)) 87 | 88 | query_layer = self.transpose_for_scores(mixed_query_layer) 89 | 90 | if self.is_decoder: 91 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 92 | # Further calls to cross_attention layer can then reuse all cross-attention 93 | # key/value_states (first "if" case) 94 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of 95 | # all previous decoder key/value_states. Further calls to uni-directional self-attention 96 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) 97 | # if encoder bi-directional self-attention `past_key_value` is always `None` 98 | past_key_value = (key_layer, value_layer) 99 | 100 | # Take the dot product between "query" and "key" to get the raw attention scores. 
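        # Shape note: after transpose_for_scores, query_layer is [batch, num_heads, q_len, head_size]
        # and key_layer is [batch, num_heads, k_len, head_size] (k_len equals q_len for self-attention,
        # or the encoder sequence length for cross-attention), so the matmul below yields raw scores
        # of shape [batch, num_heads, q_len, k_len]; the 1/sqrt(head_size) scaling and the softmax
        # are applied a few lines further down.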
101 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 102 | 103 | if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": 104 | seq_length = hidden_states.size()[1] 105 | position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) 106 | position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) 107 | distance = position_ids_l - position_ids_r 108 | positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) 109 | positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility 110 | 111 | if self.position_embedding_type == "relative_key": 112 | relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) 113 | attention_scores = attention_scores + relative_position_scores 114 | elif self.position_embedding_type == "relative_key_query": 115 | relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) 116 | relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) 117 | attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key 118 | 119 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 120 | if attention_mask is not None: 121 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 122 | attention_scores = attention_scores + self.get_extended_attention_mask(attention_mask, attention_scores.dtype) 123 | 124 | # Normalize the attention scores to probabilities. 125 | attention_probs = nn.functional.softmax(attention_scores, dim=-1) 126 | 127 | # torch.save(attention_probs, "/root/luoziyang/ckpts/16_t2i.pt") 128 | # attention_probs[0].mean(dim=0)[:, 0].mean() 129 | # print(attention_probs[0].mean(dim=0).shape) 130 | entropy = None #[attention_probs[0].mean(dim=0)[:15, 0].std().data] 131 | # for i in range(12): 132 | # entropy.append((-torch.sum(attention_probs[0][i][:, 0] * torch.log(attention_probs[0][i][:, 0]))).data) 133 | 134 | # This is actually dropping out entire tokens to attend to, which might 135 | # seem a bit unusual, but is taken from the original Transformer paper. 
136 | attention_probs = self.dropout(attention_probs) 137 | 138 | # Mask heads if we want to 139 | if head_mask is not None: 140 | attention_probs = attention_probs * head_mask 141 | 142 | context_layer = torch.matmul(attention_probs, value_layer) 143 | 144 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 145 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 146 | context_layer = context_layer.view(new_context_layer_shape) 147 | 148 | outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) 149 | 150 | if self.is_decoder: 151 | outputs = outputs + (past_key_value,) 152 | return outputs, entropy 153 | 154 | def get_extended_attention_mask(self, attention_mask, selfdtype): 155 | extended_attention_mask = attention_mask[:, None, None, :] 156 | extended_attention_mask = extended_attention_mask.to(dtype=selfdtype) # fp16 compatibility 157 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 158 | return extended_attention_mask 159 | 160 | 161 | class BertSelfOutput(nn.Module): 162 | def __init__(self, config): 163 | super().__init__() 164 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 165 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 166 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 167 | 168 | def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: 169 | hidden_states = self.dense(hidden_states) 170 | hidden_states = self.dropout(hidden_states) 171 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 172 | return hidden_states 173 | 174 | 175 | class BertAttention(nn.Module): 176 | def __init__(self, config, position_embedding_type=None): 177 | super().__init__() 178 | self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type) 179 | self.output = BertSelfOutput(config) 180 | self.pruned_heads = set() 181 | 182 | def prune_heads(self, heads): 183 | if len(heads) == 0: 184 | return 185 | heads, index = find_pruneable_heads_and_indices( 186 | heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads 187 | ) 188 | 189 | # Prune linear layers 190 | self.self.query = prune_linear_layer(self.self.query, index) 191 | self.self.key = prune_linear_layer(self.self.key, index) 192 | self.self.value = prune_linear_layer(self.self.value, index) 193 | self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) 194 | 195 | # Update hyper params and store pruned heads 196 | self.self.num_attention_heads = self.self.num_attention_heads - len(heads) 197 | self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads 198 | self.pruned_heads = self.pruned_heads.union(heads) 199 | 200 | def forward( 201 | self, 202 | hidden_states: torch.Tensor, 203 | attention_mask: Optional[torch.FloatTensor] = None, 204 | head_mask: Optional[torch.FloatTensor] = None, 205 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 206 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 207 | past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, 208 | output_attentions: Optional[bool] = False, 209 | ) -> Tuple[torch.Tensor]: 210 | self_outputs, entropy = self.self( 211 | hidden_states, 212 | attention_mask, 213 | head_mask, 214 | encoder_hidden_states, 215 | encoder_attention_mask, 216 | past_key_value, 217 | output_attentions, 218 | ) 219 | attention_output = self.output(self_outputs[0], hidden_states) 220 | outputs = 
(attention_output,) + self_outputs[1:] # add attentions if we output them 221 | return outputs, entropy 222 | 223 | 224 | class BertIntermediate(nn.Module): 225 | def __init__(self, config): 226 | super().__init__() 227 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 228 | if isinstance(config.hidden_act, str): 229 | self.intermediate_act_fn = ACT2FN[config.hidden_act] 230 | else: 231 | self.intermediate_act_fn = config.hidden_act 232 | 233 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 234 | hidden_states = self.dense(hidden_states) 235 | hidden_states = self.intermediate_act_fn(hidden_states) 236 | return hidden_states 237 | 238 | 239 | class BertOutput(nn.Module): 240 | def __init__(self, config): 241 | super().__init__() 242 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 243 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 244 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 245 | 246 | def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: 247 | hidden_states = self.dense(hidden_states) 248 | hidden_states = self.dropout(hidden_states) 249 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 250 | return hidden_states 251 | 252 | 253 | class BertLayer(nn.Module): 254 | def __init__(self, config): 255 | super().__init__() 256 | self.chunk_size_feed_forward = config.chunk_size_feed_forward 257 | self.seq_len_dim = 1 258 | self.attention = BertAttention(config) 259 | self.is_decoder = config.is_decoder 260 | self.add_cross_attention = config.add_cross_attention 261 | if self.add_cross_attention: 262 | if not self.is_decoder: 263 | raise ValueError(f"{self} should be used as a decoder model if cross attention is added") 264 | self.crossattention = BertAttention(config, position_embedding_type="absolute") 265 | self.intermediate = BertIntermediate(config) 266 | self.output = BertOutput(config) 267 | 268 | def forward( 269 | self, 270 | hidden_states: torch.Tensor, 271 | attention_mask: Optional[torch.FloatTensor] = None, 272 | head_mask: Optional[torch.FloatTensor] = None, 273 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 274 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 275 | past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, 276 | output_attentions: Optional[bool] = False, 277 | ) -> Tuple[torch.Tensor]: 278 | # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 279 | entropy = None 280 | self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None 281 | self_attention_outputs, _ = self.attention( 282 | hidden_states, 283 | attention_mask, 284 | head_mask, 285 | output_attentions=output_attentions, 286 | past_key_value=self_attn_past_key_value, 287 | ) 288 | attention_output = self_attention_outputs[0] 289 | 290 | # if decoder, the last output is tuple of self-attn cache 291 | if self.is_decoder: 292 | outputs = self_attention_outputs[1:-1] 293 | present_key_value = self_attention_outputs[-1] 294 | else: 295 | outputs = self_attention_outputs[1:] # add self attentions if we output attention weights 296 | 297 | cross_attn_present_key_value = None 298 | if self.is_decoder and encoder_hidden_states is not None: 299 | if not hasattr(self, "crossattention"): 300 | raise ValueError( 301 | f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`" 302 | ) 303 | 
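            # Note: the `entropy` value captured from the cross-attention call below is an addition
            # in this repository's modified copy of BertLayer/BertAttention; the stock Hugging Face
            # BertLayer does not return it.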
304 | # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple 305 | cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None 306 | cross_attention_outputs, entropy = self.crossattention( 307 | attention_output, 308 | attention_mask, 309 | head_mask, 310 | encoder_hidden_states, 311 | encoder_attention_mask, 312 | cross_attn_past_key_value, 313 | output_attentions, 314 | ) 315 | attention_output = cross_attention_outputs[0] 316 | outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights 317 | 318 | # add cross-attn cache to positions 3,4 of present_key_value tuple 319 | cross_attn_present_key_value = cross_attention_outputs[-1] 320 | present_key_value = present_key_value + cross_attn_present_key_value 321 | 322 | layer_output = apply_chunking_to_forward( 323 | self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output 324 | ) 325 | outputs = (layer_output,) + outputs 326 | 327 | # if decoder, return the attn key/values as the last output 328 | if self.is_decoder: 329 | outputs = outputs + (present_key_value,) 330 | 331 | return outputs, entropy 332 | 333 | def feed_forward_chunk(self, attention_output): 334 | intermediate_output = self.intermediate(attention_output) 335 | layer_output = self.output(intermediate_output, attention_output) 336 | return layer_output 337 | ################################################################################# 338 | ################################################################################# 339 | ################################################################################# 340 | 341 | 342 | class SelfAttentionHead(nn.Module): 343 | def __init__(self, config): 344 | super().__init__() 345 | self.bertlayer = BertLayer(config) 346 | 347 | def forward(self, hidden_states, attention_mask=None): 348 | outputs, _ = self.bertlayer(hidden_states=hidden_states, attention_mask=attention_mask) 349 | return outputs 350 | 351 | 352 | class CrossAttentionHead(nn.Module): 353 | def __init__(self, config): 354 | super().__init__() 355 | config.is_decoder = True 356 | config.add_cross_attention = True 357 | self.bertcrosslayer = BertLayer(config) 358 | 359 | def forward( 360 | self, 361 | hidden_states, 362 | attention_mask, 363 | encoder_hidden_states, 364 | encoder_attention_mask, 365 | ): 366 | outputs, entropy = self.bertcrosslayer( 367 | hidden_states=hidden_states, 368 | attention_mask=attention_mask, 369 | encoder_hidden_states=encoder_hidden_states, 370 | encoder_attention_mask=encoder_attention_mask, 371 | ) 372 | return outputs, entropy 373 | 374 | 375 | class ContrastivePooler(nn.Module): 376 | def __init__(self, hidden_size): 377 | super().__init__() 378 | self.dense = nn.Linear(hidden_size, hidden_size) 379 | 380 | def forward(self, hidden_states): 381 | first_token_tensor = hidden_states[:, 0] 382 | pooled_output = self.dense(first_token_tensor) 383 | return pooled_output 384 | 385 | 386 | class MIMHead(nn.Module): 387 | def __init__(self, config): 388 | super().__init__() 389 | self.decoder = nn.Sequential( 390 | nn.Conv2d(in_channels=config.hidden_size, out_channels=config.encoder_stride**2 * 3, kernel_size=1), 391 | nn.PixelShuffle(config.encoder_stride), 392 | ) 393 | 394 | def forward(self, x): 395 | # cls need to be excluded 396 | batch_size, sequence_length, num_channels = x.shape 397 | height = width = int(sequence_length**0.5) 398 | x = x.permute(0, 2, 1).reshape(batch_size, num_channels, height, width) 399 
| reconstructed_pixel_values = self.decoder(x) 400 | return reconstructed_pixel_values 401 | 402 | 403 | class MLMHead(nn.Module): 404 | def __init__(self, config, weight=None): 405 | super().__init__() 406 | self.transform = BertPredictionHeadTransform(config) 407 | self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 408 | self.bias = nn.Parameter(torch.zeros(config.vocab_size)) 409 | if weight is not None: 410 | self.decoder.weight = weight 411 | 412 | def forward(self, x): 413 | x = self.transform(x) 414 | x = self.decoder(x) + self.bias 415 | return x 416 | 417 | 418 | class CAHead(nn.Module): 419 | def __init__(self): 420 | super().__init__() 421 | self.wv = nn.Linear(768, 768) 422 | self.wk = nn.Linear(768, 768) 423 | self.wq = nn.Linear(768, 768) 424 | self.wo = nn.Linear(768, 768) 425 | self.dropout = nn.Dropout(p=0.1) 426 | 427 | def forward(self, query, key, value, attention_masks): 428 | # query [bsz, 1, 768], key [bsz, seq_len, 768] 429 | # attention_masks [bsz, seq_len] 430 | # query, key, value = self.wq(query), self.wk(key), self.wv(value) 431 | # scores = torch.matmul(query, key.transpose(-1, -2)) # [bsz, 1, seq_len] 432 | # if attention_masks is not None: 433 | # attention_masks = attention_masks.to(dtype=query.dtype).unsqueeze(1) # [bsz, 1, seq_len] fp16 compatibility 434 | # attention_masks = (1.0 - attention_masks) * -10000.0 435 | # scores = scores + attention_masks 436 | # scores = torch.softmax(scores, dim=-1) 437 | # scores = self.dropout(scores) 438 | # outputs = torch.matmul(scores, value) # [bsz, 1, 768] 439 | # return self.wo(outputs), scores 440 | 441 | h = 12 442 | bsz, seq_len, dm = key.size() 443 | d_h = dm // h 444 | query, key, value = self.wq(query), self.wk(key), self.wv(value) 445 | query = query.view(bsz, -1, h, d_h).transpose(1, 2) # [bsz, h, 1, d_h] 446 | key = key.view(bsz, -1, h, d_h).transpose(1, 2) # [bsz, h, seq_len, d_h] 447 | value = value.view(bsz, -1, h, d_h).transpose(1, 2) # [bsz, h, seq_len, d_h] 448 | 449 | scores = torch.matmul(query, key.transpose(-1, -2)) # [bsz, h, 1, seq_len] 450 | if attention_masks is not None: 451 | attention_masks = attention_masks.to(dtype=query.dtype).unsqueeze(1).unsqueeze(1) # [bsz, 1, 1, seq_len] fp16 compatibility 452 | attention_masks = (1.0 - attention_masks) * -10000.0 453 | scores = scores + attention_masks 454 | scores = torch.softmax(scores, dim=-1) 455 | scores = self.dropout(scores) 456 | outputs = torch.matmul(scores, value) # [bsz, h, 1, d_h] 457 | outputs = outputs.view(bsz, 1, dm) 458 | return self.wo(outputs), scores 459 | 460 | 461 | #################################################################### 462 | #################################################################### 463 | from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2MLP 464 | 465 | class GPT2Block(nn.Module): 466 | def __init__(self, config, layer_idx=None): 467 | super().__init__() 468 | hidden_size = 768 469 | inner_dim = 4 * hidden_size 470 | 471 | self.ln_1 = nn.LayerNorm(hidden_size) 472 | self.attn = GPT2Attention(config, layer_idx=layer_idx) 473 | self.ln_2 = nn.LayerNorm(hidden_size) 474 | self.crossattention = GPT2Attention(config, is_cross_attention=True, layer_idx=layer_idx) 475 | self.ln_cross_attn = nn.LayerNorm(hidden_size) 476 | self.mlp = GPT2MLP(inner_dim, config) 477 | 478 | def forward( 479 | self, 480 | hidden_states: Optional[Tuple[torch.FloatTensor]], 481 | layer_past: Optional[Tuple[torch.Tensor]] = None, 482 | attention_mask: Optional[torch.FloatTensor] = None, 
483 | head_mask: Optional[torch.FloatTensor] = None, 484 | encoder_hidden_states: Optional[torch.Tensor] = None, 485 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 486 | use_cache: Optional[bool] = False, 487 | output_attentions: Optional[bool] = False, 488 | ): 489 | residual = hidden_states 490 | hidden_states = self.ln_1(hidden_states) 491 | attn_outputs = self.attn( 492 | hidden_states, 493 | layer_past=layer_past, 494 | attention_mask=attention_mask, 495 | head_mask=head_mask, 496 | use_cache=use_cache, 497 | output_attentions=output_attentions, 498 | ) 499 | attn_output = attn_outputs[0] # output_attn: a, present, (attentions) 500 | outputs = attn_outputs[1:] 501 | # residual connection 502 | hidden_states = attn_output + residual 503 | 504 | if encoder_hidden_states is not None: 505 | # add one self-attention block for cross-attention 506 | if not hasattr(self, "crossattention"): 507 | raise ValueError( 508 | f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " 509 | "cross-attention layers by setting `config.add_cross_attention=True`" 510 | ) 511 | residual = hidden_states 512 | hidden_states = self.ln_cross_attn(hidden_states) 513 | cross_attn_outputs = self.crossattention( 514 | hidden_states, 515 | attention_mask=attention_mask, 516 | head_mask=head_mask, 517 | encoder_hidden_states=encoder_hidden_states, 518 | encoder_attention_mask=encoder_attention_mask, 519 | output_attentions=output_attentions, 520 | ) 521 | attn_output = cross_attn_outputs[0] 522 | # residual connection 523 | hidden_states = residual + attn_output 524 | outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights 525 | 526 | residual = hidden_states 527 | hidden_states = self.ln_2(hidden_states) 528 | feed_forward_hidden_states = self.mlp(hidden_states) 529 | # residual connection 530 | hidden_states = residual + feed_forward_hidden_states 531 | 532 | if use_cache: 533 | outputs = (hidden_states,) + outputs 534 | else: 535 | outputs = (hidden_states,) + outputs[1:] 536 | 537 | return outputs # hidden_states, present, (attentions, cross_attentions) -------------------------------------------------------------------------------- /src/modules/mm_module.py: -------------------------------------------------------------------------------- 1 | from email.errors import NonPrintableDefect 2 | import torch 3 | import torch.nn as nn 4 | import pytorch_lightning as pl 5 | import numpy as np 6 | import random 7 | import json 8 | import jsonlines 9 | 10 | from torch import distributed as dist 11 | from transformers import CLIPVisionModel, T5Tokenizer 12 | 13 | from . import mm_utils 14 | from . import heads, objectives 15 | from . 
import dist_utils 16 | from .t5_model import T5ForMultimodalGeneration 17 | 18 | torch.backends.cudnn.enabled = False 19 | 20 | class MMTransformerSS(pl.LightningModule): 21 | def __init__(self, config): 22 | super().__init__() 23 | self.save_hyperparameters() 24 | self.mode = self.hparams.config["mode"] 25 | self.out_path = self.hparams.config["out_path"] 26 | 27 | if torch.distributed.is_initialized(): 28 | if torch.distributed.get_rank() == 0: 29 | CLIPVisionModel.from_pretrained(config["vit"]) 30 | T5ForMultimodalGeneration.from_pretrained(config['tokenizer']) 31 | torch.distributed.barrier() 32 | 33 | ##################################################################################### 34 | self.image_transformer = CLIPVisionModel.from_pretrained(config["vit"]) 35 | self.text_transformer = T5ForMultimodalGeneration.from_pretrained( 36 | config['tokenizer'], 37 | config["input_image_embed_size"], 38 | ) 39 | self.tokenizer = T5Tokenizer.from_pretrained(config['tokenizer']) 40 | ##################################################################################### 41 | for param in self.image_transformer.parameters(): 42 | param.requires_grad = False 43 | 44 | mm_utils.set_metrics(self) 45 | self.current_tasks = list() 46 | 47 | # ===================== load model ====================== 48 | if self.hparams.config["load_path"] != "": 49 | ckpt = torch.load(self.hparams.config["load_path"], map_location="cpu") 50 | state_dict = ckpt["state_dict"] 51 | self.load_state_dict(state_dict, strict=False) 52 | 53 | self.pred_result = {} 54 | self.gold_result = {} 55 | 56 | def encode_image( 57 | self, 58 | image_features, 59 | ): 60 | last_hidden_state = self.image_transformer( 61 | pixel_values=image_features, 62 | ).last_hidden_state 63 | return last_hidden_state 64 | 65 | def infer( 66 | self, 67 | batch, 68 | ): 69 | text_ids = batch[f"text_ids"] 70 | label_ids = batch[f"label_ids"] if self.mode != "rationale" or "rationale_ids" not in batch else batch[f"rationale_ids"] 71 | label_ids[label_ids==0] = -100 72 | text_masks = batch[f"text_masks"] 73 | image_features = batch[f"image_features"] 74 | 75 | image_features = self.encode_image(image_features) 76 | text_outputs = self.text_transformer( 77 | input_ids=text_ids, 78 | attention_mask=text_masks, 79 | image_ids=image_features, 80 | labels=label_ids, 81 | ) 82 | 83 | ret = { 84 | "text_outputs": text_outputs, 85 | } 86 | 87 | return ret 88 | 89 | def forward(self, batch): 90 | ret = dict() 91 | 92 | ret.update(self.infer(batch)) 93 | 94 | ret.update(objectives.compute_clm(self, ret)) 95 | 96 | return ret 97 | 98 | def training_step(self, batch, batch_idx): 99 | mm_utils.set_task(self) 100 | output = self(batch) 101 | total_loss = sum([v for k, v in output.items() if "loss" in k]) 102 | 103 | return total_loss 104 | 105 | def training_epoch_end(self, outs): 106 | mm_utils.epoch_wrapup(self) 107 | 108 | def validation_step(self, batch, batch_idx): 109 | mm_utils.set_task(self) 110 | if self.mode != "rationale": 111 | text_ids = batch[f"text_ids"] 112 | image_features = batch[f"image_features"] 113 | image_features = self.encode_image(image_features) 114 | self.text_transformer.encoder.update_image_ids(image_features) 115 | self.text_transformer.update_image_ids(image_features) 116 | outputs = self.text_transformer.generate(text_ids, max_length=256) 117 | pred = self.tokenizer.batch_decode(outputs, skip_special_tokens=True) 118 | for iid in range(len(pred)): 119 | self.pred_result[batch["img_index"][iid]] = pred[iid] 120 | 
self.gold_result[batch["img_index"][iid]] = batch["label"][iid].split("The answer is: ")[-1].strip() 121 | ret = dict() 122 | else: 123 | ret = self(batch) 124 | 125 | return ret 126 | 127 | def validation_epoch_end(self, outs): 128 | if self.mode != "rationale": 129 | correct = 0 130 | for iid in self.gold_result: 131 | if iid not in self.pred_result: 132 | correct = 0 133 | break 134 | label = self.gold_result[iid] 135 | pred = self.pred_result[iid].split("The answer is: ")[-1].strip() 136 | if pred == label: 137 | correct += 1 138 | self.acc = correct / len(self.gold_result) 139 | self.pred_result = {} 140 | mm_utils.epoch_wrapup(self) 141 | 142 | def test_step(self, batch, batch_idx): 143 | mm_utils.set_task(self) 144 | 145 | text_ids = batch[f"text_ids"] 146 | image_features = batch[f"image_features"] 147 | image_features = self.encode_image(image_features) 148 | self.text_transformer.encoder.update_image_ids(image_features) 149 | self.text_transformer.update_image_ids(image_features) 150 | outputs = self.text_transformer.generate(text_ids, max_length=256) 151 | pred = self.tokenizer.decode(outputs[0], skip_special_tokens=True) 152 | ret = dict() 153 | self.pred_result[batch["img_index"][0]] = pred 154 | 155 | return ret 156 | 157 | def test_epoch_end(self, outs): 158 | with open(self.out_path, "w") as fout: 159 | json.dump(self.pred_result, fout) 160 | mm_utils.epoch_wrapup(self) 161 | 162 | def configure_optimizers(self): 163 | return mm_utils.set_schedule(self) -------------------------------------------------------------------------------- /src/modules/mm_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | 4 | from transformers.optimization import AdamW 5 | from transformers import ( 6 | get_polynomial_decay_schedule_with_warmup, 7 | get_cosine_schedule_with_warmup, 8 | ) 9 | from .dist_utils import all_gather 10 | from .objectives import compute_irtr_recall 11 | from ..gadgets.my_metrics import Accuracy, VQAScore, Scalar 12 | 13 | 14 | def set_metrics(pl_module): 15 | for split in ["train", "val"]: 16 | for k, v in pl_module.hparams.config["loss_names"].items(): 17 | if v < 1: 18 | continue 19 | setattr(pl_module, f"{split}_{k}_loss", Scalar()) 20 | 21 | def epoch_wrapup(pl_module): 22 | phase = "train" if pl_module.training else "val" 23 | the_metric = 0 24 | 25 | for loss_name, v in pl_module.hparams.config["loss_names"].items(): 26 | if v < 1: 27 | continue 28 | 29 | if phase == "val" and pl_module.mode != "rationale": 30 | value = pl_module.acc 31 | loss = getattr(pl_module, f"{phase}_{loss_name}_loss").compute() 32 | pl_module.log(f"{loss_name}/{phase}/loss_epoch", loss) 33 | getattr(pl_module, f"{phase}_{loss_name}_loss").reset() 34 | else: 35 | value = getattr(pl_module, f"{phase}_{loss_name}_loss").compute() 36 | pl_module.log(f"{loss_name}/{phase}/loss_epoch", value) 37 | getattr(pl_module, f"{phase}_{loss_name}_loss").reset() 38 | the_metric = the_metric + value 39 | 40 | pl_module.log(f"{phase}/the_metric", the_metric) 41 | 42 | 43 | def check_non_acc_grad(pl_module): 44 | if pl_module.token_type_embeddings.weight.grad is None: 45 | return True 46 | else: 47 | grad = pl_module.token_type_embeddings.weight.grad 48 | return (grad.sum() == 0).item() 49 | 50 | 51 | def set_task(pl_module): 52 | pl_module.current_tasks = [ 53 | k for k, v in pl_module.hparams.config["loss_names"].items() if v >= 1 54 | ] 55 | return 56 | 57 | def set_schedule(pl_module): 58 | lr = 
pl_module.hparams.config["learning_rate"] 59 | wd = pl_module.hparams.config["weight_decay"] 60 | 61 | no_decay = [ 62 | "bias", 63 | "LayerNorm.bias", 64 | "LayerNorm.weight", 65 | "norm.bias", 66 | "norm.weight", 67 | "norm1.bias", 68 | "norm1.weight", 69 | "norm2.bias", 70 | "norm2.weight", 71 | ] 72 | end_lr = pl_module.hparams.config["end_lr"] 73 | decay_power = pl_module.hparams.config["decay_power"] 74 | optim_type = pl_module.hparams.config["optim_type"] 75 | optimizer_grouped_parameters = [ 76 | { 77 | "params": [ 78 | p 79 | for n, p in pl_module.named_parameters() 80 | if not any(nd in n for nd in no_decay) 81 | ], 82 | "weight_decay": wd, 83 | "lr": lr, 84 | }, 85 | { 86 | "params": [ 87 | p 88 | for n, p in pl_module.named_parameters() 89 | if any(nd in n for nd in no_decay) 90 | ], 91 | "weight_decay": 0.0, 92 | "lr": lr, 93 | }, 94 | ] 95 | 96 | if optim_type == "adamw": 97 | optimizer = AdamW( 98 | optimizer_grouped_parameters, lr=lr, eps=1e-8, betas=(0.9, 0.999) 99 | ) 100 | elif optim_type == "adam": 101 | optimizer = torch.optim.Adam(optimizer_grouped_parameters, lr=lr) 102 | elif optim_type == "sgd": 103 | optimizer = torch.optim.SGD(optimizer_grouped_parameters, lr=lr, momentum=0.9) 104 | 105 | if pl_module.trainer.max_steps is None: 106 | max_steps = ( 107 | len(pl_module.trainer.datamodule.train_dataloader()) 108 | * pl_module.trainer.max_epochs 109 | // pl_module.trainer.accumulate_grad_batches 110 | ) 111 | else: 112 | max_steps = pl_module.trainer.max_steps 113 | 114 | warmup_steps = pl_module.hparams.config["warmup_steps"] 115 | if isinstance(pl_module.hparams.config["warmup_steps"], float): 116 | warmup_steps = int(max_steps * warmup_steps) 117 | 118 | if decay_power == "cosine": 119 | scheduler = get_cosine_schedule_with_warmup( 120 | optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps, 121 | ) 122 | else: 123 | scheduler = get_polynomial_decay_schedule_with_warmup( 124 | optimizer, 125 | num_warmup_steps=warmup_steps, 126 | num_training_steps=max_steps, 127 | lr_end=end_lr, 128 | power=decay_power, 129 | ) 130 | 131 | sched = {"scheduler": scheduler, "interval": "step"} 132 | 133 | return ( 134 | [optimizer], 135 | [sched], 136 | ) 137 | -------------------------------------------------------------------------------- /src/modules/objectives.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import os 5 | import glob 6 | import json 7 | import tqdm 8 | import functools 9 | import numpy as np 10 | 11 | from torch.utils.data.distributed import DistributedSampler 12 | from einops import rearrange 13 | from .dist_utils import all_gather 14 | 15 | SMALL_NUM = np.log(1e-45) 16 | 17 | def compute_clm(pl_module, ret): 18 | clm_loss = ret["text_outputs"].loss 19 | 20 | new_ret = { 21 | f"clm_loss": clm_loss 22 | } 23 | 24 | phase = "train" if pl_module.training else "val" 25 | loss_clm = getattr(pl_module, f"{phase}_clm_loss")(clm_loss) 26 | pl_module.log(f"clm/{phase}/clm_loss", loss_clm) 27 | return new_ret 28 | 29 | def compute_mim(pl_module, ret, mode): 30 | reconstructed_pixel_values = ret[f"{mode}_logits"] 31 | image_features = ret["image_features"] 32 | if "self" in mode: 33 | image_masks = ret["encoder_image_masks"] 34 | else: 35 | image_masks = ret["decoder_image_masks"] 36 | 37 | size = pl_module.hparams.config["image_size"] // pl_module.hparams.config["patch_size"] 38 | bool_masked_pos = image_masks.reshape(-1, size, size) 39 
| mask = ( 40 | bool_masked_pos.repeat_interleave(pl_module.hparams.config["patch_size"], 1) 41 | .repeat_interleave(pl_module.hparams.config["patch_size"], 2) 42 | .unsqueeze(1) 43 | .contiguous() 44 | ) 45 | reconstruction_loss = nn.functional.l1_loss( 46 | image_features, reconstructed_pixel_values, reduction="none" 47 | ) 48 | mim_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / 3 49 | 50 | new_ret = { 51 | f"{mode}_mim_loss": mim_loss 52 | } 53 | 54 | phase = "train" if pl_module.training else "val" 55 | loss_mim = getattr(pl_module, f"{phase}_{mode}_loss")(mim_loss) 56 | pl_module.log(f"{mode}/{phase}/{mode}_loss", loss_mim) 57 | return new_ret 58 | 59 | 60 | def compute_contrastive(pl_module, ret): 61 | # Query 62 | text_reps = F.normalize(ret["text_bottleneck_repre"]) 63 | image_reps = F.normalize(ret["image_bottleneck_repre"]) 64 | 65 | all_text_reps = pl_module.gather(text_reps) 66 | all_image_reps = pl_module.gather(image_reps) 67 | 68 | # in-batch contrastive 69 | # Cross Entropy 70 | logits_per_text = torch.einsum("nc,ck->nk", [all_text_reps, all_image_reps.transpose(-2, -1)]) / pl_module.T 71 | contrastive_loss = clip_loss(logits_per_text) 72 | 73 | new_ret = { 74 | "contrastive_loss": contrastive_loss, 75 | } 76 | 77 | phase = "train" if pl_module.training else "val" 78 | loss = getattr(pl_module, f"{phase}_contrastive_loss")(new_ret["contrastive_loss"]) 79 | pl_module.log(f"contrastive/{phase}/loss", loss) 80 | 81 | return new_ret 82 | 83 | @torch.no_grad() 84 | def compute_irtr_recall(pl_module): 85 | ### 86 | # device = "cuda" if torch.cuda.is_available() else "cpu" 87 | # model, preprocess = clip.load("ViT-B/16", device=device) 88 | ### 89 | 90 | text_dset = pl_module.trainer.datamodule.dms[0].make_val_dset() 91 | text_dset.tokenizer = pl_module.trainer.datamodule.dms[0].tokenizer 92 | text_loader = torch.utils.data.DataLoader( 93 | text_dset, 94 | batch_size=64, 95 | num_workers=pl_module.hparams.config["num_workers"], 96 | pin_memory=True, 97 | collate_fn=functools.partial( 98 | text_dset.collate, 99 | mlm_collator=pl_module.trainer.datamodule.dms[0].mlm_collator, 100 | ), 101 | ) 102 | 103 | image_dset = pl_module.trainer.datamodule.dms[0].make_val_dset( 104 | image_only=True 105 | ) 106 | image_dset.tokenizer = pl_module.trainer.datamodule.dms[0].tokenizer 107 | dist_sampler = DistributedSampler(image_dset, shuffle=False) 108 | image_loader = torch.utils.data.DataLoader( 109 | image_dset, 110 | batch_size=1, 111 | num_workers=pl_module.hparams.config["num_workers"], 112 | sampler=dist_sampler, 113 | pin_memory=True, 114 | collate_fn=functools.partial( 115 | image_dset.collate, 116 | mlm_collator=pl_module.trainer.datamodule.dms[0].mlm_collator, 117 | ), 118 | ) 119 | 120 | text_preload = list() 121 | for _b in tqdm.tqdm(text_loader, desc="text prefetch loop"): 122 | # # print(_b) 123 | # texts = clip.tokenize(_b["text"], truncate=True).to(device) 124 | # text_features = model.encode_text(texts) 125 | # # assert 1 == 2 126 | text_ids = _b["text_ids"].to(pl_module.device) 127 | text_masks = _b["text_masks"].to(pl_module.device) 128 | text_preload.append( 129 | { 130 | "img_index": _b["img_index"], 131 | "text_reps": pl_module.encode_text( 132 | text_ids, text_masks)[0] 133 | } 134 | ) 135 | 136 | tiids = list() 137 | for pre in text_preload: 138 | tiids += pre["img_index"] 139 | tiids = torch.tensor(tiids) 140 | 141 | image_preload = dict() 142 | image_preload_reps = list() 143 | for _b in tqdm.tqdm(image_loader, desc="image prefetch loop"): 144 | 
img_index = _b["img_index"][0] 145 | if img_index not in image_preload: 146 | # ### 147 | # img_features = [] 148 | # for img_dir in _b['img_dirs']: 149 | # img_feature = preprocess(Image.open(img_dir)).unsqueeze(0).to(device) 150 | # img_features.append(img_feature) 151 | # img_features = torch.cat(img_features, dim=0) 152 | # ### 153 | # img_reps = model.encode_image(img_features) 154 | 155 | image_features = _b["image_features"].to(pl_module.device) 156 | img_reps = pl_module.encode_image(image_features)[0] # [bsz, 768] 157 | image_preload[img_index] = 1 158 | image_preload_reps.append((img_reps, _b["img_index"])) 159 | 160 | rank_scores = list() 161 | rank_iids = list() 162 | 163 | for img_batch in tqdm.tqdm(image_preload_reps, desc="rank loop"): 164 | _img_reps, _iid = img_batch # [bsz, 768] 165 | _img_reps = _img_reps / torch.norm(_img_reps, dim=-1, keepdim=True) 166 | 167 | img_batch_score = list() 168 | for txt_batch in text_preload: 169 | _text_reps = txt_batch["text_reps"] # [bsz, 768] 170 | _text_reps = _text_reps / torch.norm(_text_reps, dim=-1, keepdim=True) 171 | with torch.cuda.amp.autocast(): 172 | score = torch.einsum('nc,cm->nm', [_img_reps, _text_reps.transpose(-1, -2)]) 173 | img_batch_score.append(score) 174 | img_batch_score = torch.cat(img_batch_score, dim=-1) # [bsz, num_texts] 175 | rank_scores.append(img_batch_score.cpu().tolist()) 176 | rank_iids += _iid 177 | 178 | ### 179 | torch.distributed.barrier() 180 | gather_rank_scores = all_gather(rank_scores) 181 | gather_rank_iids = all_gather(rank_iids) 182 | 183 | iids = torch.tensor(gather_rank_iids) 184 | iids = iids.view(-1) 185 | scores = torch.tensor(gather_rank_scores) 186 | scores = scores.view(len(iids), -1) 187 | 188 | # scores = torch.cat(rank_scores, dim=0) # [5000, 25010] 189 | # iids = torch.tensor(rank_iids).view(-1) # all image ids, [5000] 190 | ### 191 | 192 | topk5 = scores.topk(5, dim=0) 193 | topk5_iids = iids[topk5.indices] # [5, 25010] 194 | # print(topk5.values[:, 20:25]) 195 | # print(topk5_iids[:, 20:25]) 196 | # assert 1 == 2 197 | 198 | topk10 = scores.topk(10, dim=1) 199 | topk5 = scores.topk(5, dim=1) 200 | topk1 = scores.topk(1, dim=1) 201 | topk10_iids = tiids[topk10.indices] # [5000, 10] 202 | topk5_iids = tiids[topk5.indices] # [5000, 5] 203 | topk1_iids = tiids[topk1.indices] # [5000, 1] 204 | 205 | 206 | tr_r10 = (iids.unsqueeze(1) == topk10_iids).float().max(dim=1)[0].mean() 207 | tr_r5 = (iids.unsqueeze(1) == topk5_iids).float().max(dim=1)[0].mean() 208 | tr_r1 = (iids.unsqueeze(1) == topk1_iids).float().max(dim=1)[0].mean() 209 | 210 | topk10 = scores.topk(10, dim=0) 211 | topk5 = scores.topk(5, dim=0) 212 | topk1 = scores.topk(1, dim=0) 213 | topk10_iids = iids[topk10.indices] # [10, 25010] 214 | topk5_iids = iids[topk5.indices] # [5, 25010] 215 | topk1_iids = iids[topk1.indices] # [1, 25010] 216 | # tiids [25010] 217 | 218 | ir_r10 = (tiids.unsqueeze(0) == topk10_iids).float().max(dim=0)[0].mean() 219 | ir_r5 = (tiids.unsqueeze(0) == topk5_iids).float().max(dim=0)[0].mean() 220 | ir_r1 = (tiids.unsqueeze(0) == topk1_iids).float().max(dim=0)[0].mean() 221 | # print((ir_r1, ir_r5, ir_r10, tr_r1, tr_r5, tr_r10)) 222 | 223 | return (ir_r1, ir_r5, ir_r10, tr_r1, tr_r5, tr_r10) 224 | 225 | 226 | def init_weights(module): 227 | if isinstance(module, (nn.Linear, nn.Embedding)): 228 | module.weight.data.normal_(mean=0.0, std=0.02) 229 | elif isinstance(module, nn.LayerNorm): 230 | module.bias.data.zero_() 231 | module.weight.data.fill_(1.0) 232 | 233 | if isinstance(module, nn.Linear) 
and module.bias is not None: 234 | module.bias.data.zero_() -------------------------------------------------------------------------------- /src/modules/t5_model.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Adapted from https://github.com/huggingface/transformers 3 | ''' 4 | 5 | from transformers import T5Config, T5ForConditionalGeneration 6 | from transformers.models.t5.modeling_t5 import __HEAD_MASK_WARNING_MSG, T5EncoderModel, T5PreTrainedModel, T5Block, T5LayerNorm 7 | import copy 8 | import math 9 | import os 10 | import warnings 11 | from typing import Optional, Tuple, Union 12 | import torch 13 | from torch import nn 14 | from torch.nn import CrossEntropyLoss 15 | from transformers.modeling_outputs import ( 16 | BaseModelOutput, 17 | Seq2SeqLMOutput, 18 | BaseModelOutputWithPastAndCrossAttentions 19 | ) 20 | from transformers.utils import ( 21 | logging, 22 | ) 23 | from transformers.utils.model_parallel_utils import assert_device_map, get_device_map 24 | from torch.utils.checkpoint import checkpoint 25 | 26 | 27 | logger = logging.get_logger(__name__) 28 | 29 | class SimpleCrossAttention(nn.Module): 30 | def __init__(self, img_hsz, txt_hsz): 31 | super().__init__() 32 | self.proj_q = nn.Linear(txt_hsz, txt_hsz // 2) 33 | self.proj_k = nn.Linear(img_hsz, txt_hsz // 2) 34 | self.proj_v = copy.deepcopy(self.proj_k) 35 | self.proj_o = nn.Linear(txt_hsz // 2, txt_hsz) 36 | self.d_h = txt_hsz // 2 37 | 38 | def forward(self, img, txt): 39 | q, k, v = self.proj_q(txt), self.proj_k(img), self.proj_v(img) 40 | score = torch.softmax(torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.d_h), dim=-1) 41 | o = torch.matmul(score, v) 42 | return self.proj_o(o) 43 | 44 | 45 | class T5Stack(T5PreTrainedModel): 46 | def __init__(self, config, embed_tokens=None, img_hsz=512): 47 | super().__init__(config) 48 | 49 | self.embed_tokens = embed_tokens 50 | self.is_decoder = config.is_decoder 51 | 52 | self.block = nn.ModuleList( 53 | [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] 54 | ) 55 | self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) 56 | self.dropout = nn.Dropout(config.dropout_rate) 57 | 58 | if not self.is_decoder: 59 | self.ca_block = nn.ModuleList( 60 | [SimpleCrossAttention(img_hsz, config.hidden_size) for _ in range(config.num_layers)] 61 | ) 62 | 63 | # Initialize weights and apply final processing 64 | self.post_init() 65 | # Model parallel 66 | self.model_parallel = False 67 | self.device_map = None 68 | self.gradient_checkpointing = False 69 | self.image_ids = None 70 | 71 | def parallelize(self, device_map=None): 72 | # Check validity of device_map 73 | self.device_map = ( 74 | get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map 75 | ) 76 | assert_device_map(self.device_map, len(self.block)) 77 | self.model_parallel = True 78 | self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) 79 | self.last_device = "cuda:" + str(max(self.device_map.keys())) 80 | # Load onto devices 81 | for k, v in self.device_map.items(): 82 | for layer in v: 83 | cuda_device = "cuda:" + str(k) 84 | self.block[layer] = self.block[layer].to(cuda_device) 85 | 86 | # Set embed_tokens to first layer 87 | self.embed_tokens = self.embed_tokens.to(self.first_device) 88 | # Set final layer norm to last device 89 | self.final_layer_norm = self.final_layer_norm.to(self.last_device) 90 
| 91 | def deparallelize(self): 92 | self.model_parallel = False 93 | self.device_map = None 94 | self.first_device = "cpu" 95 | self.last_device = "cpu" 96 | for i in range(len(self.block)): 97 | self.block[i] = self.block[i].to("cpu") 98 | self.embed_tokens = self.embed_tokens.to("cpu") 99 | self.final_layer_norm = self.final_layer_norm.to("cpu") 100 | torch.cuda.empty_cache() 101 | 102 | def get_input_embeddings(self): 103 | return self.embed_tokens 104 | 105 | def set_input_embeddings(self, new_embeddings): 106 | self.embed_tokens = new_embeddings 107 | 108 | def forward( 109 | self, 110 | input_ids=None, 111 | attention_mask=None, 112 | encoder_hidden_states=None, 113 | encoder_attention_mask=None, 114 | inputs_embeds=None, 115 | head_mask=None, 116 | cross_attn_head_mask=None, 117 | past_key_values=None, 118 | use_cache=None, 119 | output_attentions=None, 120 | output_hidden_states=None, 121 | return_dict=None, 122 | image_ids=None, 123 | ): 124 | if image_ids is None: 125 | image_ids = self.image_ids 126 | # Model parallel 127 | if self.model_parallel: 128 | torch.cuda.set_device(self.first_device) 129 | self.embed_tokens = self.embed_tokens.to(self.first_device) 130 | use_cache = use_cache if use_cache is not None else self.config.use_cache 131 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 132 | output_hidden_states = ( 133 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 134 | ) 135 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 136 | 137 | if input_ids is not None and inputs_embeds is not None: 138 | err_msg_prefix = "decoder_" if self.is_decoder else "" 139 | raise ValueError( 140 | f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" 141 | ) 142 | elif input_ids is not None: 143 | input_shape = input_ids.size() 144 | input_ids = input_ids.view(-1, input_shape[-1]) 145 | elif inputs_embeds is not None: 146 | input_shape = inputs_embeds.size()[:-1] 147 | else: 148 | err_msg_prefix = "decoder_" if self.is_decoder else "" 149 | raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") 150 | 151 | if inputs_embeds is None: 152 | assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings" 153 | inputs_embeds = self.embed_tokens(input_ids) 154 | 155 | batch_size, seq_length = input_shape 156 | 157 | # required mask seq length can be calculated via length of past 158 | mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length 159 | 160 | if use_cache is True: 161 | assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder" 162 | 163 | if attention_mask is None: 164 | attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) 165 | if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: 166 | encoder_seq_length = encoder_hidden_states.shape[1] 167 | encoder_attention_mask = torch.ones( 168 | batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long 169 | ) 170 | 171 | # initialize past_key_values with `None` if past does not exist 172 | if past_key_values is None: 173 | past_key_values = [None] * len(self.block) 174 | 175 | # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] 176 | # 
ourselves in which case we just need to make it broadcastable to all heads. 177 | extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) 178 | 179 | # If a 2D or 3D attention mask is provided for the cross-attention 180 | # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] 181 | if self.is_decoder and encoder_hidden_states is not None: 182 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() 183 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) 184 | if encoder_attention_mask is None: 185 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device) 186 | encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) 187 | else: 188 | encoder_extended_attention_mask = None 189 | 190 | # Prepare head mask if needed 191 | head_mask = self.get_head_mask(head_mask, self.config.num_layers) 192 | cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) 193 | present_key_value_states = () if use_cache else None 194 | all_hidden_states = () if output_hidden_states else None 195 | all_attentions = () if output_attentions else None 196 | all_cross_attentions = () if (output_attentions and self.is_decoder) else None 197 | position_bias = None 198 | encoder_decoder_position_bias = None 199 | 200 | hidden_states = self.dropout(inputs_embeds) 201 | 202 | for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): 203 | layer_head_mask = head_mask[i] 204 | cross_attn_layer_head_mask = cross_attn_head_mask[i] 205 | # Model parallel 206 | if self.model_parallel: 207 | torch.cuda.set_device(hidden_states.device) 208 | # Ensure that attention_mask is always on the same device as hidden_states 209 | if attention_mask is not None: 210 | attention_mask = attention_mask.to(hidden_states.device) 211 | if position_bias is not None: 212 | position_bias = position_bias.to(hidden_states.device) 213 | if encoder_hidden_states is not None: 214 | encoder_hidden_states = encoder_hidden_states.to(hidden_states.device) 215 | if encoder_extended_attention_mask is not None: 216 | encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device) 217 | if encoder_decoder_position_bias is not None: 218 | encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) 219 | if layer_head_mask is not None: 220 | layer_head_mask = layer_head_mask.to(hidden_states.device) 221 | if cross_attn_layer_head_mask is not None: 222 | cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device) 223 | if output_hidden_states: 224 | all_hidden_states = all_hidden_states + (hidden_states,) 225 | 226 | if self.gradient_checkpointing and self.training: 227 | if use_cache: 228 | logger.warning( 229 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
230 | ) 231 | use_cache = False 232 | 233 | def create_custom_forward(module): 234 | def custom_forward(*inputs): 235 | return tuple(module(*inputs, use_cache, output_attentions)) 236 | 237 | return custom_forward 238 | 239 | layer_outputs = checkpoint( 240 | create_custom_forward(layer_module), 241 | hidden_states, 242 | extended_attention_mask, 243 | position_bias, 244 | encoder_hidden_states, 245 | encoder_extended_attention_mask, 246 | encoder_decoder_position_bias, 247 | layer_head_mask, 248 | cross_attn_layer_head_mask, 249 | None, # past_key_value is always None with gradient checkpointing 250 | ) 251 | else: 252 | layer_outputs = layer_module( 253 | hidden_states, 254 | attention_mask=extended_attention_mask, 255 | position_bias=position_bias, 256 | encoder_hidden_states=encoder_hidden_states, 257 | encoder_attention_mask=encoder_extended_attention_mask, 258 | encoder_decoder_position_bias=encoder_decoder_position_bias, 259 | layer_head_mask=layer_head_mask, 260 | cross_attn_layer_head_mask=cross_attn_layer_head_mask, 261 | past_key_value=past_key_value, 262 | use_cache=use_cache, 263 | output_attentions=output_attentions, 264 | ) 265 | 266 | # layer_outputs is a tuple with: 267 | # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) 268 | if use_cache is False: 269 | layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] 270 | 271 | hidden_states, present_key_value_state = layer_outputs[:2] 272 | 273 | ###################### 274 | if not self.is_decoder: 275 | hidden_states = hidden_states + self.ca_block[i](image_ids, hidden_states) 276 | ##################### 277 | 278 | # We share the position biases between the layers - the first layer store them 279 | # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), 280 | # (cross-attention position bias), (cross-attention weights) 281 | position_bias = layer_outputs[2] 282 | if self.is_decoder and encoder_hidden_states is not None: 283 | encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] 284 | # append next layer key value states 285 | if use_cache: 286 | present_key_value_states = present_key_value_states + (present_key_value_state,) 287 | 288 | if output_attentions: 289 | all_attentions = all_attentions + (layer_outputs[3],) 290 | if self.is_decoder: 291 | all_cross_attentions = all_cross_attentions + (layer_outputs[5],) 292 | 293 | # Model Parallel: If it's the last layer for that device, put things on the next device 294 | if self.model_parallel: 295 | for k, v in self.device_map.items(): 296 | if i == v[-1] and "cuda:" + str(k) != self.last_device: 297 | hidden_states = hidden_states.to("cuda:" + str(k + 1)) 298 | 299 | hidden_states = self.final_layer_norm(hidden_states) 300 | hidden_states = self.dropout(hidden_states) 301 | 302 | # Add last layer 303 | if output_hidden_states: 304 | all_hidden_states = all_hidden_states + (hidden_states,) 305 | 306 | if not return_dict: 307 | return tuple( 308 | v 309 | for v in [ 310 | hidden_states, 311 | present_key_value_states, 312 | all_hidden_states, 313 | all_attentions, 314 | all_cross_attentions, 315 | ] 316 | if v is not None 317 | ) 318 | return BaseModelOutputWithPastAndCrossAttentions( 319 | last_hidden_state=hidden_states, 320 | past_key_values=present_key_value_states, 321 | hidden_states=all_hidden_states, 322 | attentions=all_attentions, 323 | cross_attentions=all_cross_attentions, 324 | ) 325 
| 326 | def update_image_ids(self, img): 327 | self.image_ids = img 328 | 329 | 330 | class T5ForMultimodalGeneration(T5ForConditionalGeneration): 331 | _keys_to_ignore_on_load_missing = [ 332 | r"encoder.embed_tokens.weight", 333 | r"decoder.embed_tokens.weight", 334 | r"lm_head.weight", 335 | ] 336 | _keys_to_ignore_on_load_unexpected = [ 337 | r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", 338 | ] 339 | 340 | def __init__(self, config: T5Config, img_hsz=512): 341 | super().__init__(config) 342 | self.model_dim = config.d_model 343 | 344 | self.shared = nn.Embedding(config.vocab_size, config.d_model) 345 | 346 | ############ 347 | self.ca = SimpleCrossAttention(img_hsz, config.hidden_size) 348 | ############ 349 | 350 | encoder_config = copy.deepcopy(config) 351 | encoder_config.is_decoder = False 352 | encoder_config.use_cache = False 353 | encoder_config.is_encoder_decoder = False 354 | self.encoder = T5Stack(encoder_config, self.shared, img_hsz=img_hsz) 355 | 356 | decoder_config = copy.deepcopy(config) 357 | decoder_config.is_decoder = True 358 | decoder_config.is_encoder_decoder = False 359 | decoder_config.num_layers = config.num_decoder_layers 360 | self.decoder = T5Stack(decoder_config, self.shared) 361 | 362 | self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) 363 | 364 | # Initialize weights and apply final processing 365 | self.post_init() 366 | 367 | # Model parallel 368 | self.model_parallel = False 369 | self.device_map = None 370 | 371 | self.image_ids = None 372 | self.rationale_ids = None 373 | 374 | def update_image_ids(self, img): 375 | self.image_ids = img 376 | 377 | def forward( 378 | self, 379 | input_ids: Optional[torch.LongTensor] = None, 380 | image_ids=None, 381 | attention_mask: Optional[torch.FloatTensor] = None, 382 | decoder_input_ids: Optional[torch.LongTensor] = None, 383 | decoder_attention_mask: Optional[torch.BoolTensor] = None, 384 | head_mask: Optional[torch.FloatTensor] = None, 385 | decoder_head_mask: Optional[torch.FloatTensor] = None, 386 | cross_attn_head_mask: Optional[torch.Tensor] = None, 387 | encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, 388 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 389 | inputs_embeds: Optional[torch.FloatTensor] = None, 390 | decoder_inputs_embeds: Optional[torch.FloatTensor] = None, 391 | labels: Optional[torch.LongTensor] = None, 392 | use_cache: Optional[bool] = None, 393 | output_attentions: Optional[bool] = None, 394 | output_hidden_states: Optional[bool] = None, 395 | return_dict: Optional[bool] = None, 396 | ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: 397 | use_cache = use_cache if use_cache is not None else self.config.use_cache 398 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 399 | 400 | # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask 401 | if head_mask is not None and decoder_head_mask is None: 402 | if self.config.num_layers == self.config.num_decoder_layers: 403 | warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) 404 | decoder_head_mask = head_mask 405 | 406 | ############### 407 | if image_ids is None: 408 | image_ids = self.image_ids 409 | ############## 410 | 411 | # Encode if needed (training, first prediction pass) 412 | if encoder_outputs is None: 413 | # Convert encoder inputs in embeddings if needed 414 | encoder_outputs = self.encoder( 415 | input_ids=input_ids, 416 | attention_mask=attention_mask, 417 | 
inputs_embeds=inputs_embeds, 418 | head_mask=head_mask, 419 | output_attentions=output_attentions, 420 | output_hidden_states=output_hidden_states, 421 | return_dict=return_dict, 422 | image_ids=image_ids, 423 | ) 424 | 425 | elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): 426 | encoder_outputs = BaseModelOutput( 427 | last_hidden_state=encoder_outputs[0], 428 | hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, 429 | attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, 430 | ) 431 | 432 | 433 | hidden_states = encoder_outputs[0] 434 | 435 | ########### 436 | if image_ids is None: 437 | image_ids = self.image_ids 438 | hidden_states = hidden_states + self.ca(image_ids, hidden_states) 439 | ########### 440 | 441 | if self.model_parallel: 442 | torch.cuda.set_device(self.decoder.first_device) 443 | 444 | if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: 445 | # get decoder inputs from shifting lm labels to the right 446 | decoder_input_ids = self._shift_right(labels) 447 | 448 | # Set device for model parallelism 449 | if self.model_parallel: 450 | torch.cuda.set_device(self.decoder.first_device) 451 | hidden_states = hidden_states.to(self.decoder.first_device) 452 | if decoder_input_ids is not None: 453 | decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) 454 | if attention_mask is not None: 455 | attention_mask = attention_mask.to(self.decoder.first_device) 456 | if decoder_attention_mask is not None: 457 | decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) 458 | 459 | # Decode 460 | decoder_outputs = self.decoder( 461 | input_ids=decoder_input_ids, 462 | attention_mask=decoder_attention_mask, 463 | inputs_embeds=decoder_inputs_embeds, 464 | past_key_values=past_key_values, 465 | encoder_hidden_states=hidden_states, 466 | encoder_attention_mask=attention_mask, 467 | head_mask=decoder_head_mask, 468 | cross_attn_head_mask=cross_attn_head_mask, 469 | use_cache=use_cache, 470 | output_attentions=output_attentions, 471 | output_hidden_states=output_hidden_states, 472 | return_dict=return_dict, 473 | ) 474 | 475 | sequence_output = decoder_outputs[0] 476 | 477 | # Set device for model parallelism 478 | if self.model_parallel: 479 | torch.cuda.set_device(self.encoder.first_device) 480 | self.lm_head = self.lm_head.to(self.encoder.first_device) 481 | sequence_output = sequence_output.to(self.lm_head.weight.device) 482 | 483 | if self.config.tie_word_embeddings: 484 | # Rescale output before projecting on vocab 485 | # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 486 | sequence_output = sequence_output * (self.model_dim**-0.5) 487 | 488 | lm_logits = self.lm_head(sequence_output) 489 | 490 | loss = None 491 | if labels is not None: 492 | loss_fct = CrossEntropyLoss(ignore_index=-100) 493 | loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) 494 | 495 | if not return_dict: 496 | output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs 497 | return ((loss,) + output) if loss is not None else output 498 | 499 | return Seq2SeqLMOutput( 500 | loss=loss, 501 | logits=lm_logits, 502 | past_key_values=decoder_outputs.past_key_values, 503 | decoder_hidden_states=decoder_outputs.hidden_states, 504 | decoder_attentions=decoder_outputs.attentions, 505 | cross_attentions=decoder_outputs.cross_attentions, 506 | 
encoder_last_hidden_state=encoder_outputs.last_hidden_state, 507 | encoder_hidden_states=encoder_outputs.hidden_states, 508 | encoder_attentions=encoder_outputs.attentions, 509 | ) -------------------------------------------------------------------------------- /src/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .transform import ( 2 | pixelbert_transform, 3 | pixelbert_transform_randaug, 4 | vit_transform, 5 | vit_transform_randaug, 6 | imagenet_transform, 7 | imagenet_transform_randaug, 8 | clip_transform, 9 | clip_transform_randaug, 10 | ) 11 | 12 | _transforms = { 13 | "pixelbert": pixelbert_transform, 14 | "pixelbert_randaug": pixelbert_transform_randaug, 15 | "vit": vit_transform, 16 | "vit_randaug": vit_transform_randaug, 17 | "imagenet": imagenet_transform, 18 | "imagenet_randaug": imagenet_transform_randaug, 19 | "clip": clip_transform, 20 | "clip_randaug": clip_transform_randaug, 21 | } 22 | 23 | def keys_to_transforms(keys: list, size=224): 24 | return [_transforms[key](size=size) for key in keys] 25 | -------------------------------------------------------------------------------- /src/transforms/randaug.py: -------------------------------------------------------------------------------- 1 | # code in this file is adpated from rpmcruz/autoaugment 2 | # https://github.com/rpmcruz/autoaugment/blob/master/transformations.py 3 | import random 4 | 5 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw 6 | import numpy as np 7 | import torch 8 | from PIL import Image 9 | 10 | 11 | def ShearX(img, v): # [-0.3, 0.3] 12 | assert -0.3 <= v <= 0.3 13 | if random.random() > 0.5: 14 | v = -v 15 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 16 | 17 | 18 | def ShearY(img, v): # [-0.3, 0.3] 19 | assert -0.3 <= v <= 0.3 20 | if random.random() > 0.5: 21 | v = -v 22 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 23 | 24 | 25 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 26 | assert -0.45 <= v <= 0.45 27 | if random.random() > 0.5: 28 | v = -v 29 | v = v * img.size[0] 30 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 31 | 32 | 33 | def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 34 | assert 0 <= v 35 | if random.random() > 0.5: 36 | v = -v 37 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 38 | 39 | 40 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 41 | assert -0.45 <= v <= 0.45 42 | if random.random() > 0.5: 43 | v = -v 44 | v = v * img.size[1] 45 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 46 | 47 | 48 | def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 49 | assert 0 <= v 50 | if random.random() > 0.5: 51 | v = -v 52 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 53 | 54 | 55 | def Rotate(img, v): # [-30, 30] 56 | assert -30 <= v <= 30 57 | if random.random() > 0.5: 58 | v = -v 59 | return img.rotate(v) 60 | 61 | 62 | def AutoContrast(img, _): 63 | return PIL.ImageOps.autocontrast(img) 64 | 65 | 66 | def Invert(img, _): 67 | return PIL.ImageOps.invert(img) 68 | 69 | 70 | def Equalize(img, _): 71 | return PIL.ImageOps.equalize(img) 72 | 73 | 74 | def Flip(img, _): # not from the paper 75 | return PIL.ImageOps.mirror(img) 76 | 77 | 78 | def Solarize(img, v): # [0, 256] 79 | assert 0 <= v <= 256 80 | return PIL.ImageOps.solarize(img, v) 81 | 82 | 83 | def SolarizeAdd(img, addition=0, 
threshold=128): 84 | img_np = np.array(img).astype(int) 85 | img_np = img_np + addition 86 | img_np = np.clip(img_np, 0, 255) 87 | img_np = img_np.astype(np.uint8) 88 | img = Image.fromarray(img_np) 89 | return PIL.ImageOps.solarize(img, threshold) 90 | 91 | 92 | def Posterize(img, v): # [4, 8] 93 | v = int(v) 94 | v = max(1, v) 95 | return PIL.ImageOps.posterize(img, v) 96 | 97 | 98 | def Contrast(img, v): # [0.1,1.9] 99 | assert 0.1 <= v <= 1.9 100 | return PIL.ImageEnhance.Contrast(img).enhance(v) 101 | 102 | 103 | def Color(img, v): # [0.1,1.9] 104 | assert 0.1 <= v <= 1.9 105 | return PIL.ImageEnhance.Color(img).enhance(v) 106 | 107 | 108 | def Brightness(img, v): # [0.1,1.9] 109 | assert 0.1 <= v <= 1.9 110 | return PIL.ImageEnhance.Brightness(img).enhance(v) 111 | 112 | 113 | def Sharpness(img, v): # [0.1,1.9] 114 | assert 0.1 <= v <= 1.9 115 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 116 | 117 | 118 | def Cutout(img, v): # [0, 60] => percentage: [0, 0.2] 119 | assert 0.0 <= v <= 0.2 120 | if v <= 0.0: 121 | return img 122 | 123 | v = v * img.size[0] 124 | return CutoutAbs(img, v) 125 | 126 | 127 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2] 128 | # assert 0 <= v <= 20 129 | if v < 0: 130 | return img 131 | w, h = img.size 132 | x0 = np.random.uniform(w) 133 | y0 = np.random.uniform(h) 134 | 135 | x0 = int(max(0, x0 - v / 2.0)) 136 | y0 = int(max(0, y0 - v / 2.0)) 137 | x1 = min(w, x0 + v) 138 | y1 = min(h, y0 + v) 139 | 140 | xy = (x0, y0, x1, y1) 141 | color = (125, 123, 114) 142 | # color = (0, 0, 0) 143 | img = img.copy() 144 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 145 | return img 146 | 147 | 148 | def SamplePairing(imgs): # [0, 0.4] 149 | def f(img1, v): 150 | i = np.random.choice(len(imgs)) 151 | img2 = PIL.Image.fromarray(imgs[i]) 152 | return PIL.Image.blend(img1, img2, v) 153 | 154 | return f 155 | 156 | 157 | def Identity(img, v): 158 | return img 159 | 160 | 161 | def augment_list(): # 16 operations and their ranges 162 | # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57 163 | # l = [ 164 | # (Identity, 0., 1.0), 165 | # (ShearX, 0., 0.3), # 0 166 | # (ShearY, 0., 0.3), # 1 167 | # (TranslateX, 0., 0.33), # 2 168 | # (TranslateY, 0., 0.33), # 3 169 | # (Rotate, 0, 30), # 4 170 | # (AutoContrast, 0, 1), # 5 171 | # (Invert, 0, 1), # 6 172 | # (Equalize, 0, 1), # 7 173 | # (Solarize, 0, 110), # 8 174 | # (Posterize, 4, 8), # 9 175 | # # (Contrast, 0.1, 1.9), # 10 176 | # (Color, 0.1, 1.9), # 11 177 | # (Brightness, 0.1, 1.9), # 12 178 | # (Sharpness, 0.1, 1.9), # 13 179 | # # (Cutout, 0, 0.2), # 14 180 | # # (SamplePairing(imgs), 0, 0.4), # 15 181 | # ] 182 | 183 | # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505 184 | l = [ 185 | (AutoContrast, 0, 1), 186 | (Equalize, 0, 1), 187 | # (Invert, 0, 1), 188 | (Rotate, 0, 30), 189 | (Posterize, 0, 4), 190 | (Solarize, 0, 256), 191 | (SolarizeAdd, 0, 110), 192 | (Color, 0.1, 1.9), 193 | (Contrast, 0.1, 1.9), 194 | (Brightness, 0.1, 1.9), 195 | (Sharpness, 0.1, 1.9), 196 | (ShearX, 0.0, 0.3), 197 | (ShearY, 0.0, 0.3), 198 | # (CutoutAbs, 0, 40), 199 | (TranslateXabs, 0.0, 100), 200 | (TranslateYabs, 0.0, 100), 201 | ] 202 | 203 | return l 204 | 205 | 206 | class Lighting(object): 207 | """Lighting noise(AlexNet - style PCA - based noise)""" 208 | 209 | def __init__(self, alphastd, eigval, eigvec): 210 | self.alphastd = alphastd 211 | self.eigval = torch.Tensor(eigval) 212 | 
self.eigvec = torch.Tensor(eigvec) 213 | 214 | def __call__(self, img): 215 | if self.alphastd == 0: 216 | return img 217 | 218 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 219 | rgb = ( 220 | self.eigvec.type_as(img) 221 | .clone() 222 | .mul(alpha.view(1, 3).expand(3, 3)) 223 | .mul(self.eigval.view(1, 3).expand(3, 3)) 224 | .sum(1) 225 | .squeeze() 226 | ) 227 | 228 | return img.add(rgb.view(3, 1, 1).expand_as(img)) 229 | 230 | 231 | class CutoutDefault(object): 232 | """ 233 | Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py 234 | """ 235 | 236 | def __init__(self, length): 237 | self.length = length 238 | 239 | def __call__(self, img): 240 | h, w = img.size(1), img.size(2) 241 | mask = np.ones((h, w), np.float32) 242 | y = np.random.randint(h) 243 | x = np.random.randint(w) 244 | 245 | y1 = np.clip(y - self.length // 2, 0, h) 246 | y2 = np.clip(y + self.length // 2, 0, h) 247 | x1 = np.clip(x - self.length // 2, 0, w) 248 | x2 = np.clip(x + self.length // 2, 0, w) 249 | 250 | mask[y1:y2, x1:x2] = 0.0 251 | mask = torch.from_numpy(mask) 252 | mask = mask.expand_as(img) 253 | img *= mask 254 | return img 255 | 256 | 257 | class RandAugment: 258 | def __init__(self, n, m): 259 | self.n = n 260 | self.m = m # [0, 30] 261 | self.augment_list = augment_list() 262 | 263 | def __call__(self, img): 264 | ops = random.choices(self.augment_list, k=self.n) 265 | for op, minval, maxval in ops: 266 | val = (float(self.m) / 30) * float(maxval - minval) + minval 267 | img = op(img, val) 268 | 269 | return img 270 | -------------------------------------------------------------------------------- /src/transforms/transform.py: -------------------------------------------------------------------------------- 1 | from .utils import ( 2 | inception_normalize, 3 | imagenet_normalize, 4 | MinMaxResize, 5 | ) 6 | from PIL import Image 7 | from torchvision import transforms 8 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 9 | from .randaug import RandAugment 10 | 11 | 12 | def pixelbert_transform(size=800): 13 | longer = int((1333 / 800) * size) 14 | return transforms.Compose( 15 | [ 16 | MinMaxResize(shorter=size, longer=longer), 17 | transforms.ToTensor(), 18 | inception_normalize, 19 | ] 20 | ) 21 | 22 | def pixelbert_transform_randaug(size=800): 23 | longer = int((1333 / 800) * size) 24 | trs = transforms.Compose( 25 | [ 26 | MinMaxResize(shorter=size, longer=longer), 27 | transforms.ToTensor(), 28 | inception_normalize, 29 | ] 30 | ) 31 | trs.transforms.insert(0, RandAugment(2, 9)) 32 | return trs 33 | 34 | def imagenet_transform(size=800): 35 | return transforms.Compose( 36 | [ 37 | Resize(size, interpolation=Image.BICUBIC), 38 | CenterCrop(size), 39 | transforms.ToTensor(), 40 | imagenet_normalize, 41 | ] 42 | ) 43 | 44 | def imagenet_transform_randaug(size=800): 45 | trs = transforms.Compose( 46 | [ 47 | Resize(size, interpolation=Image.BICUBIC), 48 | CenterCrop(size), 49 | transforms.ToTensor(), 50 | imagenet_normalize, 51 | ] 52 | ) 53 | trs.transforms.insert(0, RandAugment(2, 9)) 54 | return trs 55 | 56 | def vit_transform(size=800): 57 | return transforms.Compose( 58 | [ 59 | Resize(size, interpolation=Image.BICUBIC), 60 | CenterCrop(size), 61 | lambda image: image.convert("RGB"), 62 | transforms.ToTensor(), 63 | inception_normalize, 64 | ] 65 | ) 66 | 67 | def vit_transform_randaug(size=800): 68 | trs = transforms.Compose( 69 | [ 70 | Resize(size, interpolation=Image.BICUBIC), 71 | CenterCrop(size), 72 | lambda image: 
image.convert("RGB"), 73 | transforms.ToTensor(), 74 | inception_normalize, 75 | ] 76 | ) 77 | trs.transforms.insert(0, lambda image: image.convert('RGBA')) 78 | trs.transforms.insert(0, RandAugment(2, 9)) 79 | trs.transforms.insert(0, lambda image: image.convert('RGB')) 80 | return trs 81 | 82 | def clip_transform(size): 83 | return Compose([ 84 | Resize(size, interpolation=Image.BICUBIC), 85 | CenterCrop(size), 86 | lambda image: image.convert("RGB"), 87 | ToTensor(), 88 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 89 | ]) 90 | 91 | def clip_transform_randaug(size): 92 | trs = Compose([ 93 | Resize(size, interpolation=Image.BICUBIC), 94 | CenterCrop(size), 95 | lambda image: image.convert("RGB"), 96 | ToTensor(), 97 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 98 | ]) 99 | trs.transforms.insert(0, lambda image: image.convert('RGBA')) 100 | trs.transforms.insert(0, RandAugment(2, 9)) 101 | trs.transforms.insert(0, lambda image: image.convert('RGB')) 102 | return trs 103 | 104 | -------------------------------------------------------------------------------- /src/transforms/utils.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from PIL import Image 3 | 4 | 5 | class MinMaxResize: 6 | def __init__(self, shorter=800, longer=1333): 7 | self.min = shorter 8 | self.max = longer 9 | 10 | def __call__(self, x): 11 | w, h = x.size 12 | scale = self.min / min(w, h) 13 | if h < w: 14 | newh, neww = self.min, scale * w 15 | else: 16 | newh, neww = scale * h, self.min 17 | 18 | if max(newh, neww) > self.max: 19 | scale = self.max / max(newh, neww) 20 | newh = newh * scale 21 | neww = neww * scale 22 | 23 | newh, neww = int(newh + 0.5), int(neww + 0.5) 24 | newh, neww = newh // 32 * 32, neww // 32 * 32 25 | 26 | return x.resize((neww, newh), resample=Image.BICUBIC) 27 | 28 | 29 | class UnNormalize(object): 30 | def __init__(self, mean, std): 31 | self.mean = mean 32 | self.std = std 33 | 34 | def __call__(self, tensor): 35 | """ 36 | Args: 37 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 38 | Returns: 39 | Tensor: Normalized image. 
40 | """ 41 | for t, m, s in zip(tensor, self.mean, self.std): 42 | t.mul_(s).add_(m) 43 | # The normalize code -> t.sub_(m).div_(s) 44 | return tensor 45 | 46 | 47 | # This is simple maximum entropy normalization performed in Inception paper 48 | inception_normalize = transforms.Compose( 49 | [transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])] 50 | ) 51 | 52 | # ViT uses simple non-biased inception normalization 53 | # https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py#L132 54 | inception_unnormalize = transforms.Compose( 55 | [UnNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])] 56 | ) 57 | 58 | # ImageNet normalize 59 | imagenet_normalize = transforms.Compose( 60 | [transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])] 61 | ) 62 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKBUNLP/Mr.Harm-EMNLP2023/84a116fb7483621ca6c1b051b7055b443344817c/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/glossary.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | contractions = { 4 | "aint": "ain't", 5 | "arent": "aren't", 6 | "cant": "can't", 7 | "couldve": "could've", 8 | "couldnt": "couldn't", 9 | "couldn'tve": "couldn't've", 10 | "couldnt've": "couldn't've", 11 | "didnt": "didn't", 12 | "doesnt": "doesn't", 13 | "dont": "don't", 14 | "hadnt": "hadn't", 15 | "hadnt've": "hadn't've", 16 | "hadn'tve": "hadn't've", 17 | "hasnt": "hasn't", 18 | "havent": "haven't", 19 | "hed": "he'd", 20 | "hed've": "he'd've", 21 | "he'dve": "he'd've", 22 | "hes": "he's", 23 | "howd": "how'd", 24 | "howll": "how'll", 25 | "hows": "how's", 26 | "Id've": "I'd've", 27 | "I'dve": "I'd've", 28 | "Im": "I'm", 29 | "Ive": "I've", 30 | "isnt": "isn't", 31 | "itd": "it'd", 32 | "itd've": "it'd've", 33 | "it'dve": "it'd've", 34 | "itll": "it'll", 35 | "let's": "let's", 36 | "maam": "ma'am", 37 | "mightnt": "mightn't", 38 | "mightnt've": "mightn't've", 39 | "mightn'tve": "mightn't've", 40 | "mightve": "might've", 41 | "mustnt": "mustn't", 42 | "mustve": "must've", 43 | "neednt": "needn't", 44 | "notve": "not've", 45 | "oclock": "o'clock", 46 | "oughtnt": "oughtn't", 47 | "ow's'at": "'ow's'at", 48 | "'ows'at": "'ow's'at", 49 | "'ow'sat": "'ow's'at", 50 | "shant": "shan't", 51 | "shed've": "she'd've", 52 | "she'dve": "she'd've", 53 | "she's": "she's", 54 | "shouldve": "should've", 55 | "shouldnt": "shouldn't", 56 | "shouldnt've": "shouldn't've", 57 | "shouldn'tve": "shouldn't've", 58 | "somebody'd": "somebodyd", 59 | "somebodyd've": "somebody'd've", 60 | "somebody'dve": "somebody'd've", 61 | "somebodyll": "somebody'll", 62 | "somebodys": "somebody's", 63 | "someoned": "someone'd", 64 | "someoned've": "someone'd've", 65 | "someone'dve": "someone'd've", 66 | "someonell": "someone'll", 67 | "someones": "someone's", 68 | "somethingd": "something'd", 69 | "somethingd've": "something'd've", 70 | "something'dve": "something'd've", 71 | "somethingll": "something'll", 72 | "thats": "that's", 73 | "thered": "there'd", 74 | "thered've": "there'd've", 75 | "there'dve": "there'd've", 76 | "therere": "there're", 77 | "theres": "there's", 78 | "theyd": "they'd", 79 | "theyd've": "they'd've", 80 | "they'dve": "they'd've", 81 | "theyll": "they'll", 82 | "theyre": "they're", 83 | 
"theyve": "they've", 84 | "twas": "'twas", 85 | "wasnt": "wasn't", 86 | "wed've": "we'd've", 87 | "we'dve": "we'd've", 88 | "weve": "we've", 89 | "werent": "weren't", 90 | "whatll": "what'll", 91 | "whatre": "what're", 92 | "whats": "what's", 93 | "whatve": "what've", 94 | "whens": "when's", 95 | "whered": "where'd", 96 | "wheres": "where's", 97 | "whereve": "where've", 98 | "whod": "who'd", 99 | "whod've": "who'd've", 100 | "who'dve": "who'd've", 101 | "wholl": "who'll", 102 | "whos": "who's", 103 | "whove": "who've", 104 | "whyll": "why'll", 105 | "whyre": "why're", 106 | "whys": "why's", 107 | "wont": "won't", 108 | "wouldve": "would've", 109 | "wouldnt": "wouldn't", 110 | "wouldnt've": "wouldn't've", 111 | "wouldn'tve": "wouldn't've", 112 | "yall": "y'all", 113 | "yall'll": "y'all'll", 114 | "y'allll": "y'all'll", 115 | "yall'd've": "y'all'd've", 116 | "y'alld've": "y'all'd've", 117 | "y'all'dve": "y'all'd've", 118 | "youd": "you'd", 119 | "youd've": "you'd've", 120 | "you'dve": "you'd've", 121 | "youll": "you'll", 122 | "youre": "you're", 123 | "youve": "you've", 124 | } 125 | 126 | manual_map = { 127 | "none": "0", 128 | "zero": "0", 129 | "one": "1", 130 | "two": "2", 131 | "three": "3", 132 | "four": "4", 133 | "five": "5", 134 | "six": "6", 135 | "seven": "7", 136 | "eight": "8", 137 | "nine": "9", 138 | "ten": "10", 139 | } 140 | articles = ["a", "an", "the"] 141 | period_strip = re.compile("(?!<=\d)(\.)(?!\d)") 142 | comma_strip = re.compile("(\d)(\,)(\d)") 143 | punct = [ 144 | ";", 145 | r"/", 146 | "[", 147 | "]", 148 | '"', 149 | "{", 150 | "}", 151 | "(", 152 | ")", 153 | "=", 154 | "+", 155 | "\\", 156 | "_", 157 | "-", 158 | ">", 159 | "<", 160 | "@", 161 | "`", 162 | ",", 163 | "?", 164 | "!", 165 | ] 166 | 167 | 168 | def normalize_word(token): 169 | _token = token 170 | for p in punct: 171 | if (p + " " in token or " " + p in token) or ( 172 | re.search(comma_strip, token) != None 173 | ): 174 | _token = _token.replace(p, "") 175 | else: 176 | _token = _token.replace(p, " ") 177 | token = period_strip.sub("", _token, re.UNICODE) 178 | 179 | _token = [] 180 | temp = token.lower().split() 181 | for word in temp: 182 | word = manual_map.setdefault(word, word) 183 | if word not in articles: 184 | _token.append(word) 185 | for i, word in enumerate(_token): 186 | if word in contractions: 187 | _token[i] = contractions[word] 188 | token = " ".join(_token) 189 | token = token.replace(",", "") 190 | return token 191 | --------------------------------------------------------------------------------