├── data └── .keep ├── samples └── .keep ├── screenshots ├── luigi_raceway.png └── record_setup.png ├── .gitignore ├── requirements.txt ├── LICENSE.txt ├── play.py ├── train.py ├── README.md ├── record.py └── utils.py /data/.keep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /samples/.keep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /screenshots/luigi_raceway.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevinhughes27/TensorKart/HEAD/screenshots/luigi_raceway.png -------------------------------------------------------------------------------- /screenshots/record_setup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevinhughes27/TensorKart/HEAD/screenshots/record_setup.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | samples/** 2 | MarioKart64.z64 3 | *.pyc 4 | *.npy 5 | checkpoint 6 | model.ckpt* 7 | *.gif 8 | *.ogv 9 | *.mp4 10 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | inputs==0.1 2 | mss==3.2.0 3 | matplotlib==1.5.2 4 | numpy==1.19.5 5 | scikit-image==0.12.3 6 | termcolor==1.1.0 7 | h5py==2.7.0 8 | tensorflow==2.5.0 9 | Pillow==4.1.0 10 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Kevin Hughes 4 | 5 | Permission is hereby granted, 
free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /play.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from utils import resize_image, XboxController 4 | from termcolor import cprint 5 | 6 | import gym 7 | import gym_mupen64plus 8 | from train import create_model 9 | import numpy as np 10 | 11 | # Play 12 | class Actor(object): 13 | 14 | def __init__(self): 15 | # Load in model from train.py and load in the trained weights 16 | self.model = create_model(keep_prob=1) # no dropout 17 | self.model.load_weights('model_weights.h5') 18 | 19 | # Init contoller for manual override 20 | self.real_controller = XboxController() 21 | 22 | def get_action(self, obs): 23 | 24 | ### determine manual override 25 | manual_override = self.real_controller.LeftBumper == 1 26 | 27 | if not manual_override: 28 | ## Look 29 | vec = resize_image(obs) 30 | vec = np.expand_dims(vec, axis=0) # expand dimensions for predict, it wants (1,66,200,3) not (66, 200, 3) 31 | ## Think 32 | joystick = self.model.predict(vec, batch_size=1)[0] 33 | 34 | else: 35 | joystick = self.real_controller.read() 36 | joystick[1] *= -1 # flip y (this is in the config when it runs normally) 37 | 38 | 39 | ## Act 40 | 41 | ### calibration 42 | output = [ 43 | int(joystick[0] * 80), 44 | int(joystick[1] * 80), 45 | int(round(joystick[2])), 46 | int(round(joystick[3])), 47 | int(round(joystick[4])), 48 | ] 49 | 50 | ### print to console 51 | if manual_override: 52 | cprint("Manual: " + str(output), 'yellow') 53 | else: 54 | cprint("AI: " + str(output), 'green') 55 | 56 | return output 57 | 58 | 59 | if __name__ == '__main__': 60 | env = gym.make('Mario-Kart-Royal-Raceway-v0') 61 | 62 | obs = env.reset() 63 | env.render() 64 | print('env ready!') 65 | 66 | actor = Actor() 67 | print('actor ready!') 68 | 69 | print('beginning episode loop') 70 | total_reward = 0 71 | end_episode = False 72 | while not end_episode: 
73 | action = actor.get_action(obs) 74 | obs, reward, end_episode, info = env.step(action) 75 | env.render() 76 | total_reward += reward 77 | 78 | print('end episode... total reward: ' + str(total_reward)) 79 | 80 | obs = env.reset() 81 | print('env ready!') 82 | 83 | input('press to quit') 84 | 85 | env.close() 86 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | from tensorflow.keras.models import Sequential 6 | from tensorflow.keras.layers import Dense, Dropout, Flatten 7 | from tensorflow.keras.layers import Conv2D 8 | from tensorflow.keras import optimizers 9 | from tensorflow.keras import backend as K 10 | from tensorflow.keras.callbacks import ModelCheckpoint 11 | from utils import Sample 12 | 13 | # Global variable 14 | OUT_SHAPE = 5 15 | INPUT_SHAPE = (Sample.IMG_H, Sample.IMG_W, Sample.IMG_D) 16 | 17 | 18 | def customized_loss(y_true, y_pred, loss='euclidean'): 19 | # Simply a mean squared error that penalizes large joystick summed values 20 | if loss == 'L2': 21 | L2_norm_cost = 0.001 22 | val = K.mean(K.square((y_pred - y_true)), axis=-1) \ 23 | + K.sum(K.square(y_pred), axis=-1)/2 * L2_norm_cost 24 | # euclidean distance loss 25 | elif loss == 'euclidean': 26 | val = K.sqrt(K.sum(K.square(y_pred-y_true), axis=-1)) 27 | return val 28 | 29 | 30 | def create_model(keep_prob = 0.8): 31 | model = Sequential() 32 | 33 | # NVIDIA's model 34 | model.add(Conv2D(24, kernel_size=(5, 5), strides=(2, 2), activation='relu', input_shape= INPUT_SHAPE)) 35 | model.add(Conv2D(36, kernel_size=(5, 5), strides=(2, 2), activation='relu')) 36 | model.add(Conv2D(48, kernel_size=(5, 5), strides=(2, 2), activation='relu')) 37 | model.add(Conv2D(64, kernel_size=(3, 3), activation='relu')) 38 | model.add(Conv2D(64, kernel_size=(3, 3), activation='relu')) 39 | 
model.add(Flatten()) 40 | model.add(Dense(1164, activation='relu')) 41 | drop_out = 1 - keep_prob 42 | model.add(Dropout(drop_out)) 43 | model.add(Dense(100, activation='relu')) 44 | model.add(Dropout(drop_out)) 45 | model.add(Dense(50, activation='relu')) 46 | model.add(Dropout(drop_out)) 47 | model.add(Dense(10, activation='relu')) 48 | model.add(Dropout(drop_out)) 49 | model.add(Dense(OUT_SHAPE, activation='softsign')) 50 | 51 | return model 52 | 53 | 54 | if __name__ == '__main__': 55 | # Load Training Data 56 | x_train = np.load("data/X.npy") 57 | y_train = np.load("data/y.npy") 58 | 59 | print(x_train.shape[0], 'train samples') 60 | 61 | # Training loop variables 62 | epochs = 100 63 | batch_size = 50 64 | 65 | model = create_model() 66 | 67 | checkpoint = ModelCheckpoint('model_weights.h5', monitor='val_loss', verbose=1, save_best_only=True, mode='min') 68 | callbacks_list = [checkpoint] 69 | 70 | model.compile(loss=customized_loss, optimizer=optimizers.adam()) 71 | model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, shuffle=True, validation_split=0.1, callbacks=callbacks_list) 72 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | TensorKart 2 | ========== 3 | 4 | self-driving MarioKart with TensorFlow 5 | 6 | Driving a new (untrained) section of the Royal Raceway: 7 | 8 | ![RoyalRaceway.gif](https://media.giphy.com/media/1435VvCosVezQY/giphy.gif) 9 | 10 | Driving Luigi Raceway: 11 | 12 | [![LuigiRacewayVideo](/screenshots/luigi_raceway.png?raw=true)](https://youtu.be/vrccd3yeXnc) 13 | 14 | The model was trained with: 15 | * 4 races on Luigi Raceway 16 | * 2 races on Kalimari Desert 17 | * 2 races on Mario Raceway 18 | 19 | With even a small training set the model is sometimes able to generalize to a new track (Royal Raceway seen above). 
20 | 21 | 22 | Dependencies 23 | ------------ 24 | * `python` and `pip`; then run `pip install -r requirements.txt` 25 | * `mupen64plus` (install via apt-get) 26 | 27 | 28 | Recording Samples 29 | ----------------- 30 | 1. Start your emulator program (`mupen64plus`) and run Mario Kart 64 31 | 2. Make sure you have a joystick connected and that `mupen64plus` is using the sdl input plugin 32 | 3. Run `record.py` 33 | 4. Make sure the graph responds to joystick input. 34 | 5. Position the emulator window so that the image is captured by the program (top left corner; the program captures a 640x480 region) 35 | 6. Press record and play through a level. You can trim some images off the front and back of the data you collect afterwards (by removing lines in `data.csv`). 36 | 37 | ![record](/screenshots/record_setup.png?raw=true) 38 | 39 | Notes 40 | - the GUI will stop updating while recording to avoid any slowdowns. 41 | - double check the samples; sometimes the screenshot is the desktop instead. Remove the appropriate lines from the `data.csv` file 42 | 43 | 44 | Viewing Samples 45 | --------------- 46 | Run `python utils.py viewer samples/luigi_raceway` to view the samples 47 | 48 | 49 | Preparing Training Data 50 | ----------------------- 51 | Run `python utils.py prepare samples/*` with a list of sample directories to build an `X` and `y` matrix for training. (Your shell, e.g. zsh or bash, expands `samples/*` to all of the sample directories.) 52 | 53 | `X` is a 4-dimensional array of images: (sample, height, width, color channel) 54 | 55 | `y` is the expected joystick output as an array: 56 | 57 | ``` 58 | [0] joystick x axis 59 | [1] joystick y axis 60 | [2] button a 61 | [3] button b 62 | [4] button rb 63 | ``` 64 | 65 | 66 | Training 67 | -------- 68 | The `train.py` program will train a model using Google's TensorFlow framework and cuDNN for GPU acceleration. Training can take a while (~1 hour) depending on how much data you are training with and your system specs.
While training, the program saves the model to `model_weights.h5` every time the validation loss improves, so the file always holds the best epoch seen so far. 69 | 70 | 71 | Play 72 | ---- 73 | The `play.py` program will use the [`gym-mupen64plus`](https://github.com/bzier/gym-mupen64plus) environment to execute the trained agent against the MarioKart environment. The environment will provide the screenshots of the emulator. These images will be sent to the model to acquire the joystick command to send. The AI joystick commands can be overridden by holding the 'LB' button on the controller. 74 | 75 | 76 | Future Work / Ideas: 77 | -------------------- 78 | * Add a reinforcement layer based on lap time or other metrics so that the AI can start to teach itself now that it has a baseline. The environment currently provides a reward signal of `-1` per time-step, which gives the AI agent a metric to calculate its performance during each race (episode), the goal being to maximize reward and, therefore, minimize overall race duration. 79 | * Could also have a shadow mode where the AI just draws out what it would do rather than sending actions. A real self-driving car would have this and use it a lot before letting it take the wheel. 80 | * Deep learning is all about data; perhaps a community could form around collecting a large amount of data and pushing the performance of this AI. 81 | 82 | 83 | Related Projects: 84 | -------------------- 85 | [`Xbox Game AI`](https://github.com/mgagvani/Xbox-Game-AI) - Uses [`PYXInput`](https://github.com/bayangan1991/PYXInput) for direct control of any Xbox/PC game. 86 | 87 | [`SerpentAI`](https://github.com/SerpentAI/SerpentAI) - Game Agent Framework to create AIs for any game. 88 | 89 | [`Donkey Gym`](https://github.com/tawnkramer/gym-donkeycar) - OpenAI Gym Environments for self-driving "[`Donkey Car`](https://github.com/autorope/donkeycar)". 90 | 91 | [`AirSim`](https://github.com/microsoft/AirSim) - An Unreal Engine simulator for autonomous vehicles.
92 | 93 | Special Thanks To 94 | ----------------- 95 | * https://github.com/SullyChen/Autopilot-TensorFlow 96 | 97 | 98 | Contributing 99 | ------------ 100 | Open a PR! I promise I am friendly :) 101 | -------------------------------------------------------------------------------- /record.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import numpy as np 4 | import os 5 | import shutil 6 | import mss 7 | import matplotlib 8 | matplotlib.use('TkAgg') 9 | from datetime import datetime 10 | from matplotlib.figure import Figure 11 | from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg as FigCanvas 12 | 13 | from PIL import ImageTk, Image 14 | 15 | import sys 16 | 17 | PY3_OR_LATER = sys.version_info[0] >= 3 18 | 19 | if PY3_OR_LATER: 20 | # Python 3 specific definitions 21 | import tkinter as tk 22 | import tkinter.ttk as ttk 23 | import tkinter.messagebox as tkMessageBox 24 | else: 25 | # Python 2 specific definitions 26 | import Tkinter as tk 27 | import ttk 28 | import tkMessageBox 29 | 30 | from utils import Screenshot, XboxController 31 | 32 | IMAGE_SIZE = (320, 240) 33 | IDLE_SAMPLE_RATE = 1500 34 | SAMPLE_RATE = 200 35 | IMAGE_TYPE = ".png" 36 | 37 | class MainWindow(): 38 | """ Main frame of the application 39 | """ 40 | 41 | def __init__(self): 42 | self.root = tk.Tk() 43 | self.sct = mss.mss() 44 | 45 | self.root.title('Data Acquisition') 46 | self.root.geometry("660x325") 47 | self.root.resizable(False, False) 48 | 49 | # Init controller 50 | self.controller = XboxController() 51 | 52 | # Create GUI 53 | self.create_main_panel() 54 | 55 | # Timer 56 | self.rate = IDLE_SAMPLE_RATE 57 | self.sample_rate = SAMPLE_RATE 58 | self.idle_rate = IDLE_SAMPLE_RATE 59 | self.recording = False 60 | self.t = 0 61 | self.pause_timer = False 62 | self.on_timer() 63 | 64 | self.root.mainloop() 65 | 66 | def create_main_panel(self): 67 | # Panels 68 | top_half = tk.Frame(self.root) 69 | 
top_half.pack(side=tk.TOP, expand=True, padx=5, pady=5) 70 | message = tk.Label(self.root, text="(Note: UI updates are disabled while recording)") 71 | message.pack(side=tk.TOP, padx=5) 72 | bottom_half = tk.Frame(self.root) 73 | bottom_half.pack(side=tk.LEFT, padx=5, pady=10) 74 | 75 | # Images 76 | self.img_panel = tk.Label(top_half, image=ImageTk.PhotoImage("RGB", size=IMAGE_SIZE)) # Placeholder 77 | self.img_panel.pack(side = tk.LEFT, expand=False, padx=5) 78 | 79 | # Joystick 80 | self.init_plot() 81 | self.PlotCanvas = FigCanvas(figure=self.fig, master=top_half) 82 | self.PlotCanvas.get_tk_widget().pack(side=tk.RIGHT, expand=False, padx=5) 83 | 84 | # Recording 85 | textframe = tk.Frame(bottom_half, width=332, height=15, padx=5) 86 | textframe.pack(side=tk.LEFT) 87 | textframe.pack_propagate(0) 88 | self.outputDirStrVar = tk.StringVar() 89 | self.txt_outputDir = tk.Entry(textframe, textvariable=self.outputDirStrVar, width=100) 90 | self.txt_outputDir.pack(side=tk.LEFT) 91 | self.outputDirStrVar.set("samples/" + datetime.now().strftime('%Y-%m-%d_%H:%M:%S')) 92 | 93 | self.record_button = ttk.Button(bottom_half, text="Record", command=self.on_btn_record) 94 | self.record_button.pack(side = tk.LEFT, padx=5) 95 | 96 | 97 | def init_plot(self): 98 | self.plotMem = 50 # how much data to keep on the plot 99 | self.plotData = [[0] * (5)] * self.plotMem # mem storage for plot 100 | 101 | self.fig = Figure(figsize=(4,3), dpi=80) # 320,240 102 | self.axes = self.fig.add_subplot(111) 103 | 104 | 105 | def on_timer(self): 106 | self.poll() 107 | 108 | # stop drawing if recording to avoid slow downs 109 | if self.recording == False: 110 | self.draw() 111 | 112 | if not self.pause_timer: 113 | self.root.after(self.rate, self.on_timer) 114 | 115 | 116 | def poll(self): 117 | self.img = self.take_screenshot() 118 | self.controller_data = self.controller.read() 119 | self.update_plot() 120 | 121 | if self.recording == True: 122 | self.save_data() 123 | self.t += 1 124 | 125 | 
126 | def take_screenshot(self): 127 | # Get raw pixels from the screen 128 | sct_img = self.sct.grab({ "top": Screenshot.OFFSET_Y, 129 | "left": Screenshot.OFFSET_X, 130 | "width": Screenshot.SRC_W, 131 | "height": Screenshot.SRC_H}) 132 | 133 | # Create the Image 134 | return Image.frombytes('RGB', sct_img.size, sct_img.bgra, 'raw', 'BGRX') 135 | 136 | 137 | def update_plot(self): 138 | self.plotData.append(self.controller_data) # adds to the end of the list 139 | self.plotData.pop(0) # remove the first item in the list, ie the oldest 140 | 141 | 142 | def save_data(self): 143 | image_file = self.outputDir+'/'+'img_'+str(self.t)+IMAGE_TYPE 144 | self.img.save(image_file) 145 | 146 | # write csv line 147 | self.outfile.write( image_file + ',' + ','.join(map(str, self.controller_data)) + '\n' ) 148 | 149 | 150 | def draw(self): 151 | # Image 152 | self.img.thumbnail(IMAGE_SIZE, Image.ANTIALIAS) # Resize 153 | self.img_panel.img = ImageTk.PhotoImage(self.img) 154 | self.img_panel['image'] = self.img_panel.img 155 | 156 | # Joystick 157 | x = np.asarray(self.plotData) 158 | self.axes.clear() 159 | self.axes.plot(range(0,self.plotMem), x[:,0], 'r') 160 | self.axes.plot(range(0,self.plotMem), x[:,1], 'b') 161 | self.axes.plot(range(0,self.plotMem), x[:,2], 'g') 162 | self.axes.plot(range(0,self.plotMem), x[:,3], 'k') 163 | self.axes.plot(range(0,self.plotMem), x[:,4], 'y') 164 | self.PlotCanvas.draw() 165 | 166 | 167 | def on_btn_record(self): 168 | # pause timer 169 | self.pause_timer = True 170 | 171 | if self.recording: 172 | self.recording = False 173 | else: 174 | self.start_recording() 175 | 176 | if self.recording: 177 | self.t = 0 # Reset our counter for the new recording 178 | self.record_button["text"] = "Stop" 179 | self.rate = self.sample_rate 180 | # make / open outfile 181 | self.outfile = open(self.outputDir+'/'+'data.csv', 'a') 182 | else: 183 | self.record_button["text"] = "Record" 184 | self.rate = self.idle_rate 185 | self.outfile.close() 186 | 187 | 
# un pause timer 188 | self.pause_timer = False 189 | self.on_timer() 190 | 191 | 192 | def start_recording(self): 193 | should_record = True 194 | 195 | # check that a dir has been specified 196 | if not self.outputDirStrVar.get(): 197 | tkMessageBox.showerror(title='Error', message='Specify the Output Directory', parent=self.root) 198 | should_record = False 199 | 200 | else: # a directory was specified 201 | self.outputDir = self.outputDirStrVar.get() 202 | 203 | # check if path exists - i.e. may be saving over data 204 | if os.path.exists(self.outputDir): 205 | 206 | # overwrite the data, yes/no? 207 | if tkMessageBox.askyesno(title='Warning!', message='Output Directory Exists - Overwrite Data?', parent=self.root): 208 | # delete & re-make the dir: 209 | shutil.rmtree(self.outputDir) 210 | os.mkdir(self.outputDir) 211 | 212 | # answer was 'no', so do not overwrite the data 213 | else: 214 | should_record = False 215 | self.txt_outputDir.focus_set() 216 | 217 | # directory doesn't exist, so make one 218 | else: 219 | os.mkdir(self.outputDir) 220 | 221 | self.recording = should_record 222 | 223 | 224 | if __name__ == '__main__': 225 | app = MainWindow() 226 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import sys 4 | import array 5 | 6 | import numpy as np 7 | 8 | from skimage.color import rgb2gray 9 | from skimage.transform import resize 10 | from skimage.io import imread 11 | 12 | import matplotlib.pyplot as plt 13 | import matplotlib.image as mpimg 14 | 15 | from inputs import get_gamepad 16 | import math 17 | import threading 18 | 19 | 20 | def resize_image(img): 21 | im = resize(img, (Sample.IMG_H, Sample.IMG_W, Sample.IMG_D)) 22 | im_arr = im.reshape((Sample.IMG_H, Sample.IMG_W, Sample.IMG_D)) 23 | return im_arr 24 | 25 | 26 | class Screenshot(object): 27 | SRC_W = 640 28 | SRC_H = 480 29 | SRC_D 
= 3 30 | 31 | OFFSET_X = 0 32 | OFFSET_Y = 0 33 | 34 | 35 | class Sample: 36 | IMG_W = 200 37 | IMG_H = 66 38 | IMG_D = 3 39 | 40 | 41 | class XboxController(object): 42 | MAX_TRIG_VAL = math.pow(2, 8) 43 | MAX_JOY_VAL = math.pow(2, 15) 44 | 45 | def __init__(self): 46 | 47 | self.LeftJoystickY = 0 48 | self.LeftJoystickX = 0 49 | self.RightJoystickY = 0 50 | self.RightJoystickX = 0 51 | self.LeftTrigger = 0 52 | self.RightTrigger = 0 53 | self.LeftBumper = 0 54 | self.RightBumper = 0 55 | self.A = 0 56 | self.X = 0 57 | self.Y = 0 58 | self.B = 0 59 | self.LeftThumb = 0 60 | self.RightThumb = 0 61 | self.Back = 0 62 | self.Start = 0 63 | self.LeftDPad = 0 64 | self.RightDPad = 0 65 | self.UpDPad = 0 66 | self.DownDPad = 0 67 | 68 | self._monitor_thread = threading.Thread(target=self._monitor_controller, args=()) 69 | self._monitor_thread.daemon = True 70 | self._monitor_thread.start() 71 | 72 | 73 | def read(self): 74 | x = self.LeftJoystickX 75 | y = self.LeftJoystickY 76 | a = self.A 77 | b = self.X # b=1, x=2 78 | rb = self.RightBumper 79 | return [x, y, a, b, rb] 80 | 81 | 82 | def _monitor_controller(self): 83 | while True: 84 | events = get_gamepad() 85 | for event in events: 86 | if event.code == 'ABS_Y': 87 | self.LeftJoystickY = event.state / XboxController.MAX_JOY_VAL # normalize between -1 and 1 88 | elif event.code == 'ABS_X': 89 | self.LeftJoystickX = event.state / XboxController.MAX_JOY_VAL # normalize between -1 and 1 90 | elif event.code == 'ABS_RY': 91 | self.RightJoystickY = event.state / XboxController.MAX_JOY_VAL # normalize between -1 and 1 92 | elif event.code == 'ABS_RX': 93 | self.RightJoystickX = event.state / XboxController.MAX_JOY_VAL # normalize between -1 and 1 94 | elif event.code == 'ABS_Z': 95 | self.LeftTrigger = event.state / XboxController.MAX_TRIG_VAL # normalize between 0 and 1 96 | elif event.code == 'ABS_RZ': 97 | self.RightTrigger = event.state / XboxController.MAX_TRIG_VAL # normalize between 0 and 1 98 | elif event.code == 
'BTN_TL': 99 | self.LeftBumper = event.state 100 | elif event.code == 'BTN_TR': 101 | self.RightBumper = event.state 102 | elif event.code == 'BTN_SOUTH': 103 | self.A = event.state 104 | elif event.code == 'BTN_NORTH': 105 | self.X = event.state 106 | elif event.code == 'BTN_WEST': 107 | self.Y = event.state 108 | elif event.code == 'BTN_EAST': 109 | self.B = event.state 110 | elif event.code == 'BTN_THUMBL': 111 | self.LeftThumb = event.state 112 | elif event.code == 'BTN_THUMBR': 113 | self.RightThumb = event.state 114 | elif event.code == 'BTN_SELECT': 115 | self.Back = event.state 116 | elif event.code == 'BTN_START': 117 | self.Start = event.state 118 | elif event.code == 'BTN_TRIGGER_HAPPY1': 119 | self.LeftDPad = event.state 120 | elif event.code == 'BTN_TRIGGER_HAPPY2': 121 | self.RightDPad = event.state 122 | elif event.code == 'BTN_TRIGGER_HAPPY3': 123 | self.UpDPad = event.state 124 | elif event.code == 'BTN_TRIGGER_HAPPY4': 125 | self.DownDPad = event.state 126 | 127 | 128 | class Data(object): 129 | def __init__(self): 130 | self._X = np.load("data/X.npy") 131 | self._y = np.load("data/y.npy") 132 | self._epochs_completed = 0 133 | self._index_in_epoch = 0 134 | self._num_examples = self._X.shape[0] 135 | 136 | @property 137 | def num_examples(self): 138 | return self._num_examples 139 | 140 | def next_batch(self, batch_size): 141 | start = self._index_in_epoch 142 | self._index_in_epoch += batch_size 143 | if self._index_in_epoch > self._num_examples: 144 | # Finished epoch 145 | self._epochs_completed += 1 146 | # Start next epoch 147 | start = 0 148 | self._index_in_epoch = batch_size 149 | assert batch_size <= self._num_examples 150 | end = self._index_in_epoch 151 | return self._X[start:end], self._y[start:end] 152 | 153 | 154 | def load_sample(sample): 155 | image_files = np.loadtxt(sample + '/data.csv', delimiter=',', dtype=str, usecols=(0,)) 156 | joystick_values = np.loadtxt(sample + '/data.csv', delimiter=',', usecols=(1,2,3,4,5)) 157 | 
return image_files, joystick_values 158 | 159 | def load_imgs(sample): 160 | image_files = np.loadtxt(sample + '/data.csv', delimiter=',', dtype=str, usecols=(0,)) 161 | return image_files 162 | 163 | # training data viewer 164 | def viewer(sample): 165 | image_files, joystick_values = load_sample(sample) 166 | 167 | plotData = [] 168 | 169 | plt.ion() 170 | plt.figure('viewer', figsize=(16, 6)) 171 | 172 | for i in range(len(image_files)): 173 | 174 | # joystick 175 | print(i, " ", joystick_values[i,:]) 176 | 177 | # format data 178 | plotData.append( joystick_values[i,:] ) 179 | if len(plotData) > 30: 180 | plotData.pop(0) 181 | x = np.asarray(plotData) 182 | 183 | # image (every 3rd) 184 | if (i % 3 == 0): 185 | plt.subplot(121) 186 | image_file = image_files[i] 187 | img = mpimg.imread(image_file) 188 | plt.imshow(img) 189 | 190 | # plot 191 | plt.subplot(122) 192 | plt.plot(range(i,i+len(plotData)), x[:,0], 'r') 193 | plt.hold(True) 194 | plt.plot(range(i,i+len(plotData)), x[:,1], 'b') 195 | plt.plot(range(i,i+len(plotData)), x[:,2], 'g') 196 | plt.plot(range(i,i+len(plotData)), x[:,3], 'k') 197 | plt.plot(range(i,i+len(plotData)), x[:,4], 'y') 198 | plt.draw() 199 | plt.hold(False) 200 | 201 | plt.pause(0.0001) # seconds 202 | i += 1 203 | 204 | 205 | # prepare training data 206 | def prepare(samples): 207 | print("Preparing data") 208 | 209 | for sample in samples: 210 | image_files = load_imgs(sample) 211 | num_samples += len(image_files) 212 | 213 | print(f"There are {num_samples} samples") 214 | 215 | X = np.empty(shape=(num_samples, Sample.IMG_H, Sample.IMG_W, 3), dtype=np.uint8) 216 | y = [] 217 | 218 | for idx, sample in enumerate(samples): 219 | print(sample) 220 | 221 | # load sample 222 | image_files, joystick_values = load_sample(sample) 223 | 224 | # add joystick values to y 225 | y.append(joystick_values) 226 | 227 | # load, prepare and add images to X 228 | for image_file in image_files: 229 | image = imread(image_file) 230 | vec = 
resize_image(image) 231 | X[idx] = vec 232 | 233 | print("Saving to file...") 234 | y = np.concatenate(y) 235 | 236 | np.save("data/X", X) 237 | np.save("data/y", y) 238 | 239 | print("Done!") 240 | return 241 | 242 | 243 | if __name__ == '__main__': 244 | if sys.argv[1] == 'viewer': 245 | viewer(sys.argv[2]) 246 | elif sys.argv[1] == 'prepare': 247 | prepare(sys.argv[2:]) 248 | --------------------------------------------------------------------------------