├── README.md
├── data_loader.py
├── datasets
│   ├── data.yaml
│   ├── images
│   │   ├── train
│   │   │   └── readme.txt
│   │   └── val
│   │       └── readme.txt
│   └── labels
│       ├── train
│       │   └── readme.txt
│       └── val
│           └── readme.txt
├── distill.py
├── models
│   └── readme.txt
├── reformat_dataset.py
└── results
    └── readme.txt

/README.md:
--------------------------------------------------------------------------------
# YOLOv8 Knowledge Distillation with Custom Dataset
This project implements knowledge distillation for YOLOv8, transferring knowledge from a larger teacher model to a smaller student model using your own custom dataset.

## Installation
Based on reported issues and suggestions, this code is still being improved. Please don't use it yet; wait for the update. Thanks!

--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
import os
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from pathlib import Path

class YOLODataset(Dataset):
    def __init__(self, images_folder, labels_folder, transform=None):
        self.images_folder = images_folder
        self.labels_folder = labels_folder
        self.image_filenames = sorted([f for f in os.listdir(images_folder) if f.endswith('.jpg')])
        self.transform = transform

        # Print the number of images found, for debugging
        print(f"Found {len(self.image_filenames)} images in {images_folder}")

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        # Load image
        img_path = os.path.join(self.images_folder, self.image_filenames[idx])
        image = Image.open(img_path).convert("RGB")

        # Path of the corresponding label file (optional)
        label_path = os.path.join(self.labels_folder, self.image_filenames[idx].replace('.jpg', '.txt'))

        # Debug: print the paths to make sure the files exist
        print(f"Loading image: {img_path}")
        print(f"Loading label: {label_path}")

        # Apply transformations (resizing, etc.)
        if self.transform:
            image = self.transform(image)

        # Return the image and the label path
        return image, label_path

# Dataset and DataLoader setup
def load_dataset(data_path, batch_size, transform):
    """
    Load the training and validation sets with the custom YOLO dataset loader.
    The images/ and labels/ folders are expected to sit next to data.yaml.
    """
    train_images_path = Path(data_path).parent / 'images/train'
    train_labels_path = Path(data_path).parent / 'labels/train'
    val_images_path = Path(data_path).parent / 'images/val'
    val_labels_path = Path(data_path).parent / 'labels/val'

    # Datasets for training and validation
    train_dataset = YOLODataset(images_folder=train_images_path, labels_folder=train_labels_path, transform=transform)
    val_dataset = YOLODataset(images_folder=val_images_path, labels_folder=val_labels_path, transform=transform)

    # DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader
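
# Minimal usage sketch (illustrative, not part of the training pipeline): shows how
# load_dataset can be called on its own. The transform, batch size, and the relative
# data.yaml path below are assumptions for demonstration; adjust them to your setup.
if __name__ == "__main__":
    from torchvision import transforms

    example_transform = transforms.Compose([
        transforms.Resize((640, 640)),  # YOLOv8 default input size
        transforms.ToTensor(),
    ])
    train_loader, val_loader = load_dataset('datasets/data.yaml', batch_size=4, transform=example_transform)
    images, label_paths = next(iter(train_loader))
    print(images.shape)  # e.g. torch.Size([4, 3, 640, 640])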
--------------------------------------------------------------------------------
/datasets/data.yaml:
--------------------------------------------------------------------------------
# Replace {user_path} with your username and project folder, e.g., your_username/kd
# Path to the training images folder
train: C:/Users/{user_path}/datasets/images/train

# Path to the validation images folder
val: C:/Users/{user_path}/datasets/images/val

# Number of classes
nc: 5

# Names of the classes
names: ['Apple', 'Banana', 'Kiwi', 'Orange', 'Pear']
--------------------------------------------------------------------------------
/datasets/images/train/readme.txt:
--------------------------------------------------------------------------------
Place your training dataset images here, for example:
image1.jpg
image2.jpg
etc.
--------------------------------------------------------------------------------
/datasets/images/val/readme.txt:
--------------------------------------------------------------------------------
Place your validation images here, for example:
image3.jpg
image4.jpg
etc.
--------------------------------------------------------------------------------
/datasets/labels/train/readme.txt:
--------------------------------------------------------------------------------
Place your annotation .txt files here. Each file must have the same name as its corresponding image in the training dataset, for example:
image1.txt
image2.txt
--------------------------------------------------------------------------------
/datasets/labels/val/readme.txt:
--------------------------------------------------------------------------------
Place your annotation .txt files here. Each file must have the same name as its corresponding image in the validation dataset, for example:
image3.txt
image4.txt
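
Each annotation file is assumed to use the standard YOLO label format: one object per line, with the class index followed by the normalized bounding-box center and size (class x_center y_center width height), for example:
0 0.512 0.430 0.210 0.315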
--------------------------------------------------------------------------------
/distill.py:
--------------------------------------------------------------------------------
import torch
from ultralytics import YOLO
import torch.nn.functional as F
from torchvision import transforms
from data_loader import load_dataset  # Make sure data_loader.py is in the same folder

# Load teacher and student models
teacher_model = YOLO('models/yolov8m.pt')  # Teacher model (YOLOv8 medium)
student_model = YOLO('models/yolov8n.pt')  # Student model (YOLOv8 nano)

# Dataset path
data_path = 'C:/Users/phantom/kd/datasets/data.yaml'  # Absolute path to your data.yaml file

# Hyperparameters
alpha = 0.5      # weight for the distillation loss
temperature = 3  # temperature for softening the logits
batch_size = 16
epochs = 100

# Data transformations
transform = transforms.Compose([
    transforms.Resize((640, 640)),  # Resize images to 640x640 (as required by YOLOv8)
    transforms.ToTensor(),
])

def distillation_loss(student_logits, teacher_logits, temperature, alpha):
    """
    Compute the distillation (soft-target) loss: KL divergence between the
    temperature-softened teacher and student class distributions. Both inputs
    are expected to be per-class logits with the class dimension along dim=1.
    """
    student_soft = F.log_softmax(student_logits / temperature, dim=1)
    teacher_soft = F.softmax(teacher_logits / temperature, dim=1)
    loss = F.kl_div(student_soft, teacher_soft, reduction='batchmean') * (temperature ** 2)
    return alpha * loss

def train_with_distillation(data_path):
    """
    Main function for training the student model with distillation.
    NOTE (see README): this loop is still being reworked; the parts that extract
    predictions and the supervised loss from the YOLOv8 outputs are placeholders
    marked with TODOs below.
    """
    # .model is the underlying torch nn.Module inside the ultralytics wrapper
    optimizer = torch.optim.Adam(student_model.model.parameters(), lr=0.001)

    # Build the dataloaders once, from the folders next to data.yaml
    train_loader, val_loader = load_dataset(data_path, batch_size, transform)

    for epoch in range(epochs):
        student_model.model.train()  # put the student in training mode
        teacher_model.model.eval()   # keep the teacher frozen in eval mode

        running_loss = 0.0

        for batch in train_loader:
            images, _ = batch  # the dataloader returns (image, label_path); only images are used here

            # Move data to the same device as the model
            images = images.to(student_model.device)

            # Get predictions from both models. The teacher needs no gradients;
            # the student forward pass must stay OUTSIDE torch.no_grad() so the
            # distillation loss can backpropagate.
            with torch.no_grad():
                teacher_out = teacher_model(images)
            student_out = student_model(images)

            # TODO: calling an ultralytics YOLO model does not return a dict with
            # 'pred'/'loss' keys; the raw head outputs and the supervised detection
            # loss still need to be extracted properly. The lines below keep the
            # intended structure as a placeholder.
            distill_loss = distillation_loss(student_out['pred'], teacher_out['pred'], temperature, alpha)

            # Combine with the student model's own supervised loss
            student_loss = student_out['loss']
            total_loss = student_loss + distill_loss

            # Backpropagation
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            running_loss += total_loss.item()

        print(f'Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(train_loader)}')

    # Save the student model after distillation
    student_model.save('results/student_model_distilled.pt')
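
# Minimal shape-level sanity check for distillation_loss (illustrative sketch only).
# The dummy tensors are assumptions standing in for per-class logits; real YOLOv8
# head outputs are structured differently and must be adapted before use.
# Call this manually to confirm the loss computation runs end to end.
def _check_distillation_loss():
    dummy_student = torch.randn(8, 5)  # batch of 8 samples, 5 classes (matches nc in data.yaml)
    dummy_teacher = torch.randn(8, 5)
    loss = distillation_loss(dummy_student, dummy_teacher, temperature, alpha)
    print(f"Dummy distillation loss: {loss.item():.4f}")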

if __name__ == "__main__":
    # Make sure data.yaml is configured correctly before running
    train_with_distillation(data_path)
--------------------------------------------------------------------------------
/models/readme.txt:
--------------------------------------------------------------------------------
The YOLOv8 model weights will be downloaded to this folder automatically.
Please choose your own teacher and student models, for example:
yolov8m = teacher
yolov8n = student
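
For reference, the available YOLOv8 detection models, from smallest to largest, are yolov8n, yolov8s, yolov8m, yolov8l, and yolov8x; any larger/smaller pair can be used as teacher/student.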
--------------------------------------------------------------------------------
/reformat_dataset.py:
--------------------------------------------------------------------------------
# This script reformats the dataset folder structure from the Roboflow export format
# to the layout required by this project's knowledge-distillation dataset.
# Please modify the folder paths accordingly.

import os
import shutil
import warnings
from tqdm import tqdm

# Define the source folders (Roboflow export layout)
# Modify the source and destination folder paths based on your dataset.
source_folder = {
    'train_images': 'source_datasets/datasets/train/images',  # Source folder for training images
    'train_labels': 'source_datasets/datasets/train/labels',  # Source folder for training labels
    'valid_images': 'source_datasets/datasets/valid/images',  # Source folder for validation images
    'valid_labels': 'source_datasets/datasets/valid/labels'   # Source folder for validation labels
}

# Define the destination folder structure (the layout expected by data.yaml and data_loader.py)
destination_folder = {
    'train_images': 'datasets/images/train',  # Destination folder for training images
    'val_images': 'datasets/images/val',      # Destination folder for validation images
    'train_labels': 'datasets/labels/train',  # Destination folder for training labels
    'val_labels': 'datasets/labels/val'       # Destination folder for validation labels
}

# Function to move files from the source folder to the destination folder
def move_files(source, destination):
    # Create the destination folder if it doesn't exist
    if not os.path.exists(destination):
        os.makedirs(destination)

    # Get a list of all files in the source folder
    files = os.listdir(source)

    # Use tqdm to show a progress bar while moving files
    for filename in tqdm(files, desc=f"Moving files from {source} to {destination}"):
        file_path = os.path.join(source, filename)       # Full path of the source file
        dest_path = os.path.join(destination, filename)  # Full path of the destination file

        try:
            # If the entry is a regular file, move it to the destination folder
            if os.path.isfile(file_path):
                shutil.move(file_path, dest_path)
            else:
                # Warn if the entry is not a regular file (e.g., a sub-folder)
                warnings.warn(f"{filename} is not a valid file", UserWarning)
        except Exception as e:
            # Warn if there was an error moving the file
            warnings.warn(f"Failed to move {filename} due to: {str(e)}", UserWarning)

# Move images and labels from the train and valid sets, with a progress bar,
# from the Roboflow folder structure into the new layout.
move_files(source_folder['train_images'], destination_folder['train_images'])
move_files(source_folder['train_labels'], destination_folder['train_labels'])
move_files(source_folder['valid_images'], destination_folder['val_images'])
move_files(source_folder['valid_labels'], destination_folder['val_labels'])
--------------------------------------------------------------------------------
/results/readme.txt:
--------------------------------------------------------------------------------
The distilled student model will be placed here!
--------------------------------------------------------------------------------