├── .gitignore
├── README.md
├── module.py
├── plain.py
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | # the virtual environment is not added to GitHub (it can be recreated from requirements.txt)
2 | venv
3 | # __pycache__ is generated when the code is run
4 | __pycache__
5 | # PyCharm-specific folder
6 | .idea
7 | # editor swap file (e.g. from Vim) created while editing .gitignore
8 | .gitignore.swp
9 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PytorchDataloaderForTensorflow
2 | ## Use
3 | Because PyTorch (through torchvision) provides image transforms that are not easy to reproduce with tf.keras, it can be useful to load image data with a PyTorch dataloader even when fitting a tf.keras model. This repository therefore implements a class that wraps a PyTorch DataLoader object (which applies the transforms to the data) and can be passed to the `fit_generator` method of a tf.keras model to supply its training data.
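The mechanism behind the class can be illustrated without PyTorch or TensorFlow. The sketch below is a simplified, hypothetical stand-in (the name `BatchAdapter` and the use of a plain Python list in place of a DataLoader are assumptions for illustration only): it pulls batches from a re-iterable loader, starts a new pass when the loader is exhausted, reorders the image axes from (N, C, H, W) to (N, H, W, C), and one-hot encodes integer labels, which is what the DataGenerator class does with a real PyTorch DataLoader.

```python
import numpy as np


class BatchAdapter:
    """Hypothetical, framework-free stand-in for DataGenerator (illustration only)."""

    def __init__(self, loader, ncl):
        # loader: any re-iterable of (images, labels) batches, images shaped (N, C, H, W)
        self.loader = loader
        self.ncl = ncl
        self.it = iter(loader)

    def __len__(self):
        # number of batches per epoch, taken from the wrapped loader
        return len(self.loader)

    def __getitem__(self, _):
        # the index is ignored: batches arrive in the order the loader yields them
        try:
            ims, lbs = next(self.it)
        except StopIteration:
            # loader exhausted (end of an epoch): start a new pass over it
            self.it = iter(self.loader)
            ims, lbs = next(self.it)
        ims = np.transpose(np.asarray(ims), (0, 2, 3, 1))  # (N, C, H, W) -> (N, H, W, C)
        lbs = np.eye(self.ncl)[np.asarray(lbs)]            # integer labels -> one-hot rows
        return ims, lbs


# a plain list stands in for a DataLoader: 2 batches of 3 RGB images, 4 x 5 pixels each
rng = np.random.default_rng(0)
fake_loader = [(rng.random((3, 3, 4, 5)), np.array([0, 2, 1])) for _ in range(2)]
gen = BatchAdapter(fake_loader, ncl=3)

x, y = gen[0]
print(x.shape, y.shape)  # (3, 4, 5, 3) (3, 3)
```

Because the index passed to `__getitem__` is ignored, the order of the batches (and any shuffling) is decided by the wrapped loader, e.g. by passing `shuffle=True` to the PyTorch DataLoader.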
4 | ## Setup
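5 | The Python files were written for Python 3.7, but they may also work with other versions. Note that the package versions pinned in requirements.txt (e.g. tensorflow 2.12.1 and torch 2.2.0) require Python 3.8 or newer.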
6 | To use this class, a few Python packages must be installed first. With pip, they can be installed either by running
7 | `pip install -r requirements.txt`
8 | in a terminal from the directory that contains requirements.txt, or by running
9 | `pip install tensorflow==2.0.0 torch==1.3.1 torchvision==0.4.2`
10 | (Python and pip must already be installed; pip 19.3.1 or newer is recommended). These module versions were used when the files were created, but newer versions may also work. The packages can also be installed with conda.
11 | ## Code
12 | There are two ways to use the class when fitting a tf.keras model:
13 | 1. Put the code directly into your Python file:
14 | copy the code from [plain.py](plain.py) into that file.
15 | 2. Import the class from a separate Python file:
16 | place [module.py](module.py) in the project folder that contains the executed file and import the class at the top of the executed file with
17 | `from module import DataGenerator`
18 |
19 | A script that uses the class should contain the following elements:
20 | ```python
21 | # load the required modules
22 | import tensorflow.keras as k
23 | import torch as pt
24 | import torchvision as tv
25 | from module import DataGenerator  # option 2 above; omit if the class was pasted in from plain.py
26 | # define the transforms for the PyTorch dataloader
27 | # additional transforms from the torchvision.transforms package can be added
28 | transform = tv.transforms.Compose([
29 |     # [...]
30 |     tv.transforms.ToTensor(),
31 |     # [...]
32 | ])
33 |
34 | # create the data generator for the tf.keras model from a PyTorch DataLoader object
35 | dataset = tv.datasets.ImageFolder('path/to/folder', transform=transform)
36 | dataloader = DataGenerator(pt.utils.data.DataLoader(dataset, [...]), len(dataset.classes))  # second argument: number of classes (ncl)
37 |
38 | # creating and defining the tf.keras model
39 | model = k.models.Sequential()
40 | # [...] new layers can be added to the model with model.add([...])
41 |
42 | model.compile([...])  # compile the model (custom parameter choices); the labels are one-hot, so e.g. loss='categorical_crossentropy'
43 | model.fit_generator(dataloader, [...])  # fit the model with the data generator (custom parameter choices); later TF 2.x versions deprecate fit_generator in favour of model.fit(dataloader, [...])
44 |
45 | model.save('path/to/model/name.h5') # save the model (optional but useful)
46 | ```
47 | Importing the class from module.py is the recommended option, because module.py documents the class with docstrings; plain.py leaves them out to keep the code short.
48 |
--------------------------------------------------------------------------------
/module.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow.keras as k
3 |
4 |
5 | class DataGenerator(k.utils.Sequence):
6 |     """
7 |     class to be fed into the model.fit_generator method of a tf.keras model
8 |
9 |     uses a PyTorch dataloader object to create a new generator object that can be used by tf.keras
10 |     the PyTorch dataloader must be used to load image data
11 |     transforms on the input image data can be done with PyTorch, model fitting is still done with tf.keras
12 |
13 |     ...
14 |
15 |     Attributes
16 |     ----------
17 |     gen : torch.utils.data.dataloader.DataLoader
18 |         PyTorch dataloader object; should be able to load image data for a PyTorch model
19 |     ncl : int
20 |         number of classes of input data; equal to the number of outputs of the model
21 |     """
22 |     def __init__(self, gen, ncl):
23 |         """
24 |         Parameters
25 |         ----------
26 |         gen : torch.utils.data.dataloader.DataLoader
27 |             PyTorch dataloader object; should be able to load image data for a PyTorch model
28 |         ncl : int
29 |             number of classes of input data; equal to the number of outputs of the model
30 |         """
31 |         self.gen = gen
32 |         self.iter = iter(gen)
33 |         self.ncl = ncl
34 |
35 |     def __getitem__(self, _):
36 |         """
37 |         function used by model.fit_generator to get the next input image batch; the batch index is ignored
38 |
39 |         Variables
40 |         ---------
41 |         ims : np.ndarray
42 |             image inputs; array of shape (batch_size, height, width, channels); input of the model
43 |         lbs : np.ndarray
44 |             labels; one-hot array of shape (batch_size, number_of_classes); correct outputs for the model
45 |         """
46 |         # get the next batch; the generation of data is handled by the PyTorch dataloader
47 |         try:
48 |             ims, lbs = next(self.iter)
49 |         # catch when no items are left in the iterator (end of an epoch)
50 |         except StopIteration:
51 |             self.iter = iter(self.gen)  # re-instantiate the iterator over the data
52 |             ims, lbs = next(self.iter)
53 |         # swap dimensions of image data from PyTorch (N, C, H, W) to tf.keras (N, H, W, C) ordering
54 |         ims = np.swapaxes(np.swapaxes(ims.numpy(), 1, 3), 1, 2)
55 |         # convert the label tensor to a NumPy integer array, then to one-hot representation
56 |         lbs = np.eye(self.ncl)[lbs.numpy()]
57 |         return ims, lbs
58 |
59 |     def __len__(self):
60 |         """
61 |         function that returns the number of batches in one epoch
62 |         """
63 |         return len(self.gen)
64 |
--------------------------------------------------------------------------------
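The image conversion in `__getitem__` depends on the batch layout: a PyTorch DataLoader fed by `ToTensor` yields images as (batch, channels, height, width), while tf.keras layers such as `Conv2D` default to channels-last (batch, height, width, channels). The NumPy-only check below, with made-up array sizes, is a small sketch confirming that the two `np.swapaxes` calls perform exactly this reordering and are equivalent to a single `np.transpose(..., (0, 2, 3, 1))`.

```python
import numpy as np

# a made-up batch in PyTorch layout: N=2 images, C=3 channels, H=4, W=5
ims = np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5)

# the expression used in DataGenerator.__getitem__:
# swap axes 1 and 3 -> (N, W, H, C), then swap axes 1 and 2 -> (N, H, W, C)
nhwc = np.swapaxes(np.swapaxes(ims, 1, 3), 1, 2)
print(nhwc.shape)  # (2, 4, 5, 3)

# the same reordering written as a single transpose
assert np.array_equal(nhwc, np.transpose(ims, (0, 2, 3, 1)))

# pixel (h=2, w=3) of channel 0 in image 1 keeps its value after the move
assert nhwc[1, 2, 3, 0] == ims[1, 0, 2, 3]
```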
/plain.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow.keras as k
3 |
4 |
5 | class DataGenerator(k.utils.Sequence):
6 |
7 |     def __init__(self, gen, ncl):
8 |         self.gen = gen
9 |         self.iter = iter(gen)
10 |         self.ncl = ncl
11 |
12 |     def __getitem__(self, _):
13 |         try:
14 |             ims, lbs = next(self.iter)
15 |         except StopIteration:
16 |             self.iter = iter(self.gen)
17 |             ims, lbs = next(self.iter)
18 |         ims = np.swapaxes(np.swapaxes(ims.numpy(), 1, 3), 1, 2)
19 |         lbs = np.eye(self.ncl)[lbs.numpy()]
20 |         return ims, lbs
21 |
22 |     def __len__(self):
23 |         return len(self.gen)
24 |
--------------------------------------------------------------------------------
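`DataGenerator.__len__` returns `len(self.gen)`, i.e. the number of batches the wrapped DataLoader produces per epoch. For a map-style dataset such as `ImageFolder` with the default sampler, this is ceil(N / batch_size), and the last batch holds only the remaining N mod batch_size samples unless `drop_last=True` is passed to the DataLoader. Code that assumes every batch has exactly `batch_size` samples (for example, reshaping the labels to `(batch_size, ncl)`) therefore fails on that last batch. A small sketch of the arithmetic; the helper names are hypothetical:

```python
def batches_per_epoch(n_samples, batch_size, drop_last=False):
    """Number of batches a DataLoader over a map-style dataset yields per epoch."""
    if drop_last:
        return n_samples // batch_size
    # integer ceiling division: a partial final batch still counts as a batch
    return (n_samples + batch_size - 1) // batch_size


def batch_sizes(n_samples, batch_size, drop_last=False):
    """Sizes of the individual batches in one epoch, in order."""
    sizes = [batch_size] * (n_samples // batch_size)
    remainder = n_samples % batch_size
    if remainder and not drop_last:
        sizes.append(remainder)  # the final, smaller batch
    return sizes


print(batches_per_epoch(10, 4))                 # 3
print(batch_sizes(10, 4))                       # [4, 4, 2]
print(batch_sizes(10, 4, drop_last=True))       # [4, 4]
```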
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.8.1
2 | astor==0.8.0
3 | cachetools==3.1.1
4 | certifi==2024.7.4
5 | chardet==3.0.4
6 | gast==0.2.2
7 | google-auth==1.7.1
8 | google-auth-oauthlib==0.4.1
9 | google-pasta==0.1.8
10 | grpcio==1.53.2
11 | h5py==2.10.0
12 | idna==3.7
13 | Keras-Applications==1.0.8
14 | Keras-Preprocessing==1.1.0
15 | Markdown==3.1.1
16 | numpy==1.22.0
17 | oauthlib==3.1.0
18 | opt-einsum==3.1.0
19 | Pillow==10.3.0
20 | protobuf==3.18.3
21 | pyasn1==0.4.8
22 | pyasn1-modules==0.2.7
23 | requests==2.32.2
24 | requests-oauthlib==1.3.0
25 | rsa==4.7
26 | six==1.13.0
27 | tensorboard==2.12.0
28 | tensorflow==2.12.1
29 | tensorflow-estimator==2.12.0
30 | termcolor==1.1.0
31 | torch==2.2.0
32 | torchvision==0.17.0
33 | urllib3==1.26.19
34 | Werkzeug==3.0.3
35 | wrapt==1.11.2
36 |
--------------------------------------------------------------------------------