├── .gitignore
├── README.md
├── app.py
├── data
│   ├── disaster.csv
│   └── english_to_latex.csv
├── images
│   ├── Venus_and_Adonis_by_Peter_Paul_Rubens.jpg
│   └── oreilly.png
├── notebooks
│   ├── BERT_vs_GPT_for_CLF.ipynb
│   ├── constructing_a_vqa_system.ipynb
│   ├── deployment_and_optimization.ipynb
│   ├── latex_gpt2.ipynb
│   ├── mnist.ipynb
│   ├── rnn_and_cnn.ipynb
│   ├── using_our_vqa.ipynb
│   └── vgg_and_bert.ipynb
└── requirements.txt

/.gitignore:
--------------------------------------------------------------------------------
.DS_Store
*.pt
notebooks/clf/
notebooks/wandb/
.ipynb_checkpoints
*.onnx
data/MNIST/
data/art-styles/

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
![oreilly-logo](images/oreilly.png)

# Deep Learning for Modern AI

This repository contains code for the [O'Reilly Live Online Training for Deep Learning for Modern AI](https://learning.oreilly.com/live-events/deep-learning-for-modern-ai/0642572005084).

This training provides the theory and practical concepts for a comprehensive introduction to machine learning and deep learning with PyTorch: the foundational knowledge needed to successfully build and train GenAI and multimodal models. Through several real-world case studies, including object recognition and text classification, this session is an excellent crash course in deep learning with PyTorch.

We use tools including large pre-trained models and model training dashboards to set up reproducible deep learning experiments and build machine learning models optimized for performance. There are several code examples throughout the training to help solidify the theoretical concepts that will be introduced.
Models like Stable Diffusion, Llama 3, GPT, and BERT are highlighted as we uncover the training and optimization strategies to get the most out of our models' performance, speed, and memory usage.

### Notebooks


#### 1. Introduction to Deep Learning

All data for the art classification example can be downloaded [here](https://drive.google.com/file/d/1jofGOHQ4PwZ50kpGuDqBeVXwDNcjPE6B/view?usp=sharing). Note that it is about 6 GB, so the download may take a while.

- [**First steps with Deep Learning with MNIST**](notebooks/mnist.ipynb)
- [**RNNs and CNNs**](notebooks/rnn_and_cnn.ipynb)
- [**Working with pre-trained VGG-11 and BERT models**](notebooks/vgg_and_bert.ipynb)
- [**Fine-tuning BERT vs ChatGPT**](notebooks/BERT_vs_GPT_for_CLF.ipynb)
    - [Fine-tuning OpenAI](https://github.com/sinanuozdemir/quick-start-guide-to-llms/blob/main/notebooks/05_openai_app_review_fine_tuning.ipynb): the OpenAI fine-tuning code to compare against BERT
- [**Fine-tuning GPT-2 to convert English to LaTeX**](notebooks/latex_gpt2.ipynb)
- [**Fine-tuning Llama 3 to be a chatbot**](https://colab.research.google.com/drive/1gN7jsUFQTPAj5uFrq06HcSLQSZzT7hZz?usp=sharing)

#### 2. Optimizing models

- [**Production Optimization**](notebooks/deployment_and_optimization.ipynb)
- [**Quantizing Llama 3**](https://colab.research.google.com/drive/12RTnrcaXCeAqyGQNbWsrvcqKyOdr0NSm?usp=sharing)
- [**Testing different fine-tuning configurations**](https://colab.research.google.com/drive/1fdx2XlqfAjBoyiTktkRwa8SFaRF3Ch82?usp=sharing)
- [**Distilling BERT models**](https://github.com/sinanuozdemir/quick-start-guide-to-llms/blob/main/notebooks/11_distillation_example_2.ipynb)

#### 3. Going Further

- **[Intro to Multimodality](https://colab.research.google.com/drive/1zYSzDuYFa_cbRlti3scUjfmvradK8Sf4?usp=sharing)**: An introduction to multimodality with CLIP and SHAP-E + Diffusion

- **[Whisper](https://colab.research.google.com/drive/1KxLWEEBtgix4zgP52pnxlIoJrZ8sHEYC?usp=sharing)**: An introduction to using Whisper for audio transcription

- **[LLaVA](https://colab.research.google.com/drive/1IwNAz1Ee4YUSRNCU-SOsa7FS8Q2vmpoL?usp=sharing)**: Using an open-source multi-turn multimodal engine

- **[CLIP-based Stock Image Search](https://colab.research.google.com/drive/1aUz0FKQDSAyXyhRyvkkRsSy7S30mpRJc?usp=sharing)**: Using CLIP to search through a library of images

- **[Dreambooth](https://colab.research.google.com/drive/1tQt1pE6l0MI79W8ZX0MMu0YVmF2I0GB3?usp=sharing)**: Fine-tuning a Stable Diffusion model to make images of yours truly! Ever wonder what I look like blonde? Me neither, but AI gave me some ideas of what it would look like.


- **Visual Q/A** - This case study requires you to [download the data from my Dropbox here](https://www.dropbox.com/scl/fo/w6iyfox8gnflvm7g10n47/AB47L7tNEl2Q8eyemZa2GMA?rlkey=v9s8bv6cmjukykpilzimswar0&st=fbulzw4e&dl=0). The notebooks can also download the data in code if that is easier! Our goal is to emulate the process used by [Llama 3.2-Vision-Instruct](https://colab.research.google.com/drive/1r6Nab2L7rYUBV5e8K8u8EFw98adJu5uh?usp=sharing), one of Meta's latest Llama models that can take in images.

    - Method 1: BERT + ViT -> GPT-2 (Fusion)

        - Constructing and training our model: [Local](notebooks/constructing_a_vqa_system.ipynb) notebook and [Colab](https://colab.research.google.com/drive/1zvbruS1DvFrVgXjNouSrrF9-PphKLWWl?usp=sharing)
        - Using our VQA system: [Local](notebooks/using_our_vqa.ipynb) notebook and [Colab](https://colab.research.google.com/drive/16GOBndQuIBO-UfXdpPte-PXaZS2nsW1H?usp=sharing)

    - Method 2: BERT + ViT -> GPT-2 (Fusion)
        - [Train the VQA Model](https://colab.research.google.com/drive/1DSh8_yfubuu5xPVM2BQ-I_eH5rrxLKZU?usp=sharing) and [use it here](https://colab.research.google.com/drive/1AWAk7NTvgTbjktUNB6bmS6T37bgTzRgt?usp=sharing)

- **[Training a Reasoning Model with Unsloth](https://colab.research.google.com/drive/1Cws1IL_T_0_cP0-cHxFA0FEsXYdiAN_8?usp=sharing)** - Advanced - See how companies like DeepSeek and Anthropic train their reasoning models. [Unsloth AI](https://unsloth.ai/) is a package that aims to make fine-tuning more streamlined, faster, and more memory efficient by hand-writing components such as the backpropagation steps in a faster form. We salute them for their work!

#### How to Use the Image Recognition Flask App

`app.py` is a Flask app that uses a fine-tuned VGG-11 model to classify the art style of an uploaded image. The app currently supports 10 different art styles:

- Abstract Expressionism
- Art Nouveau (Modern)
- Baroque
- Expressionism
- Impressionism
- Northern Renaissance
- Post-Impressionism
- Realism
- Romanticism
- Symbolism

Start the Flask app with `python app.py`.

This should start the Flask app and make it available at `http://localhost:5000`.
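
Each prediction the app returns is a ranking. `app.py` applies a softmax to the model's 10 output logits, pairs each probability with its label from `CLASSES`, and sorts the pairs from most to least likely. The sketch below reproduces only that post-processing step in plain Python (no PyTorch), so the ranking logic can be inspected on its own. The function names `softmax` and `rank_predictions` and the example logits are hypothetical illustrations, not part of the repository.

```python
import math

# Class labels in the same order as CLASSES in app.py
CLASSES = [
    "Abstract_Expressionism",
    "Art_Nouveau_Modern",
    "Baroque",
    "Expressionism",
    "Impressionism",
    "Northern_Renaissance",
    "Post_Impressionism",
    "Realism",
    "Romanticism",
    "Symbolism",
]


def softmax(logits):
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def rank_predictions(logits, classes=CLASSES):
    """Pair each class label with its probability, sorted from most to least likely."""
    if len(logits) != len(classes):
        raise ValueError(f"expected {len(classes)} logits, got {len(logits)}")
    probs = softmax(logits)
    return sorted(zip(classes, probs), key=lambda pair: pair[1], reverse=True)


if __name__ == "__main__":
    # Hypothetical logits for a single image; the real app gets these from VGG-11
    example_logits = [0.1, -0.3, 1.2, 0.0, 0.4, 0.9, 0.2, 0.5, 0.3, 0.6]
    for label, prob in rank_predictions(example_logits):
        print(f"{label}: {prob:.3f}")
```

Because softmax is monotonic, the ranking order always matches the order of the raw logits. The probabilities only add a normalized scale that sums to 1, which is what the JSON response reports.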

#### How to Use the App

To classify an image, send a cURL request in the following format:

```
curl -X POST -F 'image=@/path/to/your/image.jpg' http://localhost:5000/predict
```

Replace `/path/to/your/image.jpg` with the path to your own image. The response is a JSON list of `[art style, probability]` pairs, sorted from most to least likely. For example:

```
curl -X POST -F \
  'image=@images/Venus_and_Adonis_by_Peter_Paul_Rubens.jpg' \
  http://localhost:5000/predict

[
  ["Northern_Renaissance",0.13392961025238037],
  ["Realism",0.12794768810272217],
  ["Romanticism",0.12592236697673798],
  ["Post_Impressionism",0.11863630264997482],
  ["Baroque",0.11325731128454208],
  ["Symbolism",0.1120268702507019],
  ["Expressionism",0.08971412479877472],
  ["Impressionism",0.086906298995018],
  ["Art_Nouveau_Modern",0.05910796299576759],
  ["Abstract_Expressionism",0.03255145251750946]
]
```

If no image is provided, the response contains an error message instead:

```
{
  "error": "No image provided"
}
```


## Instructor

**Sinan Ozdemir** is the Founder and CTO of LoopGenius, where he uses state-of-the-art AI to help people create and run their businesses. Sinan is a former lecturer of Data Science at Johns Hopkins University and the author of multiple textbooks on data science and machine learning. Additionally, he is the founder of the recently acquired Kylie.ai, an enterprise-grade conversational AI platform with RPA capabilities. He holds a master’s degree in Pure Mathematics from Johns Hopkins University and is based in San Francisco, CA.

## For More

- Check out [Deep Learning Illustrated](https://www.amazon.com/dp/0135116694?ref_=cm_sw_r_ffobk_cp_ud_dp_T500T43FCOX9F12OYRFO&peakEvent=5&dealEvent=0&bestFormat=true): a bestseller by Jon Krohn that offers a very visual introduction to deep learning
- [Deep Learning course: lecture slides and lab notebooks](https://m2dsupsdlclass.github.io/lectures-labs/): The course covers the basics of deep learning, with a focus on applications.

--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, request, jsonify
from torch import nn

# Initialize the Flask app
app = Flask(__name__)

# Define the image transformations to be applied
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize the input image to 224 x 224 pixels
    transforms.ToTensor(),  # Convert the image to a PyTorch tensor
])

# Build the VGG-11 architecture (randomly initialized; the fine-tuned weights are loaded below)
trained_vgg_model = models.vgg11()

# Modify the model's final layer to have the correct number of output classes
num_classes = 10
trained_vgg_model.classifier[-1] = nn.Linear(trained_vgg_model.classifier[-1].in_features, num_classes)

# Load the trained weights (mapped to the CPU so no GPU is required) and set the model to evaluation mode
trained_vgg_model.load_state_dict(torch.load("data/trained_vgg_model_quantized.pt", map_location="cpu"))
trained_vgg_model.eval()

# Define the class labels
CLASSES = [
    'Abstract_Expressionism',
    'Art_Nouveau_Modern',
    'Baroque',
    'Expressionism',
    'Impressionism',
    'Northern_Renaissance',
    'Post_Impressionism',
    'Realism',
    'Romanticism',
    'Symbolism'
]


# Define the predict endpoint
@app.route('/predict', methods=['POST'])
def predict():
    if request.files.get('image'):  # Check if an image file was uploaded
        # Preprocess the image and pass it through the model
        image = Image.open(request.files['image'].stream).convert('RGB')
        input_tensor = transform(image).unsqueeze(0)  # Add a batch dimension: (1, 3, 224, 224)
        with torch.no_grad():
            output = trained_vgg_model(input_tensor)

        # Get the probabilities for all classes, sorted from most to least likely
        probabilities = torch.softmax(output, dim=1)[0].tolist()
        class_probabilities = {CLASSES[class_id]: prob for class_id, prob in enumerate(probabilities)}
        sorted_probabilities = sorted(class_probabilities.items(), key=lambda x: x[1], reverse=True)

        return jsonify(sorted_probabilities)

    # Return an error message with HTTP 400 (Bad Request) if no image file was uploaded
    return jsonify({'error': 'No image provided'}), 400


if __name__ == '__main__':
    app.run()  # Run the Flask app if this module is executed directly (not imported)

--------------------------------------------------------------------------------
/data/disaster.csv:
--------------------------------------------------------------------------------
1 | index,id,keyword,location,text,target,label 2 | 7138,10224,volcano,,@MrMikeEaton @Muazimus_Prime hill hill mountain volcano of hell mountain hill hil.,1,1 3 | 2151,3086,deaths,Blackpool,"Cancers equate for around 25% of all deaths in #Blackpool. 4 | 5 | Kowing the signs could save your life: http://t.co/5lNIdvoBff 6 | #BeClearOnCancer",1,1 7 | 4395,6247,hijacking,World,The Murderous Story Of America‰Ûªs First Hijacking: Earnest Pletch‰Ûªs cold-blooded killing of‰Û_ http://t.co/B9JAxx0vCf,1,1 8 | 2508,3602,desolation,"Birmingham, UK",The date for the release of EP03 DESOLATION is set. Stay tuned for more info while we finalise the schedule. #alt #electro #rock #comingsoon,1,1 9 | 1378,1987,bush%20fires,London/Bristol/Guildford,On holiday to relax sunbathe and drink ... Putting out bush fires? Not so much ??
#spain https://t.co/dRno7OKM21,0,0 10 | 6825,9775,trapped,????s ?? ????Ìø????Ì¡a,(?EudryLantiqua?) Hollywood Movie About Trapped Miners Released in Chile: 'The 33' Holly... http://t.co/us1DMdXZVb (?EudryLantiqua?),1,1 11 | 3877,5514,flattened,Some other mansion,"Flattened all cartoony-like. 12 | 'Whoa there Papa!' https://t.co/4zmcqRMOIs",0,0 13 | 3465,4957,exploded,,#news #science London warship exploded in 1665 because sailors were recycling artillery cartridges... http://t.co/r4WGXrA59M #life #tech,1,1 14 | 6245,8921,snowstorm,"Brooklyn, NY",'Cooler than Freddie Jackson sippin' a milkshake in a snowstorm',0,0 15 | 5664,8083,rescue,Wanderlust,Mary coming to Troy rescue. ?????? https://t.co/rosVXQeLQj,0,0 16 | 3938,5599,flood,New York,Spot Flood Combo 53inch 300W Curved Cree LED Work Light Bar 4X4 Offroad Fog Lamp - Full re‰Û_ http://t.co/fDSaoOiskJ http://t.co/2uVmq4vAfQ,0,0 17 | 3059,4388,earthquake,in the Word of God,@DArchambau THX for your great encouragement and for RT of NEW VIDEO http://t.co/cybKsXHF7d The Coming Apocalyptic US Earthquake & Tsunami,1,1 18 | 5985,8546,screams,,When you on the phone and @Worstoverdose screams 'jaileens caked up on the phone' so everyone looks at you ????????????,0,0 19 | 7186,10296,weapon,,Cat Of Nine Irons XII: This nightmarishly brutal weapon is used in ritualistic country club de http://t.co/xpFmR368uF http://t.co/nmAUMYdKe1,1,1 20 | 6493,9283,sunk,"San Jose, CA",Its like I never left. I just sunk to the background,0,0 21 | 816,1185,blizzard,columbus ohio,@StevenOnTwatter @PussyxDestroyer just order a blizzard pay then put your nuts in it say they have you ball flavored. 
Boom free ice cream,0,0 22 | 4127,5868,hailstorm,facebook.com/tradcatknights,"Canada: Hailstorm flash flooding slam Calgary knocks out power to 20k customers 23 | http://t.co/SkY9EokgGB http://t.co/5IyZsDA6xB",1,1 24 | 4623,6571,injury,,incident with injury:I-495 inner loop Exit 31 - MD 97/Georgia Ave Silver Spring,1,1 25 | 895,1296,bloody,Leicester,Bloody insomnia again! Grrrr!! #Insomnia,1,1 26 | 7450,10662,wounds,The American Wasteland (MV),@FEVWarrior -in the Vault that could take a look at those wounds of yours if you'd like to go to one of these places first.' Zarry has had-,0,0 27 | 3043,4366,earthquake,a box,"@AGeekyFangirl14 's things she looks in a significant other: 28 | 1. Beautiful eyes. 29 | 2. Humor. 30 | 3. Farts that creates an earthquake. 31 | 32 | ????????",0,0 33 | 422,612,arsonist,America,If you don't have anything nice to say you can come sit with me.,0,0 34 | 194,273,ambulance,Loveland Colorado,@Kiwi_Karyn Check out what's in my parking lot!! He said that until last year it was an ambulance in St Johns. http://t.co/hPvOdUD7iP,0,0 35 | 542,788,avalanche,Freeport il ,The possible new jerseys for the Avalanche next year. ???? http://t.co/nruzhR5XQu,0,0 36 | 6950,9972,tsunami,,"Crptotech tsunami and banks. 37 | http://t.co/KHzTeVeDja #Banking #tech #bitcoing #blockchain",1,1 38 | 4860,6919,mass%20murderer,,Julian Knight - @SCVSupremeCourt dismisses mass murderer's attempt to increase prisoner pay. Challenged quantum of 5% increase 2013.,1,1 39 | 3297,4725,evacuate,,Disregard my snap story there is an angry white girl riot happening as we speak. 
#evacuate,1,1 40 | 4977,7097,military," Quantico Marine Base, VA.",@FurTrix then find cougars who look like her even better if they're in military uniform!,0,0 41 | 411,594,arson,,Arson suspect linked to 30 fires caught in Northern California http://t.co/u1fuWrGK5U,1,1 42 | 3817,5425,first%20responders,,Carmike Cinemas on Antioch Shooting: 'We Are Grateful' for Staff and First Responders Safety Is 'Highest Priority' http://t.co/BehfHspPud,1,1 43 | 1121,1617,bombed,"Screwston, TX",'Redskins WR Roberts Belly-Bombed ' via @TeamStream http://t.co/GbcvVEvDTY,1,1 44 | 4233,6013,hazardous,"Oregon, USA",Is it possible to sneak into a hospital so I can stab myself with a hazardous needle and inject some crazy disease into my veins until I die,0,0 45 | 6397,9143,suicide%20bomber,,Suicide Bomber Kills 13 At Saudi Mosque http://t.co/oZ1DS3Xu0D #Saudi #Tripoli #Libya,1,1 46 | 4622,6570,injury,,@Sport_EN Just being linked to Arsenal causes injury.,1,1 47 | 4149,5898,harm,Kansas City,@dinallyhot Love what you picked! We're playing WORTH IT by FIFTH HARM/KID INK because of you! Listen & Vote: http://t.co/0wrATkA2jL,0,0 48 | 2927,4206,drowned,"United Kingdom,Fraserburgh","@Stephen_Georg Hey Stephen Remember that time you drowned all the yellows 49 | 50 | Read: http://t.co/0sa6Xx1oQ7",0,0 51 | 4895,6970,massacre,,"Headed to the massacre 52 | Bodies arriving everyday 53 | What were those shells you heard 54 | Picking the bones up along the way",1,1 55 | 2267,3249,deluged,balvanera,'Afterwards I had to be alone for an hour to savour and prolong the almost physical intensity of the feelings that deluged me'. Y sigue:,0,0 56 | 6027,8613,seismic,,SEISMIC AUDIO SA-15T SA15T Padded Black Speaker COVERS (2) - Qty of 1 = 1 Pair! 
http://t.co/2jbIbeib9G http://t.co/p5KtaqW5QG,0,0 57 | 7354,10527,wildfire,"Eddyville, Oregon 97343",Oregon's biggest wildfire slows growth http://t.co/P0GoS5URXG via @katunews,1,1 58 | 1935,2783,curfew,Uganda,@DavisKawalya I know @Mauryn143 will be saying her final goodbyes to grandpa as seen on news RiP Me? always open to ideas but may ve curfew,0,0 59 | 932,1349,blown%20up,Cobblestone,"'If a truckload of soldiers will be blown up nobody panics but when one little lion dies everyone loses their mind' 60 | http://t.co/wjNTaOkdHf",1,1 61 | 3963,5632,flooding,"Honolulu, Hawaii",If Flooding occurs in your home - How to stop mold growth via @ProudGreenHome http://t.co/KAVAovJz2V,1,1 62 | 2805,4033,disaster,,@cncpts @SOLELINKS what a disaster - can't say I'm surprised,0,0 63 | 1683,2428,collide,The Forever Girl,The Witches of the Glass Castle. Supernatural YA where sibling rivalry magic and love collide #wogc #kindle http://t.co/IzakNpJeQW,0,0 64 | 5754,8214,riot,,"@abran_caballero Discovered by @NickCannon 65 | Listen/Buy @realmandyrain #RIOT on @iTunesMusic @iTunes https://t.co/dehMym5lpk ‰Û_ #BlowMandyUp",0,0 66 | 3907,5557,flattened,,100 1' MIX NEW FLAT DOUBLE SIDED LINERLESS BOTTLE CAPS YOU CHOOSE MIX FLATTENED - Full re‰Û_ http://t.co/w00kjPrfdR http://t.co/mIXl1pFRJe,0,0 67 | 4123,5861,hailstorm,,Summer heat drives bobcats to Calgary backyards ~ 3 http://t.co/kEstuUYc4t http://t.co/PzFBb1P1mj,1,1 68 | 4842,6896,mass%20murder,"Huntsville, AL",Okay not sure the word 'mass murder' applies during this war but it was horrendous none the less. https://t.co/Sb3rjQqzIX,1,1 69 | 6892,9880,traumatised,,@wrongdejavu I'm traumatised,0,0 70 | 4122,5858,hailstorm,,IG: http://t.co/2WBiVKzJIP 'It's hailing again! 
#abstorm #yyc #hail #hailstorm #haildamage #yycweather #calgary #captureyyc #alberta #sto‰Û_,1,1 71 | 5093,7265,nuclear%20disaster,,3 former executives to be prosecuted in Fukushima nuclear disaster http://t.co/Gvj7slbELP,1,1 72 | 5327,7605,pandemonium,"Mumbai, India",@minhazmerchant Govt should pass the bills in the Pandemonium. UPA used to do it why cant NDA?,0,0 73 | 2348,3379,demolition,,Israel continues its demolition of Palestinian homes #gop #potus #irandeal #isis https://t.co/NMgp7iMEIi,1,1 74 | 1123,1619,bombed,My old New England home,I liked a @YouTube video http://t.co/FX7uZZXtE4 Benedict Cumberbatch Gets Video Bombed,0,0 75 | 2645,3797,destruction,denver colorado,it sure made an impact on me http://t.co/GS50DdG1JY,0,0 76 | 3220,4620,emergency%20services,"Vancouver, British Columbia",Removing tsunami debris from the West Coast: Karen Robinson Enviromental and Emergency services manager of the‰Û_ http://t.co/1MeEo3WJcO,1,1 77 | 4117,5850,hailstorm,Washington State,We The Free Hailstorm Maxi http://t.co/ERWs6IELdG,1,1 78 | 1913,2753,curfew,,@TheComedyQuote @50ShadezOfGrey the thirst has no curfew ???????????? @P45Perez,0,0 79 | 3372,4830,evacuation,,Run out evacuation hospital indexing remedial of angioplasty dissertation at power elite hospitals dismayed:‰Û_ http://t.co/VGvJGr8zoO,1,1 80 | 6382,9120,suicide%20bomb,lagos. Unilag,16yr old PKK suicide bomber who detonated bomb in Turkey Army trench released http://t.co/mGZslZz1wF,1,1 81 | 1966,2830,cyclone,"Sioux Falls, SD",Excited for Cyclone football https://t.co/Xqv6gzZMmN,0,0 82 | 150,215,airplane%20accident,"New Mexico, USA",@mickinyman @TheAtlantic That or they might be killed in an airplane accident in the night a car wreck! 
Politics at it's best.,0,0 83 | 1938,2786,curfew,"Elkhart, IN",Had to cancel my cats doctor appointment because she decided to go out and play and not come home by curfew ...,0,0 84 | 4735,6734,lava,Venezuela,I LAVA YOU.,0,0 85 | 4111,5842,hailstorm,,Hail:The user summons a hailstorm lasting five turns. It damages all Pokemon except the Ice type.,0,0 86 | 498,721,attacked,Peshawar,IK Only Troll His Pol Rivals Never Literally Abused Them Or Attacked Their Families. While All Of Them Literally Abuse IK. Loosers,0,0 87 | 4246,6033,hazardous,United States,DLH issues Hazardous Weather Outlook (HWO) http://t.co/a0Ad8z5Vsr #WX,1,1 88 | 4578,6511,injuries,"Sutton, London UK",@fkhanage look what Shad Forsythe has done in 1 year we won't have as many injuries as before we will inevitably have injuries like others,0,0 89 | 3532,5049,eyewitness,Pennsylvania,A true #TBT Eyewitness News WBRE WYOU http://t.co/JHVigsX5Jg,0,0 90 | 3070,4405,electrocute,,when you got an extension cord that extends from your bed to your bath tub ?? lets pray I don't electrocute myself,0,0 91 | 1829,2629,crashed,,Pakistan says army helicopter has crashed in country's restive northwest killing at least 8 http://t.co/QV1RMZI3J1,1,1 92 | 6818,9765,trapped,10 Steps Ahead. Cloud 9,Bomb head? Explosive decisions dat produced more dead children than dead bodies trapped tween buildings on that day in September there,1,1 93 | 4355,6187,hijacker,San Francisco,Governor allows parole for California school bus hijacker who kidnapped 26 children in 1976. http://t.co/hdAhLgrprl http://t.co/Z1s3T77P3L,1,1 94 | 2961,4253,drowning,Numa casa de old yellow bricks,LONDON IS DROWNING AND IIII LIVE BY THE RIVEEEEEER,1,1 95 | 4285,6088,hellfire,,Hellfire is surrounded by desires so be careful and don‰Ûªt let your desires control you! 
#Afterlife,0,0 96 | 6572,9405,survivors,"Dallas, TX ",Haunting memories drawn by survivors http://t.co/pRAro2OWia,1,1 97 | 5047,7195,natural%20disaster,home ,she's a natural disaster she's the last of the American girls ??,0,0 98 | 249,354,annihilation,Spain,:StarMade: :Stardate 3: :Planetary Annihilation:: http://t.co/I2hHvIUmTm via @YouTube,1,1 99 | 5382,7680,panic,,@panic awesome thanks.,0,0 100 | 2898,4163,drown,,@cameronhigdon34 I can't drown my demons they know how to swim.,0,0 101 | 2771,3983,devastation,,70 Years After Atomic Bombs Japan Still Struggles With War Past: The anniversary of the devastation wrought b... http://t.co/Targ56iGBZ,1,1 102 | 6043,8638,seismic,,A subcontractor working for French seismic survey group CGG has been kidnapped in Cairo and is held by Islamic State the company said on W‰Û_,1,1 103 | 3868,5499,flames,,If you have an opinion and you don't put it on thh internet you will furst into flames.,0,0 104 | 7452,10665,wounds,"Cleveland, OH",If time heals all wounds how come the belly button stays the same?,0,0 105 | 5434,7753,police,,Police unions retard justice & drain gov $ but cops vote 4 Rs so Rs go after teachers unions instead! @DCCC @VJ44 @Lawrence @JBouie @mmfa,0,0 106 | 6757,9681,tornado,Toronto,Environment Canada confirms 2nd tornado touched down last weekend åÈ http://t.co/x8zqbwNfO1,1,1 107 | 5874,8391,ruin,"MÌ©rida, YucatÌÁn",babe I'm gonna ruin you if you let me stay,0,0 108 | 6036,8627,seismic,"Hatteras, North Carolina",Agency seeks comments on seismic permits http://t.co/9Vd6x4WDOY,0,0 109 | 7589,10843,,,Omg earthquake,1,1 110 | 5386,7686,panic,Milwaukee WI,Someone asked me about a monkey fist about 2 feet long with a panic snap like the one pictured to be used as a... 
http://t.co/Yi9BBbx3FE,0,0 111 | 4604,6547,injury,"Los Angeles, California",California Law‰ÛÓNegligence and Fireworks Explosion Incidents http://t.co/d5w2zynP7b,1,1 112 | 5743,8201,riot,"Cardiff, UK",@PipRhys I predict a riot.,1,1 113 | 1015,1474,body%20bagging,401 livin',Aubrey really out here body-bagging Meek.,1,1 114 | 3425,4898,explode,they/them,"tagged by @attackonstiles 115 | 116 | millions 117 | a-punk 118 | hang em high 119 | alpha dog 120 | yeah boy and doll face 121 | little white lies 122 | explode http://t.co/lAtsSUo4wS",0,0 123 | 102,146,aftershock,Instagram - @heyimginog ,@afterShock_DeLo scuf ps live and the game... cya,0,0 124 | 5373,7669,panic,"Philadelphia, PA",Despite the crippling anxiety and overwhelming panic attacks I'd say I'm fairly well-adjusted.,0,0 125 | 3850,5479,flames,hell,@gilderoy i wish i was good enough to add flames to my nails im on fire,1,1 126 | 5945,8491,screamed,,i dont even remember slsp happening i just remember being like wtf and then the lights turned off and everyone screamed for the encore,0,0 127 | 5920,8452,screamed,,I heard the steven universe theme song from upstairs and screamed his name at the part of the song and scared my cousin,0,0 128 | 5310,7583,outbreak,Finland,Families to sue over Legionnaires: More than 40 families affected by the fatal outbreak of Legionnaires' disea... http://t.co/ZA4AXFJSVB,1,1 129 | 5955,8506,screaming,rio de janeiro | brazil,SCREAMING @MariahCarey @ArianaGrande http://t.co/xxZD1nmb1i,0,0 130 | 5454,7780,police,"Mesa, AZ",@ArizonaDOT Price Rd North bound closed from University to Rio Salado.. Lots of police.. 
What's crackin?,1,1 131 | 365,523,army,,One Direction Is my pick for http://t.co/q2eBlOKeVE Fan Army #Directioners http://t.co/eNCmhz6y34 x1386,0,0 132 | 5598,7988,razed,,The Latest: More Homes Razed by Northern California Wildfire - ABC News http://t.co/2872J5d4HB,1,1 133 | 5152,7348,obliterate,"Cavite, Philippines",The People's Republic Of China ( PROC ): Abandon the West Philippine Sea and all the ... https://t.co/pD14GsrfSC via @ChangePilipinas,0,0 134 | 128,184,aftershock,304,'Remembering that you are going to die is the best way I know to avoid the trap of thinking you have something to lose.' ‰ÛÒ Steve Jobs,0,0 135 | 4635,6588,inundated,"Zeerust, South Africa",Most of us ddnt get this English RT @ReIgN_CoCo: The World Is Inundated With Ostentatious People Stay Woke!,0,0 136 | 2519,3619,desolation,Arizona,Did Josephus get it wrong about Antiochus Epiphanes and the Abomination of Desolation? Read more: http://t.co/FWj9CcYw6k,0,0 137 | 2684,3850,detonation,,Ignition Knock (Detonation) Sensor-Senso Standard KS161 http://t.co/WadPP69LwJ http://t.co/yjTh2nABv5,0,0 138 | 4250,6038,heat%20wave,USA,Heat Advisory is In Effect From 1 PM Through 7 PM Thursday. A Building Heat Wave and Increasing Humidity... #lawx http://t.co/u0SYkowVWV,1,1 139 | 6610,9466,terrorism,,DHS Refuses to Call Chattanooga ‰Û÷Islamic Terrorism‰Ûª out of respect for MUSLIMS ... http://t.co/u8RGB51d22 via @po_st http://t.co/2tnu95VGFE,1,1 140 | 7386,10570,windstorm,Houston,New roof and hardy up..Windstorm inspection tomorrow http://t.co/kKeH8qCgc3,0,0 141 | 3921,5577,flood,New York,Spot Flood Combo 53inch 300W Curved Cree LED Work Light Bar 4X4 Offroad Fog Lamp - Full re‰Û_ http://t.co/mTmoIa0Oo0 http://t.co/Nn4ZtCmSRU,0,0 142 | 629,906,bioterrorism,,70 won 70...& some think possibility of my full transformation is impossible. I don't quite like medical mysteries. 
BIOTERRORISM sucks.,0,0 143 | 6986,10018,twister,,Twister was fun https://t.co/qCT6fb8wOn,0,0 144 | 7173,10278,war%20zone,,@RobertONeill31 Getting hit by a foul ball while sitting there is hardly a freak accident. It's a war zone.,0,0 145 | 2493,3582,desolate,Macclesfield,@binellithresa TY for the follow Go To http://t.co/UAN05TNkSW BRUTALLY ABUSED+DESOLATE&LOST + HER LOVELY MUM DIES..Is it Murder?,1,1 146 | 7467,10684,wounds,Charlotte,ME says many of these wounds could be fatal some rather quickly others slower and a couple not lethal at all. #KerrickTrial,0,0 147 | 1657,2394,collapsed,#ForeverWithBAP 8 ,and he almost collapsed bc he said his wish came true moderately FUCK,0,0 148 | 4737,6736,lava,Indonesia,@YoungHeroesID 4. Lava Blast Power Red #PantherAttack,0,0 149 | 5690,8120,rescued,"Winston-Salem, NC",'You can only be rescued from where you actually are and not from where you pretend you are.' Giorgio Hiatt,0,0 150 | 807,1173,blight,Kama | 18 | France ,@anellatulip there is a theory that makes way too much sense that says that the dwarves may be the actual origin of the blight,0,0 151 | 3149,4523,emergency,We are global!,SF Asian Women's Shelter crisis line (415) 751-0880. Emergency shelter/support services 4 non-English speaking Asian women & children.,1,1 152 | 3194,4585,emergency%20plan,Los Angeles,Senators calling for emergency housing: Boxer Feinstein back plan to move #homeless vets to VA campus http://t.co/Gm80X3vutf,1,1 153 | 3583,5119,fatal,Nakhon Si Thammarat,Investigators shift focus to cause of fatal Waimate fire http://t.co/c9dVDsSoFn,1,1 154 | 2380,3421,derail, Road to the Billionaires Club,@TheJenMorillo GM! 
I pray any attack of the enemy 2 derail ur destiny is blocked by the Lord & that He floods ur life w/heavenly Blessings,0,0 155 | 4456,6339,hostages,"Cumming, GA",C-130 specially modified to land in a stadium and rescue hostages in Iran in 1980 http://t.co/jkD7CTi2iW http://t.co/LAjN2n5e2d,1,1 156 | 2669,3831,detonate,,@WoundedPigeon http://t.co/s9soAeVcVo Detonate by @ApolloBrown ft. M.O.P.,0,0 157 | 3393,4857,evacuation,The Empire/First Order,@ariabrisard @leiaorganasolo Good. Play along with her. You may begin your operation with the death star. The evacuation is nearly complete.,0,0 158 | 3134,4503,emergency,,Emergency Flow http://t.co/lH9mrYpDrJ mp3 http://t.co/PqhuthSS3i rar http://t.co/0iW6dRf5X9,1,1 159 | 1799,2585,crash,Galatians 2:20 ,Please keep Josh the Salyers/Blair/Hall families & Jenna's friends in your prayers. She was taken far too soon. RIP http://t.co/bDN2FDPdAz,0,0 160 | 6339,9061,structural%20failure,,Investigators say a fatal Virgin Galactic spaceship crash last year was caused by structural failure after the co‰Û_,1,1 161 | 5162,7361,obliterate,Cymru araul,@McCaineNL Think how spectacular it will look when the Stonewall riots obliterate the white house.,1,1 162 | 4788,6811,loud%20bang,,Who the fuck plays music extremely loud at 9am on a Tuesday morning?? 
Bruv do they want me to come bang them??,0,0 163 | 2871,4127,drought,Meereen ,Pizza drought is over I just couldn't anymore...,0,0 164 | 3020,4335,dust%20storm,"Atlanta, GA",DUST IN THE WIND: @82ndABNDIV paratroopers move to a loading zone during a dust storm in support of Operation Fury: http://t.co/uGesKLCn8M,1,1 165 | 880,1275,blood,,Shed innocent blood of their sons and daughters and the land was polluted Psalms 106:38 Help stop the sin of abortion.,0,0 166 | 5138,7328,nuclear%20reactor,"Den Helder, Rijkswerf",US Navy Sidelines 3 Newest #Subs http://t.co/9WQixGMHfh,0,0 167 | 4216,5989,hazardous,,DLH issues Hazardous Weather Outlook (HWO) http://t.co/WOzuBXRi2p,1,1 168 | 1310,1892,burning,"Caracas, Venezuela.",We will be burning up like neon lights??????,1,1 169 | 3939,5600,flood,United States,Flood Advisory issued August 05 at 4:28PM EDT by NWS http://t.co/D02sbM0ojs #WxKY,1,1 170 | 6017,8595,seismic,garowe puntland somalia,Oil and Gas Exploration http://t.co/PckF0nl2yN,1,1 171 | 6105,8717,sinking,"Michigan, USA",‰Û¢‰Û¢If your lost & alone or your sinking like a stone carry onå¡å¡,0,0 172 | 937,1354,blown%20up,Nowhere Islands/Smash Manor,@TheBoyOfMasks 'Thanks again for letting me stay here since the manor was blown up..... Anyways how are you doing buddy?',0,0 173 | 314,457,armageddon,Canada,@ENews Ben Affleck......I know there's a wife/kids and other girls but I can't help it. I've loved him since Armageddon #eonlinechat,0,0 174 | 6275,8966,storm,,What tropical storm? #guillermo by hawaiianpaddlesports http://t.co/LgPgAjgomY http://t.co/FKd1mBTB68,1,1 175 | 1889,2715,crushed,#HAMont,Edwin wow. Crushed.,0,0 176 | 5550,7916,rainstorm,,#DnB #NewRelease EDGE Jimmy - Summer Rainstorm (Lapaka Sounds) http://t.co/4L8h2FKlNO via http://t.co/ZITQKDFXJY,0,0 177 | 6415,9173,suicide%20bomber,,http://t.co/9k1tqsAarM Suicide bomber kills 15 in Saudi security site mosque - Reuters http://t.co/Ev3nX9scx3,1,1 178 | 5722,8165,rescuers,,Last Second Ebay Bid RT? 
http://t.co/oEKUcq4ZL0 Shaolin Rescuers (dvd 2010) Shen Chan Nan Chiang Five Venoms Kung Fu ?Please Favori,0,0 179 | 5071,7229,natural%20disaster,"Suburban Detroit, Michigan",We have overrun a Natural Disaster Survival server!,1,1 180 | 4274,6072,heat%20wave,"Oklahoma City, OK",Longest Streak of Triple-Digit Heat Since 2013 Forecast in Dallas: An unrelenting and dangerous heat wave will... http://t.co/s4Srgrmqcz,1,1 181 | 1035,1501,body%20bags,Swaning Around,US ‰Û÷Institute Of Peace‰Ûª Chairman Wants Russian Body Bags http://t.co/owbUjez3q4,0,0 182 | 4316,6128,hellfire,"Denver, Colorado",@MechaMacGyver Wow bet you got blamed for that too huh?,0,0 183 | 2077,2983,dead,"Sochi, KDA, RU",@hlportal Hello! I'm looking for mod Cold Ice. I saw it on your site but link to download dead. Maybe you have it and share with me? Thanks.,0,0 184 | 3899,5545,flattened,Le Memenet,@KainYusanagi @Grummz @PixelCanuck and flattened raynor. Raynor was a balding imperfect biker marine not a emo generic western hero.,0,0 185 | 3974,5650,flooding,Global-NoLocation,#flood #disaster Burst Water Pipe Floods Apartments at NYCHA Senior Center - NY1: NY1Burst Water Pipe Floods A... http://t.co/w7SIIdujOH,1,1 186 | 7050,10101,typhoon,The Peach State,I think a Typhoon just passed through here lol,1,1 187 | 3383,4845,evacuation,"sydney, australia",my school just put the evacuation alarms on accidently with 2 different trial exams happening are you kidding me,0,0 188 | 7562,10811,wrecked,,you wrecked my whole world,0,0 189 | 2088,3001,dead,Milton Keynes ,Can't believe Ross is dead???????? @emmerdale @MikeParrActor #Emmerdale #summerfate,0,0 190 | 6011,8584,screams,5-Feb,When you go to a concert and someone screams in your ear... Does it look like I wanna loose my hearing anytime soon???,1,1 191 | 2295,3292,demolish,,I have completed the quest 'Demolish 5 Murlo...' in the #Android game The Tribez. 
http://t.co/pBclFsXRld #androidgames #gameinsight,0,0 192 | 4745,6749,lightning,,Thunder and lightning possible in the Pinpoint Foothill Forecast. http://t.co/CtIjdPXABk,1,1 193 | 16,24,,,I love fruits,0,0 194 | 6596,9448,terrorism,,@RobPulseNews @huyovoeTripolye Phillips should be charged for assisting terrorism. LDNR-terrorists' organizations. http://t.co/XwnJYsV9V9,1,1 195 | 514,740,attacked,"SÌ£o Paulo SP, Brasil",Christian Attacked by Muslims at the Temple Mount after Waving Israeli Flag via Pamela Geller - ... http://t.co/e4wK8Uri8A,1,1 196 | 2551,3659,destroy,,"ng2x5 mhtw4fnet 197 | 198 | Watch Michael Jordan absolutely destroy this meme-baiting camper - FOXSportscom",0,0 199 | 6184,8829,sirens,,@iK4LEN Sirens was cancelled.,0,0 200 | 7359,10536,windstorm,,Windstorm lastingness perquisite - acquiesce in a twister retreat: ZiUW http://t.co/iRt4kkgsJx,1,1 201 | 927,1343,blown%20up,Georgia,Man why hasn't @machinegunkelly blown up? He's still underground.,0,0 202 | 7238,10366,weapons,??????,#Battlefield 1942 forgotten hope secret weapons,1,1 203 | 88,130,accident,"Manchester, NH",Accident left lane blocked in #Manchester on Rt 293 NB before Eddy Rd stop and go traffic back to NH-3A delay of 4 mins #traffic,1,1 204 | 4374,6215,hijacker,,Governor allows parole for California school bus hijacker: Local... http://t.co/tAM6aoskoJ http://t.co/eL24mnFcHw,1,1 205 | 5922,8454,screamed,ljp/4,THE GIRLS NEXT TO ME SCREAMED WHAT THE FUCK IS A CHONCE I'm CRYIBG,0,0 206 | 6834,9788,trapped,å_å_Los Mina City‰ã¢,Hollywood Movie About Trapped Miners Released in Chile: 'The 33' Hollywood movie about trapped miners starring... http://t.co/x8moYeVjsJ,0,0 207 | 5842,8348,ruin,London / Birmingham,Im so anxious though because so many ppl will me watching me meet them and that makes me uncomfortable BUT I CANT LET THAT RUIN THE MOMENT,0,0 208 | 2039,2927,danger,Spinning through time.,@riverroaming 'And not too much danger please.',0,0 209 | 2942,4230,drowned,,Thank you Uplifting spirit. 
When Im drowned you've been an anchor,0,0 210 | 523,755,avalanche,Ireland,A little piece I wrote for the Avalanche Designs blog! I'd appreciate it greatly if you checked it out :-) https://t.co/rfvjh58eF2,0,0 211 | 6432,9203,suicide%20bombing,,< 25 Dead In Kuwait Mosque Suicide Bombing Claimed By ISIS Offshoot on http://t.co/eTITgPSrUN,1,1 212 | 925,1340,blown%20up,USA,@troylercraft YEAH ITS NOT WORTH IT BC HE ALREADY HAS SO MANY SPAMMERS & HIS TWITTER IS PROBABLY BLOWN UP EVERY SECOND,0,0 213 | 587,848,bioterror,iTunes,#world FedEx no longer to transport bioterror germs in wake of anthrax lab mishaps http://t.co/wvExJjRG6E,1,1 214 | 3661,5211,fatality,"Lowell, MA",@EBROINTHEAM jay....big L....pun....biggie...wrap over...zero question....fatality...flawless victory http://t.co/Y33QcKq7qD,0,0 215 | 6503,9297,survive,United States,AHH forgot my headphones how am I supposed to survive a day without music AYHHHHHDJJFJRJJRDJJEKS,0,0 216 | 5120,7301,nuclear%20reactor,"USA, North Dakota",Salem 2 nuclear reactor shut down over electrical circuit failure on pump: The Salem 2 nuclear reactor had bee... http://t.co/5hkGXzJLmX,1,1 217 | 7541,10782,wreckage,"New Delhi,India",Wreckage 'Conclusively Confirmed' as From MH370: Malaysia PM: Investigators and the families of those who were... http://t.co/1YIxFG1Hdy,1,1 218 | 4779,6800,loud%20bang,Kenya,tkyonly1fmk: Breaking news! Unconfirmed! I just heard a loud bang nearby. in what appears to be a blast of wind from my neighbour's ass.,0,0 219 | 2342,3369,demolition,,#download & #watch Demolition Frog (2002) http://t.co/81nEizeknm #movie,1,1 220 | 7526,10764,wreckage,"Dublin City, Ireland",Wreckage 'Conclusively Confirmed' as From MH370: Malaysia PM: Investigators and the families of those who were... http://t.co/VAZpG0ftmU,1,1 221 | 2614,3752,destruction,,Crackdown 3 Destruction Restricted to Multiplayer: Crackdown 3 impressed earlier this week with a demonstratio... 
http://t.co/LMWKjsYCgj,0,0 222 | 2638,3786,destruction,Jersey,That's the ultimate road to destruction,0,0 223 | 963,1393,body%20bag,,new summer long thin body bag hip A word skirt Blue http://t.co/lvKoEMsq8m http://t.co/CjiRhHh4vj,0,0 224 | 383,552,arson,"Charlotte, NC",Add Familia to the arson squad.,0,0 225 | 4522,6425,hurricane,,Stream HYPE HURRICANE,0,0 226 | 2049,2941,danger,Instagram: trillrebel_,"Guns are for protection.. 227 | That shit really shouldn't be used unless your life in danger",0,0 228 | 7532,10770,wreckage,iTunes,#science Now that a piece of wreckage from flight MH370 has been confirmed on RÌ©union Island is it possible t... http://t.co/Z2vDGIyOwf,1,1 229 | 5345,7630,pandemonium,illinois. united state ,Pandemonium In Aba As Woman Delivers Baby Without Face (Photos) @... http://t.co/JbxBi93CLu,0,0 230 | 1342,1940,burning%20buildings,,Hero's fight wars and save ppl from burning buildings etc I'm sorry but u gotta do more than pay 4 a sex change be4 I call u a hero,0,0 231 | -------------------------------------------------------------------------------- /data/english_to_latex.csv: -------------------------------------------------------------------------------- 1 | English,LaTeX 2 | integral from a to b of x squared,"\int_{a}^{b} x^2 \,dx" 3 | integral from negative 1 to 1 of x squared,"\int_{-1}^{1} x^2 \,dx" 4 | integral from negative 1 to infinity of x cubed,"\int_{-1}^{\inf} x^3 \,dx" 5 | integral from 0 to infinity of x squared,"\int_{0}^{\inf} x^2 \,dx" 6 | integral from 0 to infinity of y squared,"\int_{0}^{\inf} y^2 \,dy" 7 | integral from 1 to 2 of x over 2,"\int_{1}^{2} \frac{x}{2} \,dx" 8 | f of x equals x squared,f(x) = x^2 9 | h of x equals x squared,h(x) = x^2 10 | g of x equals x squared,g(x) = x^2 11 | g of x equals x to the eighth power,g(x) = x^8 12 | f of x equals x cubed,f(x) = x^3 13 | f of x equals x,f(x) = x 14 | h of x equals x to the fifth power,h(x) = x^5 15 | g of x equals integral from 0 to 10 of x cubed,"g(x) = 
\int_{0}^{10} x^3 \,dx" 16 | f of x equals x over n,f(x) = \frac{x}{n} 17 | f of x equals integral from 1 to 2 of x,"f(x) = \int_{1}^{2} x \,dx" 18 | f of x equals integral from 0 to 2 of x,"f(x) = \int_{0}^{2} x \,dx" 19 | f of x equals integral from 1 to 2 of x over 2,"f(x) = \int_{1}^{2} \frac{x}{2} \,dx" 20 | f of x equals sum from 1 to 5 of x squared,f(x) = \sum_{1}^{5} x^2 21 | x squared,x^2 22 | x cubed,x^3 23 | pi squared,\pi^2 24 | z squared,z^2 25 | z over x squared,\frac{z}{x^2} 26 | f of x equals x squared,f(x) = x^2 27 | 1 over 6,\frac{1}{6} 28 | 2 pi,2 * \pi 29 | s cubed,s^3 30 | s to the sixth power,s^6 31 | 2 pi r,2 * \pi * r 32 | pi over n,\frac{\pi}{n} 33 | f of n equals pi over n,f(n) = \frac{\pi}{n} 34 | pi times x,\pi*x 35 | pi to the fourth power,\pi^4 36 | pi to the fifth power,\pi^5 37 | f of x equals x times pi to the fifth power,f(x) = x * \pi^5 38 | g of x equals x times pi cubed,g(x) = x * \pi^3 39 | g of x equals pi cubed,g(x) = \pi^3 40 | 1 over n,\frac{1}{n} 41 | x squared over n,\frac{x^2}{n} 42 | y squared over x^2,\frac{y^2}{x^2} 43 | 1 over 7 to the seventh power,(\frac{1}{7})^7 44 | 1 over 9 to the seventh power,(\frac{1}{9})^7 45 | f of x equals x over 9 to the seventh power,(f(x) = \frac{x}{9})^7 46 | sum from i to n of X i,\sum_{i}^{n} X_i 47 | sum from 0 to n of 77 n,\sum_{0}^{n} 77 * n 48 | sum from 0 to 5 of x,\sum_{0}^{5} x 49 | sum from 1 to x of x,\sum_{1}^{x} x 50 | sum from 1 to x of x squared,\sum_{1}^{x} x^2 51 | sum from 1 to 10 of pi squared,\sum_{1}^{10} \pi^2 52 | -------------------------------------------------------------------------------- /images/Venus_and_Adonis_by_Peter_Paul_Rubens.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sinanuozdemir/oreilly-pytorch-dl/3e9ccbcf5182f2e89e6db4655242f1208fa34058/images/Venus_and_Adonis_by_Peter_Paul_Rubens.jpg -------------------------------------------------------------------------------- 
/images/oreilly.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sinanuozdemir/oreilly-pytorch-dl/3e9ccbcf5182f2e89e6db4655242f1208fa34058/images/oreilly.png -------------------------------------------------------------------------------- /notebooks/mnist.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "876f9496-2aeb-4b3c-9e92-7d8e3c1bf051", 6 | "metadata": {}, 7 | "source": [ 8 | "# Training a DL model on the MNIST dataset" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "4cfbde3a-e540-49e8-99f0-63993987a6d5", 14 | "metadata": { 15 | "jp-MarkdownHeadingCollapsed": true 16 | }, 17 | "source": [ 18 | "## Imports and Data Setup" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 2, 24 | "id": "bdfba6c5", 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "# Import required libraries\n", 29 | "import torch\n", 30 | "import torch.nn as nn\n", 31 | "import matplotlib.pyplot as plt\n", 32 | "import torch.nn.functional as F\n", 33 | "\n", 34 | "import torchvision\n", 35 | "import numpy as np\n", 36 | "import torchvision.transforms as transforms\n", 37 | "from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score\n" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 3, 43 | "id": "deb4a334", 44 | "metadata": {}, 45 | "outputs": [ 46 | { 47 | "data": { 48 | "text/plain": [ 49 | "" 50 | ] 51 | }, 52 | "execution_count": 3, 53 | "metadata": {}, 54 | "output_type": "execute_result" 55 | } 56 | ], 57 | "source": [ 58 | "# standardize random numbers for reproducibility\n", 59 | "torch.manual_seed(0)\n" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 4, 65 | "id": "cf93c989", 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "img = torchvision.datasets.MNIST(root='../data', 
train=True, download=True)[0]" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 5, 75 | "id": "41a542d4-c4c1-4640-9145-3c0836e0e5c3", 76 | "metadata": {}, 77 | "outputs": [ 78 | { 79 | "data": { 80 | "text/plain": [ 81 | "(, 5)" 82 | ] 83 | }, 84 | "execution_count": 5, 85 | "metadata": {}, 86 | "output_type": "execute_result" 87 | } 88 | ], 89 | "source": [ 90 | "img # PIL == Pillow" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 6, 96 | "id": "0de01bf4", 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "# Load MNIST dataset\n", 101 | "batch_size = 4\n", 102 | "transform = transforms.Compose(\n", 103 | " [\n", 104 | " transforms.ToTensor(), # convert to PyTorch tensor\n", 105 | " transforms.Normalize((0.1307,), (0.3081,)), # standardize with the MNIST training-set mean and std (precomputed ahead of time)\n", 106 | " # transforms.RandomVerticalFlip()\n", 107 | " ] \n", 108 | ")\n", 109 | "\n", 110 | "train_dataset = torchvision.datasets.MNIST(root='../data', train=True, transform=transform, download=True)\n", 111 | "test_dataset = torchvision.datasets.MNIST(root='../data', train=False, transform=transform, download=True)\n", 112 | "\n", 113 | "train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)\n", 114 | "test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)\n" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 7, 120 | "id": "e926a88c", 121 | "metadata": {}, 122 | "outputs": [ 123 | { 124 | "data": { 125 | "image/png": "", 126 | "text/plain": [ 127 | "
" 128 | ] 129 | }, 130 | "metadata": {}, 131 | "output_type": "display_data" 132 | } 133 | ], 134 | "source": [ 135 | "# transpose the array to (28, 28) format expected by Matplotlib\n", 136 | "array = np.squeeze(np.transpose(train_dataset[0][0], (1, 2, 0)))\n", 137 | "\n", 138 | "# plot the image using Matplotlib\n", 139 | "plt.imshow(array)\n", 140 | "plt.show()" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 8, 146 | "id": "42f64815", 147 | "metadata": {}, 148 | "outputs": [ 149 | { 150 | "data": { 151 | "text/plain": [ 152 | "torch.Size([4, 1, 28, 28])" 153 | ] 154 | }, 155 | "execution_count": 8, 156 | "metadata": {}, 157 | "output_type": "execute_result" 158 | } 159 | ], 160 | "source": [ 161 | "next(iter(test_loader))[0].shape" 162 | ] 163 | }, 164 | { 165 | "cell_type": "markdown", 166 | "id": "30d74d36-e002-48f2-884f-7f695fb75050", 167 | "metadata": { 168 | "jp-MarkdownHeadingCollapsed": true 169 | }, 170 | "source": [ 171 | "## Creating our Model" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": 9, 177 | "id": "9ef55b65", 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [ 181 | "# Define the neural network model\n", 182 | "class NeuralNet(nn.Module):\n", 183 | " def __init__(self, input_size, hidden_size, num_classes): # e.g. 
input_size=10, hidden_size=4, num_classes=2\n", 184 | " super(NeuralNet, self).__init__()\n", 185 | " self.fc1 = nn.Linear(input_size, hidden_size) # 1x10 input vector X 10x4 = 1x4\n", 186 | " self.relu = nn.ReLU()\n", 187 | " self.fc2 = nn.Linear(hidden_size, num_classes) # 1x4 X 4x2 = 1x2\n", 188 | " \n", 189 | " def forward(self, x):\n", 190 | " out = self.fc1(x)\n", 191 | " out = self.relu(out)\n", 192 | " out = self.fc2(out)\n", 193 | " return out" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 10, 199 | "id": "aaae9c90-4106-41e0-8d44-4360ce1d36c5", 200 | "metadata": {}, 201 | "outputs": [ 202 | { 203 | "data": { 204 | "text/plain": [ 205 | "NeuralNet(\n", 206 | " (fc1): Linear(in_features=10, out_features=4, bias=True)\n", 207 | " (relu): ReLU()\n", 208 | " (fc2): Linear(in_features=4, out_features=2, bias=True)\n", 209 | ")" 210 | ] 211 | }, 212 | "execution_count": 10, 213 | "metadata": {}, 214 | "output_type": "execute_result" 215 | } 216 | ], 217 | "source": [ 218 | "model = NeuralNet(10, 4, 2)\n", 219 | "model" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": 11, 225 | "id": "96c8b840-f4ba-4991-bdee-df14528e877a", 226 | "metadata": {}, 227 | "outputs": [ 228 | { 229 | "data": { 230 | "text/plain": [ 231 | "torch.Size([4, 10])" 232 | ] 233 | }, 234 | "execution_count": 11, 235 | "metadata": {}, 236 | "output_type": "execute_result" 237 | } 238 | ], 239 | "source": [ 240 | "model.fc1.weight.shape" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": null, 246 | "id": "ab9aac7b-0a3c-48fd-b03e-6b552fb365ae", 247 | "metadata": {}, 248 | "outputs": [], 249 | "source": [] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": 12, 254 | "id": "27f2be33-9fdd-4bd3-8f56-25f374cf3137", 255 | "metadata": {}, 256 | "outputs": [ 257 | { 258 | "data": { 259 | "image/png": "", 260 | "text/plain": [ 261 | "
" 262 | ] 263 | }, 264 | "metadata": {}, 265 | "output_type": "display_data" 266 | } 267 | ], 268 | "source": [ 269 | "import matplotlib.pyplot as plt\n", 270 | "import matplotlib.patches as patches\n", 271 | "\n", 272 | "def visualize_neural_network(input_size, hidden_size, num_classes):\n", 273 | " fig, ax = plt.subplots()\n", 274 | "\n", 275 | " # Define the layers\n", 276 | " layers = [input_size, hidden_size, num_classes]\n", 277 | " layer_positions = [(0, i) for i in range(len(layers))]\n", 278 | "\n", 279 | " # Draw the layers\n", 280 | " for i, (x, layer) in enumerate(zip(layer_positions, layers)):\n", 281 | " for j in range(layer):\n", 282 | " circle = patches.Circle((i * 2, j - layer / 2), radius=0.1, fill=True)\n", 283 | " ax.add_patch(circle)\n", 284 | "\n", 285 | " # Draw connections to the next layer\n", 286 | " if i < len(layers) - 1:\n", 287 | " for k in range(layers[i + 1]):\n", 288 | " line = plt.Line2D([i * 2 + 0.1, (i + 1) * 2 - 0.1], [j - layer / 2, k - layers[i + 1] / 2], color=\"black\")\n", 289 | " ax.add_line(line)\n", 290 | "\n", 291 | " ax.set_xlim(-0.5, len(layers) * 2)\n", 292 | " ax.set_ylim(-max(layers) / 2 - 0.5, max(layers) / 2 + 0.5)\n", 293 | " ax.set_aspect('equal')\n", 294 | " plt.axis('off')\n", 295 | " plt.title('Neural Network Architecture')\n", 296 | " plt.show()\n", 297 | "\n", 298 | "# Example parameters\n", 299 | "input_size = 10\n", 300 | "hidden_size = 4\n", 301 | "num_classes = 2\n", 302 | "\n", 303 | "# Visualize the neural network\n", 304 | "visualize_neural_network(input_size, hidden_size, num_classes)\n" 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": 13, 310 | "id": "fca379b3", 311 | "metadata": {}, 312 | "outputs": [], 313 | "source": [ 314 | "# Some loss function options\n", 315 | "cross_entropy_loss = nn.CrossEntropyLoss()\n", 316 | "nll_loss = nn.NLLLoss()\n", 317 | "mse_loss = nn.MSELoss()\n", 318 | "\n", 319 | "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n" 320 | ] 
321 | }, 322 | { 323 | "cell_type": "markdown", 324 | "id": "020fbccd-eb05-4ea5-a0ca-c92198d6eefa", 325 | "metadata": { 326 | "jp-MarkdownHeadingCollapsed": true 327 | }, 328 | "source": [ 329 | "## Training our model" 330 | ] 331 | }, 332 | { 333 | "cell_type": "code", 334 | "execution_count": 14, 335 | "id": "c8df8e47", 336 | "metadata": {}, 337 | "outputs": [], 338 | "source": [ 339 | "# input layer, a single hidden layer, and an output layer\n", 340 | "# Define hyperparameters\n", 341 | "input_size = 784 # 28x28 pixels\n", 342 | "hidden_size = 128\n", 343 | "num_classes = 10\n", 344 | "\n", 345 | "batch_size = 100" 346 | ] 347 | }, 348 | { 349 | "cell_type": "code", 350 | "execution_count": 31, 351 | "id": "502faaf8", 352 | "metadata": {}, 353 | "outputs": [], 354 | "source": [ 355 | "# Function to train the model\n", 356 | "import time\n", 357 | "def train_model(model, loss_function, train_loader, optimizer, num_epochs=10):\n", 358 | " # Put the model in training mode\n", 359 | " model.train()\n", 360 | " start_time = time.time()\n", 361 | " \n", 362 | " # Iterate over the number of epochs\n", 363 | " for epoch in range(num_epochs):\n", 364 | " # Initialize the running loss for this epoch to zero\n", 365 | " running_loss = 0.0\n", 366 | " \n", 367 | " # Iterate over each batch in the training loader\n", 368 | " for i, (images, labels) in enumerate(train_loader):\n", 369 | " \n", 370 | " # Reshape the images tensor to have size (batch_size, input_size)\n", 371 | " images = images.reshape(-1, input_size)\n", 372 | " # Forward pass: compute the outputs of the model given the input images\n", 373 | " outputs = model(images)\n", 374 | " # Compute the loss between the outputs and the true labels\n", 375 | " loss = loss_function(outputs, labels)\n", 376 | " # Backward pass: compute the gradients of the loss with respect to the model parameters\n", 377 | " optimizer.zero_grad()\n", 378 | " loss.backward()\n", 379 | " # Update the model parameters using the 
optimizer\n", 380 | " optimizer.step()\n", 381 | " \n", 382 | " # Add the current batch loss to the running loss for this epoch\n", 383 | " running_loss += loss.item()\n", 384 | "\n", 385 | " # Compute the average loss over all batches for this epoch and print it\n", 386 | " print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')\n", 387 | " training_time = (time.time() - start_time)\n", 388 | " print(f'Training took: {round(training_time, 1)} seconds.')\n" 389 | ] 390 | }, 391 | { 392 | "cell_type": "code", 393 | "execution_count": 32, 394 | "id": "410c88bf", 395 | "metadata": {}, 396 | "outputs": [], 397 | "source": [ 398 | "# Function to evaluate the model\n", 399 | "def evaluate_model(model, test_loader):\n", 400 | " # Put the model in evaluation mode\n", 401 | " model.eval()\n", 402 | " \n", 403 | " # Initialize empty lists to store true and predicted labels\n", 404 | " y_true, y_pred = [], []\n", 405 | " \n", 406 | " # Disable gradient computation since we're only evaluating the model\n", 407 | " with torch.no_grad():\n", 408 | " # Iterate over each batch in the test loader\n", 409 | " for images, labels in test_loader:\n", 410 | " # Reshape the images tensor to have size (batch_size, input_size)\n", 411 | " images = images.reshape(-1, input_size)\n", 412 | " \n", 413 | " # Forward pass: compute the outputs of the model given the input images\n", 414 | " outputs = model(images)\n", 415 | " \n", 416 | " # Find the predicted class for each image in the batch\n", 417 | " _, predicted = torch.max(outputs.data, 1)\n", 418 | " \n", 419 | " # Append the true and predicted labels for this batch to the lists\n", 420 | " y_true.extend(labels.numpy())\n", 421 | " y_pred.extend(predicted.numpy())\n", 422 | " \n", 423 | " # Calculate evaluation metrics using the true and predicted labels\n", 424 | " accuracy, f1, precision, recall = evaluate_model_metrics(np.array(y_true), np.array(y_pred))\n", 425 | " \n", 426 | " # Return the evaluation 
metrics\n", 427 | " return accuracy, f1, precision, recall\n", 428 | "\n", 429 | "# Function to calculate evaluation metrics\n", 430 | "def evaluate_model_metrics(y_true, y_pred):\n", 431 | " # Compute the accuracy, F1 score, precision, and recall\n", 432 | " accuracy = accuracy_score(y_true, y_pred)\n", 433 | " f1 = f1_score(y_true, y_pred, average='macro')\n", 434 | " precision = precision_score(y_true, y_pred, average='macro')\n", 435 | " recall = recall_score(y_true, y_pred, average='macro')\n", 436 | "\n", 437 | " # Return the evaluation metrics\n", 438 | " return accuracy, f1, precision, recall\n", 439 | "\n", 440 | "# Define a dictionary of loss functions\n", 441 | "loss_functions = {\n", 442 | " 'CrossEntropyLoss': nn.CrossEntropyLoss(),\n", 443 | " 'NLLLoss': nn.NLLLoss(), # requires logsoftmax output (according to docs, so loss will be wrong without it!)\n", 444 | " 'MultiMarginLoss': nn.MultiMarginLoss(),\n", 445 | " 'KLDivLoss': nn.KLDivLoss() # requires logsoftmax output (according to docs, so loss will be wrong without it!)\n", 446 | "}" 447 | ] 448 | }, 449 | { 450 | "cell_type": "code", 451 | "execution_count": 33, 452 | "id": "559edd99-720b-47f1-b677-b8264c4e9981", 453 | "metadata": {}, 454 | "outputs": [ 455 | { 456 | "data": { 457 | "text/plain": [ 458 | "784" 459 | ] 460 | }, 461 | "execution_count": 33, 462 | "metadata": {}, 463 | "output_type": "execute_result" 464 | } 465 | ], 466 | "source": [ 467 | "input_size" 468 | ] 469 | }, 470 | { 471 | "cell_type": "code", 472 | "execution_count": 34, 473 | "id": "156ac9db", 474 | "metadata": { 475 | "scrolled": true 476 | }, 477 | "outputs": [ 478 | { 479 | "name": "stdout", 480 | "output_type": "stream", 481 | "text": [ 482 | "Training with CrossEntropyLoss:\n", 483 | "Epoch [1/5], Loss: 0.2124\n", 484 | "Epoch [2/5], Loss: 0.1234\n", 485 | "Epoch [3/5], Loss: 0.1008\n", 486 | "Epoch [4/5], Loss: 0.0905\n", 487 | "Epoch [5/5], Loss: 0.0793\n", 488 | "Training took: 74.2 seconds.\n", 489 | 
"Performance metrics for CrossEntropyLoss:\n", 490 | "Accuracy: 0.9716, F1-score: 0.9715, Precision: 0.9718, Recall: 0.9716\n", 491 | "\n", 492 | "Training with NLLLoss:\n", 493 | "Epoch [1/5], Loss: -2130538.3142\n", 494 | "Epoch [2/5], Loss: -12779579.6207\n", 495 | "Epoch [3/5], Loss: -32398089.4337\n", 496 | "Epoch [4/5], Loss: -60796260.3199\n", 497 | "Epoch [5/5], Loss: -97924332.6755\n", 498 | "Training took: 74.6 seconds.\n" 499 | ] 500 | }, 501 | { 502 | "name": "stderr", 503 | "output_type": "stream", 504 | "text": [ 505 | "/Users/sinanozdemir/Library/Python/3.9/lib/python/site-packages/sklearn/metrics/_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", 506 | " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n" 507 | ] 508 | }, 509 | { 510 | "name": "stdout", 511 | "output_type": "stream", 512 | "text": [ 513 | "Performance metrics for NLLLoss:\n", 514 | "Accuracy: 0.1135, F1-score: 0.0204, Precision: 0.0114, Recall: 0.1000\n", 515 | "\n", 516 | "Training with MultiMarginLoss:\n", 517 | "Epoch [1/5], Loss: 0.0368\n", 518 | "Epoch [2/5], Loss: 0.0201\n", 519 | "Epoch [3/5], Loss: 0.0166\n", 520 | "Epoch [4/5], Loss: 0.0140\n", 521 | "Epoch [5/5], Loss: 0.0134\n", 522 | "Training took: 70.7 seconds.\n", 523 | "Performance metrics for MultiMarginLoss:\n", 524 | "Accuracy: 0.9661, F1-score: 0.9658, Precision: 0.9663, Recall: 0.9658\n", 525 | "\n", 526 | "Training with KLDivLoss:\n" 527 | ] 528 | }, 529 | { 530 | "name": "stderr", 531 | "output_type": "stream", 532 | "text": [ 533 | "/Users/sinanozdemir/Library/Python/3.9/lib/python/site-packages/torch/nn/functional.py:2919: UserWarning: reduction: 'mean' divides the total loss by both the batch size and the support size.'batchmean' divides only by the batch size, and aligns with the KL div math definition.'mean' will be changed to behave the same 
as 'batchmean' in the next major release.\n", 534 | " warnings.warn(\n" 535 | ] 536 | }, 537 | { 538 | "ename": "RuntimeError", 539 | "evalue": "kl_div: Integral inputs not supported.", 540 | "output_type": "error", 541 | "traceback": [ 542 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 543 | "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", 544 | "Cell \u001b[0;32mIn[34], line 10\u001b[0m\n\u001b[1;32m 7\u001b[0m optimizer \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39moptim\u001b[38;5;241m.\u001b[39mAdam(model\u001b[38;5;241m.\u001b[39mparameters(), lr\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.001\u001b[39m)\n\u001b[1;32m 9\u001b[0m \u001b[38;5;66;03m# Train the model\u001b[39;00m\n\u001b[0;32m---> 10\u001b[0m \u001b[43mtrain_model\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mloss_function\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_loader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_epochs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m5\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;66;03m# Evaluate the model\u001b[39;00m\n\u001b[1;32m 13\u001b[0m accuracy, f1, precision, recall \u001b[38;5;241m=\u001b[39m evaluate_model(model, test_loader)\n", 545 | "Cell \u001b[0;32mIn[31], line 21\u001b[0m, in \u001b[0;36mtrain_model\u001b[0;34m(model, loss_function, train_loader, optimizer, num_epochs)\u001b[0m\n\u001b[1;32m 19\u001b[0m outputs \u001b[38;5;241m=\u001b[39m model(images)\n\u001b[1;32m 20\u001b[0m \u001b[38;5;66;03m# Compute the loss between the outputs and the true labels\u001b[39;00m\n\u001b[0;32m---> 21\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[43mloss_function\u001b[49m\u001b[43m(\u001b[49m\u001b[43moutputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mlabels\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 22\u001b[0m \u001b[38;5;66;03m# Backward pass: compute the gradients of the loss with respect to the model parameters\u001b[39;00m\n\u001b[1;32m 23\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n", 546 | "File \u001b[0;32m~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks 
\u001b[38;5;241m=\u001b[39m [], []\n", 547 | "File \u001b[0;32m~/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/loss.py:471\u001b[0m, in \u001b[0;36mKLDivLoss.forward\u001b[0;34m(self, input, target)\u001b[0m\n\u001b[1;32m 470\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor, target: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 471\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mkl_div\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreduction\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreduction\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlog_target\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlog_target\u001b[49m\u001b[43m)\u001b[49m\n", 548 | "File \u001b[0;32m~/Library/Python/3.9/lib/python/site-packages/torch/nn/functional.py:2931\u001b[0m, in \u001b[0;36mkl_div\u001b[0;34m(input, target, size_average, reduce, reduction, log_target)\u001b[0m\n\u001b[1;32m 2928\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 2929\u001b[0m reduction_enum \u001b[38;5;241m=\u001b[39m _Reduction\u001b[38;5;241m.\u001b[39mget_enum(reduction)\n\u001b[0;32m-> 2931\u001b[0m reduced \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mkl_div\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreduction_enum\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mlog_target\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlog_target\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2933\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m reduction \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbatchmean\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28minput\u001b[39m\u001b[38;5;241m.\u001b[39mdim() \u001b[38;5;241m!=\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m 2934\u001b[0m reduced \u001b[38;5;241m=\u001b[39m reduced \u001b[38;5;241m/\u001b[39m \u001b[38;5;28minput\u001b[39m\u001b[38;5;241m.\u001b[39msize()[\u001b[38;5;241m0\u001b[39m]\n", 549 | "\u001b[0;31mRuntimeError\u001b[0m: kl_div: Integral inputs not supported." 550 | ] 551 | } 552 | ], 553 | "source": [ 554 | "# Train and evaluate the model using different loss functions\n", 555 | "for loss_name, loss_function in loss_functions.items():\n", 556 | " print(f'Training with {loss_name}:')\n", 557 | " \n", 558 | " # Initialize a new model and optimizer for each loss function\n", 559 | " model = NeuralNet(input_size, hidden_size, num_classes)\n", 560 | " optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n", 561 | " \n", 562 | " # Train the model\n", 563 | " train_model(model, loss_function, train_loader, optimizer, num_epochs=5)\n", 564 | "\n", 565 | " # Evaluate the model\n", 566 | " accuracy, f1, precision, recall = evaluate_model(model, test_loader)\n", 567 | " print(f'Performance metrics for {loss_name}:')\n", 568 | " print(f'Accuracy: {accuracy:.4f}, F1-score: {f1:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}\\n')\n" 569 | ] 570 | }, 571 | { 572 | "cell_type": "code", 573 | "execution_count": null, 574 | "id": "0116ef96-efc8-4699-8dd2-2c3a36b96bfd", 575 | "metadata": {}, 576 | "outputs": [], 577 | "source": [ 578 | "# UH OH, KL didn't even run :(" 579 | ] 580 | }, 581 | { 582 | "cell_type": "markdown", 583 | "id": "c714633a-9df3-41d9-9c82-6163d5dfb174", 584 | 
"metadata": { 585 | "jp-MarkdownHeadingCollapsed": true 586 | }, 587 | "source": [ 588 | "## Making our code work with KL and NLL" 589 | ] 590 | }, 591 | { 592 | "cell_type": "code", 593 | "execution_count": 39, 594 | "id": "3efdc826-eb71-4974-8bb5-52bbad2784c7", 595 | "metadata": {}, 596 | "outputs": [], 597 | "source": [ 598 | "def one_hot_encode(labels, num_classes):\n", 599 | " # Create a tensor of zeros with shape [len(labels), num_classes]\n", 600 | " one_hot = torch.zeros(len(labels), num_classes)\n", 601 | " \n", 602 | " # Use scatter_ to assign 1s to the correct class indices\n", 603 | " one_hot.scatter_(1, labels.unsqueeze(1), 1)\n", 604 | " \n", 605 | " return one_hot\n" 606 | ] 607 | }, 608 | { 609 | "cell_type": "code", 610 | "execution_count": 40, 611 | "id": "47c8d6a0-1abb-4b5e-9e0e-7cfcd14da1b0", 612 | "metadata": {}, 613 | "outputs": [ 614 | { 615 | "data": { 616 | "text/plain": [ 617 | "tensor([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 618 | " [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]])" 619 | ] 620 | }, 621 | "execution_count": 40, 622 | "metadata": {}, 623 | "output_type": "execute_result" 624 | } 625 | ], 626 | "source": [ 627 | "one_hot_encode(torch.tensor([1, 7]), 10)" 628 | ] 629 | }, 630 | { 631 | "cell_type": "code", 632 | "execution_count": 43, 633 | "id": "cbb6dd60-1a9d-4063-b38e-e82135a335b3", 634 | "metadata": {}, 635 | "outputs": [], 636 | "source": [ 637 | "# Function to train the model\n", 638 | "def train_model(model, loss_function, train_loader, optimizer, num_epochs=10):\n", 639 | " # Put the model in training mode\n", 640 | " model.train()\n", 641 | " start_time = time.time()\n", 642 | " \n", 643 | " # Iterate over the number of epochs\n", 644 | " for epoch in range(num_epochs):\n", 645 | " # Initialize the running loss for this epoch to zero\n", 646 | " running_loss = 0.0\n", 647 | " \n", 648 | " # Iterate over each batch in the training loader\n", 649 | " for i, (images, labels) in enumerate(train_loader):\n", 650 | " \n", 
651 | " # Reshape the images tensor to have size (batch_size, input_size)\n", 652 | " images = images.reshape(-1, input_size)\n", 653 | " # Forward pass: compute the outputs of the model given the input images\n", 654 | " outputs = model(images)\n", 655 | "\n", 656 | " #### NEW ####\n", 657 | " if 'KL' in str(loss_function) or 'NLL' in str(loss_function): # both expect log-probabilities as input\n", 658 | " outputs = F.log_softmax(outputs, dim=1)\n", 659 | " if 'KL' in str(loss_function): # KL also needs one-hot (probability) targets: [[0, 0, 1, 0, 0], [0, 0, 0, 0, 1]] instead of [2, 4]\n", 660 | " labels = one_hot_encode(labels, num_classes)\n", 661 | " #### NEW ####\n", 662 | " \n", 663 | " # Compute the loss between the outputs and the true labels\n", 664 | " loss = loss_function(outputs, labels)\n", 665 | " # Backward pass: compute the gradients of the loss with respect to the model parameters\n", 666 | " optimizer.zero_grad()\n", 667 | " loss.backward()\n", 668 | " # Update the model parameters using the optimizer\n", 669 | " optimizer.step()\n", 670 | " \n", 671 | " # Add the current batch loss to the running loss for this epoch\n", 672 | " running_loss += loss.item()\n", 673 | "\n", 674 | " \n", 675 | "\n", 676 | " # Compute the average loss over all batches for this epoch and print it\n", 677 | " print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')\n", 678 | " training_time = (time.time() - start_time)\n", 679 | " print(f'Training took: {round(training_time, 1)} seconds.')\n" 680 | ] 681 | }, 682 | { 683 | "cell_type": "code", 684 | "execution_count": 44, 685 | "id": "e5c84301-45e4-497d-849b-56124f51f571", 686 | "metadata": {}, 687 | "outputs": [ 688 | { 689 | "name": "stdout", 690 | "output_type": "stream", 691 | "text": [ 692 | "Training with NLLLoss:\n", 693 | "Epoch [1/5], Loss: 0.2193\n", 694 | "Epoch [2/5], Loss: 0.1263\n", 695 | "Epoch [3/5], Loss: 0.1039\n", 696 | "Epoch [4/5], Loss: 0.0916\n", 697 | "Epoch [5/5], Loss: 0.0841\n", 698 | "Training took: 83.9 seconds.\n", 699 | 
"Performance metrics for NLLLoss:\n", 700 | "Accuracy: 0.9680, F1-score: 0.9679, Precision: 0.9684, Recall: 0.9678\n", 701 | "\n", 702 | "Training with KLDivLoss:\n", 703 | "Epoch [1/5], Loss: 0.2135\n", 704 | "Epoch [2/5], Loss: 0.1239\n", 705 | "Epoch [3/5], Loss: 0.1030\n", 706 | "Epoch [4/5], Loss: 0.0904\n", 707 | "Epoch [5/5], Loss: 0.0815\n", 708 | "Training took: 79.7 seconds.\n", 709 | "Performance metrics for KLDivLoss:\n", 710 | "Accuracy: 0.9679, F1-score: 0.9677, Precision: 0.9681, Recall: 0.9674\n", 711 | "\n", 712 | "Training with CrossEntropyLoss:\n", 713 | "Epoch [1/5], Loss: 0.2133\n", 714 | "Epoch [2/5], Loss: 0.1242\n", 715 | "Epoch [3/5], Loss: 0.1019\n", 716 | "Epoch [4/5], Loss: 0.0918\n", 717 | "Epoch [5/5], Loss: 0.0833\n", 718 | "Training took: 74.7 seconds.\n", 719 | "Performance metrics for CrossEntropyLoss:\n", 720 | "Accuracy: 0.9674, F1-score: 0.9670, Precision: 0.9672, Recall: 0.9669\n", 721 | "\n" 722 | ] 723 | } 724 | ], 725 | "source": [ 726 | "# Define a dictionary of loss functions\n", 727 | "loss_functions = {\n", 728 | "\n", 729 | " 'NLLLoss': nn.NLLLoss(), \n", 730 | " 'KLDivLoss': nn.KLDivLoss(reduction='batchmean'), # 'batchmean' matches the mathematical definition of KL divergence and avoids PyTorch's warning about the default 'mean'\n", 731 | " 'CrossEntropyLoss': nn.CrossEntropyLoss()\n", 732 | "}\n", 733 | "# Train and evaluate the model using different loss functions\n", 734 | "for loss_name, loss_function in loss_functions.items():\n", 735 | " print(f'Training with {loss_name}:')\n", 736 | " \n", 737 | " # Initialize a new model and optimizer for each loss function\n", 738 | " model = NeuralNet(input_size, hidden_size, num_classes)\n", 739 | " optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n", 740 | " \n", 741 | " # Train the model\n", 742 | " train_model(model, loss_function, train_loader, optimizer, num_epochs=5)\n", 743 | "\n", 744 | " # Evaluate the model\n", 745 | " accuracy, f1, precision, recall = evaluate_model(model, test_loader)\n", 746 | " 
print(f'Performance metrics for {loss_name}:')\n", 747 | " print(f'Accuracy: {accuracy:.4f}, F1-score: {f1:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}\\n')\n" 748 | ] 749 | }, 750 | { 751 | "cell_type": "code", 752 | "execution_count": 45, 753 | "id": "ff44b62b-0e6c-4a12-8e65-22ca97c1929a", 754 | "metadata": {}, 755 | "outputs": [ 756 | { 757 | "name": "stderr", 758 | "output_type": "stream", 759 | "text": [ 760 | "/Users/sinanozdemir/Library/Python/3.9/lib/python/site-packages/torchvision/datasets/mnist.py:65: UserWarning: train_labels has been renamed targets\n", 761 | " warnings.warn(\"train_labels has been renamed targets\")\n" 762 | ] 763 | }, 764 | { 765 | "data": { 766 | "text/plain": [ 767 | "" 768 | ] 769 | }, 770 | "execution_count": 45, 771 | "metadata": {}, 772 | "output_type": "execute_result" 773 | }, 774 | { 775 | "data": { 776 | "image/png": "", 777 | "text/plain": [ 778 | "
" 779 | ] 780 | }, 781 | "metadata": {}, 782 | "output_type": "display_data" 783 | } 784 | ], 785 | "source": [ 786 | "import pandas as pd\n", 787 | "\n", 788 | "# KL is often suggested for soft or imbalanced targets, but with one-hot labels it reduces to cross-entropy (a one-hot target has zero entropy); MNIST's classes are also roughly balanced\n", 789 | "pd.Series(test_dataset.train_labels).value_counts().plot(kind='bar')" 790 | ] 791 | }, 792 | { 793 | "cell_type": "code", 794 | "execution_count": 46, 795 | "id": "00bad6b3-bb53-4cf2-a3c4-a1431a7c46c3", 796 | "metadata": {}, 797 | "outputs": [ 798 | { 799 | "data": { 800 | "text/plain": [ 801 | "Dataset MNIST\n", 802 | " Number of datapoints: 10000\n", 803 | " Root location: ../data\n", 804 | " Split: Test\n", 805 | " StandardTransform\n", 806 | "Transform: Compose(\n", 807 | " ToTensor()\n", 808 | " Normalize(mean=(0.1307,), std=(0.3081,))\n", 809 | " )" 810 | ] 811 | }, 812 | "execution_count": 46, 813 | "metadata": {}, 814 | "output_type": "execute_result" 815 | } 816 | ], 817 | "source": [ 818 | "test_dataset" 819 | ] 820 | }, 821 | { 822 | "cell_type": "code", 823 | "execution_count": 47, 824 | "id": "39a4155c-8fda-44e1-9064-4dca92cd74b7", 825 | "metadata": {}, 826 | "outputs": [ 827 | { 828 | "data": { 829 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAbKUlEQVR4nO3df3DU9b3v8dcCyQqYbAwh2UQCBvxBFUinFNJclMaSS4hnGFDOHVBvBxwvXGlwhNTqiaMgbeemxTno0UPxnxbqGQHLuQJHTi8djSaMbYKHKIfLtWZIJhYYklBzD9kQJATyuX9wXV1JwO+ym3eyPB8z3xmy+/3k+/br6pNvsvnG55xzAgBggA2zHgAAcH0iQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwMQI6wG+rre3VydPnlRKSop8Pp/1OAAAj5xz6uzsVE5OjoYN6/86Z9AF6OTJk8rNzbUeAwBwjY4fP65x48b1+/ygC1BKSook6W7dpxFKMp4GAODVBfXoff0+/P/z/sQtQJs2bdILL7yg1tZW5efn65VXXtHMmTOvuu6LL7uNUJJG+AgQAAw5//8Oo1f7Nkpc3oTwxhtvqLy8XOvWrdOHH36o/Px8lZSU6NSpU/E4HABgCIpLgDZu3Kjly5frkUce0Z133qlXX31Vo0aN0m9+85t4HA4AMATFPEDnz59XfX29iouLvzzIsGEqLi5WbW3tZft3d3crFApFbACAxBfzAH322We6ePGisrKyIh7PyspSa2vrZftXVlYqEAiEN94BBwDXB/MfRK2oqFBHR0d4O378uPVIAIABEPN3wWVkZGj48OFqa2uLeLytrU3BYPCy/f1+v/x+f6zHAAAMcjG/AkpOTtb06dNVVVUVfqy3t1dVVVUqLCyM9eEAAENUXH4OqLy8XEuXLtV3v/tdzZw5Uy+99JK6urr0yCOPxONwAIAhKC4BWrx4sf76179q7dq1am1t1be//W3t27fvsjcmAACuXz7nnLMe4qtCoZACgYCKtIA7IQDAEHTB9ahae9TR0aHU1NR+9zN/FxwA4PpEgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMxDxAzz//vHw+X8Q2efLkWB8GADDEjYjHJ73rrrv0zjvvfHmQEXE5DABgCItLGUaMGKFgMBiPTw0ASBBx+R7Q0aNHlZOTo4kTJ+rhhx/WsWPH+t23u7tboVAoYgMAJL6YB6igoEBbt27Vvn37tHnzZjU3N+uee+5RZ2dnn/tXVlYqEAiEt9zc3FiPBAAYhHzOORfPA5w+fVoTJkzQxo0b9eijj172fHd3t7q7u8Mfh0Ih5ebmqkgLNMKXFM/RAABxcMH1qFp71NHRodTU1H73i/u7A9LS0nT77bersbGxz+f9fr/8fn+8xwAADDJx/zmgM2fOqKmpSdnZ2fE+FABgCIl5gJ588knV1NTo008/1Z/+9Cfdf//9Gj58uB588MFYHwoAMITF/EtwJ06c0IMPPqj29naNHTtWd999t+rq6jR27NhYHwoAMITFPEA7duyI9acEACQg7gUHADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJiI+y+kw8BqX17oec34H/b9ywKv5pNTWZ7XnO/2/ltub97ufc2oE2c8r5Gk3kMfR7UOgHd
cAQEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEd8NOME/9ZJvnNYtG/0d0B5sU3TLPirwv+fTC2agO9Q9/vTeqdRg4H5ya4HnN6L8PRHWsEVX1Ua3DN8MVEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABggpuRJpiXn1niec3aadH9PeSmPzvPa/7jWz7Pa5Knnfa8ZsOUNz2vkaQXsw94XvOvZ2/0vOZvRp3xvGYgfe7Oe15zoHu05zVFN/R4XqMo/h3duvi/ez+OpNurolqGb4grIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABDcjTTCj/9n7jRpH/3McBulH6gAd55VgUVTrfj7rFs9rUmsaPa/ZUHSr5zUDacTnvZ7XjD7c4nnNmP3/0/OaqclJnteM+tT7GsQfV0AAABMECABgwnOA9u/fr/nz5ysnJ0c+n0+7d++OeN45p7Vr1yo7O1sjR45UcXGxjh49Gqt5AQAJwnOAurq6lJ+fr02bNvX5/IYNG/Tyyy/r1Vdf1YEDBzR69GiVlJTo3Llz1zwsACBxeH4TQmlpqUpLS/t8zjmnl156Sc8++6wWLFggSXrttdeUlZWl3bt3a8kS77+tEwCQmGL6PaDm5ma1traquLg4/FggEFBBQYFqa2v7XNPd3a1QKBSxAQASX0wD1NraKknKysqKeDwrKyv83NdVVlYqEAiEt9zc3FiOBAAYpMzfBVdRUaGOjo7wdvz4ceuRAAADIKYBCgaDkqS2traIx9va2sLPfZ3f71dqamrEBgBIfDENUF5enoLBoKqqqsKPhUIhHThwQIWFhbE8FABgiPP8LrgzZ86osfHLW480Nzfr0KFDSk9P1/jx47V69Wr9/Oc/12233aa8vDw999xzysnJ0cKFC2M5NwBgiPMcoIMHD+ree+8Nf1xeXi5JWrp0qbZu3aqnnnpKXV1dWrFihU6fPq27775b+/bt0w033BC7qQEAQ57POeesh/iqUCikQCCgIi3QCB83EASGivb/5v3L7LXr/9Hzmo3/d7LnNfvnTvK8RpIutPT97l1c2QXXo2rtUUdHxxW/r2/+LjgAwPWJAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJjz/OgYAiW/EhFzPa/7xGe93tk7yDfe8Zuc/FHteM6al1vMaxB9XQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACW5GCuAyn6y52fOaGX6f5zX/5/znntekf3zW8xoMTlwBAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmuBkpkMC6/2ZGVOs+/NsXo1jl97xi5RNPeF4z8k8feF6DwYkrIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABDcjBRLYsdLo/o55o8/7jUUfbP7PnteM2vfvntc4zyswWHEFBAAwQYAAACY8B2j//v2aP3++cnJy5PP5tHv37ojnly1bJp/PF7HNmzcvVvMCABKE5wB1dXUpPz9fmzZt6nefefPmqaWlJbxt3779moYEACQez29CKC0tVWlp6RX38fv9CgaDUQ8FAEh8cfkeUHV1tTIzM3XHHXdo5cqVam9v73ff7u5uhUKhiA0AkPhiHqB58+bptddeU1VVlX75y1+qpqZGpaWlunjxYp/7V1ZWKhAIhLfc3NxYjwQAGIRi/nNAS5YsCf956tSpmjZtmiZNmqTq6mrNmTPnsv0rKipUXl4e/jgUChEhALgOxP1
t2BMnTlRGRoYaGxv7fN7v9ys1NTViAwAkvrgH6MSJE2pvb1d2dna8DwUAGEI8fwnuzJkzEVczzc3NOnTokNLT05Wenq7169dr0aJFCgaDampq0lNPPaVbb71VJSUlMR0cADC0eQ7QwYMHde+994Y//uL7N0uXLtXmzZt1+PBh/fa3v9Xp06eVk5OjuXPn6mc/+5n8fu/3lgIAJC7PASoqKpJz/d8O8A9/+MM1DQSgb8NSUjyv+eE970d1rFDvOc9rTv2PiZ7X+Lv/zfMaJA7uBQcAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATMf+V3ADi4+jzd3leszfjV1Eda8HRRZ7X+H/Pna3hDVdAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJbkYKGOj4r9/zvObw4pc9r2m60ON5jSSd+eU4z2v8aonqWLh+cQUEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJjgZqTANRpxc47nNaufe8PzGr/P+3+uS/79h57XSNLY//VvUa0DvOAKCABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwwc1Iga/wjfD+n0T+3hOe1/yXG9s9r3m9M9Pzmqznovs7Zm9UqwBvuAICAJggQAAAE54CVFlZqRkzZiglJUWZmZlauHChGhoaIvY5d+6cysrKNGbMGN14441atGiR2traYjo0AGDo8xSgmpoalZWVqa6uTm+//bZ6eno0d+5cdXV1hfdZs2aN3nrrLe3cuVM1NTU6efKkHnjggZgPDgAY2jx9x3Xfvn0RH2/dulWZmZmqr6/X7Nmz1dHRoV//+tfatm2bfvCDH0iStmzZom9961uqq6vT9773vdhNDgAY0q7pe0AdHR2SpPT0dElSfX29enp6VFxcHN5n8uTJGj9+vGpra/v8HN3d3QqFQhEbACDxRR2g3t5erV69WrNmzdKUKVMkSa2trUpOTlZaWlrEvllZWWptbe3z81RWVioQCIS33NzcaEcCAAwhUQeorKxMR44c0Y4dO65pgIqKCnV0dIS348ePX9PnAwAMDVH9IOqqVau0d+9e7d+/X+PGjQs/HgwGdf78eZ0+fTriKqitrU3BYLDPz+X3++X3+6MZAwAwhHm6AnLOadWqVdq1a5feffdd5eXlRTw/ffp0JSUlqaqqKvxYQ0ODjh07psLCwthMDABICJ6ugMrKyrRt2zbt2bNHKSkp4e/rBAIBjRw5UoFAQI8++qjKy8uVnp6u1NRUPf744yosLOQdcACACJ4CtHnzZklSUVFRxONbtmzRsmXLJEkvvviihg0bpkWLFqm7u1slJSX61a9+FZNhAQCJw+ecc9ZDfFUoFFIgEFCRFmiEL8l6HFxnfNPv8rzmX//ln+IwyeX+U0WZ5zVpr/X94w9APF1wParWHnV0dCg1NbXf/bgXHADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAExE9RtRgcFu+J23R7VuxY49MZ6kb3f+xvudrW/5p7o4TALY4QoIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADDBzUiRkD750U1RrZs/KhTjSfo2rvq890XOxX4QwBBXQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACW5GikHv3PyZntdUzf/7KI82Ksp1ALziCggAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMMH
NSDHonZw13POa8SMG7qair3dmel6TFDrveY3zvAIY3LgCAgCYIEAAABOeAlRZWakZM2YoJSVFmZmZWrhwoRoaGiL2KSoqks/ni9gee+yxmA4NABj6PAWopqZGZWVlqqur09tvv62enh7NnTtXXV1dEfstX75cLS0t4W3Dhg0xHRoAMPR5ehPCvn37Ij7eunWrMjMzVV9fr9mzZ4cfHzVqlILBYGwmBAAkpGv6HlBHR4ckKT09PeLx119/XRkZGZoyZYoqKip09uzZfj9Hd3e3QqFQxAYASHxRvw27t7dXq1ev1qxZszRlypTw4w899JAmTJignJwcHT58WE8//bQaGhr05ptv9vl5KisrtX79+mjHAAAMUVEHqKysTEeOHNH7778f8fiKFSvCf546daqys7M1Z84cNTU1adKkSZd9noqKCpWXl4c/DoVCys3NjXYsAMAQEVWAVq1apb1792r//v0aN27cFfctKCiQJDU2NvYZIL/fL7/fH80YAIAhzFOAnHN6/PHHtWvXLlVXVysvL++qaw4dOiRJys7OjmpAAEBi8hSgsrIybdu2TXv27FFKSopaW1slSYFAQCNHjlRTU5O2bdum++67T2PGjNHhw4e1Zs0azZ49W9OmTYvLPwAAYGjyFKDNmzdLuvTDpl+1ZcsWLVu2TMnJyXrnnXf00ksvqaurS7m5uVq0aJGeffbZmA0MAEgMnr8EdyW5ubmqqam5poEAANcH7oYNfEVl+52e19SW3OJ5jWv5357XAImGm5ECAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACa4GSkGvYl/V+t5zX1/9504TNKf1gE8FpA4uAICAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgYtDdC845J0m6oB7JGQ8DAPDsgnokffn/8/4MugB1dnZKkt7X740nAQBci87OTgUCgX6f97mrJWqA9fb26uTJk0pJSZHP54t4LhQKKTc3V8ePH1dqaqrRhPY4D5dwHi7hPFzCebhkMJwH55w6OzuVk5OjYcP6/07PoLsCGjZsmMaNG3fFfVJTU6/rF9gXOA+XcB4u4Txcwnm4xPo8XOnK5wu8CQEAYIIAAQBMDKkA+f1+rVu3Tn6/33oUU5yHSzgPl3AeLuE8XDKUzsOgexMCAOD6MKSugAAAiYMAAQBMECAAgAkCBAAwMWQCtGnTJt1yyy264YYbVFBQoA8++MB6pAH3/PPPy+fzRWyTJ0+2Hivu9u/fr/nz5ysnJ0c+n0+7d++OeN45p7Vr1yo7O1sjR45UcXGxjh49ajNsHF3tPCxbtuyy18e8efNsho2TyspKzZgxQykpKcrMzNTChQvV0NAQsc+5c+dUVlamMWPG6MYbb9SiRYvU1tZmNHF8fJPzUFRUdNnr4bHHHjOauG9DIkBvvPGGysvLtW7dOn344YfKz89XSUmJTp06ZT3agLvrrrvU0tIS3t5//33rkeKuq6tL+fn52rRpU5/Pb9iwQS+//LJeffVVHThwQKNHj1ZJSYnOnTs3wJPG19XOgyTNmzcv4vWxffv2AZww/mpqalRWVqa6ujq9/fbb6unp0dy5c9XV1RXeZ82aNXrrrbe0c+dO1dTU6OTJk3rggQcMp469b3IeJGn58uURr4cNGzYYTdwPNwTMnDnTlZWVhT++ePGiy8nJcZWVlYZTDbx169a5/Px86zFMSXK7du0Kf9zb2+uCwaB74YUXwo+dPn3a+f1+t337doMJB8bXz4Nzzi1dutQtWLDAZB4rp06dcpJcTU2Nc+7Sv/ukpCS3c+fO8D5//vOfnSRXW1trNWbcff08OOfc97//fffEE0/YDfUNDPoroPPnz6u+vl7FxcXhx4YNG6bi4mLV1tYaTmbj6NGjysnJ0cSJE/Xwww/r2LFj1iOZam5uVmtra8TrIxAIqKCg4Lp8fVR
XVyszM1N33HGHVq5cqfb2duuR4qqjo0OSlJ6eLkmqr69XT09PxOth8uTJGj9+fEK/Hr5+Hr7w+uuvKyMjQ1OmTFFFRYXOnj1rMV6/Bt3NSL/us88+08WLF5WVlRXxeFZWlj755BOjqWwUFBRo69atuuOOO9TS0qL169frnnvu0ZEjR5SSkmI9nonW1lZJ6vP18cVz14t58+bpgQceUF5enpqamvTMM8+otLRUtbW1Gj58uPV4Mdfb26vVq1dr1qxZmjJliqRLr4fk5GSlpaVF7JvIr4e+zoMkPfTQQ5owYYJycnJ0+PBhPf3002poaNCbb75pOG2kQR8gfKm0tDT852nTpqmgoEATJkzQ7373Oz366KOGk2EwWLJkSfjPU6dO1bRp0zRp0iRVV1drzpw5hpPFR1lZmY4cOXJdfB/0Svo7DytWrAj/eerUqcrOztacOXPU1NSkSZMmDfSYfRr0X4LLyMjQ8OHDL3sXS1tbm4LBoNFUg0NaWppuv/12NTY2Wo9i5ovXAK+Py02cOFEZGRkJ+fpYtWqV9u7dq/feey/i17cEg0GdP39ep0+fjtg/UV8P/Z2HvhQUFEjSoHo9DPoAJScna/r06aqqqgo/1tvbq6qqKhUWFhpOZu/MmTNqampSdna29Shm8vLyFAwGI14foVBIBw4cuO5fHydOnFB7e3tCvT6cc1q1apV27dqld999V3l5eRHPT58+XUlJSRGvh4aGBh07diyhXg9XOw99OXTokCQNrteD9bsgvokdO3Y4v9/vtm7d6j7++GO3YsUKl5aW5lpbW61HG1A//vGPXXV1tWtubnZ//OMfXXFxscvIyHCnTp2yHi2uOjs73UcffeQ++ugjJ8lt3LjRffTRR+4vf/mLc865X/ziFy4tLc3t2bPHHT582C1YsMDl5eW5zz//3Hjy2LrSeejs7HRPPvmkq62tdc3Nze6dd95x3/nOd9xtt93mzp07Zz16zKxcudIFAgFXXV3tWlpawtvZs2fD+zz22GNu/Pjx7t1333UHDx50hYWFrrCw0HDq2LvaeWhsbHQ//elP3cGDB11zc7Pbs2ePmzhxops9e7bx5JGGRICcc+6VV15x48ePd8nJyW7mzJmurq7OeqQBt3jxYpedne2Sk5PdzTff7BYvXuwaGxutx4q79957z0m6bFu6dKlz7tJbsZ977jmXlZXl/H6/mzNnjmtoaLAdOg6udB7Onj3r5s6d68aOHeuSkpLchAkT3PLlyxPuL2l9/fNLclu2bAnv8/nnn7sf/ehH7qabbnKjRo1y999/v2tpabEbOg6udh6OHTvmZs+e7dLT053f73e33nqr+8lPfuI6OjpsB/8afh0DAMDEoP8eEAAgMREgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJv4fx1BnJzDsp98AAAAASUVORK5CYII=", 830 | "text/plain": [ 831 | "
" 832 | ] 833 | }, 834 | "metadata": {}, 835 | "output_type": "display_data" 836 | } 837 | ], 838 | "source": [ 839 | "# transpose the array to (28, 28) format expected by Matplotlib\n", 840 | "array = np.squeeze(np.transpose(test_dataset[0][0], (1, 2, 0)))\n", 841 | "\n", 842 | "# plot the image using Matplotlib\n", 843 | "plt.imshow(array)\n", 844 | "plt.show()" 845 | ] 846 | }, 847 | { 848 | "cell_type": "code", 849 | "execution_count": 48, 850 | "id": "c6b6d5fc-a407-4704-8a0d-bce5060e8eb1", 851 | "metadata": {}, 852 | "outputs": [ 853 | { 854 | "data": { 855 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAac0lEQVR4nO3dfWxU973n8c8A9gQSe6gx9niCoYYQSENwbym4viSUFC/G0UU83Yo8dBeiCAQx2QJNk3WVhKSt5IZIaZTUBe1uC41uIA9SgA1KqYiJjdLadHFgEdvUF1tuMYttGrSeMSYYg3/7B5tpJtjQY2b89Zj3SzoSnjk/n29ORnnneIZjn3POCQCAATbMegAAwM2JAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMjrAf4sp6eHp0+fVppaWny+XzW4wAAPHLOqaOjQ6FQSMOG9X2dM+gCdPr0aeXm5lqPAQC4Qc3NzRo3blyfzw+6AKWlpUmS7tUDGqEU42kAAF5dUrc+0vvR/573JWEBqqio0EsvvaTW1lbl5+frtdde06xZs6677vMfu41Qikb4CBAAJJ3/f4fR672NkpAPIbz11lvauHGjNm3apI8//lj5+fkqLi7WmTNnEnE4AEASSkiAXn75Za1atUqPPvqovva1r2nr1q0aNWqUfv3rXyficACAJBT3AF28eFF1dXUqKir6+0GGDVNRUZFqamqu2r+rq0uRSCRmAwAMfXEP0KeffqrLly8rOzs75vHs7Gy1trZetX95ebkCgUB04xNwAHBzMP+LqGVlZQqHw9GtubnZeiQAwACI+6fgMjMzNXz4cLW1tcU83tbWpmAweNX+fr9ffr8/3mMAAAa5uF8BpaamasaMGaqsrIw+1tPTo8rKShUWFsb7cACAJJWQvwe0ceNGrVixQt/85jc1a9YsvfLKK+rs7NSjjz6aiMMBAJJQQgK0fPly/e1vf9Nzzz2n1tZWff3rX9e+ffuu+mACAODm5XPOOeshvigSiSgQCGiuFnEnBABIQpdct6q0R+FwWOnp6X3uZ/4pOADAzYkAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwMcJ6AACJ0/6fCvu17tDPtnhe87WKxz2vGf/
iHz2vcZcueV6DwYkrIACACQIEADAR9wA9//zz8vl8MdvUqVPjfRgAQJJLyHtAd999tz744IO/H2QEbzUBAGIlpAwjRoxQMBhMxLcGAAwRCXkP6MSJEwqFQpo4caIeeeQRnTx5ss99u7q6FIlEYjYAwNAX9wAVFBRo+/bt2rdvn7Zs2aKmpibdd9996ujo6HX/8vJyBQKB6JabmxvvkQAAg1DcA1RSUqLvfve7mj59uoqLi/X++++rvb1db7/9dq/7l5WVKRwOR7fm5uZ4jwQAGIQS/umA0aNH684771RDQ0Ovz/v9fvn9/kSPAQAYZBL+94DOnTunxsZG5eTkJPpQAIAkEvcAPfnkk6qurtZf/vIX/eEPf9CSJUs0fPhwPfTQQ/E+FAAgicX9R3CnTp3SQw89pLNnz2rs2LG69957VVtbq7Fjx8b7UACAJOZzzjnrIb4oEokoEAhorhZphC/Fehxg0Bhxe8jzmu9X7+/XseaP6u7XOq9KptzneU1PH5+oxeBxyXWrSnsUDoeVnp7e537cCw4AYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMJHwX0gHID7OFE/wvGagbioqSd84vNzzmrHn/j0BkyBZcAUEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAE9wNGzAwbNQoz2uK//NHCZgkfvxvfsX7IufiPwiSBldAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJbkYKGOj657s8r/lp1q8SMEnvzvdc9LwmfUdtAibBUMYVEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABggpuRAgaalg63HuGa/vXE4n6sOh3vMTDEcQUEADBBgAAAJjwH6ODBg1q4cKFCoZB8Pp92794d87xzTs8995xycnI0cuRIFRUV6cSJE/GaFwAwRHgOUGdnp/Lz81VRUdHr85s3b9arr76qrVu36tChQ7r11ltVXFysCxcu3PCwAIChw/OHEEpKSlRSUtLrc845vfLKK3rmmWe0aNEiSdLrr7+u7Oxs7d69Ww8++OCNTQsAGDLi+h5QU1OTWltbVVRUFH0sEAiooKBANTU1va7p6upSJBKJ2QAAQ19cA9Ta2ipJys7Ojnk8Ozs7+tyXlZeXKxAIRLfc3Nx4jgQAGKTMPwVXVlamcDgc3Zqbm61HAgAMgLgGKBgMSpLa2tpiHm9ra4s+92V+v1/p6ekxGwBg6ItrgPLy8hQMBlVZWRl9LBKJ6NChQyosLIznoQAASc7zp+DOnTunhoaG6NdNTU06evSoMjIyNH78eK1fv14//elPNXnyZOXl5enZZ59VKBTS4sWL4zk3ACDJeQ7Q4cOHdf/990e/3rhxoyRpxYoV2r59u5566il1dnZq9erVam9v17333qt9+/bplltuid/UAICk53POOeshvigSiSgQCGiuFmmEL8V6HCAhJv9Pv+c1v7j9kOc14Z7PPK+RpH995HHPa4ZVH+nXsTD0XHLdqtIehcPha76vb/4pOADAzYkAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmPP86BgCxuh6Y6XnNL27/bwmY5GqnLvVvHXe2xkDgCggAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMMHNSIEb1DYzxXqEPi3cu75f6ybrUHwHAXrBFRAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIKbkQI3KPWf/u+AHOeTi+c9r5n66qf9Otblfq0
CvOEKCABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwwc1IgS+48C+zPK85PHNLP4403POK+u4sz2su/3uj5zXAQOEKCABgggABAEx4DtDBgwe1cOFChUIh+Xw+7d69O+b5lStXyufzxWwLFiyI17wAgCHCc4A6OzuVn5+vioqKPvdZsGCBWlpaotvOnTtvaEgAwNDj+UMIJSUlKikpueY+fr9fwWCw30MBAIa+hLwHVFVVpaysLE2ZMkVr167V2bNn+9y3q6tLkUgkZgMADH1xD9CCBQv0+uuvq7KyUi+++KKqq6tVUlKiy5d7/y3z5eXlCgQC0S03NzfeIwEABqG4/z2gBx98MPrne+65R9OnT9ekSZNUVVWlefPmXbV/WVmZNm7cGP06EokQIQC4CST8Y9gTJ05UZmamGhoaen3e7/crPT09ZgMADH0JD9CpU6d09uxZ5eTkJPpQAIAk4vlHcOfOnYu5mmlqatLRo0eVkZGhjIwMvfDCC1q2bJmCwaAaGxv11FNP6Y477lBxcXFcBwcAJDfPATp8+LDuv//+6Nefv3+zYsUKbdmyRceOHdNvfvMbtbe3KxQKaf78+frJT34iv98fv6kBAEnPc4Dmzp0r51yfz//ud7+7oYEAS59ler9JaIrP+5r+eKpuqec1eTqWgEmA+OBecAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADAR91/JDSSzrsXtA3KcTy6e97xm3H9PScAkgB2ugAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAE9yMFEPS8Dsn9Wvd4Zn/1p+jeV7x23PTPK9J+aDO8xpgMOMKCABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwwc1IMSS13Z/Vr3UpPu83Fu2PX3z4HzyvmaxDCZgEsMMVEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABggpuRYki6kOEbsGPVdV30vOauF095XnPJ8wpgcOMKCABgggABAEx4ClB5eblmzpyptLQ0ZWVlafHixaqvr4/Z58KFCyotLdWYMWN02223admyZWpra4vr0ACA5OcpQNXV1SotLVVtba3279+v7u5uzZ8/X52dndF9NmzYoPfee0/vvPOOqqurdfr0aS1dujTugwMAkpunDyHs27cv5uvt27crKytLdXV1mjNnjsLhsH71q19px44d+s53viNJ2rZtm+666y7V1tbqW9/6VvwmBwAktRt6DygcDkuSMjIyJEl1dXXq7u5WUVFRdJ+pU6dq/Pjxqqmp6fV7dHV1KRKJxGwAgKGv3wHq6enR+vXrNXv2bE2bNk2S1NraqtTUVI0ePTpm3+zsbLW2tvb6fcrLyxUIBKJbbm5uf0cCACSRfgeotLRUx48f15tvvnlDA5SVlSkcDke35ubmG/p+AIDk0K+/iLpu3Trt3btXBw8e1Lhx46KPB4NBXbx4Ue3t7TFXQW1tbQoGg71+L7/fL7/f358xAABJzNMVkHNO69at065du3TgwAHl5eXFPD9jxgylpKSosrIy+lh9fb1OnjypwsLC+EwMABgSPF0BlZaWaseOHdqzZ4/S0tKi7+sEAgGNHDlSgUBAjz32mDZu3KiMjAylp6friSeeUGFhIZ+AAwDE8BSgLVu2SJLmzp0b8/i2bdu0cuVKSdLPf/5zDRs2TMuWLVNXV5eKi4v1y1/+Mi7DAgCGDk8Bcs5dd59bbrlFFRUVqqio6PdQwI3K+s7/GbBj/Y/IP3lec/lvnyZgEiC5cC84AIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIAC
ACQIEADBBgAAAJggQAMAEAQIAmOjXb0QFBpKvH78xd1HofyVgkt6dvXib5zWuqysBkwDJhSsgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAENyPF4Hf5sucl//WTe/t1qPX//BfPa6qa7/C85nb9b89rgKGGKyAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQ3I8Wg5y5d8rzmq/+ls1/Huqv8P3pe4zua1q9jATc7roAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABPcjBRD0uWGpn6tG//dOA8CoE9cAQEATBAgAIAJTwEqLy/XzJkzlZaWpqysLC1evFj19fUx+8ydO1c+ny9mW7NmTVyHBgAkP08Bqq6uVmlpqWpra7V//351d3dr/vz56uyM/eVfq1atUktLS3TbvHlzXIcGACQ/Tx9C2LdvX8zX27dvV1ZWlurq6jRnzpzo46NGjVIwGIzPhACAIemG3gMKh8OSpIyMjJjH33jjDWVmZmratGkqKyvT+fPn+/weXV1dikQiMRsAYOjr98ewe3p6tH79es2ePVvTpk2LPv7www9rwoQJCoVCOnbsmJ5++mnV19fr3Xff7fX7lJeX64UXXujvGACAJOVzzrn+LFy7dq1++9vf6qOPPtK4ceP63O/AgQOaN2+eGhoaNGnSpKue7+rqUldXV/TrSCSi3NxczdUijfCl9Gc0AIChS65bVdqjcDis9PT0Pvfr1xXQunXrtHfvXh08ePCa8ZGkgoICSeozQH6/X36/vz9jAACSmKcAOef0xBNPaNeuXaqqqlJeXt511xw9elSSlJOT068BAQBDk6cAlZaWaseOHdqzZ4/S0tLU2toqSQoEAho5cqQaGxu1Y8cOPfDAAxozZoyOHTumDRs2aM6cOZo+fXpC/gEAAMnJ03tAPp+v18e3bdumlStXqrm5Wd/73vd0/PhxdXZ2Kjc3V0uWLNEzzzxzzZ8DflEkElEgEOA9IABIUgl5D+h6rcrNzVV1dbWXbwkAuElxLzgAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgIkR1gN8mXNOknRJ3ZIzHgYA4NkldUv6+3/P+zLoAtTR0SFJ+kjvG08CALgRHR0dCgQCfT7vc9dL1ADr6enR6dOnlZaWJp/PF/NcJBJRbm6umpublZ6ebjShPc7DFZyHKzgPV3AerhgM58E5p46ODoVCIQ0b1vc7PYPuCmjYsGEaN27cNfdJT0+/qV9gn+M8XMF5uILzcAXn4Qrr83CtK5/P8SEEAIAJAgQAMJFUAfL7/dq0aZP8fr/1KKY4D1dwHq7gPFzBebgimc7DoPsQAgDg5pBUV0AAgKGDAAEATBAgAIAJAgQAMJE0AaqoqNBXv/pV3XLLLSooKNAf//hH65EG3PPPPy+fzxezTZ061XqshDt48KAWLlyoUCgkn8+n3bt3xzzvnNNzzz2nnJwcjRw5UkVFRTpx4oTNsAl0vfOwcuXKq14fCxYssBk2QcrLyzVz5kylpaUpKytLixcvVn19fcw+Fy5cUGlpqcaMGaPbbrtNy5YtU1tbm9HEifGPnIe5c+de9XpYs2aN0cS9S4oAvfXWW9q4caM2bdqkjz/+WPn5+SouLtaZM2esRxtwd999t1paWqLbRx99ZD1SwnV2dio/P18VFRW9Pr9582a9+uqr2rp1qw4dOqRbb71VxcXFunDhwgBPmljXOw+StGDBgpjXx86dOwdwwsSrrq5WaWmpamtrtX//fnV3d2v+/Pnq7OyM7rN
hwwa99957euedd1RdXa3Tp09r6dKlhlPH3z9yHiRp1apVMa+HzZs3G03cB5cEZs2a5UpLS6NfX7582YVCIVdeXm441cDbtGmTy8/Ptx7DlCS3a9eu6Nc9PT0uGAy6l156KfpYe3u78/v9bufOnQYTDowvnwfnnFuxYoVbtGiRyTxWzpw54yS56upq59yVf/cpKSnunXfeie7zySefOEmupqbGasyE+/J5cM65b3/72+773/++3VD/gEF/BXTx4kXV1dWpqKgo+tiwYcNUVFSkmpoaw8lsnDhxQqFQSBMnTtQjjzyikydPWo9kqqmpSa2trTGvj0AgoIKCgpvy9VFVVaWsrCxNmTJFa9eu1dmzZ61HSqhwOCxJysjIkCTV1dWpu7s75vUwdepUjR8/fki/Hr58Hj73xhtvKDMzU9OmTVNZWZnOnz9vMV6fBt3NSL/s008/1eXLl5WdnR3zeHZ2tv785z8bTWWjoKBA27dv15QpU9TS0qIXXnhB9913n44fP660tDTr8Uy0trZKUq+vj8+fu1ksWLBAS5cuVV5enhobG/WjH/1IJSUlqqmp0fDhw63Hi7uenh6tX79es2fP1rRp0yRdeT2kpqZq9OjRMfsO5ddDb+dBkh5++GFNmDBBoVBIx44d09NPP636+nq9++67htPGGvQBwt+VlJRE/zx9+nQVFBRowoQJevvtt/XYY48ZTobB4MEHH4z++Z577tH06dM1adIkVVVVad68eYaTJUZpaamOHz9+U7wPei19nYfVq1dH/3zPPfcoJydH8+bNU2NjoyZNmjTQY/Zq0P8ILjMzU8OHD7/qUyxtbW0KBoNGUw0Oo0eP1p133qmGhgbrUcx8/hrg9XG1iRMnKjMzc0i+PtatW6e9e/fqww8/jPn1LcFgUBcvXlR7e3vM/kP19dDXeehNQUGBJA2q18OgD1BqaqpmzJihysrK6GM9PT2qrKxUYWGh4WT2zp07p8bGRuXk5FiPYiYvL0/BYDDm9RGJRHTo0KGb/vVx6tQpnT17dki9PpxzWrdunXbt2qUDBw4oLy8v5vkZM2YoJSUl5vVQX1+vkydPDqnXw/XOQ2+OHj0qSYPr9WD9KYh/xJtvvun8fr/bvn27+9Of/uRWr17tRo8e7VpbW61HG1A/+MEPXFVVlWtqanK///3vXVFRkcvMzHRnzpyxHi2hOjo63JEjR9yRI0ecJPfyyy+7I0eOuL/+9a/OOed+9rOfudGjR7s9e/a4Y8eOuUWLFrm8vDz32WefGU8eX9c6Dx0dHe7JJ590NTU1rqmpyX3wwQfuG9/4hps8ebK7cOGC9ehxs3btWhcIBFxVVZVraWmJbufPn4/us2bNGjd+/Hh34MABd/jwYVdYWOgKCwsNp46/652HhoYG9+Mf/9gdPnzYNTU1uT179riJEye6OXPmGE8eKykC5Jxzr732mhs/frxLTU11s2bNcrW1tdYjDbjly5e7nJwcl5qa6m6//Xa3fPly19DQYD1Wwn344YdO0lXbihUrnHNXPor97LPPuuzsbOf3+928efNcfX297dAJcK3zcP78eTd//nw3duxYl5KS4iZMmOBWrVo15P4nrbd/fklu27Zt0X0+++wz9/jjj7uvfOUrbtSoUW7JkiWupaXFbugEuN55OHnypJszZ47LyMhwfr/f3XHHHe6HP/yhC4fDtoN/Cb+OAQBgYtC/BwQAGJoIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABP/D7c1Y6KPV4dUAAAAAElFTkSuQmCC", 856 | "text/plain": [ 857 | "
" 858 | ] 859 | }, 860 | "metadata": {}, 861 | "output_type": "display_data" 862 | } 863 | ], 864 | "source": [ 865 | "# move the channel axis last, (1, 28, 28) -> (28, 28, 1), then squeeze to the (28, 28) array Matplotlib expects\n", 866 | "array = np.squeeze(np.transpose(test_dataset[5][0], (1, 2, 0)))\n", 867 | "\n", 868 | "# plot the image using Matplotlib\n", 869 | "plt.imshow(array)\n", 870 | "plt.show()" 871 | ] 872 | }, 873 | { 874 | "cell_type": "markdown", 875 | "id": "5b6ed347-5fd3-4749-861e-09cd9c532fdf", 876 | "metadata": { 877 | "jp-MarkdownHeadingCollapsed": true 878 | }, 879 | "source": [ 880 | "## Playing with Mean Squared Error Loss (MSE)" 881 | ] 882 | }, 883 | { 884 | "cell_type": "code", 885 | "execution_count": 49, 886 | "id": "9fe7de7b-a05b-463d-bbf1-b3b99d8e7ddc", 887 | "metadata": {}, 888 | "outputs": [ 889 | { 890 | "data": { 891 | "text/plain": [ 892 | "tensor(0.)" 893 | ] 894 | }, 895 | "execution_count": 49, 896 | "metadata": {}, 897 | "output_type": "execute_result" 898 | } 899 | ], 900 | "source": [ 901 | "# MSE between an image and itself compares pixel values: every pixel matches, so the loss is 0\n", 902 | "mse_loss(test_dataset[0][0], test_dataset[0][0])" 903 | ] 904 | }, 905 | { 906 | "cell_type": "code", 907 | "execution_count": 50, 908 | "id": "534493f7-b39e-44f5-9740-a0f6d99037d7", 909 | "metadata": {}, 910 | "outputs": [ 911 | { 912 | "data": { 913 | "text/plain": [ 914 | "tensor(1.0462)" 915 | ] 916 | }, 917 | "execution_count": 50, 918 | "metadata": {}, 919 | "output_type": "execute_result" 920 | } 921 | ], 922 | "source": [ 923 | "# MSE between two different images compares pixel values: they do not match, so the loss is greater than 0\n", 924 | "# keep this in mind for diffusion\n", 925 | "mse_loss(test_dataset[0][0], test_dataset[5][0])" 926 | ] 927 | }, 928 | { 929 | "cell_type": "code", 930 | "execution_count": null, 931 | "id": "bc77cc73-fef8-4cc0-bfd7-6d5e177d58ad", 932 | "metadata": {}, 933 | "outputs": [], 934 | "source": [] 935 | }, 936 | { 937 | "cell_type": "code", 938 | 
"execution_count": null, 939 | "id": "5e1b2e96-e61b-407c-8dc3-c9f4898a55dc", 940 | "metadata": {}, 941 | "outputs": [], 942 | "source": [] 943 | }, 944 | { 945 | "cell_type": "code", 946 | "execution_count": null, 947 | "id": "fcdc0cfd-ef75-4432-b675-9f0718520695", 948 | "metadata": {}, 949 | "outputs": [], 950 | "source": [] 951 | } 952 | ], 953 | "metadata": { 954 | "kernelspec": { 955 | "display_name": "Python (/usr/bin/python3)", 956 | "language": "python", 957 | "name": "my_python3" 958 | }, 959 | "language_info": { 960 | "codemirror_mode": { 961 | "name": "ipython", 962 | "version": 3 963 | }, 964 | "file_extension": ".py", 965 | "mimetype": "text/x-python", 966 | "name": "python", 967 | "nbconvert_exporter": "python", 968 | "pygments_lexer": "ipython3", 969 | "version": "3.9.6" 970 | } 971 | }, 972 | "nbformat": 4, 973 | "nbformat_minor": 5 974 | } 975 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.0.0 2 | torchvision==0.15.1 3 | Flask==2.1.0 4 | Pillow==10.4.0 5 | transformers[torch]==4.44.0 6 | datasets==2.21.0 7 | jupyterlab 8 | ipykernel 9 | matplotlib 10 | scikit-learn==1.5.1 11 | onnx==1.16.2 12 | tiktoken 13 | wandb 14 | evaluate==0.4.2 15 | numpy<2 --------------------------------------------------------------------------------