├── .gitignore ├── LearningLibTorch ├── LearningLibTorch.sln └── LearningLibTorch │ ├── DeepLearning.cpp │ ├── LearningLibTorch.vcxproj │ ├── LearningLibTorch.vcxproj.filters │ └── LearningLibTorch.vcxproj.user ├── README.md └── data ├── t10k-images-idx3-ubyte ├── t10k-labels-idx1-ubyte ├── train-images-idx3-ubyte └── train-labels-idx1-ubyte /.gitignore: -------------------------------------------------------------------------------- 1 | # Prerequisites 2 | *.d 3 | 4 | # Compiled Object files 5 | *.slo 6 | *.lo 7 | *.o 8 | *.obj 9 | 10 | # Precompiled Headers 11 | *.gch 12 | *.pch 13 | 14 | # Compiled Dynamic libraries 15 | *.so 16 | *.dylib 17 | *.dll 18 | 19 | # Fortran module files 20 | *.mod 21 | *.smod 22 | 23 | # Compiled Static libraries 24 | *.lai 25 | *.la 26 | *.a 27 | *.lib 28 | 29 | # Executables 30 | *.exe 31 | *.out 32 | *.app 33 | 34 | **/x64/** 35 | */.vs/* 36 | -------------------------------------------------------------------------------- /LearningLibTorch/LearningLibTorch.sln: -------------------------------------------------------------------------------- 1 |  2 | Microsoft Visual Studio Solution File, Format Version 12.00 3 | # Visual Studio Version 16 4 | VisualStudioVersion = 16.0.30523.141 5 | MinimumVisualStudioVersion = 10.0.40219.1 6 | Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "LearningLibTorch", "LearningLibTorch\LearningLibTorch.vcxproj", "{05769FE2-11F1-476C-BAA2-2F04EC03339B}" 7 | EndProject 8 | Global 9 | GlobalSection(SolutionConfigurationPlatforms) = preSolution 10 | Debug|x64 = Debug|x64 11 | Debug|x86 = Debug|x86 12 | Release|x64 = Release|x64 13 | Release|x86 = Release|x86 14 | EndGlobalSection 15 | GlobalSection(ProjectConfigurationPlatforms) = postSolution 16 | {05769FE2-11F1-476C-BAA2-2F04EC03339B}.Debug|x64.ActiveCfg = Debug|x64 17 | {05769FE2-11F1-476C-BAA2-2F04EC03339B}.Debug|x64.Build.0 = Debug|x64 18 | {05769FE2-11F1-476C-BAA2-2F04EC03339B}.Debug|x86.ActiveCfg = Debug|Win32 19 | 
{05769FE2-11F1-476C-BAA2-2F04EC03339B}.Debug|x86.Build.0 = Debug|Win32 20 | {05769FE2-11F1-476C-BAA2-2F04EC03339B}.Release|x64.ActiveCfg = Release|x64 21 | {05769FE2-11F1-476C-BAA2-2F04EC03339B}.Release|x64.Build.0 = Release|x64 22 | {05769FE2-11F1-476C-BAA2-2F04EC03339B}.Release|x86.ActiveCfg = Release|Win32 23 | {05769FE2-11F1-476C-BAA2-2F04EC03339B}.Release|x86.Build.0 = Release|Win32 24 | EndGlobalSection 25 | GlobalSection(SolutionProperties) = preSolution 26 | HideSolutionNode = FALSE 27 | EndGlobalSection 28 | GlobalSection(ExtensibilityGlobals) = postSolution 29 | SolutionGuid = {89936C1F-3E2E-4FA3-997F-3BAF40B8C36C} 30 | EndGlobalSection 31 | EndGlobal 32 | -------------------------------------------------------------------------------- /LearningLibTorch/LearningLibTorch/DeepLearning.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | // Where to find the MNIST dataset. 11 | const std::string kDataRoot = R"(..\..\data)"; 12 | 13 | // The batch size for training. 14 | const int64_t kTrainBatchSize = 64; 15 | 16 | // The batch size for testing. 17 | const int64_t kTestBatchSize = 1000; 18 | 19 | // The number of epochs to train. 20 | const int64_t kNumberOfEpochs = 10; 21 | 22 | // After how many batches to log a new update with the loss value. 
const int64_t kLogInterval = 10;

// Small convolutional classifier for MNIST digits.
// The flatten size 320 in forward() equals 20 channels * 4 * 4. That holds for
// 1x28x28 inputs (28 -conv5-> 24 -pool2-> 12 -conv5-> 8 -pool2-> 4); other input
// sizes would need a different value.
struct Net : torch::nn::Module {
  Net()
      : conv1(torch::nn::Conv2dOptions(1, 10, /*kernel_size=*/5)),
        conv2(torch::nn::Conv2dOptions(10, 20, /*kernel_size=*/5)),
        fc1(320, 50),
        fc2(50, 10) {
    // Registration makes the submodules' parameters visible to parameters(),
    // to(), train() and eval().
    register_module("conv1", conv1);
    register_module("conv2", conv2);
    register_module("conv2_drop", conv2_drop);
    register_module("fc1", fc1);
    register_module("fc2", fc2);
  }

  // Returns log-probabilities over the 10 classes (log_softmax along dim 1),
  // which is the input form torch::nll_loss expects.
  torch::Tensor forward(torch::Tensor x) {
    x = torch::relu(torch::max_pool2d(conv1->forward(x), 2));
    x = torch::relu(
        torch::max_pool2d(conv2_drop->forward(conv2->forward(x)), 2));
    x = x.view({-1, 320});
    x = torch::relu(fc1->forward(x));
    // Dropout is only active while the module is in training mode.
    x = torch::dropout(x, /*p=*/0.5, /*train=*/is_training());
    x = fc2->forward(x);
    return torch::log_softmax(x, /*dim=*/1);
  }

  torch::nn::Conv2d conv1;
  torch::nn::Conv2d conv2;
  torch::nn::Dropout2d conv2_drop;
  torch::nn::Linear fc1;
  torch::nn::Linear fc2;
};

// Runs one training epoch over `data_loader`, updating `model` with `optimizer`.
// Prints an in-place progress line every kLogInterval batches.
//   epoch        - 1-based epoch number, used only for logging.
//   device       - device that each batch is moved to before the forward pass.
//   dataset_size - total number of training samples, used only for logging.
template <typename DataLoader>
void train(
    size_t epoch,
    Net& model,
    torch::Device device,
    DataLoader& data_loader,
    torch::optim::Optimizer& optimizer,
    size_t dataset_size) {
  model.train();
  size_t batch_idx = 0;
  // Count the samples actually seen so a short final batch is reported correctly.
  size_t samples_seen = 0;
  for (auto& batch : data_loader) {
    auto data = batch.data.to(device);
    auto targets = batch.target.to(device);
    optimizer.zero_grad();
    auto output = model.forward(data);
    auto loss = torch::nll_loss(output, targets);
    AT_ASSERT(!std::isnan(loss.template item<float>()));
    loss.backward();
    optimizer.step();

    samples_seen += static_cast<size_t>(batch.data.size(0));
    if (batch_idx++ % kLogInterval == 0) {
      // %zu matches size_t on every platform. %ld is wrong on Win64, where
      // long is 32-bit but size_t is 64-bit.
      std::printf(
          "\rTrain Epoch: %zu [%5zu/%5zu] Loss: %.4f",
          epoch,
          samples_seen,
          dataset_size,
          loss.template item<float>());
    }
  }
}

// Evaluates `model` on every batch of `data_loader` with gradients disabled.
// Prints the mean negative log-likelihood per sample and the accuracy as a
// fraction in [0, 1].
//   dataset_size - number of samples in the test set, used as the divisor.
template <typename DataLoader>
void test(
    Net& model,
    torch::Device device,
    DataLoader& data_loader,
    size_t dataset_size) {
  torch::NoGradGuard no_grad;
  model.eval();
  double test_loss = 0;
  // int64_t matches the type returned by item<int64_t>() below (no narrowing).
  int64_t correct = 0;
  for (const auto& batch : data_loader) {
    auto data = batch.data.to(device);
    auto targets = batch.target.to(device);
    auto output = model.forward(data);
    // Sum (not mean) per batch, so the division by dataset_size below gives a
    // per-sample average even when the last batch is smaller.
    test_loss += torch::nll_loss(
                     output,
                     targets,
                     /*weight=*/{},
                     torch::Reduction::Sum)
                     .template item<float>();
    auto pred = output.argmax(1);
    correct += pred.eq(targets).sum().template item<int64_t>();
  }

  test_loss /= dataset_size;
  std::printf(
      "\nTest set: Average loss: %.4f | Accuracy: %.3f\n",
      test_loss,
      static_cast<double>(correct) / dataset_size);
}

// Trains the MNIST classifier for kNumberOfEpochs epochs, evaluating on the
// test split after each epoch. Uses CUDA when available, otherwise the CPU.
auto main() -> int {
  torch::manual_seed(1);

  const bool use_cuda = torch::cuda::is_available();
  if (use_cuda) {
    std::cout << "CUDA available! Training on GPU." << std::endl;
  } else {
    std::cout << "Training on CPU." << std::endl;
  }
  const torch::Device device(use_cuda ? torch::kCUDA : torch::kCPU);

  Net model;
  model.to(device);

  auto train_dataset =
      torch::data::datasets::MNIST(kDataRoot)
          .map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
          .map(torch::data::transforms::Stack<>());
  const size_t train_dataset_size = train_dataset.size().value();
  // NOTE(review): SequentialSampler yields training samples in the same order
  // every epoch. make_data_loader's default RandomSampler would shuffle them.
  auto train_loader =
      torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
          std::move(train_dataset), kTrainBatchSize);

  auto test_dataset =
      torch::data::datasets::MNIST(
          kDataRoot, torch::data::datasets::MNIST::Mode::kTest)
          .map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
          .map(torch::data::transforms::Stack<>());
  const size_t test_dataset_size = test_dataset.size().value();
  auto test_loader =
      torch::data::make_data_loader(std::move(test_dataset), kTestBatchSize);

  torch::optim::SGD optimizer(
      model.parameters(), torch::optim::SGDOptions(0.01).momentum(0.5));

  // Cast once to avoid a signed/unsigned comparison in the loop condition.
  const auto num_epochs = static_cast<size_t>(kNumberOfEpochs);
  for (size_t epoch = 1; epoch <= num_epochs; ++epoch) {
    train(epoch, model, device, *train_loader, optimizer, train_dataset_size);
    test(model, device, *test_loader, test_dataset_size);
  }
}
-------------------------------------------------------------------------------- /LearningLibTorch/LearningLibTorch/LearningLibTorch.vcxproj: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | Debug 6 | Win32 7 | 8 | 9 | Release 10 | Win32 11 | 12 | 13 | Debug 14 | x64 15 | 16 | 17 | Release 18 | x64 19 | 20 | 21 | 22 | 16.0 23 | Win32Proj 24 | {05769fe2-11f1-476c-baa2-2f04ec03339b} 25 | LearningTorchScript 26 | 10.0 27 | 28 | 29 | 30 | Application 31 | true 32 | v142 33 | Unicode 34 | 35 | 36 | Application 37 | false 38 | v142 39 | true 40 | Unicode 41 | 42 | 43 | Application 44 | true 45 | v142 46 | Unicode 47 | 48 | 49 | Application 50 | false 51 | v142 52 | true 53 | Unicode 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | true 75 | 76 | 77 | false 78 | 79 | 80 | true 81 | 82 | 83 | false 84 | 85 | 86 | 87 | Level3 88 | true 89 | WIN32;_DEBUG;_CONSOLE;%(PreprocessorDefinitions) 90 | false 91 | $(ProjectDir)..\..\..\libtorch-win-shared-with-deps-debug-1.6.0+cpu\libtorch\include;$(ProjectDir)..\..\..\libtorch-win-shared-with-deps-debug-1.6.0+cpu\libtorch\include\torch\csrc\api\include;%(AdditionalIncludeDirectories) 92 | 93 | 94 | Console 95 | true 96 | $(ProjectDir)..\..\..\libtorch-win-shared-with-deps-debug-1.6.0+cpu\libtorch\lib;%(AdditionalLibraryDirectories) 97 | torch.lib;c10.lib;torch_cpu.lib;%(AdditionalDependencies) 98 | 99 | 100 | 101 | 102 | Level3 103 | true 104 | true 105 | true 106 | WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions) 107 | true 108 | 109 | 110 | Console 111 | true 112 | true 113 | true 114 | 115 | 116 | 117 | 118 | Level3 119 | true 120 |
_DEBUG;_CONSOLE;%(PreprocessorDefinitions) 121 | false 122 | $(ProjectDir)..\..\..\libtorch-win-shared-with-deps-debug-1.6.0+cpu\libtorch\include;$(ProjectDir)..\..\..\libtorch-win-shared-with-deps-debug-1.6.0+cpu\libtorch\include\torch\csrc\api\include;%(AdditionalIncludeDirectories) 123 | 124 | 125 | Console 126 | true 127 | $(ProjectDir)..\..\..\libtorch-win-shared-with-deps-debug-1.6.0+cpu\libtorch\lib;%(AdditionalLibraryDirectories) 128 | torch_cpu.lib;torch.lib;c10.lib;%(AdditionalDependencies) 129 | 130 | 131 | xcopy $(ProjectDir)..\..\..\libtorch-win-shared-with-deps-debug-1.6.0+cpu\libtorch\lib\*.dll $(SolutionDir)$(Platform)\$(Configuration)\ /c /y 132 | 133 | 134 | 135 | 136 | Level3 137 | true 138 | true 139 | true 140 | NDEBUG;_CONSOLE;%(PreprocessorDefinitions) 141 | false 142 | $(ProjectDir)..\..\..\libtorch-win-shared-with-deps-1.6.0+cpu\libtorch\include;$(ProjectDir)..\..\..\libtorch-win-shared-with-deps-1.6.0+cpu\libtorch\include\torch\csrc\api\include;%(AdditionalIncludeDirectories) 143 | 144 | 145 | Console 146 | true 147 | true 148 | true 149 | $(ProjectDir)..\..\..\libtorch-win-shared-with-deps-1.6.0+cpu\libtorch\lib;%(AdditionalLibraryDirectories) 150 | torch_cpu.lib;c10.lib;torch.lib;%(AdditionalDependencies) 151 | 152 | 153 | xcopy $(ProjectDir)..\..\..\libtorch-win-shared-with-deps-1.6.0+cpu\libtorch\lib\*.dll $(SolutionDir)$(Platform)\$(Configuration)\ /c /y 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | -------------------------------------------------------------------------------- /LearningLibTorch/LearningLibTorch/LearningLibTorch.vcxproj.filters: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | {4FC737F1-C7A5-4376-A066-2A32D752A2FF} 6 | cpp;c;cc;cxx;c++;cppm;ixx;def;odl;idl;hpj;bat;asm;asmx 7 | 8 | 9 | {93995380-89BD-4b04-88EB-625FBE52EBFB} 10 | h;hh;hpp;hxx;h++;hm;inl;inc;ipp;xsd 11 | 12 | 13 | {67DA6AB6-F800-4c08-8B7A-83BB121AAD01} 14 | 
rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms 15 | 16 | 17 | 18 | 19 | Source Files 20 | 21 | 22 | -------------------------------------------------------------------------------- /LearningLibTorch/LearningLibTorch/LearningLibTorch.vcxproj.user: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # libtorch-mnist-visual-studio 2 | This repository contains a visual studio project for training a classifier on the mnist dataset using the libtorch c++ wrapper. 3 | 4 | The tutorial can be found here: 5 | 6 | https://expoundai.wordpress.com/2020/10/13/setting-up-a-cpp-project-in-visual-studio-2019-with-libtorch-1-6/ 7 | -------------------------------------------------------------------------------- /data/t10k-images-idx3-ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/msminhas93/libtorch-mnist-visual-studio/512eab52845ad62e209ab696b7a39ae18af69c6a/data/t10k-images-idx3-ubyte -------------------------------------------------------------------------------- /data/t10k-labels-idx1-ubyte: -------------------------------------------------------------------------------- 1 | '                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     
                                                                                                                                                                                                                                                                                                                                        -------------------------------------------------------------------------------- /data/train-images-idx3-ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/msminhas93/libtorch-mnist-visual-studio/512eab52845ad62e209ab696b7a39ae18af69c6a/data/train-images-idx3-ubyte -------------------------------------------------------------------------------- /data/train-labels-idx1-ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/msminhas93/libtorch-mnist-visual-studio/512eab52845ad62e209ab696b7a39ae18af69c6a/data/train-labels-idx1-ubyte --------------------------------------------------------------------------------