├── Applications
│   └── Range of applications.md
├── Images
│   ├── Image
│   └── wav-kan.jpg
├── KAN.py
├── LICENSE
├── README.md
└── wavKAN.py

/Applications/Range of applications.md:
--------------------------------------------------------------------------------
1 | # Possible Applications of Wav-KAN
2 | 
3 | ## Wav-KAN Applications
4 | Wav-KAN can be applied in all of the following fields:
5 | 
6 | ## Applications of Wavelet Transform
7 | [Link to Source](https://www.intechopen.com/chapters/1181336)
8 | 
9 | The relevance of the wavelet transform lies in its capacity to analyze signals and images at many resolutions while maintaining time-frequency localization, making it a potent tool in many domains (a short code sketch illustrating this appears after the list below). The wavelet transform is significant and applicable in the following major areas:
10 | 
11 | 1. **Image Compression:** Reducing the data size of images while retaining quality by analyzing and discarding less critical image information in different frequency bands.
12 | 2. **Signal Denoising:** Removing unwanted noise from signals while preserving essential features by separating noise from the signal components at various scales.
13 | 3. **Audio Compression:** Reducing the size of audio files without significant loss of quality, vital for efficient storage and transmission of audio data.
14 | 4. **Feature Extraction in Image Processing:** Identifying and extracting meaningful features from images, such as edges or textures, for subsequent analysis or pattern recognition.
15 | 5. **Seismic Signal Analysis:** Studying seismic waves to understand subsurface structures and earthquake characteristics, aiding in geophysical exploration.
16 | 6. **Edge Detection in Image Processing:** Identifying boundaries or edges between objects in images, crucial for object recognition and image segmentation.
17 | 7. **Financial Time-Series Analysis:** Studying financial data trends, identifying patterns, and predicting market behavior for investment decisions.
18 | 8. **Speech Processing:** Analyzing speech signals for tasks like speech recognition, language translation, and voice-based interfaces.
19 | 9. **Biometric Systems:** Extracting distinctive features from biometric data (like fingerprints or irises) for identity verification.
20 | 10. **Communication Systems:** Analyzing modulated signals in communication systems for signal processing, error correction, and so forth.
21 | 11. **Pattern Recognition:** Identifying and categorizing patterns or objects in data, crucial in machine learning and computer vision.
22 | 12. **Geophysical Data Analysis:** Processing geophysical data to understand geological formations and subsurface structures.
23 | 13. **Texture Analysis in Image Processing:** Characterizing textures in images for various applications, including remote sensing and materials analysis.
24 | 14. **Nondestructive Testing:** Analyzing signals to detect flaws or defects in materials without causing damage, used in industry and materials science.
25 | 15. **Vibration Analysis:** Studying vibrations in mechanical systems for fault detection and condition monitoring in machinery.
26 | 16. **Time-Frequency Analysis in EEG Signals:** Extracting frequency information over time from EEG signals to understand brain activity patterns.
27 | 17. **Molecular Biology:** Analyzing biological signals to study genetic patterns, molecular interactions, and so on, in biological research.
28 | 18. **Fault Detection in Power Systems:** Monitoring power systems to detect and diagnose faults for maintaining grid stability.
29 | 19. **Environmental Data Analysis:** Analyzing environmental signals for studying climate patterns, ecological changes, and so forth.
30 | 20. **Video Compression:** Compressing video data efficiently for storage, streaming, and transmission.
31 | 21. **Sonar Signal Processing:** Analyzing underwater signals for navigation, target detection, and marine communication.
32 | 22. **Radar Signal Processing:** Analyzing radar signals for object detection, tracking, and navigation in aerospace and defense.
33 | 23. **Spectral Analysis:** Decomposing signals into frequency components for analyzing spectral characteristics.
34 | 24. **Image Enhancement:** Improving the quality or appearance of images for better visualization or analysis.
35 | 25. **Data Fusion:** Combining multiple sources of information to enhance data accuracy or completeness.
36 | 26. **Character Recognition:** Identifying and converting characters from images into text for OCR applications.
37 | 27. **Object Tracking:** Following the movement of objects in video sequences for surveillance or monitoring.
38 | 28. **Fractal Analysis:** Analyzing complex patterns or shapes using fractal geometry for various applications.
39 | 29. **Remote Sensing:** Using sensors to collect data from a distance for environmental or geographical analysis.
40 | 30. **System Identification:** Modeling and understanding the behavior of dynamical systems from measured data.
41 | 31. **Image Watermarking:** Embedding information into images for copyright protection or authentication.
42 | 32. **Wireless Communication Systems:** Analyzing signals in wireless networks for efficient data transmission.
43 | 33. **Image Registration:** Aligning multiple images for comparison or creating panoramic views.
44 | 34. **Anomaly Detection:** Identifying unusual patterns or events in data that deviate from expected behavior.
45 | 35. **Quality Assessment in Images:** Evaluating image quality for various applications like printing or medical imaging.
46 | 36. **Time Series Forecasting:** Predicting future values based on past data patterns in time series analysis.
47 | 37. **Motion Detection in Video:** Detecting movement in video sequences for security or activity monitoring.
48 | 38. **Hyperspectral Imaging Analysis:** Analyzing images with numerous spectral bands for detailed material identification.
49 | 39. **Structural Health Monitoring:** Monitoring structural conditions of buildings or infrastructure for maintenance.
50 | 40. **Channel Equalization:** Compensating for distortion in communication channels to recover transmitted signals.
51 | 41. **Quantum Signal Processing:** Analyzing quantum signals or information processing in quantum systems.
52 | 42. **Robotics and Vision Systems:** Processing visual data for robot guidance and control in robotics applications.
53 | 43. **ECG Signal Analysis:** Analyzing electrocardiogram signals for diagnosing heart conditions or abnormalities.
54 | 44. **Sonography Image Processing:** Enhancing and analyzing ultrasound images for medical diagnosis.
55 | 45. **DNA Sequence Analysis:** Analyzing DNA sequences for understanding genetic information and mutations.
56 | 46. **Audio Signal Separation:** Separating mixed audio sources into individual components for analysis or modification.
57 | 47. **Speaker Recognition:** Identifying individuals by analyzing characteristics of their voice patterns.
58 | 48. **Waveform Analysis:** Analyzing waveforms to understand characteristics or patterns in signals.
59 | 49. **Information Retrieval:** Extracting relevant information from large datasets or databases.
60 | 50. **Computational Neuroscience:** Applying computational methods to study brain function and neural systems.
61 | 51. **Gait Analysis:** Analyzing human walking patterns for medical, biomechanical, or forensic purposes.
62 | 52. **Gesture Recognition:** Recognizing and interpreting human gestures for human–computer interaction.
63 | 53. **Traffic Analysis and Prediction:** Analyzing traffic patterns for congestion prediction and management.
64 | 54. **Functional Magnetic Resonance Imaging (fMRI) Analysis:** Analyzing brain activity based on fMRI scans to understand brain function.
65 | 55. **Texture Synthesis:** Creating new textures based on existing ones for graphics or modeling.
66 | 56. **Sleep Pattern Analysis:** Studying sleep patterns and stages for sleep disorder diagnosis.
67 | 57. **Electroencephalography (EEG) Analysis:** Analyzing brain electrical activity for neuroscience or medical diagnostics.
68 | 58. **Antenna Array Processing:** Processing signals from antenna arrays for improved wireless communications.
69 | 59. **Intrusion Detection:** Detecting and preventing unauthorized access or attacks in computer systems.
70 | 60. **Text Mining:** Extracting useful information or patterns from large volumes of text data.
71 | 61. **Time-Frequency Analysis in Music:** Analyzing music signals to understand their frequency and time characteristics.
72 | 62. **Eye Tracking:** Tracking eye movements to understand visual attention or diagnose eye conditions.
73 | 63. **Glottal Analysis:** Studying characteristics of vocal fold vibrations for speech and voice analysis.
74 | 64. **Solar Activity Prediction:** Predicting solar activities like sunspots or flares for space weather forecasting.
75 | 65. **Image Matting:** Extracting foreground objects from an image for editing or composition.
76 | 66. **Electrocardiography (ECG) Signal Analysis:** Analyzing heart electrical activity for diagnosing cardiac conditions.
77 | 67. **Spatiotemporal Data Analysis:** Analyzing data considering both space and time dimensions for various applications.
78 | 68. **Synthetic Aperture Radar (SAR) Processing:** Analyzing radar data for high-resolution imaging in remote sensing applications.
79 | 69. **Gene Expression Analysis:** Studying patterns of gene activity to understand biological processes and diseases.
80 | 70. **Surface Defect Detection:** Identifying defects or anomalies on surfaces for quality control in manufacturing.
81 | 71. **Oceanographic Data Analysis:** Analyzing ocean data for understanding marine ecosystems, currents, and climate.
82 | 72. **Financial Volatility Analysis:** Studying fluctuations in financial markets to assess risk and volatility.
83 | 73. **ECG-based Biometric Systems:** Using ECG signals for biometric identification or authentication purposes.
84 | 74. **Structural Damage Identification:** Identifying structural damage or deterioration in buildings or infrastructure.
85 | 75. **Traffic Signal Timing Optimization:** Optimizing traffic signal timings for better traffic flow and congestion management.
86 | 76. **Human Activity Recognition:** Identifying and categorizing human activities from sensor data for various applications.
87 | 77. **Biomedical Image Fusion:** Combining multiple biomedical images for better visualization or analysis.
88 | 78. **Radio Astronomy Data Analysis:** Analyzing signals from radio telescopes for studying celestial objects or phenomena.
89 | 79. **Brain-Computer Interfaces:** Using brain signals for controlling external devices or computers.
90 | 80. **Solar Power Forecasting:** Predicting solar energy production for efficient grid management.
91 | 81. **Gesture-based Human–Computer Interaction:** Using gestures for controlling or interacting with computers or devices.
92 | 82. **Melody Extraction in Music Signals:** Extracting melodies or dominant pitches from music signals.
93 | 83. **Ionosphere Signal Processing:** Analyzing ionospheric signals for communication or navigation purposes.
94 | 84. **Neuroimaging Data Analysis:** Processing brain imaging data for studying brain structure or function.
95 | 85. **Cyber-Physical Systems Analysis:** Analyzing systems that integrate physical and computational components.
96 | 86. **Photonics Signal Processing:** Processing signals in photonics for various optical or light-based applications.
97 | 87. **Object Detection in Images:** Detecting and locating objects within images or videos for various applications.
98 | 88. **Forensic Image Analysis:** Analyzing images for forensic investigations or evidence examination.
99 | 
100 | These applications showcase the wide-ranging utility of the wavelet transform across diverse fields, illustrating its pivotal role in signal processing, data analysis, and scientific research in numerous domains.
101 | 
102 | 
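As a minimal illustration of the scale and translation idea behind the applications above, the following sketch (not part of the repository files; the helper function below is written for this illustration and simply mirrors the `mexican_hat` branch of `KANLinear.wavelet_transform` in `KAN.py`) evaluates a Mexican-hat wavelet at a narrow and a broad scale. The narrow scale localizes high-frequency detail near its translation, while the broad scale responds to low-frequency structure:

```python
import math
import torch

def mexican_hat(x: torch.Tensor, scale: float, translation: float) -> torch.Tensor:
    # Same functional form as the 'mexican_hat' branch of KANLinear.wavelet_transform in KAN.py
    x_scaled = (x - translation) / scale
    return (2 / (math.sqrt(3) * math.pi ** 0.25)) * (x_scaled ** 2 - 1) * torch.exp(-0.5 * x_scaled ** 2)

t = torch.linspace(-5.0, 5.0, steps=101)
narrow = mexican_hat(t, scale=0.5, translation=0.0)   # localized, high-frequency response
broad = mexican_hat(t, scale=2.0, translation=1.0)    # spread-out, low-frequency response
print(narrow.abs().argmax().item(), broad.abs().argmax().item())  # peak responses near t = 0.0 and t = 1.0
```

In `KAN.py`, `scale` and `translation` are learnable `nn.Parameter` tensors of shape `(out_features, in_features)`, so every input-output edge of the network can adapt its own resolution.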
--------------------------------------------------------------------------------
/Images/Image:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/Images/wav-kan.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zavareh1/Wav-KAN/030e66a821221cf2608ac3fec2a9997a66eeac32/Images/wav-kan.jpg
--------------------------------------------------------------------------------
/KAN.py:
--------------------------------------------------------------------------------
1 | '''This is a sample code for the simulations of the paper:
2 | Bozorgasl, Zavareh and Chen, Hao, Wav-KAN: Wavelet Kolmogorov-Arnold Networks (May, 2024)
3 | 
4 | https://arxiv.org/abs/2405.12832
5 | and also available at:
6 | https://papers.ssrn.com/sol3/papers.cfm?abstract_id=4835325
7 | We used the efficient-KAN notation and parts of its code: https://github.com/Blealtan/efficient-kan
8 | 
9 | '''
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | import torch.optim as optim
14 | import torchvision
15 | import torchvision.transforms as transforms
16 | from torch.utils.data import DataLoader
17 | from tqdm import tqdm
18 | import math
19 | 
20 | class KANLinear(nn.Module):
21 |     def __init__(self, in_features, out_features, wavelet_type='mexican_hat'):
22 |         super(KANLinear, self).__init__()
23 |         self.in_features = in_features
24 |         self.out_features = out_features
25 |         self.wavelet_type = wavelet_type
26 | 
27 |         # Parameters for wavelet transformation
28 |         self.scale = nn.Parameter(torch.ones(out_features, in_features))
29 |         self.translation = nn.Parameter(torch.zeros(out_features, in_features))
30 | 
31 |         # Linear weights for combining outputs
32 |         #self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
33 |         self.weight1 = nn.Parameter(torch.Tensor(out_features, in_features))  # not used here; you may use it to weight a base activation and add it, as in the Spl-KAN paper
34 |         self.wavelet_weights = nn.Parameter(torch.Tensor(out_features, in_features))
35 | 
36 |         nn.init.kaiming_uniform_(self.wavelet_weights, a=math.sqrt(5))
37 |         nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))
38 | 
39 |         # Base activation function (not used in this experiment)
40 |         self.base_activation = nn.SiLU()
41 | 
42 |         # Batch normalization
43 |         self.bn = nn.BatchNorm1d(out_features)
44 | 
45 |     def wavelet_transform(self, x):
46 |         if x.dim() == 2:
47 |             x_expanded = x.unsqueeze(1)
48 |         else:
49 |             x_expanded = x
50 | 
51 |         translation_expanded = self.translation.unsqueeze(0).expand(x.size(0), -1, -1)
52 |         scale_expanded = self.scale.unsqueeze(0).expand(x.size(0), -1, -1)
53 |         x_scaled = (x_expanded - translation_expanded) / scale_expanded
54 | 
55 |         # Implementation of different wavelet types
56 |         if self.wavelet_type == 'mexican_hat':
57 |             term1 = ((x_scaled ** 2) - 1)
58 |             term2 = torch.exp(-0.5 * x_scaled ** 2)
59 |             wavelet = (2 / (math.sqrt(3) * math.pi**0.25)) * term1 * term2
60 |             wavelet_weighted = wavelet * self.wavelet_weights.unsqueeze(0).expand_as(wavelet)
61 |             wavelet_output = wavelet_weighted.sum(dim=2)
62 |         elif self.wavelet_type == 'morlet':
63 |             omega0 = 5.0  # Central frequency
64 |             real = torch.cos(omega0 * x_scaled)
65 |             envelope = torch.exp(-0.5 * x_scaled ** 2)
66 |             wavelet = envelope * real
67 |             wavelet_weighted = wavelet * self.wavelet_weights.unsqueeze(0).expand_as(wavelet)
68 |             wavelet_output = wavelet_weighted.sum(dim=2)
69 | 
70 |         elif self.wavelet_type == 'dog':
71 |             # Implementing Derivative of Gaussian Wavelet
72 |             dog = -x_scaled * torch.exp(-0.5 * x_scaled ** 2)
73 |             wavelet = dog
74 |             wavelet_weighted = wavelet * self.wavelet_weights.unsqueeze(0).expand_as(wavelet)
75 |             wavelet_output = wavelet_weighted.sum(dim=2)
76 |         elif self.wavelet_type == 'meyer':
77 |             # Implement Meyer Wavelet here
78 |             # Constants for the Meyer wavelet transition boundaries
79 |             v = torch.abs(x_scaled)
80 |             pi = math.pi
81 | 
82 |             def meyer_aux(v):
83 |                 return torch.where(v <= 1/2, torch.ones_like(v), torch.where(v >= 1, torch.zeros_like(v), torch.cos(pi / 2 * nu(2 * v - 1))))
84 | 
85 |             def nu(t):
86 |                 return t**4 * (35 - 84*t + 70*t**2 - 20*t**3)
87 |             # Meyer wavelet calculation using the auxiliary function
88 |             wavelet = torch.sin(pi * v) * meyer_aux(v)
89 |             wavelet_weighted = wavelet * self.wavelet_weights.unsqueeze(0).expand_as(wavelet)
90 |             wavelet_output = wavelet_weighted.sum(dim=2)
91 |         elif self.wavelet_type == 'shannon':
92 |             # Windowing the sinc function to limit its support
93 |             pi = math.pi
94 |             sinc = torch.sinc(x_scaled / pi)  # sinc(x) = sin(pi*x) / (pi*x)
95 | 
96 |             # Applying a Hamming window to limit the infinite support of the sinc function
97 |             window = torch.hamming_window(x_scaled.size(-1), periodic=False, dtype=x_scaled.dtype, device=x_scaled.device)
98 |             # Shannon wavelet is the product of the sinc function and the window
99 |             wavelet = sinc * window
100 |             wavelet_weighted = wavelet * self.wavelet_weights.unsqueeze(0).expand_as(wavelet)
101 |             wavelet_output = wavelet_weighted.sum(dim=2)
102 |             # You can try many more wavelet types ...
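        # Each branch above produces `wavelet` with shape (batch, out_features, in_features);
        # multiplying by `wavelet_weights` and summing over dim=2 collapses the input
        # dimension, giving a (batch, out_features) output.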
103 |         else:
104 |             raise ValueError("Unsupported wavelet type")
105 | 
106 |         return wavelet_output
107 | 
108 |     def forward(self, x):
109 |         wavelet_output = self.wavelet_transform(x)
110 |         # You may like to test the cases like Spl-KAN
111 |         #wav_output = F.linear(wavelet_output, self.weight)
112 |         #base_output = F.linear(self.base_activation(x), self.weight1)
113 | 
114 |         base_output = F.linear(x, self.weight1)
115 |         combined_output = wavelet_output #+ base_output
116 | 
117 |         # Apply batch normalization
118 |         return self.bn(combined_output)
119 | 
120 | class KAN(nn.Module):
121 |     def __init__(self, layers_hidden, wavelet_type='mexican_hat'):
122 |         super(KAN, self).__init__()
123 |         self.layers = nn.ModuleList()
124 |         for in_features, out_features in zip(layers_hidden[:-1], layers_hidden[1:]):
125 |             self.layers.append(KANLinear(in_features, out_features, wavelet_type))
126 | 
127 |     def forward(self, x):
128 |         for layer in self.layers:
129 |             x = layer(x)
130 |         return x
131 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 Zavareh Bozorgasl
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Wav-KAN: Wavelet Kolmogorov-Arnold Networks
2 | Code to replicate the simulations of the paper "Wav-KAN: Wavelet Kolmogorov-Arnold Networks". To see a diverse range of possible applications of Wav-KAN, check the "Applications" folder!
3 | 
4 | ### Links to the Paper
5 | - Available at: [arXiv](https://arxiv.org/abs/2405.12832)
6 | - Also available at: [SSRN](https://papers.ssrn.com/sol3/papers.cfm?abstract_id=4835325)
7 | 
8 | 
9 | ### We applied Wav-KAN to Hyperspectral Image Classification
10 | - Available at: [arXiv](https://arxiv.org/abs/2406.07869)
11 | 
12 | ### Highlights of Wav-KAN on Social Media (**X**, formerly Twitter)
13 | ![View on X](Images/wav-kan.jpg)
14 | 
15 | This image showcases Wav-KAN being highlighted and shared with the community on social media. It reflects the growing interest and engagement around this innovative framework.
16 | 
17 | ## Current Contents of the Repository
18 | - **MNIST Training and Testing**:
19 |   - The repository currently contains the code required to replicate MNIST training and testing (a minimal usage sketch is shown below).
20 |   - More code and examples will be added in future updates.
21 | - **Possible applications of Wavelet/Wav-KAN**
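### Minimal Usage Sketch
A minimal sketch of how the `KAN` model from `KAN.py` can be constructed and called (assuming `KAN.py` is importable from the working directory; the layer sizes and wavelet type match the MNIST script `wavKAN.py`):

```python
import torch
from KAN import KAN

# [28*28, 32, 10]: flattened MNIST input, one hidden layer of width 32, 10 output classes
model = KAN([28 * 28, 32, 10], wavelet_type='mexican_hat')

x = torch.randn(64, 28 * 28)   # a dummy batch of 64 flattened 28x28 images
logits = model(x)              # shape: (64, 10)
print(logits.shape)
```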
22 | 
23 | ## Abstract
24 | In this paper, we introduce **Wav-KAN**, an innovative neural network architecture that leverages the **Wavelet Kolmogorov-Arnold Networks (Wav-KAN)** framework to enhance interpretability and performance.
25 | 
26 | Traditional multilayer perceptrons (MLPs) and even recent advancements like Spl-KAN face challenges such as:
27 | - Interpretability
28 | - Training speed
29 | - Robustness
30 | - Computational efficiency
31 | - Performance limitations
32 | 
33 | ### Wav-KAN addresses these issues by:
34 | - Incorporating **wavelet functions** into the Kolmogorov-Arnold network structure (a sketch of the per-layer computation is shown just below this list).
35 | - Efficiently capturing both **high-frequency** and **low-frequency components** of input data.
36 | - Using **continuous (dyadic) wavelet transforms** for multiresolution analysis, eliminating the need for recalculations in detail extraction.
37 | 
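Concretely, each `KANLinear` layer in `KAN.py` places a learnable wavelet on every input-output edge and sums the weighted responses before batch normalization. A sketch of the per-layer computation, in the notation of the code (ψ is the chosen mother wavelet, with learnable `scale` s, `translation` t, and `wavelet_weights` w):

$$
y_i = \mathrm{BN}\Bigg(\sum_{j=1}^{n_\text{in}} w_{ij}\,\psi\!\left(\frac{x_j - t_{ij}}{s_{ij}}\right)\Bigg), \qquad i = 1, \dots, n_\text{out}
$$

The base path `F.linear(x, self.weight1)` is computed in `forward` but, by default, is not added to the wavelet output (`combined_output = wavelet_output #+ base_output`).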
38 | ### Key Features:
39 | - Wavelet-based approximations employ **orthogonal or semi-orthogonal bases**, balancing data structure representation and noise reduction.
40 | - Enhanced accuracy, faster training speeds, and increased robustness compared to Spl-KAN and MLPs.
41 | - Adaptability to the data structure, akin to how water conforms to its container.
42 | 
43 | Our results highlight the potential of Wav-KAN as a **powerful tool** for developing interpretable and high-performance neural networks, with applications across various fields. This work paves the way for further exploration and implementation of Wav-KAN in frameworks such as **PyTorch** and **TensorFlow**, aspiring to make wavelets in KAN as common as activation functions like **ReLU** and **sigmoid** in universal approximation theory (UAT).
44 | 
45 | ## Future Updates
46 | - Additional code implementations and simulations for Wav-KAN.
47 | 
48 | 
49 | ---
50 | 
51 | Stay tuned for updates! Feedback and contributions are welcome. 🚀
52 | 
53 | 
54 | ## Reference
55 | 
56 | If you find this repository helpful in your research or projects, please consider citing the following paper:
57 | 
58 | ```bibtex
59 | @article{bozorgasl2024wavkan,
60 |   author = {Zavareh Bozorgasl and Hao Chen},
61 |   title = {Wav-KAN: Wavelet Kolmogorov-Arnold Networks},
62 |   journal = {arXiv preprint arXiv:2405.12832},
63 |   year = {2024},
64 |   url = {https://arxiv.org/abs/2405.12832}
65 | }
66 | ```
--------------------------------------------------------------------------------
/wavKAN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.optim as optim
5 | import torchvision
6 | import torchvision.transforms as transforms
7 | from torch.utils.data import DataLoader
8 | from tqdm import tqdm
9 | import pandas as pd
10 | from KAN import *
11 | 
12 | # Defining the wavelet types
13 | # wavelet_types = ['mexican_hat', 'morlet', 'dog', 'meyer', 'shannon', 'bump', ...]  # it can include all wavelet types
14 | wavelet_types = ['mexican_hat', 'morlet', 'dog']
15 | 
16 | # Loading MNIST data set
17 | transform = transforms.Compose(
18 |     [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
19 | )
20 | trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
21 | valset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
22 | trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
23 | valloader = DataLoader(valset, batch_size=64, shuffle=False)
24 | 
25 | # Trials and Epochs (epochs per trial)
26 | trials = 5
27 | epochs_per_trial = 50
28 | 
29 | # Looping over each wavelet type
30 | for wavelet in wavelet_types:
31 |     all_train_losses, all_train_accuracies = [], []
32 |     all_val_losses, all_val_accuracies = [], []
33 |     print(f'Wavelet is {wavelet}')
34 |     # For a specified number of trials
35 |     for trial in range(trials):
36 |         print(f'Trial is {trial}')
37 |         # Define model, optimizer, scheduler for each trial
38 |         model = KAN([28 * 28, 32, 10], wavelet_type=wavelet)
39 |         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40 |         model.to(device)
41 |         optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
42 |         scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
43 |         criterion = nn.CrossEntropyLoss()
44 | 
45 |         trial_train_losses, trial_val_losses = [], []
46 |         trial_train_accuracies, trial_val_accuracies = [], []
47 |         # For a specified number of epochs
48 |         for epoch in range(epochs_per_trial):
49 |             # Training
50 |             train_loss, train_correct, train_total = 0.0, 0, 0
51 |             model.train()
52 |             #for images, labels in tqdm(trainloader):
53 |             for images, labels in trainloader:
54 |                 images = images.view(-1, 28 * 28).to(device)
55 |                 labels = labels.to(device)
56 |                 optimizer.zero_grad()
57 |                 outputs = model(images)
58 |                 loss = criterion(outputs, labels)
59 |                 loss.backward()
60 |                 optimizer.step()
61 | 
62 |                 train_loss += loss.item()
63 |                 _, predicted = torch.max(outputs.data, 1)
64 |                 train_total += labels.size(0)
65 |                 train_correct += (predicted == labels).sum().item()
66 | 
67 |             train_loss /= len(trainloader)
68 |             train_acc = 100 * train_correct / train_total
69 |             trial_train_losses.append(train_loss)
70 |             trial_train_accuracies.append(train_acc)
71 | 
72 |             # Validation
73 |             val_loss, val_correct, val_total = 0.0, 0, 0
74 |             model.eval()
75 |             with torch.no_grad():
76 |                 for images, labels in valloader:
77 |                     images = images.view(-1, 28 * 28).to(device)
78 |                     labels = labels.to(device)
79 |                     outputs = model(images)
80 |                     loss = criterion(outputs, labels)
81 |                     val_loss += loss.item()
82 |                     _, predicted = torch.max(outputs.data, 1)
83 |                     val_total += labels.size(0)
84 |                     val_correct += (predicted == labels).sum().item()
85 |             val_loss /= len(valloader)
86 |             val_acc = 100 * val_correct / val_total
87 |             trial_val_losses.append(val_loss)
88 |             trial_val_accuracies.append(val_acc)
89 | 
90 |             # Update learning rate
91 |             scheduler.step()
92 |         # Collecting statistics
93 |         all_train_losses.append(trial_train_losses)
94 |         all_train_accuracies.append(trial_train_accuracies)
95 |         all_val_losses.append(trial_val_losses)
96 |         all_val_accuracies.append(trial_val_accuracies)
97 |     # Average results across trials and write to Excel
98 |     avg_train_losses = pd.DataFrame(all_train_losses).mean().tolist()
99 |     avg_train_accuracies = pd.DataFrame(all_train_accuracies).mean().tolist()
100 |     avg_val_losses = pd.DataFrame(all_val_losses).mean().tolist()
101 |     avg_val_accuracies = pd.DataFrame(all_val_accuracies).mean().tolist()
102 | 
103 |     results_df = pd.DataFrame({
104 |         'Epoch': range(1, epochs_per_trial + 1),
105 |         'Train Loss': avg_train_losses,
106 |         'Train Accuracy': avg_train_accuracies,
107 |         'Validation Loss': avg_val_losses,
108 |         'Validation Accuracy': avg_val_accuracies
109 |     })
110 | 
111 |     # Saving the results to an Excel file named after the wavelet type
112 |     file_name = f'{wavelet}_results.xlsx'
113 |     results_df.to_excel(file_name, index=False)
114 | 
115 |     print(f"Results saved to {file_name}.")
116 | 
--------------------------------------------------------------------------------