├── requirements.txt
├── .gitignore
├── main.py
├── predict.py
├── train.py
├── get_data.py
├── README.md
├── inspect_model.py
├── show.py
├── track.py
└── get_model.py

/requirements.txt:
--------------------------------------------------------------------------------
1 | just
2 | tensorflow
3 | keras
4 | pynput
5 | opencv-python
6 | numpy
7 | matplotlib
8 |
--------------------------------------------------------------------------------

/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *#*
3 | *.DS_STORE
4 | *.log
5 | *Data.fs*
6 | *flymake*
7 | dist/*
8 | *egg*
9 | urllist*
10 | build/
11 | __pycache__/
12 | /.Python
13 | /bin/
14 | /include/
15 | /lib/
16 | /pip-selfcheck.json
17 | .tox/
18 | .cache
19 | .coverage
20 | .coverage.*
21 | .coveralls.yml
22 | models/*
--------------------------------------------------------------------------------

/main.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | import train
4 | import track
5 | import predict
6 |
7 | # Sub-command name -> function called with the remaining CLI arguments.
8 | COMMANDS = {
9 |     "train": train.train,  # train <data_name> <model_name>
10 |     "track": track.record,  # track <data_name>
11 |     "predict": predict.loop,  # predict <model_name>
12 | }
13 |
14 |
15 | def main(argv):
16 |     """Dispatch ``python main.py <command> [args ...]`` to the matching module."""
17 |     if len(argv) < 2 or argv[1] not in COMMANDS:
18 |         sys.exit("usage: python main.py {train,track,predict} [args ...]")
19 |     COMMANDS[argv[1]](*argv[2:])
20 |
21 |
22 | if __name__ == "__main__":
23 |     main(sys.argv)
24 |
--------------------------------------------------------------------------------

/predict.py:
--------------------------------------------------------------------------------
1 | import get_model
2 | import get_data
3 | import track
4 |
5 |
6 | def loop(model_name):
7 |     """Move the mouse to wherever the model predicts you are looking.
8 |
9 |     Frames come from track.yield_images (press "s" in the video window to
10 |     start, "q" to quit). Each frame is preprocessed exactly like the training
11 |     data, and the model's (x, y) output becomes the new mouse position.
12 |     """
13 |     model = get_model.get_model(model_name)
14 |
15 |     for image, _, _ in track.yield_images():
16 |         im = get_data.prep_images(image)
17 |         x, y = model.predict(im)[0]
18 |         print(x, y)
19 |         track.set_mouse_position(x, y)
20 |
--------------------------------------------------------------------------------

/train.py:
--------------------------------------------------------------------------------
1 | import get_data
2 | import get_model
3 |
4 |
5 | def train(data_name, model_name):
6 |     """Fit `model_name` on the samples recorded under ~/tracktrack/<data_name>/.
7 |
8 |     Samples are split by interleaving: offset 0 of every three is used for
9 |     training and offset 1 for validation; offset 2 is held out (it is what
10 |     inspect_model.py plots). Ctrl-C stops training early and the model is
11 |     still saved.
12 |     """
13 |     X, y = get_data.get_training_xy("~/tracktrack/" + data_name + "/")
14 |     model = get_model.get_model(model_name)
15 |
16 |     try:
17 |         # Keras 1 API: `nb_epoch` was renamed to `epochs` in Keras 2.
18 |         model.fit(X[0::3], y[0::3], verbose=1, batch_size=64, nb_epoch=300,
19 |                   validation_data=(X[1::3], y[1::3]))
20 |     except KeyboardInterrupt:
21 |         # Deliberate: stop a long run early without losing the progress made.
22 |         pass
23 |
24 |     get_model.save_model(model, model_name)
25 |
--------------------------------------------------------------------------------

/get_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 |
4 | import cv2
5 | import just
6 | import numpy as np
7 |
8 | # Network input size (height, width); must match get_model's input_shape.
9 | IMAGE_HEIGHT, IMAGE_WIDTH = 72, 128
10 |
11 |
12 | def prep_images(images):
13 |     """Resize frames to the network input size and move channels first.
14 |
15 |     `images` is one HxWx3 frame or a list of them, as returned by OpenCV
16 |     (BGR, uint8). Returns an array of shape (N, 3, 72, 128).
17 |     """
18 |     if not isinstance(images, list):
19 |         images = [images]
20 |     # cv2.resize takes (width, height). INTER_LINEAR is bilinear, which was
21 |     # also the default of scipy.misc.imresize (removed in SciPy 1.3).
22 |     X = np.array([cv2.resize(im, (IMAGE_WIDTH, IMAGE_HEIGHT),
23 |                              interpolation=cv2.INTER_LINEAR)
24 |                   for im in images])
25 |     X = np.moveaxis(X, -1, 1)
26 |     return X
27 |
28 |
29 | def _image_index(fname):
30 |     """Return the frame number n of an ``im_<n>.png`` file (-1 if absent)."""
31 |     match = re.search(r"im_(\d+)\.png$", fname)
32 |     return int(match.group(1)) if match else -1
33 |
34 |
35 | def get_training_xy(data_path="~/tracktrack/"):
36 |     """Load recorded frames and mouse positions as arrays (X, y).
37 |
38 |     The last m frames (ordered by frame number) are paired with the last m
39 |     positions, where m is the smaller of the two counts.
40 |
41 |     Raises:
42 |         ValueError: if no images or no positions were found.
43 |     """
44 |     positions = list(just.iread(data_path + "positions.jsonl"))
45 |     # Sort by frame number: glob order is not guaranteed, and lexicographic
46 |     # order puts im_10 before im_2, which misaligns images and positions.
47 |     fnames = sorted(just.glob(data_path + "im*.png"), key=_image_index)
48 |     m = min(len(fnames), len(positions))
49 |     if m == 0:
50 |         # Without this guard, [-0:] below would select every element.
51 |         raise ValueError("no training data found in " + data_path)
52 |
53 |     # Read with OpenCV so training frames are BGR uint8, exactly like the live
54 |     # frames from track.yield_images (matplotlib would give RGB floats).
55 |     images = [cv2.imread(os.path.expanduser(f)) for f in fnames[-m:]]
56 |     X = prep_images(images)
57 |     y = np.array(positions[-m:])
58 |
59 |     return X, y
60 |
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
1 | ## Deep Eye2Mouse
2 |
3 | Very alpha. TODO: the next step should be to create a balanced dataset!
4 |
5 | ### Installation
6 |
7 | Install the packages in `requirements.txt` (`pip install -r requirements.txt`).
8 | OpenCV can be tough to install; good luck with TensorFlow and Keras as well.
9 | Note that the code targets the Keras 1 API (`border_mode`, `nb_epoch`).
10 |
11 | ### track (to learn)
12 |
13 | ```bash
14 | python main.py track livin_room
15 | ```
16 |
17 | Press "s" in the video window to start recording; each press records up to
18 | 1000 frames together with the mouse position. Follow the mouse around and keep
19 | your eyes focused on it.
20 | Press "q" in the video window to quit.
21 |
22 | ### train (model)
23 |
24 | ```bash
25 | python main.py train livin_room deep1
26 | ```
27 |
28 | Trains on your data. Press Ctrl-C to stop early; the model is still saved to `models/`.
29 |
30 | ### predict (go live with the model)
31 |
32 | ```bash
33 | python main.py predict deep1
34 | ```
35 |
36 | Press "s" in the video window to start predicting and moving the mouse around.
37 | Press "q" in the video window to quit.
38 |
--------------------------------------------------------------------------------

/inspect_model.py:
--------------------------------------------------------------------------------
1 | """Print and plot how well a trained model fits the recorded data.
2 |
3 | Usage: python inspect_model.py [data_name] [model_name]
4 | Without arguments it reads ~/tracktrack/ and the model "deep_eye2mouse".
5 | """
6 | import sys
7 |
8 | import matplotlib.pyplot as plt
9 | import numpy as np
10 |
11 | import get_data
12 | import get_model
13 |
14 | data_path = ("~/tracktrack/" + sys.argv[1] + "/") if len(sys.argv) > 1 else "~/tracktrack/"
15 | model_name = sys.argv[2] if len(sys.argv) > 2 else "deep_eye2mouse"
16 |
17 | X, y = get_data.get_training_xy(data_path)
18 | model = get_model.get_model(model_name)
19 |
20 | # Same interleaved split as train.py: 0 = train, 1 = validation, 2 = held out.
21 | for i in range(3):
22 |     print(i, np.mean((model.predict(X[i::3]) - y[i::3]) ** 2))
23 |
24 | # Predict the held-out split once and reuse it for printing and plotting.
25 | pred, actual = model.predict(X[2::3]), y[2::3]
26 |
27 | for p, t in zip(pred[:100], actual[:100]):
28 |     print(p)
29 |     print(t)
30 |     print()
31 |
32 | plt.scatter(pred[:100, 0], pred[:100, 1], c="red")
33 | plt.scatter(actual[:100, 0], actual[:100, 1], c="blue")
34 | # Animate the next (up to) 60 held-out samples.
35 | for v in range(100, min(160, len(actual))):
36 |     plt.scatter(pred[v, 0], pred[v, 1], c="red")
37 |     plt.scatter(actual[v, 0], actual[v, 1], c="blue")
38 |     plt.pause(0.05)
39 | plt.show()
40 |
--------------------------------------------------------------------------------

/show.py:
--------------------------------------------------------------------------------
1 | """Mirror the webcam and steer the mouse from an externally written prediction.
2 |
3 | After "s" is pressed, every ~0.1 s the current frame is saved to
4 | ~/tracktrack/output.png and, if ~/tracktrack/output.json exists, its (x, y)
5 | is applied to the mouse. Nothing in this repository writes output.json;
6 | NOTE(review): presumably a separate prediction process does - confirm.
7 | Press "q" in the video window to quit.
8 | """
9 | import os
10 | import time
11 |
12 | import cv2
13 | import just
14 | from pynput.mouse import Controller
15 |
16 | OUT_DIR = os.path.expanduser("~/tracktrack/")
17 | os.makedirs(OUT_DIR, exist_ok=True)
18 |
19 | cap = cv2.VideoCapture(0)
20 | mouse = Controller()
21 |
22 | it = 0
23 | until_val = 0
24 | t1 = time.time()
25 | try:
26 |     while True:
27 |         ret, frame = cap.read()
28 |         if not ret:  # camera unavailable or disconnected
29 |             break
30 |         frame = cv2.flip(frame, 1)
31 |         cv2.imshow('frame', frame)
32 |
33 |         # Poll the keyboard once per frame: a second waitKey call would
34 |         # consume (and silently drop) keys the first one did not check for.
35 |         key = cv2.waitKey(1) & 0xFF
36 |         if key == ord('q'):
37 |             break
38 |         if key == ord('s'):
39 |             until_val = it + 10000
40 |
41 |         t2 = time.time()
42 |         if t2 > t1 + 0.1 and it < until_val:
43 |             cv2.imwrite(OUT_DIR + "output.png", frame)
44 |             prediction = just.read(OUT_DIR + "output.json", no_exist=None)
45 |             if prediction is not None:
46 |                 x, y = prediction
47 |                 mouse.position = (x, y)
48 |             it += 1
49 |             t1 = t2
50 | finally:
51 |     # Release the camera and close the window even if the loop raises.
52 |     cap.release()
53 |     cv2.destroyAllWindows()
54 |
--------------------------------------------------------------------------------

/track.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 |
4 | import cv2
5 | import just
6 | from pynput.mouse import Controller
7 |
8 | # Lazily created singletons shared by the functions below.
9 | interfaces = {"mouse": None, "video_cap": None}
10 |
11 |
12 | def get_interfaces():
13 |     """Return (mouse controller, webcam capture), creating them on first use."""
14 |     if interfaces["mouse"] is None:
15 |         interfaces["mouse"] = Controller()
16 |     if interfaces["video_cap"] is None:
17 |         interfaces["video_cap"] = cv2.VideoCapture(0)
18 |     return interfaces["mouse"], interfaces["video_cap"]
19 |
20 |
21 | def set_mouse_position(x, y):
22 |     """Move the mouse pointer to (x, y), truncated to ints."""
23 |     if interfaces["mouse"] is None:
24 |         # Don't require get_interfaces(), which would also open the webcam.
25 |         interfaces["mouse"] = Controller()
26 |     interfaces["mouse"].position = (int(x), int(y))
27 |
28 |
29 | def yield_images(interval=0.1, frames_per_start=1000):
30 |     """Show the mirrored webcam feed and yield (frame, index, mouse_position).
31 |
32 |     Nothing is yielded until "s" is pressed in the video window; each press
33 |     allows up to `frames_per_start` more frames, at most one per `interval`
34 |     seconds. "q" in the video window stops the generator. The camera and the
35 |     window are released when the generator finishes or is closed.
36 |     """
37 |     mouse, video_cap = get_interfaces()
38 |     it = 0
39 |     until_val = 0
40 |     t1 = time.time()
41 |     try:
42 |         while True:
43 |             ok, frame = video_cap.read()
44 |             if not ok:  # camera unavailable or disconnected
45 |                 break
46 |
47 |             # flip horizontally so the preview behaves like a mirror
48 |             frame = cv2.flip(frame, 1)
49 |             cv2.imshow('frame', frame)
50 |
51 |             # Poll the keyboard once per frame: a second waitKey call would
52 |             # consume (and silently drop) keys the first one did not check for.
53 |             key = cv2.waitKey(1) & 0xFF
54 |             if key == ord('q'):
55 |                 break
56 |             if key == ord('s'):
57 |                 until_val = it + frames_per_start
58 |
59 |             t2 = time.time()
60 |             if t2 > t1 + interval and it < until_val:
61 |                 yield frame, it, list(mouse.position)
62 |                 it += 1
63 |                 t1 = t2
64 |     finally:
65 |         video_cap.release()
66 |         # Forget the released capture so a later call reopens the camera.
67 |         interfaces["video_cap"] = None
68 |         cv2.destroyAllWindows()
69 |
70 |
71 | def record(data_name, data_path="~/tracktrack/"):
72 |     """Record training data into <data_path><data_name>/.
73 |
74 |     Each yielded frame is saved as im_<n>.png and its mouse position is
75 |     appended to positions.jsonl. Numbering continues after the images already
76 |     in the directory, so repeated sessions add to the same dataset.
77 |     """
78 |     path = os.path.expanduser(just.make_path(data_path + data_name + "/"))
79 |     offset = len(just.glob(os.path.join(path, "im*.png")))
80 |     for image, it, mouse_pos in yield_images():
81 |         cv2.imwrite(os.path.join(path, "im_{}.png".format(it + offset)), image)
82 |         just.append(mouse_pos, os.path.join(path, "positions.jsonl"))
83 |
--------------------------------------------------------------------------------

/get_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import just
4 | from keras.layers import (Activation, Convolution2D, Dense, Dropout, Flatten, MaxPooling2D)
5 | from keras.models import Sequential
6 | from keras.models import model_from_json
7 | from keras.layers.normalization import BatchNormalization
8 |
9 |
10 | def _model_files(model_name, models_path):
11 |     """Return the (architecture .json, weights .h5) paths for `model_name`."""
12 |     return (models_path + model_name + ".json",
13 |             models_path + model_name + ".h5")
14 |
15 |
16 | def _build_model():
17 |     """Create a fresh network mapping a (3, 72, 128) frame to (x, y).
18 |
19 |     The input is channels-first, as produced by get_data.prep_images, so the
20 |     layers say so explicitly instead of depending on the keras.json default
21 |     (channels-last for the TensorFlow backend). Uses the Keras 1 API.
22 |     """
23 |     model = Sequential()
24 |     model.add(BatchNormalization(axis=1, input_shape=(3, 72, 128)))
25 |     model.add(Convolution2D(32, 3, 3, border_mode='same', activation='relu',
26 |                             dim_ordering='th'))
27 |     model.add(Dropout(0.15))
28 |     model.add(BatchNormalization(axis=1))
29 |     model.add(Convolution2D(32, 3, 3, activation='relu', border_mode='same',
30 |                             dim_ordering='th'))
31 |     model.add(MaxPooling2D(pool_size=(2, 2), dim_ordering='th'))
32 |     model.add(Flatten())
33 |     model.add(Dense(512, activation='relu'))
34 |     model.add(Dropout(0.5))
35 |     # Two linear outputs: the predicted mouse (x, y).
36 |     model.add(Dense(2))
37 |     model.add(Activation('linear'))
38 |     return model
39 |
40 |
41 | def get_model(model_name=None, models_path="models/"):
42 |     """Load `model_name` from `models_path`, or build a fresh model.
43 |
44 |     A fresh model is built when no name is given or when its .json file does
45 |     not exist yet. Either way the returned model is compiled (MSE, Adam).
46 |     """
47 |     model = None
48 |     if model_name:
49 |         model_file, weights_file = _model_files(model_name, models_path)
50 |         if os.path.isfile(model_file):
51 |             model = model_from_json(json.dumps(just.read(model_file)))
52 |             # load weights into new model
53 |             model.load_weights(weights_file)
54 |             print("Loaded model from disk")
55 |     if model is None:
56 |         # Covers both "no name given" and "not saved yet".
57 |         print("Cannot read model, creating fresh one")
58 |         model = _build_model()
59 |     model.compile(loss='mean_squared_error', optimizer="adam")
60 |     return model
61 |
62 |
63 | def save_model(model, model_name, models_path="models/"):
64 |     """Save the architecture (.json) and weights (.h5) of `model`."""
65 |     model_file, weights_file = _model_files(model_name, models_path)
66 |     # models/* is git-ignored, so the directory may be missing in a fresh checkout.
67 |     if models_path:
68 |         os.makedirs(models_path, exist_ok=True)
69 |     # serialize model to JSON
70 |     just.write(json.loads(model.to_json()), model_file)
71 |     model.save_weights(weights_file)
72 |     print("Saved model to disk")
73 |
--------------------------------------------------------------------------------