├── requirements.txt ├── __init__.py ├── outputs.npy ├── dataset_features.npy ├── dataset_inputs.npy ├── dataset_outputs.npy ├── .github └── FUNDING.yml ├── pytorch_pygad_regression.py ├── torchga.py ├── pytorch_pygad_XOR_classification.py ├── pytorch_pygad_image_classification_Dense.py ├── pytorch_pygad_image_classification_CNN.py └── README.md /requirements.txt: -------------------------------------------------------------------------------- 1 | pygad 2 | torch 3 | numpy -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .torchga import * 2 | 3 | __version__ = "1.4.0" 4 | -------------------------------------------------------------------------------- /outputs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ahmedfgad/TorchGA/HEAD/outputs.npy -------------------------------------------------------------------------------- /dataset_features.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ahmedfgad/TorchGA/HEAD/dataset_features.npy -------------------------------------------------------------------------------- /dataset_inputs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ahmedfgad/TorchGA/HEAD/dataset_inputs.npy -------------------------------------------------------------------------------- /dataset_outputs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ahmedfgad/TorchGA/HEAD/dataset_outputs.npy -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: # 
Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | # paypal: http://paypal.me/ahmedfgad # Replace with a single Patreon username 5 | open_collective: pygad 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | custom: ['https://donate.stripe.com/eVa5kO866elKgM0144', 'http://paypal.me/ahmedfgad'] # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 13 | -------------------------------------------------------------------------------- /pytorch_pygad_regression.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pygad.torchga 3 | import pygad 4 | 5 | def fitness_func(ga_instanse, solution, sol_idx): 6 | global data_inputs, data_outputs, torch_ga, model, loss_function 7 | 8 | predictions = pygad.torchga.predict(model=model, 9 | solution=solution, 10 | data=data_inputs) 11 | abs_error = loss_function(predictions, data_outputs).detach().numpy() + 0.00000001 12 | 13 | solution_fitness = 1.0 / abs_error 14 | 15 | return solution_fitness 16 | 17 | def callback_generation(ga_instance): 18 | print("Generation = {generation}".format(generation=ga_instance.generations_completed)) 19 | print("Fitness = {fitness}".format(fitness=ga_instance.best_solution()[1])) 20 | 21 | # Create the PyTorch model. 22 | input_layer = torch.nn.Linear(3, 2) 23 | relu_layer = torch.nn.ReLU() 24 | output_layer = torch.nn.Linear(2, 1) 25 | 26 | model = torch.nn.Sequential(input_layer, 27 | relu_layer, 28 | output_layer) 29 | # print(model) 30 | 31 | # Create an instance of the pygad.torchga.TorchGA class to build the initial population. 
32 | torch_ga = pygad.torchga.TorchGA(model=model, 33 | num_solutions=10) 34 | 35 | loss_function = torch.nn.L1Loss() 36 | 37 | # Data inputs 38 | data_inputs = torch.tensor([[0.02, 0.1, 0.15], 39 | [0.7, 0.6, 0.8], 40 | [1.5, 1.2, 1.7], 41 | [3.2, 2.9, 3.1]]) 42 | 43 | # Data outputs 44 | data_outputs = torch.tensor([[0.1], 45 | [0.6], 46 | [1.3], 47 | [2.5]]) 48 | 49 | # Prepare the PyGAD parameters. Check the documentation for more information: https://pygad.readthedocs.io/en/latest/README_pygad_ReadTheDocs.html#pygad-ga-class 50 | num_generations = 250 # Number of generations. 51 | num_parents_mating = 5 # Number of solutions to be selected as parents in the mating pool. 52 | initial_population = torch_ga.population_weights # Initial population of network weights 53 | 54 | ga_instance = pygad.GA(num_generations=num_generations, 55 | num_parents_mating=num_parents_mating, 56 | initial_population=initial_population, 57 | fitness_func=fitness_func, 58 | on_generation=callback_generation) 59 | 60 | ga_instance.run() 61 | 62 | # After the generations complete, some plots are showed that summarize how the outputs/fitness values evolve over generations. 63 | ga_instance.plot_fitness(title="PyGAD & PyTorch - Iteration vs. Fitness", linewidth=4) 64 | 65 | # Returning the details of the best solution. 
66 | solution, solution_fitness, solution_idx = ga_instance.best_solution() 67 | print("Fitness value of the best solution = {solution_fitness}".format(solution_fitness=solution_fitness)) 68 | print("Index of the best solution : {solution_idx}".format(solution_idx=solution_idx)) 69 | 70 | predictions = pygad.torchga.predict(model=model, 71 | solution=solution, 72 | data=data_inputs) 73 | print("Predictions : \n", predictions.detach().numpy()) 74 | 75 | abs_error = loss_function(predictions, data_outputs) 76 | print("Absolute Error : ", abs_error.detach().numpy()) 77 | -------------------------------------------------------------------------------- /torchga.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy 3 | import torch 4 | 5 | def model_weights_as_vector(model): 6 | weights_vector = [] 7 | 8 | for curr_weights in model.state_dict().values(): 9 | # Calling detach() to remove the computational graph from the layer. 10 | # cpu() is called for making sure the data is moved from GPU to CPU 11 | # numpy() is called for converting the tensor into a NumPy array. 12 | curr_weights = curr_weights.cpu().detach().numpy() 13 | vector = numpy.reshape(curr_weights, newshape=(curr_weights.size)) 14 | weights_vector.extend(vector) 15 | 16 | return numpy.array(weights_vector) 17 | 18 | def model_weights_as_dict(model, weights_vector): 19 | weights_dict = model.state_dict() 20 | 21 | start = 0 22 | for key in weights_dict: 23 | # Calling detach() to remove the computational graph from the layer. 24 | # cpu() is called for making sure the data is moved from GPU to CPU 25 | # numpy() is called for converting the tensor into a NumPy array. 
26 | w_matrix = weights_dict[key].cpu().detach().numpy() 27 | layer_weights_shape = w_matrix.shape 28 | layer_weights_size = w_matrix.size 29 | 30 | layer_weights_vector = weights_vector[start:start + layer_weights_size] 31 | layer_weights_matrix = numpy.reshape(layer_weights_vector, newshape=(layer_weights_shape)) 32 | weights_dict[key] = torch.from_numpy(layer_weights_matrix) 33 | 34 | start = start + layer_weights_size 35 | 36 | return weights_dict 37 | 38 | def predict(model, solution, data): 39 | # Fetch the parameters of the best solution. 40 | model_weights_dict = model_weights_as_dict(model=model, 41 | weights_vector=solution) 42 | 43 | # Use the current solution as the model parameters. 44 | _model = copy.deepcopy(model) 45 | _model.load_state_dict(model_weights_dict) 46 | 47 | with torch.no_grad(): 48 | predictions = _model(data) 49 | 50 | return predictions 51 | 52 | class TorchGA: 53 | 54 | def __init__(self, model, num_solutions): 55 | 56 | """ 57 | Creates an instance of the TorchGA class to build a population of model parameters. 58 | 59 | model: A PyTorch model class. 60 | num_solutions: Number of solutions in the population. Each solution has different model parameters. 61 | """ 62 | 63 | self.model = model 64 | 65 | self.num_solutions = num_solutions 66 | 67 | # A list holding references to all the solutions (i.e. networks) used in the population. 68 | self.population_weights = self.create_population() 69 | 70 | def create_population(self): 71 | 72 | """ 73 | Creates the initial population of the genetic algorithm as a list of networks' weights (i.e. solutions). Each element in the list holds a different weights of the PyTorch model. 74 | 75 | The method returns a list holding the weights of all solutions. 
76 | """ 77 | 78 | model_weights_vector = model_weights_as_vector(model=self.model) 79 | 80 | net_population_weights = [] 81 | net_population_weights.append(model_weights_vector) 82 | 83 | for idx in range(self.num_solutions-1): 84 | 85 | net_weights = copy.deepcopy(model_weights_vector) 86 | net_weights = numpy.array(net_weights) + numpy.random.uniform(low=-1.0, high=1.0, size=model_weights_vector.size) 87 | 88 | # Appending the weights to the population. 89 | net_population_weights.append(net_weights) 90 | 91 | return net_population_weights 92 | -------------------------------------------------------------------------------- /pytorch_pygad_XOR_classification.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pygad.torchga 3 | import pygad 4 | 5 | def fitness_func(ga_instanse, solution, sol_idx): 6 | global data_inputs, data_outputs, torch_ga, model, loss_function 7 | 8 | predictions = pygad.torchga.predict(model=model, 9 | solution=solution, 10 | data=data_inputs) 11 | 12 | solution_fitness = 1.0 / (loss_function(predictions, data_outputs).detach().numpy() + 0.00000001) 13 | 14 | return solution_fitness 15 | 16 | def callback_generation(ga_instance): 17 | print("Generation = {generation}".format(generation=ga_instance.generations_completed)) 18 | print("Fitness = {fitness}".format(fitness=ga_instance.best_solution()[1])) 19 | 20 | # Create the PyTorch model. 21 | input_layer = torch.nn.Linear(2, 4) 22 | relu_layer = torch.nn.ReLU() 23 | dense_layer = torch.nn.Linear(4, 2) 24 | output_layer = torch.nn.Softmax(1) 25 | 26 | model = torch.nn.Sequential(input_layer, 27 | relu_layer, 28 | dense_layer, 29 | output_layer) 30 | # print(model) 31 | 32 | # Create an instance of the pygad.torchga.TorchGA class to build the initial population. 
33 | torch_ga = pygad.torchga.TorchGA(model=model, 34 | num_solutions=10) 35 | 36 | loss_function = torch.nn.BCELoss() 37 | 38 | # XOR problem inputs 39 | data_inputs = torch.tensor([[0.0, 0.0], 40 | [0.0, 1.0], 41 | [1.0, 0.0], 42 | [1.0, 1.0]]) 43 | 44 | # XOR problem outputs 45 | data_outputs = torch.tensor([[1.0, 0.0], 46 | [0.0, 1.0], 47 | [0.0, 1.0], 48 | [1.0, 0.0]]) 49 | 50 | # Prepare the PyGAD parameters. Check the documentation for more information: https://pygad.readthedocs.io/en/latest/README_pygad_ReadTheDocs.html#pygad-ga-class 51 | num_generations = 250 # Number of generations. 52 | num_parents_mating = 5 # Number of solutions to be selected as parents in the mating pool. 53 | initial_population = torch_ga.population_weights # Initial population of network weights. 54 | 55 | # Create an instance of the pygad.GA class 56 | ga_instance = pygad.GA(num_generations=num_generations, 57 | num_parents_mating=num_parents_mating, 58 | initial_population=initial_population, 59 | fitness_func=fitness_func, 60 | parallel_processing=3, 61 | on_generation=callback_generation) 62 | 63 | # Start the genetic algorithm evolution. 64 | ga_instance.run() 65 | 66 | # After the generations complete, some plots are showed that summarize how the outputs/fitness values evolve over generations. 67 | ga_instance.plot_fitness(title="PyGAD & PyTorch - Iteration vs. Fitness", linewidth=4) 68 | 69 | # Returning the details of the best solution. 70 | solution, solution_fitness, solution_idx = ga_instance.best_solution() 71 | print("Fitness value of the best solution = {solution_fitness}".format(solution_fitness=solution_fitness)) 72 | print("Index of the best solution : {solution_idx}".format(solution_idx=solution_idx)) 73 | 74 | predictions = pygad.torchga.predict(model=model, 75 | solution=solution, 76 | data=data_inputs) 77 | print("Predictions : \n", predictions.detach().numpy()) 78 | 79 | # Calculate the binary crossentropy for the trained model. 
80 | print("Binary Crossentropy : ", loss_function(predictions, data_outputs).detach().numpy()) 81 | 82 | # Calculate the classification accuracy of the trained model. 83 | a = torch.max(predictions, axis=1) 84 | b = torch.max(data_outputs, axis=1) 85 | accuracy = torch.true_divide(torch.sum(a.indices == b.indices), len(data_outputs)) 86 | print("Accuracy : ", accuracy.detach().numpy()) 87 | -------------------------------------------------------------------------------- /pytorch_pygad_image_classification_Dense.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pygad.torchga 3 | import pygad 4 | import numpy 5 | 6 | def fitness_func(ga_instanse, solution, sol_idx): 7 | global data_inputs, data_outputs, torch_ga, model, loss_function 8 | 9 | predictions = pygad.torchga.predict(model=model, 10 | solution=solution, 11 | data=data_inputs) 12 | 13 | solution_fitness = 1.0 / (loss_function(predictions, data_outputs).detach().numpy() + 0.00000001) 14 | 15 | return solution_fitness 16 | 17 | def callback_generation(ga_instance): 18 | print("Generation = {generation}".format(generation=ga_instance.generations_completed)) 19 | print("Fitness = {fitness}".format(fitness=ga_instance.best_solution()[1])) 20 | 21 | # Build the PyTorch model using the functional API. 22 | input_layer = torch.nn.Linear(360, 50) 23 | relu_layer = torch.nn.ReLU() 24 | dense_layer = torch.nn.Linear(50, 4) 25 | output_layer = torch.nn.Softmax(1) 26 | 27 | model = torch.nn.Sequential(input_layer, 28 | relu_layer, 29 | dense_layer, 30 | output_layer) 31 | 32 | # Create an instance of the pygad.torchga.TorchGA class to build the initial population. 
33 | torch_ga = pygad.torchga.TorchGA(model=model, 34 | num_solutions=10) 35 | 36 | loss_function = torch.nn.CrossEntropyLoss() 37 | 38 | # Data inputs 39 | data_inputs = torch.from_numpy(numpy.load("dataset_features.npy")).float() 40 | 41 | # Data outputs 42 | data_outputs = torch.from_numpy(numpy.load("outputs.npy")).long() 43 | # The next 2 lines are equivelant to this Keras function to perform 1-hot encoding: tensorflow.keras.utils.to_categorical(data_outputs) 44 | # temp_outs = numpy.zeros((data_outputs.shape[0], numpy.unique(data_outputs).size), dtype=numpy.uint8) 45 | # temp_outs[numpy.arange(data_outputs.shape[0]), numpy.uint8(data_outputs)] = 1 46 | 47 | # Prepare the PyGAD parameters. Check the documentation for more information: https://pygad.readthedocs.io/en/latest/README_pygad_ReadTheDocs.html#pygad-ga-class 48 | num_generations = 200 # Number of generations. 49 | num_parents_mating = 5 # Number of solutions to be selected as parents in the mating pool. 50 | initial_population = torch_ga.population_weights # Initial population of network weights. 51 | 52 | # Create an instance of the pygad.GA class 53 | ga_instance = pygad.GA(num_generations=num_generations, 54 | num_parents_mating=num_parents_mating, 55 | initial_population=initial_population, 56 | fitness_func=fitness_func, 57 | on_generation=callback_generation) 58 | 59 | # Start the genetic algorithm evolution. 60 | ga_instance.run() 61 | 62 | # After the generations complete, some plots are showed that summarize how the outputs/fitness values evolve over generations. 63 | ga_instance.plot_fitness(title="PyGAD & PyTorch - Iteration vs. Fitness", linewidth=4) 64 | 65 | # Returning the details of the best solution. 
66 | solution, solution_fitness, solution_idx = ga_instance.best_solution() 67 | print("Fitness value of the best solution = {solution_fitness}".format(solution_fitness=solution_fitness)) 68 | print("Index of the best solution : {solution_idx}".format(solution_idx=solution_idx)) 69 | 70 | predictions = pygad.torchga.predict(model=model, 71 | solution=solution, 72 | data=data_inputs) 73 | # print("Predictions : \n", predictions) 74 | 75 | # Calculate the crossentropy loss of the trained model. 76 | print("Crossentropy : ", loss_function(predictions, data_outputs).detach().numpy()) 77 | 78 | # Calculate the classification accuracy for the trained model. 79 | accuracy = torch.true_divide(torch.sum(torch.max(predictions, axis=1).indices == data_outputs), len(data_outputs)) 80 | print("Accuracy : ", accuracy.detach().numpy()) 81 | -------------------------------------------------------------------------------- /pytorch_pygad_image_classification_CNN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pygad.torchga 3 | import pygad 4 | import numpy 5 | 6 | def fitness_func(ga_instanse, solution, sol_idx): 7 | global data_inputs, data_outputs, torch_ga, model, loss_function 8 | 9 | predictions = pygad.torchga.predict(model=model, 10 | solution=solution, 11 | data=data_inputs) 12 | 13 | solution_fitness = 1.0 / (loss_function(predictions, data_outputs).detach().numpy() + 0.00000001) 14 | 15 | return solution_fitness 16 | 17 | def callback_generation(ga_instance): 18 | print("Generation = {generation}".format(generation=ga_instance.generations_completed)) 19 | print("Fitness = {fitness}".format(fitness=ga_instance.best_solution()[1])) 20 | 21 | # Build the PyTorch model. 
22 | input_layer = torch.nn.Conv2d(in_channels=3, out_channels=5, kernel_size=7) 23 | relu_layer1 = torch.nn.ReLU() 24 | max_pool1 = torch.nn.MaxPool2d(kernel_size=5, stride=5) 25 | 26 | conv_layer2 = torch.nn.Conv2d(in_channels=5, out_channels=3, kernel_size=3) 27 | relu_layer2 = torch.nn.ReLU() 28 | 29 | flatten_layer1 = torch.nn.Flatten() 30 | # The value 768 is pre-computed by tracing the sizes of the layers' outputs. 31 | dense_layer1 = torch.nn.Linear(in_features=768, out_features=15) 32 | relu_layer3 = torch.nn.ReLU() 33 | 34 | dense_layer2 = torch.nn.Linear(in_features=15, out_features=4) 35 | output_layer = torch.nn.Softmax(1) 36 | 37 | model = torch.nn.Sequential(input_layer, 38 | relu_layer1, 39 | max_pool1, 40 | conv_layer2, 41 | relu_layer2, 42 | flatten_layer1, 43 | dense_layer1, 44 | relu_layer3, 45 | dense_layer2, 46 | output_layer) 47 | 48 | # Create an instance of the pygad.torchga.TorchGA class to build the initial population. 49 | torch_ga = pygad.torchga.TorchGA(model=model, 50 | num_solutions=10) 51 | 52 | loss_function = torch.nn.CrossEntropyLoss() 53 | 54 | # Data inputs 55 | data_inputs = torch.from_numpy(numpy.load("dataset_inputs.npy")).float() 56 | data_inputs = data_inputs.reshape((data_inputs.shape[0], data_inputs.shape[3], data_inputs.shape[1], data_inputs.shape[2])) 57 | 58 | # Data outputs 59 | data_outputs = torch.from_numpy(numpy.load("dataset_outputs.npy")).long() 60 | 61 | # Prepare the PyGAD parameters. Check the documentation for more information: https://pygad.readthedocs.io/en/latest/README_pygad_ReadTheDocs.html#pygad-ga-class 62 | num_generations = 200 # Number of generations. 63 | num_parents_mating = 5 # Number of solutions to be selected as parents in the mating pool. 64 | initial_population = torch_ga.population_weights # Initial population of network weights. 
65 | 66 | # Create an instance of the pygad.GA class 67 | ga_instance = pygad.GA(num_generations=num_generations, 68 | num_parents_mating=num_parents_mating, 69 | initial_population=initial_population, 70 | fitness_func=fitness_func, 71 | on_generation=callback_generation) 72 | 73 | # Start the genetic algorithm evolution. 74 | ga_instance.run() 75 | 76 | # After the generations complete, some plots are showed that summarize how the outputs/fitness values evolve over generations. 77 | ga_instance.plot_fitness(title="PyGAD & PyTorch - Iteration vs. Fitness", linewidth=4) 78 | 79 | # Returning the details of the best solution. 80 | solution, solution_fitness, solution_idx = ga_instance.best_solution() 81 | print("Fitness value of the best solution = {solution_fitness}".format(solution_fitness=solution_fitness)) 82 | print("Index of the best solution : {solution_idx}".format(solution_idx=solution_idx)) 83 | 84 | predictions = pygad.torchga.predict(model=model, 85 | solution=solution, 86 | data=data_inputs) 87 | # print("Predictions : \n", predictions) 88 | 89 | # Calculate the crossentropy for the trained model. 90 | print("Crossentropy : ", loss_function(predictions, data_outputs).detach().numpy()) 91 | 92 | # Calculate the classification accuracy for the trained model. 93 | accuracy = torch.true_divide(torch.sum(torch.max(predictions, axis=1).indices == data_outputs), len(data_outputs)) 94 | print("Accuracy : ", accuracy.detach().numpy()) 95 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TorchGA: Training PyTorch Models using the Genetic Algorithm 2 | [TorchGA](https://github.com/ahmedfgad/TorchGA) is part of the [PyGAD](https://pypi.org/project/pygad) library for training [PyTorch](https://pytorch.org) models using the genetic algorithm (GA). This feature is supported starting from [PyGAD](https://pypi.org/project/pygad) 2.10.0. 
3 | 4 | The [TorchGA](https://github.com/ahmedfgad/TorchGA) project has a single module named `torchga.py` which has a class named `TorchGA` for preparing an initial population of PyTorch model parameters. 5 | 6 | [PyGAD](https://pypi.org/project/pygad) is an open-source Python library for building the genetic algorithm and training machine learning algorithms. Check the library's documentation at [Read The Docs](https://pygad.readthedocs.io/): https://pygad.readthedocs.io 7 | 8 | # Donation 9 | 10 | - [Credit/Debit Card](https://donate.stripe.com/eVa5kO866elKgM0144): https://donate.stripe.com/eVa5kO866elKgM0144 11 | - [Open Collective](https://opencollective.com/pygad): [opencollective.com/pygad](https://opencollective.com/pygad) 12 | - PayPal: Use either this link: [paypal.me/ahmedfgad](https://paypal.me/ahmedfgad) or the e-mail address ahmed.f.gad@gmail.com 13 | - Interac e-Transfer: Use the e-mail address ahmed.f.gad@gmail.com 14 | 15 | # Installation 16 | 17 | To install [PyGAD](https://pypi.org/project/pygad), simply use pip to download and install the library from [PyPI](https://pypi.org/project/pygad) (Python Package Index). The library lives on PyPI at this page https://pypi.org/project/pygad. 18 | 19 | ```python 20 | pip3 install pygad 21 | ``` 22 | 23 | To get started with PyGAD, please read the documentation at [Read The Docs](https://pygad.readthedocs.io/) https://pygad.readthedocs.io. 
24 | 25 | # PyGAD Source Code 26 | 27 | The source code of the `PyGAD` modules is found in the following GitHub projects: 28 | 29 | - [pygad](https://github.com/ahmedfgad/GeneticAlgorithmPython): https://github.com/ahmedfgad/GeneticAlgorithmPython 30 | - [pygad.nn](https://github.com/ahmedfgad/NumPyANN): https://github.com/ahmedfgad/NumPyANN 31 | - [pygad.gann](https://github.com/ahmedfgad/NeuralGenetic): https://github.com/ahmedfgad/NeuralGenetic 32 | - [pygad.cnn](https://github.com/ahmedfgad/NumPyCNN): https://github.com/ahmedfgad/NumPyCNN 33 | - [pygad.gacnn](https://github.com/ahmedfgad/CNNGenetic): https://github.com/ahmedfgad/CNNGenetic 34 | - [pygad.kerasga](https://github.com/ahmedfgad/KerasGA): https://github.com/ahmedfgad/KerasGA 35 | - [pygad.torchga](https://github.com/ahmedfgad/TorchGA): https://github.com/ahmedfgad/TorchGA 36 | 37 | The documentation of PyGAD is available at [Read The Docs](https://pygad.readthedocs.io/) https://pygad.readthedocs.io. 38 | 39 | # PyGAD Documentation 40 | 41 | The documentation of the PyGAD library is available at [Read The Docs](https://pygad.readthedocs.io) at this link: https://pygad.readthedocs.io. It discusses the modules supported by PyGAD, all its classes, methods, attributes, and functions. For each module, a number of examples are given. 42 | 43 | If there is an issue using PyGAD, feel free to post an issue in this [GitHub repository](https://github.com/ahmedfgad/GeneticAlgorithmPython) https://github.com/ahmedfgad/GeneticAlgorithmPython or send an e-mail to ahmed.f.gad@gmail.com. 44 | 45 | If you built a project that uses PyGAD, then please drop an e-mail to ahmed.f.gad@gmail.com with the following information so that your project is included in the documentation. 46 | 47 | - Project title 48 | - Brief description 49 | - Preferably, a link that directs the readers to your project 50 | 51 | Please check the **Contact Us** section for more contact details. 
52 | 53 | # Life Cycle of PyGAD 54 | 55 | The next figure lists the different stages in the lifecycle of an instance of the `pygad.GA` class. Note that PyGAD stops when either all generations are completed or when the function passed to the `on_generation` parameter returns the string `stop`. 56 | 57 | ![PyGAD Lifecycle](https://user-images.githubusercontent.com/16560492/89446279-9c6f8380-d754-11ea-83fd-a60ea2f53b85.jpg) 58 | 59 | The next code implements all the callback functions to trace the execution of the genetic algorithm. Each callback function prints its name. 60 | 61 | ```python 62 | import pygad 63 | import numpy 64 | 65 | function_inputs = [4,-2,3.5,5,-11,-4.7] 66 | desired_output = 44 67 | 68 | def fitness_func(ga_instance, solution, solution_idx): 69 | output = numpy.sum(solution*function_inputs) 70 | fitness = 1.0 / (numpy.abs(output - desired_output) + 0.000001) 71 | return fitness 72 | 73 | fitness_function = fitness_func 74 | 75 | def on_start(ga_instance): 76 | print("on_start()") 77 | 78 | def on_fitness(ga_instance, population_fitness): 79 | print("on_fitness()") 80 | 81 | def on_parents(ga_instance, selected_parents): 82 | print("on_parents()") 83 | 84 | def on_crossover(ga_instance, offspring_crossover): 85 | print("on_crossover()") 86 | 87 | def on_mutation(ga_instance, offspring_mutation): 88 | print("on_mutation()") 89 | 90 | def on_generation(ga_instance): 91 | print("on_generation()") 92 | 93 | def on_stop(ga_instance, last_population_fitness): 94 | print("on_stop()") 95 | 96 | ga_instance = pygad.GA(num_generations=3, 97 | num_parents_mating=5, 98 | fitness_func=fitness_function, 99 | sol_per_pop=10, 100 | num_genes=len(function_inputs), 101 | on_start=on_start, 102 | on_fitness=on_fitness, 103 | on_parents=on_parents, 104 | on_crossover=on_crossover, 105 | on_mutation=on_mutation, 106 | on_generation=on_generation, 107 | on_stop=on_stop) 108 | 109 | ga_instance.run() 110 | ``` 111 | 112 | Based on the used 3 generations as assigned to 
the `num_generations` argument, here is the output. 113 | 114 | ``` 115 | on_start() 116 | 117 | on_fitness() 118 | on_parents() 119 | on_crossover() 120 | on_mutation() 121 | on_generation() 122 | 123 | on_fitness() 124 | on_parents() 125 | on_crossover() 126 | on_mutation() 127 | on_generation() 128 | 129 | on_fitness() 130 | on_parents() 131 | on_crossover() 132 | on_mutation() 133 | on_generation() 134 | 135 | on_stop() 136 | ``` 137 | 138 | # Examples 139 | 140 | Check [PyGAD's documentation](https://pygad.readthedocs.io/en/latest/gacnn.html) for more examples and information. You can also find more information about the implementation of the examples. 141 | 142 | ## Example 1: Regression Model 143 | 144 | ```python 145 | import torch 146 | import torchga 147 | import pygad 148 | 149 | def fitness_func(ga_instance, solution, sol_idx): 150 | global data_inputs, data_outputs, torch_ga, model, loss_function 151 | 152 | model_weights_dict = torchga.model_weights_as_dict(model=model, 153 | weights_vector=solution) 154 | 155 | # Use the current solution as the model parameters. 156 | model.load_state_dict(model_weights_dict) 157 | 158 | predictions = model(data_inputs) 159 | abs_error = loss_function(predictions, data_outputs).detach().numpy() + 0.00000001 160 | 161 | solution_fitness = 1.0 / abs_error 162 | 163 | return solution_fitness 164 | 165 | def callback_generation(ga_instance): 166 | print("Generation = {generation}".format(generation=ga_instance.generations_completed)) 167 | print("Fitness = {fitness}".format(fitness=ga_instance.best_solution()[1])) 168 | 169 | # Create the PyTorch model. 170 | input_layer = torch.nn.Linear(3, 5) 171 | relu_layer = torch.nn.ReLU() 172 | output_layer = torch.nn.Linear(5, 1) 173 | 174 | model = torch.nn.Sequential(input_layer, 175 | relu_layer, 176 | output_layer) 177 | # print(model) 178 | 179 | # Create an instance of the pygad.torchga.TorchGA class to build the initial population. 
180 | torch_ga = torchga.TorchGA(model=model, 181 | num_solutions=10) 182 | 183 | loss_function = torch.nn.L1Loss() 184 | 185 | # Data inputs 186 | data_inputs = torch.tensor([[0.02, 0.1, 0.15], 187 | [0.7, 0.6, 0.8], 188 | [1.5, 1.2, 1.7], 189 | [3.2, 2.9, 3.1]]) 190 | 191 | # Data outputs 192 | data_outputs = torch.tensor([[0.1], 193 | [0.6], 194 | [1.3], 195 | [2.5]]) 196 | 197 | # Prepare the PyGAD parameters. Check the documentation for more information: https://pygad.readthedocs.io/en/latest/pygad.html#pygad-ga-class 198 | num_generations = 250 # Number of generations. 199 | num_parents_mating = 5 # Number of solutions to be selected as parents in the mating pool. 200 | initial_population = torch_ga.population_weights # Initial population of network weights 201 | 202 | ga_instance = pygad.GA(num_generations=num_generations, 203 | num_parents_mating=num_parents_mating, 204 | initial_population=initial_population, 205 | fitness_func=fitness_func, 206 | on_generation=callback_generation) 207 | 208 | ga_instance.run() 209 | 210 | # After the generations complete, some plots are showed that summarize how the outputs/fitness values evolve over generations. 211 | ga_instance.plot_fitness(title="PyGAD & PyTorch - Iteration vs. Fitness", linewidth=4) 212 | 213 | # Returning the details of the best solution. 214 | solution, solution_fitness, solution_idx = ga_instance.best_solution() 215 | print("Fitness value of the best solution = {solution_fitness}".format(solution_fitness=solution_fitness)) 216 | print("Index of the best solution : {solution_idx}".format(solution_idx=solution_idx)) 217 | 218 | # Fetch the parameters of the best solution. 
219 | best_solution_weights = torchga.model_weights_as_dict(model=model, 220 | weights_vector=solution) 221 | model.load_state_dict(best_solution_weights) 222 | predictions = model(data_inputs) 223 | print("Predictions : \n", predictions.detach().numpy()) 224 | 225 | abs_error = loss_function(predictions, data_outputs) 226 | print("Absolute Error : ", abs_error.detach().numpy()) 227 | ``` 228 | 229 | ## Example 2: XOR Binary Classification 230 | 231 | ```python 232 | import torch 233 | import torchga 234 | import pygad 235 | 236 | def fitness_func(ga_instance, solution, sol_idx): 237 | global data_inputs, data_outputs, torch_ga, model, loss_function 238 | 239 | model_weights_dict = torchga.model_weights_as_dict(model=model, 240 | weights_vector=solution) 241 | 242 | # Use the current solution as the model parameters. 243 | model.load_state_dict(model_weights_dict) 244 | 245 | predictions = model(data_inputs) 246 | 247 | solution_fitness = 1.0 / (loss_function(predictions, data_outputs).detach().numpy() + 0.00000001) 248 | 249 | return solution_fitness 250 | 251 | def callback_generation(ga_instance): 252 | print("Generation = {generation}".format(generation=ga_instance.generations_completed)) 253 | print("Fitness = {fitness}".format(fitness=ga_instance.best_solution()[1])) 254 | 255 | # Create the PyTorch model. 256 | input_layer = torch.nn.Linear(2, 4) 257 | relu_layer = torch.nn.ReLU() 258 | dense_layer = torch.nn.Linear(4, 2) 259 | output_layer = torch.nn.Softmax(1) 260 | 261 | model = torch.nn.Sequential(input_layer, 262 | relu_layer, 263 | dense_layer, 264 | output_layer) 265 | # print(model) 266 | 267 | # Create an instance of the pygad.torchga.TorchGA class to build the initial population. 
268 | torch_ga = torchga.TorchGA(model=model, 269 | num_solutions=10) 270 | 271 | loss_function = torch.nn.BCELoss() 272 | 273 | # XOR problem inputs 274 | data_inputs = torch.tensor([[0.0, 0.0], 275 | [0.0, 1.0], 276 | [1.0, 0.0], 277 | [1.0, 1.0]]) 278 | 279 | # XOR problem outputs 280 | data_outputs = torch.tensor([[1.0, 0.0], 281 | [0.0, 1.0], 282 | [0.0, 1.0], 283 | [1.0, 0.0]]) 284 | 285 | # Prepare the PyGAD parameters. Check the documentation for more information: https://pygad.readthedocs.io/en/latest/pygad.html#pygad-ga-class 286 | num_generations = 250 # Number of generations. 287 | num_parents_mating = 5 # Number of solutions to be selected as parents in the mating pool. 288 | initial_population = torch_ga.population_weights # Initial population of network weights. 289 | 290 | # Create an instance of the pygad.GA class 291 | ga_instance = pygad.GA(num_generations=num_generations, 292 | num_parents_mating=num_parents_mating, 293 | initial_population=initial_population, 294 | fitness_func=fitness_func, 295 | on_generation=callback_generation) 296 | 297 | # Start the genetic algorithm evolution. 298 | ga_instance.run() 299 | 300 | # After the generations complete, some plots are shown that summarize how the outputs/fitness values evolve over generations. 301 | ga_instance.plot_fitness(title="PyGAD & PyTorch - Iteration vs. Fitness", linewidth=4) 302 | 303 | # Returning the details of the best solution. 304 | solution, solution_fitness, solution_idx = ga_instance.best_solution() 305 | print("Fitness value of the best solution = {solution_fitness}".format(solution_fitness=solution_fitness)) 306 | print("Index of the best solution : {solution_idx}".format(solution_idx=solution_idx)) 307 | 308 | # Fetch the parameters of the best solution.
309 | best_solution_weights = torchga.model_weights_as_dict(model=model, 310 | weights_vector=solution) 311 | model.load_state_dict(best_solution_weights) 312 | predictions = model(data_inputs) 313 | print("Predictions : \n", predictions.detach().numpy()) 314 | 315 | # Calculate the binary crossentropy for the trained model. 316 | print("Binary Crossentropy : ", loss_function(predictions, data_outputs).detach().numpy()) 317 | 318 | # Calculate the classification accuracy of the trained model. 319 | a = torch.max(predictions, axis=1) 320 | b = torch.max(data_outputs, axis=1) 321 | accuracy = torch.sum(a.indices == b.indices) / len(data_outputs) 322 | print("Accuracy : ", accuracy.detach().numpy()) 323 | ``` 324 | 325 | # For More Information 326 | 327 | There are different resources that can be used to get started with the genetic algorithm, neural networks, and CNNs, and with their Python implementations. 328 | 329 | ## Tutorial: Implementing Genetic Algorithm in Python 330 | 331 | To start with coding the genetic algorithm, you can check the tutorial titled [**Genetic Algorithm Implementation in Python**](https://www.linkedin.com/pulse/genetic-algorithm-implementation-python-ahmed-gad) available at these links: 332 | 333 | - [LinkedIn](https://www.linkedin.com/pulse/genetic-algorithm-implementation-python-ahmed-gad) 334 | - [Towards Data Science](https://towardsdatascience.com/genetic-algorithm-implementation-in-python-5ab67bb124a6) 335 | - [KDnuggets](https://www.kdnuggets.com/2018/07/genetic-algorithm-implementation-python.html) 336 | 337 | [This tutorial](https://www.linkedin.com/pulse/genetic-algorithm-implementation-python-ahmed-gad) is prepared based on a previous version of the project but it is still a good resource to start with coding the genetic algorithm.
338 | 339 | [![Genetic Algorithm Implementation in Python](https://user-images.githubusercontent.com/16560492/78830052-a3c19300-79e7-11ea-8b9b-4b343ea4049c.png)](https://www.linkedin.com/pulse/genetic-algorithm-implementation-python-ahmed-gad) 340 | 341 | ## Tutorial: Introduction to Genetic Algorithm 342 | 343 | Get started with the genetic algorithm by reading the tutorial titled [**Introduction to Optimization with Genetic Algorithm**](https://www.linkedin.com/pulse/introduction-optimization-genetic-algorithm-ahmed-gad) which is available at these links: 344 | 345 | * [LinkedIn](https://www.linkedin.com/pulse/introduction-optimization-genetic-algorithm-ahmed-gad) 346 | * [Towards Data Science](https://towardsdatascience.com/introduction-to-optimization-with-genetic-algorithm-2f5001d9964b) 347 | * [KDnuggets](https://www.kdnuggets.com/2018/03/introduction-optimization-with-genetic-algorithm.html) 348 | 349 | [![Introduction to Genetic Algorithm](https://user-images.githubusercontent.com/16560492/82078259-26252d00-96e1-11ea-9a02-52a99e1054b9.jpg)](https://www.linkedin.com/pulse/introduction-optimization-genetic-algorithm-ahmed-gad) 350 | 351 | ## Tutorial: Build Neural Networks in Python 352 | 353 | Read about building neural networks in Python through the tutorial titled [**Artificial Neural Network Implementation using NumPy and Classification of the Fruits360 Image Dataset**](https://www.linkedin.com/pulse/artificial-neural-network-implementation-using-numpy-fruits360-gad) available at these links: 354 | 355 | * [LinkedIn](https://www.linkedin.com/pulse/artificial-neural-network-implementation-using-numpy-fruits360-gad) 356 | * [Towards Data Science](https://towardsdatascience.com/artificial-neural-network-implementation-using-numpy-and-classification-of-the-fruits360-image-3c56affa4491) 357 | * [KDnuggets](https://www.kdnuggets.com/2019/02/artificial-neural-network-implementation-using-numpy-and-image-classification.html) 358 | 359 | [![Building Neural
Networks Python](https://user-images.githubusercontent.com/16560492/82078281-30472b80-96e1-11ea-8017-6a1f4383d602.jpg)](https://www.linkedin.com/pulse/artificial-neural-network-implementation-using-numpy-fruits360-gad) 360 | 361 | ## Tutorial: Optimize Neural Networks with Genetic Algorithm 362 | 363 | Read about training neural networks using the genetic algorithm through the tutorial titled [**Artificial Neural Networks Optimization using Genetic Algorithm with Python**](https://www.linkedin.com/pulse/artificial-neural-networks-optimization-using-genetic-ahmed-gad) available at these links: 364 | 365 | - [LinkedIn](https://www.linkedin.com/pulse/artificial-neural-networks-optimization-using-genetic-ahmed-gad) 366 | - [Towards Data Science](https://towardsdatascience.com/artificial-neural-networks-optimization-using-genetic-algorithm-with-python-1fe8ed17733e) 367 | - [KDnuggets](https://www.kdnuggets.com/2019/03/artificial-neural-networks-optimization-genetic-algorithm-python.html) 368 | 369 | [![Training Neural Networks using Genetic Algorithm Python](https://user-images.githubusercontent.com/16560492/82078300-376e3980-96e1-11ea-821c-aa6b8ceb44d4.jpg)](https://www.linkedin.com/pulse/artificial-neural-networks-optimization-using-genetic-ahmed-gad) 370 | 371 | ## Tutorial: Building CNN in Python 372 | 373 | To start with coding convolutional neural networks (CNNs), you can check the tutorial titled [**Building Convolutional Neural Network using NumPy from Scratch**](https://www.linkedin.com/pulse/building-convolutional-neural-network-using-numpy-from-ahmed-gad) available at these links: 374 | 375 | - [LinkedIn](https://www.linkedin.com/pulse/building-convolutional-neural-network-using-numpy-from-ahmed-gad) 376 | - [Towards Data Science](https://towardsdatascience.com/building-convolutional-neural-network-using-numpy-from-scratch-b30aac50e50a) 377 | - [KDnuggets](https://www.kdnuggets.com/2018/04/building-convolutional-neural-network-numpy-scratch.html) 378 | - [Chinese
Translation](http://m.aliyun.com/yunqi/articles/585741) 379 | 380 | [This tutorial](https://www.linkedin.com/pulse/building-convolutional-neural-network-using-numpy-from-ahmed-gad) is prepared based on a previous version of the project but it is still a good resource to start with coding CNNs. 381 | 382 | [![Building CNN in Python](https://user-images.githubusercontent.com/16560492/82431022-6c3a1200-9a8e-11ea-8f1b-b055196d76e3.png)](https://www.linkedin.com/pulse/building-convolutional-neural-network-using-numpy-from-ahmed-gad) 383 | 384 | ## Tutorial: Derivation of CNN from FCNN 385 | 386 | Get started with convolutional neural networks by reading the tutorial titled [**Derivation of Convolutional Neural Network from Fully Connected Network Step-By-Step**](https://www.linkedin.com/pulse/derivation-convolutional-neural-network-from-fully-connected-gad) which is available at these links: 387 | 388 | * [LinkedIn](https://www.linkedin.com/pulse/derivation-convolutional-neural-network-from-fully-connected-gad) 389 | * [Towards Data Science](https://towardsdatascience.com/derivation-of-convolutional-neural-network-from-fully-connected-network-step-by-step-b42ebafa5275) 390 | * [KDnuggets](https://www.kdnuggets.com/2018/04/derivation-convolutional-neural-network-fully-connected-step-by-step.html) 391 | 392 | [![Derivation of CNN from FCNN](https://user-images.githubusercontent.com/16560492/82431369-db176b00-9a8e-11ea-99bd-e845192873fc.png)](https://www.linkedin.com/pulse/derivation-convolutional-neural-network-from-fully-connected-gad) 393 | 394 | ## Book: Practical Computer Vision Applications Using Deep Learning with CNNs 395 | 396 | You can also check my book cited as [**Ahmed Fawzy Gad 'Practical Computer Vision Applications Using Deep Learning with CNNs'. Dec.
2018, Apress, 978-1-4842-4167-7**](https://www.amazon.com/Practical-Computer-Vision-Applications-Learning/dp/1484241665) which discusses neural networks, convolutional neural networks, deep learning, genetic algorithm, and more. 397 | 398 | Find the book at these links: 399 | 400 | - [Amazon](https://www.amazon.com/Practical-Computer-Vision-Applications-Learning/dp/1484241665) 401 | - [Springer](https://link.springer.com/book/10.1007/978-1-4842-4167-7) 402 | - [Apress](https://www.apress.com/gp/book/9781484241660) 403 | - [O'Reilly](https://www.oreilly.com/library/view/practical-computer-vision/9781484241677) 404 | - [Google Books](https://books.google.com.eg/books?id=xLd9DwAAQBAJ) 405 | 406 | ![Fig04](https://user-images.githubusercontent.com/16560492/78830077-ae7c2800-79e7-11ea-980b-53b6bd879eeb.jpg) 407 | 408 | # Citing PyGAD - Bibtex Formatted Citation 409 | 410 | If you used PyGAD, please consider adding a citation to the following paper about PyGAD: 411 | 412 | ``` 413 | @misc{gad2021pygad, 414 | title={PyGAD: An Intuitive Genetic Algorithm Python Library}, 415 | author={Ahmed Fawzy Gad}, 416 | year={2021}, 417 | eprint={2106.06158}, 418 | archivePrefix={arXiv}, 419 | primaryClass={cs.NE} 420 | } 421 | ``` 422 | 423 | # Contact Us 424 | 425 | * E-mail: ahmed.f.gad@gmail.com 426 | * [LinkedIn](https://www.linkedin.com/in/ahmedfgad) 427 | * [Amazon Author Page](https://amazon.com/author/ahmedgad) 428 | * [Heartbeat](https://heartbeat.fritz.ai/@ahmedfgad) 429 | * [Paperspace](https://blog.paperspace.com/author/ahmed) 430 | * [KDnuggets](https://kdnuggets.com/author/ahmed-gad) 431 | * [TowardsDataScience](https://towardsdatascience.com/@ahmedfgad) 432 | * [GitHub](https://github.com/ahmedfgad) 433 | --------------------------------------------------------------------------------