├── .gitignore
├── README.md
├── module.py
├── plain.py
└── requirements.txt


/.gitignore:
--------------------------------------------------------------------------------
# not adding the virtual environment to GitHub (installed through requirements.txt)
venv
# __pycache__ is generated when the code is run
__pycache__
# PyCharm specific folder
.idea
# editor swap file of .gitignore
.gitignore.swp
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
# PytorchDataloaderForTensorflow
## Use
PyTorch (through torchvision) provides some transforms for input images that cannot easily be reproduced with tf.keras, so it is useful to be able to load image data with a PyTorch DataLoader even when fitting a tf.keras model. Therefore a class is implemented that wraps a PyTorch DataLoader object (which performs the transformations on the data) and can be passed to the `fit_generator` method of a tf.keras model to provide its training data.
## Setup
The Python files were created for Python 3.7, although they might also work with earlier or later versions.
To use this class, some Python packages need to be installed first. Using pip, the packages can be installed either by typing
`pip install -r requirements.txt`
in the terminal, if the requirements.txt file is in the current working directory, or by typing
`pip install tensorflow==2.0.0 torch==1.3.1 torchvision==0.4.2`
into the terminal (Python and pip need to be installed first; the recommended pip version is at least 19.3.1). These package versions were used when these files were created, but later versions of these packages might also work. Another way to install these packages is by using conda.

Note that requirements.txt currently pins newer versions of some packages (for example tensorflow==2.12.1 and torch==2.2.0) next to pins from the original setup (for example torchvision==0.4.2, which was released for torch 1.3.1, and tensorboard==2.0.2). These combinations are likely not compatible, and TensorFlow 2.12 and NumPy 1.22 require Python 3.8 or later, so pip may fail to install the file as it is; in that case the pins need to be aligned first.
## Code
There are two options for using the class to fit a tf.keras model:
1. Put the code directly into a Python file:
   For that, the code from the file [plain.py](plain.py) should be copied into the Python file.
2. Import the class from a different Python file:
   For that, the file [module.py](module.py) should be placed in the project folder that contains the executed file and imported at the top of the executed file:
   `from module import DataGenerator`

A script that uses the class should include the following elements:
```python
# load the required modules
import tensorflow.keras as k
import torch as pt
import torchvision as tv

from module import DataGenerator  # not needed if the code from plain.py was pasted into this file

# define the transforms for the pytorch dataloader
# additional transforms from the torchvision.transforms package can be added to the list
transform = tv.transforms.Compose([
    # [...]
    tv.transforms.ToTensor(),
    # [...]
])

# create the dataloader for the tf.keras model from a PyTorch DataLoader object
dataset = tv.datasets.ImageFolder('path/to/folder', transform=transform)
ncl = len(dataset.classes)  # ncl represents the number of classes for the model
dataloader = DataGenerator(pt.utils.data.DataLoader(dataset, [...]), ncl)  # [...]: DataLoader arguments such as batch_size

# creating and defining the tf.keras model
model = k.models.Sequential()
[...]  # new layers can be added to the model with the model.add([...]) function

model.compile([...])  # compile the model (custom parameter choices)
model.fit_generator(dataloader, [...])  # fit the model using the data generator (custom parameter choices)

model.save('path/to/model/name.h5')  # save the model (optional but useful)
```

With TensorFlow 2.1 or later, `model.fit_generator` is deprecated; `model.fit(dataloader, [...])` accepts the same `DataGenerator` object, since it is a `tf.keras.utils.Sequence`.

The recommended way of using this class is to import it from module.py, because that file contains docstrings documenting the class. In plain.py the documentation was left out to keep the code short.
--------------------------------------------------------------------------------

/module.py:
--------------------------------------------------------------------------------
import numpy as np
import tensorflow.keras as k


class DataGenerator(k.utils.Sequence):
    """
    class to be fed into the model.fit_generator method of a tf.keras model

    uses a pytorch dataloader object to create a new generator object that can be used by tf.keras
    the pytorch dataloader must be used to load the image data
    transforms on the input image data can be done with pytorch, while model fitting is still done with tf.keras

    ...

    Attributes
    ----------
    gen : torch.utils.data.dataloader.DataLoader
        pytorch dataloader object; should be able to load image data for a pytorch model
    ncl : int
        number of classes of the input data; equal to the number of outputs of the model
    """
    def __init__(self, gen, ncl):
        """
        Parameters
        ----------
        gen : torch.utils.data.dataloader.DataLoader
            pytorch dataloader object; should be able to load image data for a pytorch model
        ncl : int
            number of classes of the input data; equal to the number of outputs of the model
        """
        self.gen = gen
        self.iter = iter(gen)
        self.ncl = ncl

    def __getitem__(self, _):
        """
        function used by model.fit_generator to get the next input image batch

        the batch index passed by tf.keras is ignored; batches are taken in order from the pytorch dataloader

        Variables
        ---------
        ims : np.ndarray
            image inputs; array of shape (batch_size, height, width, channels); input of the model
        lbs : np.ndarray
            labels; array of shape (batch_size, number_of_classes); correct outputs for the model
        """
        # get the next batch from the current pass over the data
        try:
            ims, lbs = next(self.iter)  # generation of data handled by pytorch dataloader
        # catch when no items are left in the iterator
        except StopIteration:
            self.iter = iter(self.gen)  # re-instantiate the iterator over the data
            ims, lbs = next(self.iter)  # generation of data handled by pytorch dataloader
        # reorder image data from pytorch (N, C, H, W) to tf.keras (N, H, W, C) dimension ordering
        ims = np.transpose(ims.numpy(), (0, 2, 3, 1))
        # convert labels to one-hot representation; indexing with a numpy integer array
        # gives shape (batch_size, ncl) for every batch, including a smaller final batch
        lbs = np.eye(self.ncl)[lbs.numpy()]
        return ims, lbs

    def __len__(self):
        """
        function that returns the number of batches in one epoch
        """
        return len(self.gen)
--------------------------------------------------------------------------------

/plain.py:
--------------------------------------------------------------------------------
import numpy as np
import tensorflow.keras as k


class DataGenerator(k.utils.Sequence):

    def __init__(self, gen, ncl):
        self.gen = gen
        self.iter = iter(gen)
        self.ncl = ncl

    def __getitem__(self, _):
        try:
            ims, lbs = next(self.iter)
        except StopIteration:
            self.iter = iter(self.gen)
            ims, lbs = next(self.iter)
        ims = np.transpose(ims.numpy(), (0, 2, 3, 1))
        lbs = np.eye(self.ncl)[lbs.numpy()]
        return ims, lbs

    def __len__(self):
        return len(self.gen)
--------------------------------------------------------------------------------

/requirements.txt:
--------------------------------------------------------------------------------
absl-py==0.8.1
astor==0.8.0
cachetools==3.1.1
certifi==2024.7.4
chardet==3.0.4
gast==0.2.2
google-auth==1.7.1
google-auth-oauthlib==0.4.1
google-pasta==0.1.8
grpcio==1.53.2
h5py==2.10.0
idna==3.7
Keras-Applications==1.0.8
Keras-Preprocessing==1.1.0
Markdown==3.1.1
numpy==1.22.0
oauthlib==3.1.0
opt-einsum==3.1.0
Pillow==10.3.0
protobuf==3.18.3
pyasn1==0.4.8
pyasn1-modules==0.2.7
requests==2.32.2
requests-oauthlib==1.3.0
rsa==4.7
six==1.13.0
tensorboard==2.0.2
tensorflow==2.12.1
tensorflow-estimator==2.0.1
termcolor==1.1.0
torch==2.2.0
torchvision==0.4.2
urllib3==1.26.19
Werkzeug==3.0.3
wrapt==1.11.2
--------------------------------------------------------------------------------
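The two array operations at the core of `DataGenerator.__getitem__` (moving the channel axis to the end and one-hot encoding the labels) can be checked without TensorFlow or PyTorch. The following is a minimal NumPy-only sketch, assuming a batch in PyTorch's (batch, channels, height, width) layout with integer class labels; the helper name `to_keras_batch` is hypothetical and not part of module.py or plain.py. It shows that the one-hot matrix follows the actual number of images in the batch, including a smaller final batch.

```python
import numpy as np


def to_keras_batch(ims, lbs, ncl):
    """Hypothetical helper mirroring DataGenerator.__getitem__ on NumPy arrays.

    ims: array of shape (N, C, H, W), the PyTorch image layout
    lbs: integer array of shape (N,) with class indices in [0, ncl)
    returns images of shape (N, H, W, C) and one-hot labels of shape (N, ncl)
    """
    # move the channel axis to the end: (N, C, H, W) -> (N, H, W, C)
    ims = np.transpose(ims, (0, 2, 3, 1))
    # row i of the identity matrix is the one-hot vector of class index i
    lbs = np.eye(ncl)[np.asarray(lbs)]
    return ims, lbs


# a final partial batch of 3 RGB images of size 4x5 (height x width), 10 classes
batch = np.arange(3 * 3 * 4 * 5, dtype=np.float32).reshape(3, 3, 4, 5)
labels = np.array([2, 0, 9])
x, y = to_keras_batch(batch, labels, ncl=10)

print(x.shape)  # (3, 4, 5, 3)
print(y.shape)  # (3, 10)
# pixel (h=1, w=2) of channel 0 in image 0 keeps its value after the transpose
print(x[0, 1, 2, 0] == batch[0, 0, 1, 2])  # True
print(y[0].argmax(), y[2].argmax())  # 2 9
```

A batch of a single image works the same way: `to_keras_batch(batch[:1], labels[:1], 10)` returns labels of shape (1, 10), so no reshape to a fixed batch size is required.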