├── images ├── up.png ├── down.png ├── image.png ├── model.png ├── text.png ├── images.png ├── number.png ├── string.png └── dictionary.png ├── web ├── lstm.png ├── step10.png ├── step11.png ├── step12.png ├── step13.png ├── step14.png ├── step16.png ├── step17.png ├── step2.png ├── step3.png ├── step4.png ├── step5.png ├── step6.png ├── step7.png ├── step8.png ├── step9.png ├── char-rnn.png ├── function.png ├── step1-1.png ├── step1-2.png ├── step15-1.png ├── step15-2.png └── neural-net.png ├── examples ├── brain.jpg ├── tron.png ├── circuits.jpg ├── robot_head.jpg ├── make_cat_movie ├── paint_brains ├── 10xcats ├── make_new_colors ├── make_superheroes ├── make_new_curses └── star_trek_novels ├── requirements.txt ├── run_program.js ├── license ├── module.py ├── hooks.py ├── easygen.py ├── file_manager.js ├── style_transfer.py ├── readWikipedia.py ├── stylegan_runner.py ├── Easygen.ipynb ├── gpt2.py ├── lstm.py ├── module_dicts.js ├── image_modules.py └── gui.js /images/up.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markriedl/easygen/master/images/up.png -------------------------------------------------------------------------------- /web/lstm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markriedl/easygen/master/web/lstm.png -------------------------------------------------------------------------------- /web/step10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markriedl/easygen/master/web/step10.png -------------------------------------------------------------------------------- /web/step11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markriedl/easygen/master/web/step11.png -------------------------------------------------------------------------------- /web/step12.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/markriedl/easygen/master/web/step12.png -------------------------------------------------------------------------------- /web/step13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markriedl/easygen/master/web/step13.png -------------------------------------------------------------------------------- /web/step14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markriedl/easygen/master/web/step14.png -------------------------------------------------------------------------------- /web/step16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markriedl/easygen/master/web/step16.png -------------------------------------------------------------------------------- /web/step17.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markriedl/easygen/master/web/step17.png -------------------------------------------------------------------------------- /web/step2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markriedl/easygen/master/web/step2.png -------------------------------------------------------------------------------- /web/step3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markriedl/easygen/master/web/step3.png -------------------------------------------------------------------------------- /web/step4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markriedl/easygen/master/web/step4.png 
-------------------------------------------------------------------------------- /web/step5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markriedl/easygen/master/web/step5.png -------------------------------------------------------------------------------- /web/step6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markriedl/easygen/master/web/step6.png -------------------------------------------------------------------------------- /web/step7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markriedl/easygen/master/web/step7.png -------------------------------------------------------------------------------- /web/step8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markriedl/easygen/master/web/step8.png -------------------------------------------------------------------------------- /web/step9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markriedl/easygen/master/web/step9.png -------------------------------------------------------------------------------- /images/down.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markriedl/easygen/master/images/down.png -------------------------------------------------------------------------------- /images/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markriedl/easygen/master/images/image.png -------------------------------------------------------------------------------- /images/model.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/markriedl/easygen/master/images/model.png -------------------------------------------------------------------------------- /images/text.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markriedl/easygen/master/images/text.png -------------------------------------------------------------------------------- /web/char-rnn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markriedl/easygen/master/web/char-rnn.png -------------------------------------------------------------------------------- /web/function.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markriedl/easygen/master/web/function.png -------------------------------------------------------------------------------- /web/step1-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markriedl/easygen/master/web/step1-1.png -------------------------------------------------------------------------------- /web/step1-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markriedl/easygen/master/web/step1-2.png -------------------------------------------------------------------------------- /web/step15-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markriedl/easygen/master/web/step15-1.png -------------------------------------------------------------------------------- /web/step15-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markriedl/easygen/master/web/step15-2.png -------------------------------------------------------------------------------- /examples/brain.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/markriedl/easygen/master/examples/brain.jpg -------------------------------------------------------------------------------- /examples/tron.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markriedl/easygen/master/examples/tron.png -------------------------------------------------------------------------------- /images/images.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markriedl/easygen/master/images/images.png -------------------------------------------------------------------------------- /images/number.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markriedl/easygen/master/images/number.png -------------------------------------------------------------------------------- /images/string.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markriedl/easygen/master/images/string.png -------------------------------------------------------------------------------- /web/neural-net.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markriedl/easygen/master/web/neural-net.png -------------------------------------------------------------------------------- /examples/circuits.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markriedl/easygen/master/examples/circuits.jpg -------------------------------------------------------------------------------- /images/dictionary.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markriedl/easygen/master/images/dictionary.png 
-------------------------------------------------------------------------------- /examples/robot_head.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markriedl/easygen/master/examples/robot_head.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | imageio 3 | beautifulsoup4 4 | tensorflow 5 | fire>=0.1.3 6 | regex==2017.4.5 7 | requests==2.21.0 8 | tqdm==4.31.1 9 | toposort==1.5 10 | pyspellchecker 11 | selenium 12 | autocrop 13 | pillow>5 -------------------------------------------------------------------------------- /examples/make_cat_movie: -------------------------------------------------------------------------------- 1 | [{"module":"LoadModel","name":"LoadModel (1)","x":0,"y":0,"id":0,"file":"cats256x256.pkl","model":"cache/model0-1"},{"module":"StyleGAN_Movie","name":"StyleGAN_Movie (1)","x":220,"y":0,"id":1,"model":"cache/model0-1","length":"20","interp":"20","duration":"100","movie":"cache/images1-4"},{"module":"SaveImages","name":"SaveImages (1)","x":440,"y":0,"id":2,"images":"cache/images1-4","directory":"catmovie"}] -------------------------------------------------------------------------------- /run_program.js: -------------------------------------------------------------------------------- 1 | function run_program() { 2 | var input_box = document.getElementById('inp_run'); 3 | var path = input_box.value; 4 | async function foo() { 5 | console.log(path); 6 | const result = await google.colab.kernel.invokeFunction( 7 | 'notebook.python_run_hook', // The callback name. 8 | [path], // The arguments. 
9 | {}); // kwargs 10 | return result; 11 | }; 12 | foo().then(function(value) {}); 13 | } -------------------------------------------------------------------------------- /examples/paint_brains: -------------------------------------------------------------------------------- 1 | [{"module":"LoadImages","name":"LoadImages (1)","x":0,"y":0,"id":0,"directory":"easygen/examples/brain.jpg","images":"cache/images0-1"},{"module":"LoadImages","name":"LoadImages (2)","x":16,"y":151,"id":4,"directory":"easygen/examples/circuits.jpg","images":"cache/images4-1"},{"module":"LoadImages","name":"LoadImages (3)","x":22,"y":302,"id":5,"directory":"easygen/examples/tron.png","images":"cache/images5-1"},{"module":"JoinImageDirectories","name":"JoinImageDirectories (1)","x":300,"y":150,"id":6,"dir1":"cache/images4-1","dir2":"cache/images5-1","output":"cache/images6-2"},{"module":"StyleTransfer","name":"StyleTransfer (1)","x":577,"y":8,"id":2,"content_image":"cache/images0-1","style_image":"cache/images6-2","steps":"1000","size":"512","style_weight":"1000000","content_weight":"1","content_layers":"1, 2, 3, 4, 5","style_layers":"1, 2, 3, 4, 5","output":"cache/images2-8"},{"module":"SaveImages","name":"SaveImages (1)","x":813,"y":165,"id":3,"images":"cache/images2-8","directory":"new_images"}] -------------------------------------------------------------------------------- /license: -------------------------------------------------------------------------------- 1 | Copyright 2017 Mark O. 
Riedl 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /module.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import re 4 | import shutil 5 | 6 | ######################## 7 | ### GLOBALS 8 | 9 | NEWLINE = '\n' 10 | INFINITY = float("INFINITY") 11 | 12 | 13 | ######################## 14 | 15 | 16 | def checkFiles(*filenames): 17 | for filename in filenames: 18 | if not os.path.exists(filename): 19 | return False 20 | return True 21 | 22 | 23 | ######################### 24 | 25 | def convertHexToASCII(str): 26 | return re.sub(r'\%([A-Z0-9][A-Z0-9])', lambda match: "{0}".format(bytes.fromhex(match.group(1)).decode("utf-8")), str) 27 | 28 | ######################### 29 | 30 | def prep_output_dir(path, makedir=True): 31 | # Does the directory already exist? 32 | if os.path.exists(path): 33 | # it does exist... 
delete 34 | if os.path.isdir(path): 35 | shutil.rmtree(path) 36 | else: 37 | os.remove(path) 38 | # Make the directory 39 | if makedir: 40 | os.mkdir(path) 41 | 42 | ####################### 43 | ### MODULE BASE CLASS 44 | 45 | class Module(): 46 | 47 | ready = True 48 | output_files = [] 49 | 50 | def run(): 51 | pass 52 | -------------------------------------------------------------------------------- /examples/10xcats: -------------------------------------------------------------------------------- 1 | [{"module":"LoadModel","name":"LoadModel (1)","x":0,"y":0,"id":0,"file":"cats256x256.pkl","model":"cache/model0-1"},{"module":"LoadImages","name":"LoadImages (1)","x":20,"y":152,"id":1,"directory":"easygen/examples/robot_head.jpg","images":"cache/images1-1"},{"module":"ResizeImages","name":"ResizeImages (1)","x":249,"y":181,"id":2,"input":"cache/images1-1","size":"256","output":"cache/images2-2"},{"module":"StyleGAN_FineTune","name":"StyleGAN_FineTune (1)","x":488,"y":8,"id":3,"model_in":"cache/model0-1","images":"cache/images2-2","start_kimg":"7000","max_kimg":"7014","seed":"1000","schedule":"2","model_out":"cache/model3-6","animation":"cache/images3-7"},{"module":"StyleGAN_Run","name":"StyleGAN_Run (1)","x":764,"y":28,"id":4,"model":"cache/model3-6","num":"20","images":"cache/images4-2"},{"module":"StyleGAN_Movie","name":"StyleGAN_Movie (1)","x":1038,"y":172,"id":5,"model":"cache/model3-6","length":"20","interp":"20","duration":"100","movie":"cache/images5-4"},{"module":"MakeMovie","name":"MakeMovie (1)","x":786,"y":370,"id":6,"images":"cache/images3-7","duration":"150","movie":"cache/images6-2"},{"module":"SaveImages","name":"SaveImages (1)","x":1255,"y":8,"id":7,"images":"cache/images4-2","directory":"new_cats"},{"module":"SaveImages","name":"SaveImages (2)","x":1275,"y":318,"id":8,"images":"cache/images5-4","directory":"cats_anim"},{"module":"SaveImages","name":"SaveImages (3)","x":1038,"y":448,"id":9,"images":"cache/images6-2","directory":"morph_anim"}] 
-------------------------------------------------------------------------------- /examples/make_new_colors: -------------------------------------------------------------------------------- 1 | [{"module":"ReadTextFile","name":"ReadTextFile (1)","x":0,"y":0,"id":0,"file":"easygen/examples/colors","output":"cache/text0-1"},{"module":"Regex_Search","name":"Regex_Search (1)","x":220,"y":41,"id":1,"input":"cache/text0-1","expression":"\"name\"[ ]*:[ ]*\"([\\w\\W]+?)\"","output":"cache/output","group_1":"cache/text1-3","group_2":"cache/group_2"},{"module":"MakeLowercase","name":"MakeLowercase (1)","x":452,"y":176,"id":2,"input":"cache/text1-3","output":"cache/text2-1"},{"module":"CharRNN_Train","name":"CharRNN_Train (1)","x":698,"y":-1,"id":3,"data":"cache/text2-1","history":"10","layers":"2","hidden_nodes":"64","epochs":"150","learning_rate":"0.0001","model":"cache/model3-6","dictionary":"cache/dictionary3-7"},{"module":"RandomSequence","name":"RandomSequence (1)","x":694,"y":415,"id":6,"input":"cache/text2-1","length":"10","output":"cache/text6-2"},{"module":"WriteTextFile","name":"WriteTextFile (4)","x":441,"y":406,"id":11,"input":"cache/text2-1","file":"training_data"},{"module":"CharRNN_Run","name":"CharRNN_Run (1)","x":948,"y":216,"id":4,"model":"cache/model3-6","dictionary":"cache/dictionary3-7","seed":"cache/text6-2","steps":"6000","temperature":"1.0","output":"cache/text4-5"},{"module":"WriteTextFile","name":"WriteTextFile (1)","x":1219,"y":241,"id":5,"input":"cache/text4-5","file":"new_colors"},{"module":"TextSubtract","name":"TextSubtract (1)","x":1220,"y":421,"id":7,"main":"cache/text4-5","subtract":"cache/text6-2","diff":"cache/text7-2"},{"module":"WriteTextFile","name":"WriteTextFile (2)","x":1479,"y":457,"id":8,"input":"cache/text7-2","file":"new_colors2"},{"module":"Spellcheck","name":"Spellcheck (1)","x":1478,"y":257,"id":9,"input":"cache/text7-2","output":"cache/text9-1"},{"module":"WriteTextFile","name":"WriteTextFile 
(3)","x":1716,"y":297,"id":10,"input":"cache/text9-1","file":"my_output_file"}] -------------------------------------------------------------------------------- /examples/make_superheroes: -------------------------------------------------------------------------------- 1 | [{"module":"ReadWikipedia","name":"ReadWikipedia (1)","x":0,"y":0,"id":0,"wiki_directory":"wiki","pattern":"*","categories":"superhero|superheroes|supervillain|supervillains|transformers","out_file":"cache/out_file","titles_file":"cache/text0-4"},{"module":"Regex_Sub","name":"Regex_Sub (1)","x":234,"y":3,"id":1,"input":"cache/text0-4","expression":"\\([\\w\\W]+?\\)","replacement":"","output":"cache/text1-3"},{"module":"WriteTextFile","name":"WriteTextFile (1)","x":235,"y":253,"id":11,"input":"cache/text0-4","file":"raw_data"},{"module":"MakeLowercase","name":"MakeLowercase (1)","x":449,"y":86,"id":2,"input":"cache/text1-3","output":"cache/text2-1"},{"module":"RandomizeLines","name":"RandomizeLines (1)","x":670,"y":141,"id":3,"input":"cache/text2-1","output":"cache/text3-1"},{"module":"WriteTextFile","name":"WriteTextFile (2)","x":659,"y":372,"id":4,"input":"cache/text3-1","file":"training_data"},{"module":"CharRNN_Train","name":"CharRNN_Train (1)","x":893,"y":-6,"id":5,"data":"cache/text3-1","history":"10","layers":"2","hidden_nodes":"64","epochs":"150","learning_rate":"0.0001","model":"cache/model5-6","dictionary":"cache/dictionary5-7"},{"module":"RandomSequence","name":"RandomSequence (1)","x":905,"y":404,"id":7,"input":"cache/text3-1","length":"10","output":"cache/text7-2"},{"module":"CharRNN_Run","name":"CharRNN_Run (1)","x":1188,"y":282,"id":6,"model":"cache/model5-6","dictionary":"cache/dictionary5-7","seed":"cache/text7-2","steps":"6000","temperature":"0.5","output":"cache/text6-5"},{"module":"TextSubtract","name":"TextSubtract (1)","x":1476,"y":135,"id":8,"main":"cache/text6-5","subtract":"cache/text3-1","diff":"cache/text8-2"},{"module":"WriteTextFile","name":"WriteTextFile 
(3)","x":1472,"y":417,"id":9,"input":"cache/text6-5","file":"new_names1"},{"module":"WriteTextFile","name":"WriteTextFile (4)","x":1747,"y":234,"id":10,"input":"cache/text8-2","file":"my_output_file"}] -------------------------------------------------------------------------------- /examples/make_new_curses: -------------------------------------------------------------------------------- 1 | [{"module":"WebCrawl","name":"WebCrawl (1)","x":10,"y":1,"id":0,"url":"https://en.wiktionary.org/wiki/Category:English_vulgarities","link_id":"","link_text":"next page","max_hops":"50","output":"cache/text0-4"},{"module":"Regex_Search","name":"Regex_Search (1)","x":228,"y":0,"id":1,"input":"cache/text0-4","expression":"
<li>([a-zA-Z0-9\\%\\'\\_\\#\\- ]+?)</li>","output":"cache/output","group_1":"cache/text1-3","group_2":"cache/group_2"},{"module":"Regex_Search","name":"Regex_Search (2)","x":225,"y":288,"id":4,"input":"cache/text0-4","expression":"
<li>([a-zA-Z0-9\\%\\'\\_\\#\\- ]+?)</li>","output":"cache/output","group_1":"cache/text4-3","group_2":"cache/group_2"},{"module":"ConcatenateTextFiles","name":"ConcatenateTextFiles (1)","x":451,"y":176,"id":6,"input_1":"cache/text1-3","input_2":"cache/text4-3","output":"cache/text6-2"},{"module":"RemoveDuplicates","name":"RemoveDuplicates (1)","x":672,"y":51,"id":7,"input":"cache/text6-2","output":"cache/text7-1"},{"module":"MakeLowercase","name":"MakeLowercase (1)","x":676,"y":206,"id":8,"input":"cache/text7-1","output":"cache/text8-1"},{"module":"WriteTextFile","name":"WriteTextFile (1)","x":678,"y":370,"id":9,"input":"cache/text8-1","file":"training_data"},{"module":"CharRNN_Train","name":"CharRNN_Train (1)","x":910,"y":6,"id":10,"data":"cache/text8-1","history":"10","layers":"2","hidden_nodes":"64","epochs":"100","learning_rate":"0.0001","model":"cache/model10-6","dictionary":"cache/dictionary10-7"},{"module":"RandomSequence","name":"RandomSequence (1)","x":914,"y":423,"id":12,"input":"cache/text8-1","length":"10","output":"cache/text12-2"},{"module":"CharRNN_Run","name":"CharRNN_Run (1)","x":1152,"y":297,"id":11,"model":"cache/model10-6","dictionary":"cache/dictionary10-7","seed":"cache/text12-2","steps":"6000","temperature":"1.0","output":"cache/text11-5"},{"module":"TextSubtract","name":"TextSubtract (1)","x":1374,"y":131,"id":13,"main":"cache/text11-5","subtract":"cache/text8-1","diff":"cache/text13-2"},{"module":"WriteTextFile","name":"WriteTextFile (2)","x":1379,"y":343,"id":14,"input":"cache/text13-2","file":"my_output_file"}] -------------------------------------------------------------------------------- /examples/star_trek_novels: -------------------------------------------------------------------------------- 1 | [{"module":"MakeCountFile","name":"MakeCountFile (1)","x":0,"y":0,"id":0,"num":"50","prefix":"https://www.barnesandnoble.com/b/books/romance/historical-romance/_/N-29Z8q8Z17yg?Nrpp=40&page=","postfix":"","output":"cache/text0-3"},{"module":"MakeCountFile","name":"MakeCountFile
(2)","x":2,"y":234,"id":1,"num":"50","prefix":"https://www.barnesandnoble.com/b/books/science-fiction-fantasy/star-trek-fiction/_/N-29Z8q8Z182c?Nrpp=40&page=","postfix":"","output":"cache/text1-3"},{"module":"ReadAllFromWeb","name":"ReadAllFromWeb (1)","x":240,"y":24,"id":2,"urls":"cache/text0-3","data":"cache/text2-1"},{"module":"ReadAllFromWeb","name":"ReadAllFromWeb (2)","x":216,"y":287,"id":3,"urls":"cache/text1-3","data":"cache/text3-1"},{"module":"ConcatenateTextFiles","name":"ConcatenateTextFiles (1)","x":443,"y":142,"id":4,"input_1":"cache/text2-1","input_2":"cache/text3-1","output":"cache/text4-2"},{"module":"Regex_Search","name":"Regex_Search (1)","x":687,"y":13,"id":5,"input":"cache/text4-2","expression":"Title: ([\\w\\W ]+?), Author:","output":"cache/output","group_1":"cache/text5-3","group_2":"cache/group_2"},{"module":"MakeLowercase","name":"MakeLowercase (1)","x":680,"y":301,"id":6,"input":"cache/text5-3","output":"cache/text6-1"},{"module":"Regex_Sub","name":"Regex_Sub (1)","x":935,"y":9,"id":7,"input":"cache/text6-1","expression":"star trek[: \\#0-9]*","replacement":"","output":"cache/text7-3"},{"module":"Regex_Sub","name":"Regex_Sub (2)","x":942,"y":246,"id":17,"input":"cache/text7-3","expression":"\\([\\w\\W]+?\\)","replacement":"","output":"cache/text17-3"},{"module":"RemoveEmptyLines","name":"RemoveEmptyLines (1)","x":943,"y":475,"id":18,"input":"cache/text17-3","output":"cache/text18-1"},{"module":"RemoveDuplicates","name":"RemoveDuplicates (1)","x":1184,"y":16,"id":9,"input":"cache/text18-1","output":"cache/text9-1"},{"module":"RandomizeLines","name":"RandomizeLines (1)","x":1193,"y":180,"id":10,"input":"cache/text9-1","output":"cache/text10-1"},{"module":"WriteTextFile","name":"WriteTextFile (1)","x":1191,"y":407,"id":11,"input":"cache/text10-1","file":"training_data"},{"module":"CharRNN_Train","name":"CharRNN_Train 
(1)","x":1432,"y":4,"id":12,"data":"cache/text10-1","history":"10","layers":"2","hidden_nodes":"64","epochs":"150","learning_rate":"0.0001","model":"cache/model12-6","dictionary":"cache/dictionary12-7"},{"module":"RandomSequence","name":"RandomSequence (1)","x":1429,"y":415,"id":14,"input":"cache/text10-1","length":"10","output":"cache/text14-2"},{"module":"CharRNN_Run","name":"CharRNN_Run (1)","x":1640,"y":269,"id":13,"model":"cache/model12-6","dictionary":"cache/dictionary12-7","seed":"cache/text14-2","steps":"6000","temperature":"1.0","output":"cache/text13-5"},{"module":"TextSubtract","name":"TextSubtract (1)","x":1848,"y":48,"id":15,"main":"cache/text13-5","subtract":"cache/text10-1","diff":"cache/text15-2"},{"module":"Spellcheck","name":"Spellcheck (1)","x":1845,"y":248,"id":19,"input":"cache/text15-2","output":"cache/text19-1"},{"module":"WriteTextFile","name":"WriteTextFile (2)","x":1843,"y":413,"id":20,"input":"cache/text19-1","file":"my_output_file"}] -------------------------------------------------------------------------------- /hooks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from pathlib import Path 4 | from PIL import Image 5 | from IPython.display import display 6 | 7 | TRASH_PATH = '/content/.trash' 8 | 9 | def python_cwd_hook_aux(dir): 10 | result = {} 11 | for file in os.listdir(dir): 12 | path = os.path.join(dir, file) 13 | if os.path.isdir(path): 14 | file = file + '/' 15 | result[file] = path 16 | result['./'] = os.getcwd() 17 | if dir != '/': 18 | parent_dir = str(Path(dir).parent) 19 | result['../'] = parent_dir 20 | return result 21 | 22 | def python_move_hook_aux(path1, path2): 23 | status = False 24 | if os.path.exists(path1): 25 | shutil.move(path1, path2) 26 | status = True 27 | return status 28 | 29 | def python_copy_hook_aux(path1, path2): 30 | status = False 31 | if os.path.exists(path1): 32 | # path 1 exists 33 | if os.path.isdir(path1): 34 | # path1 is a 
directory 35 | if os.path.exists(path2): 36 | # path2 exists 37 | if os.path.isdir(path2): 38 | # copying directory to selected directory 39 | # make new directory inside with same name as path1 40 | basename1 = os.path.basename(os.path.normpath(path1)) 41 | path2 = os.path.join(path2, basename1) 42 | shutil.copytree(path1, path2) 43 | status = True 44 | else: 45 | # copy directory to a file 46 | # can't do that 47 | print("cannot copy directory inside a file") 48 | else: 49 | # path1 is a file 50 | # doesn't matter if path2 is a file or directory 51 | shutil.copy(path1, path2) 52 | status = True 53 | return status 54 | 55 | def python_open_text_hook_aux(path): 56 | status = False 57 | if os.path.exists(path) and not os.path.isdir(path): 58 | with open(path, 'r') as file: 59 | try: 60 | text = file.read() 61 | print(text) 62 | status = True 63 | except: 64 | print("Cannot read text file", path) 65 | return status 66 | 67 | def python_open_image_hook_aux(path): 68 | status = False 69 | if os.path.exists(path) and not os.path.isdir(path): 70 | try: 71 | pil_im = Image.open(path, 'r') 72 | display(pil_im) 73 | status = True 74 | except: 75 | print("Cannot open image file", path) 76 | return status 77 | 78 | def python_save_hook_aux(file_text, filename): 79 | status = False 80 | with open(filename, 'w') as f: 81 | try: 82 | f.write(file_text) 83 | status = True 84 | except: 85 | print("Could not write to", filename) 86 | return status 87 | 88 | def python_load_hook_aux(filename): 89 | file_text = '' 90 | with open(filename, 'r') as f: 91 | try: 92 | file_text = f.read() 93 | status = True 94 | except: 95 | print("could not write to", filename) 96 | return file_text 97 | 98 | def python_mkdir_hook_aux(path, dir_name): 99 | status = False 100 | try: 101 | os.mkdir(os.path.join(path, dir_name)) 102 | status = True 103 | except: 104 | print("Could not create directory " + dir_name + " in " + path) 105 | return status 106 | 107 | def python_trash_hook_aux(path): 108 | 
status = False 109 | try: 110 | if not os.path.exists(TRASH_PATH): 111 | os.mkdir(TRASH_PATH) 112 | if os.path.exists(path): 113 | shutil.move(path, TRASH_PATH) 114 | status = True 115 | except: 116 | print("Could not move " + path + " to " + TRASH_PATH) 117 | return status 118 | -------------------------------------------------------------------------------- /easygen.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import copy 4 | from modules import * 5 | from image_modules import * 6 | import pdb 7 | import unicodedata 8 | import argparse 9 | 10 | TEMP_DIRECTORY = './cache' 11 | HISTORY_PATH = os.path.join(TEMP_DIRECTORY, '.history') 12 | 13 | 14 | parser = argparse.ArgumentParser(description='Run an easygen program.') 15 | parser.add_argument('program', help="the file containing the easgen program.") 16 | parser.add_argument('--clear_cache', help="Clear the cache", action='store_true') 17 | 18 | # Inputs: 19 | # - json description of the module 20 | # - dictionary of history for caching 21 | # Return 2 values: 22 | # - Whether the module was run (could have not run because of caching) 23 | # - The paths to files that were produced as outputs of the module (or none if the module wasn't ready) 24 | def runModule(module_json, history = {}): 25 | module_json_copy = copy.deepcopy(module_json) 26 | if 'module' in module_json_copy: 27 | module = module_json_copy['module'] 28 | ## Convert the json to a set of parameters to pass into a class of the same name as the module name 29 | ## Take the module name out 30 | del module_json_copy['module'] 31 | ## Take out the x and y 32 | if 'x' in module_json_copy: 33 | del module_json_copy['x'] 34 | if 'y' in module_json_copy: 35 | del module_json_copy['y'] 36 | if 'collapsed' in module_json_copy: 37 | del module_json_copy['collapsed'] 38 | ## Take out the module name and id 39 | if 'name' in module_json_copy: 40 | del module_json_copy['name'] 41 | if 'id' in 
module_json_copy: 42 | del module_json_copy['id'] 43 | 44 | rest = str(module_json_copy) 45 | 46 | params = rest.replace('{', '').replace('}', '') 47 | #params = re.sub(r'u\'', '', params) 48 | #params = re.sub(r'\'', '', params) 49 | #params = re.sub(r': ', '=', params) 50 | p1 = re.compile(r'\'([0-9a-zA-Z\_]+)\':[\s]*(\'[\<\>\(\)0-9a-zA-Z\_\.\,\/\:\*\-\?\+\=\#\&\@\%\\\[\]\|\"\^ ]*\')') 51 | params = p1.sub(r"\1=\2", params) 52 | params = re.sub(r'\'True\'', 'True', params) 53 | params = re.sub(r'\'true\'', 'True', params) 54 | params = re.sub(r'\'False\'', 'False', params) 55 | params = re.sub(r'\'false\'', 'False', params) 56 | p2 = re.compile(r'\'([0-9]*[\.]*[0-9]+)\'') 57 | params = p2.sub(r'\1', params) 58 | ## Put the module name back on as class name 59 | evalString = module + '(' + convertHexToASCII(params) + ')' 60 | print("Module:", evalString) 61 | ## If everything went well, we can now evaluate the string and create a new class. 62 | module = eval(evalString) 63 | ## Do we need to run it? 64 | done_files = [] # The files that are in temp and generated by an identical process 65 | for key in history.keys(): 66 | # key is a filename 67 | value = history[key] # string descriptor of the process that generated the file 68 | # Is the descriptor the same and does the file exist? 69 | if str(value) == str(module_json) and os.path.exists(key): 70 | # and if it's a directory is it non-empty? 71 | if os.path.isfile(key) or (os.path.isdir(key) and len(os.listdir(key)) > 0): 72 | # It is! 73 | done_files.append(key) 74 | output_files = module.output_files # These are the files we would expect to see 75 | # Check if all the output files are already done 76 | if len(set(output_files).difference(set(done_files))) == 0: 77 | # All the files we were going to output have been generated by an identical process 78 | print("Using cached data.") 79 | return False, output_files 80 | else: 81 | ## Run the class. 
82 | if module.ready: 83 | print("Running...") 84 | module.run() 85 | print("Done.") 86 | return True, output_files 87 | else: 88 | print("Module not ready.") 89 | return False, None 90 | 91 | def main(program_path): 92 | ### Make sure required directories exist 93 | if not os.path.exists(TEMP_DIRECTORY): 94 | os.makedirs(TEMP_DIRECTORY) 95 | 96 | ### Read the history file 97 | history = {} 98 | if os.path.exists(HISTORY_PATH): 99 | with open(HISTORY_PATH, 'r') as historyfile: 100 | history_text = historyfile.read().strip() 101 | if len(history_text) > 0: 102 | history = eval(history_text) 103 | 104 | ### Read in the program file 105 | data_text = '' 106 | data = None 107 | for line in open(program_path, "r"): 108 | data_text = data_text + line.strip() 109 | 110 | ### Convert to json dictionary 111 | data = json.loads(data_text) 112 | #data = byteify(data) 113 | 114 | use_caching = True # Should we use the caching system? 115 | 116 | ### Run each module 117 | if data is not None: 118 | print("Running", program_path) 119 | for d in data: 120 | # Run the module! 
121 | executed, output_files = runModule(d, history if use_caching else {}) 122 | # If the module ran and produced output_files then we shouldn't use caching anymore 123 | # If the module didn't run and output_files is None then it wasn't ready 124 | # and we are probably dead in the water 125 | # If the module didn't run and produced output_files then it is drawing from the cache, 126 | # keep doing so 127 | if (executed and output_files is not None) or (not executed and output_files is None): 128 | use_caching = False 129 | # Update the history 130 | if output_files is not None: 131 | for file in output_files: 132 | history[file] = d 133 | 134 | # Write the history to file 135 | with open(HISTORY_PATH, 'w') as historyfile: 136 | # Clean up the history file 137 | history_copy = copy.deepcopy(history) 138 | for file in history_copy.keys(): 139 | if not os.path.exists(file): 140 | del history[file] 141 | # Now write 142 | historyfile.write(str(history)) 143 | else: 144 | # No data to run 145 | print("Program is empty") 146 | 147 | if __name__ == '__main__': 148 | args = parser.parse_args() 149 | if args.clear_cache: 150 | shutil.rmtree(TEMP_DIRECTORY) 151 | main(args.program) -------------------------------------------------------------------------------- /file_manager.js: -------------------------------------------------------------------------------- 1 | //////////////////////////// 2 | // GLOBALS 3 | 4 | var path1 = '/content' // the cwd of the first file list box 5 | var path2 = '/content' // the cwd of the second file list box 6 | var selected1 = '/content' // the path to a file selected in the first file list box 7 | var selected2 = '/content' // the path to a file selected in the second file list box 8 | 9 | // Call python and get a dictionary containing file names and their paths 10 | function get_files(path, list_id) { 11 | async function foo() { 12 | console.log(path); 13 | const result = await google.colab.kernel.invokeFunction( 14 | 
      'notebook.python_cwd_hook', // The callback name.
      [path], // The arguments.
      {}); // kwargs
    return result;
  };
  foo().then(function(value) {
    // parse the return value
    var returned = value.data['application/json'];
    // NOTE(review): eval() of kernel output — assumes the Python side is trusted.
    var dict = eval(returned.result); // dictionary of filenames and full paths
    var file_list = document.getElementById(list_id); // list box html element
    // Clear the list box
    removeOptions(file_list);
    var files = []; // filenames
    var key;
    // Move filesnames (keys) out of dictionary into files list
    for (key in dict) {
      files.push(key);
    }
    // Sort the files
    var sorted_files = files.sort(); // the sorted file list
    // But make sure . is at the top of the list
    var files_temp = []
    files_temp.push('./')
    var i;
    for (i = 0; i < sorted_files.length; i++) {
      if (sorted_files[i] !== './') {
        files_temp.push(sorted_files[i])
      }
    }
    sorted_files = files_temp;
    // ASSERT: files are sorted and . is at the top of the list
    // Populate the list box
    var i;
    for (i = 0; i < sorted_files.length; i++) {
      var file = sorted_files[i];
      var val_path = dict[file];
      var opt = document.createElement('option');
      opt.value = val_path; // option value carries the full path
      opt.innerHTML = file; // displayed text is just the filename
      file_list.appendChild(opt);
    }
  });
  // ASSERT: nothing after here guaranteed to be executed before foo returns
}

// Remove all options from a select box
function removeOptions(selectbox) {
  var i;
  // iterate backwards so indices stay valid while removing
  for(i = selectbox.options.length - 1 ; i >= 0 ; i--) {
    selectbox.remove(i);
  }
}



// Set up the select boxes with double_click callbacks
var file_list1 = document.getElementById("file_list1");
file_list1.ondblclick = function(){
  var filename = this.options[this.selectedIndex].innerHTML;
  var path = this.options[this.selectedIndex].value;
  if (filename[filename.length-1] === "/") {
    // this is a directory; descend into it
    path1 = path;
    selected1 = path;
    update_gui(path, "path1", "file_list1");
  }
};
// single click just records the selection
file_list1.onclick = function() {
  var filename = this.options[this.selectedIndex].innerHTML;
  var path = this.options[this.selectedIndex].value;
  selected1 = path;
};

var file_list2 = document.getElementById("file_list2");
file_list2.ondblclick = function(){
  var filename = this.options[this.selectedIndex].innerHTML;
  var path = this.options[this.selectedIndex].value;
  if (filename[filename.length-1] === "/") {
    // this is a directory; descend into it
    path2 = path;
    selected2 = path;
    update_gui(path, "path2", "file_list2");
  }
};
file_list2.onclick = function() {
  var filename = this.options[this.selectedIndex].innerHTML;
  var path = this.options[this.selectedIndex].value;
  selected2 = path;
};

// update the gui if a path has changed
function update_gui(path, dir_id, list_id) {
  var path_text = document.getElementById(dir_id);
  path_text.innerHTML = path;
  get_files(path, list_id);
}

// Copy button
function do_copy_mouse_up() {
  (async function() {
    const result = await google.colab.kernel.invokeFunction(
      'notebook.python_copy_hook', // The callback name.
      [selected1, selected2], // The arguments.
      {}); // kwargs
    const res = result.data['application/json'];
  })();
  // NOTE(review): the refresh below does not await the copy above; the lists
  // may repopulate before the kernel call finishes.
  update_gui(path1, "path1", "file_list1")
  update_gui(path2, "path2", "file_list2")
}

// Move button
function do_move_mouse_up() {
  (async function() {
    const result = await google.colab.kernel.invokeFunction(
      'notebook.python_move_hook', // The callback name.
      [selected1, selected2], // The arguments.
      {}); // kwargs
    const res = result.data['application/json'];
  })();
  // NOTE(review): same un-awaited refresh race as do_copy_mouse_up.
  update_gui(path1, "path1", "file_list1")
  update_gui(path2, "path2", "file_list2")
}

// Open text button
function do_open_text_mouse_up() {
  (async function() {
    const result = await google.colab.kernel.invokeFunction(
      'notebook.python_open_text_hook', // The callback name.
      [selected1], // The arguments.
      {}); // kwargs
    const res = result.data['application/json'];
  })();
}

// Open image button
function do_open_image_mouse_up() {
  (async function() {
    const result = await google.colab.kernel.invokeFunction(
      'notebook.python_open_image_hook', // The callback name.
      [selected1], // The arguments.
      {}); // kwargs
    const res = result.data['application/json'];
  })();
}

// Make-directory button: creates dir named in the mkdir_input box under selected1
function do_mkdir_mouse_up() {
  var input_box = document.getElementById('mkdir_input');
  // NOTE(review): missing `var` — dir_name becomes an implicit global.
  dir_name = input_box.value;
  (async function() {
    const result = await google.colab.kernel.invokeFunction(
      'notebook.python_mkdir_hook', // The callback name.
      [selected1, dir_name], // The arguments.
      {}); // kwargs
    const res = result.data['application/json'];
  })();
  update_gui(path1, "path1", "file_list1")
  update_gui(path2, "path2", "file_list2")
}

// Trash button: moves selected1 to the trash directory on the Python side
function do_trash_mouse_up() {
  (async function() {
    const result = await google.colab.kernel.invokeFunction(
      'notebook.python_trash_hook', // The callback name.
      [selected1], // The arguments.
      {}); // kwargs
    const res = result.data['application/json'];
  })();
  update_gui(path1, "path1", "file_list1")
  update_gui(path2, "path2", "file_list2")
}

// GO: initial population of both file lists
update_gui(path1, "path1", "file_list1")
update_gui(path2, "path2", "file_list2")
--------------------------------------------------------------------------------
/style_transfer.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import PIL
from PIL import Image
import matplotlib.pyplot as plt

import torchvision.transforms as transforms
import torchvision.models as models

import copy
import re

######################################
## GLOBALS

# desired depth layers to compute style/content losses :
CONTENT_LAYERS_DEFAULT = ['conv_4']
STYLE_LAYERS_DEFAULT = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
# ImageNet channel statistics used by the pretrained VGG network
CNN_NORMALIZATION_MEAN = torch.tensor([0.485, 0.456, 0.406])
CNN_NORMALIZATION_STD =
torch.tensor([0.229, 0.224, 0.225]) 24 | 25 | 26 | 27 | ####################################### 28 | ## CLASSES 29 | 30 | class ContentLoss(nn.Module): 31 | 32 | def __init__(self, target,): 33 | super(ContentLoss, self).__init__() 34 | # we 'detach' the target content from the tree used 35 | # to dynamically compute the gradient: this is a stated value, 36 | # not a variable. Otherwise the forward method of the criterion 37 | # will throw an error. 38 | self.target = target.detach() 39 | 40 | def forward(self, input): 41 | self.loss = F.mse_loss(input, self.target) 42 | return input 43 | 44 | 45 | class StyleLoss(nn.Module): 46 | 47 | def __init__(self, target_feature): 48 | super(StyleLoss, self).__init__() 49 | self.target = gram_matrix(target_feature).detach() 50 | 51 | def forward(self, input): 52 | G = gram_matrix(input) 53 | self.loss = F.mse_loss(G, self.target) 54 | return input 55 | 56 | # create a module to normalize input image so we can easily put it in a 57 | # nn.Sequential 58 | class Normalization(nn.Module): 59 | def __init__(self, mean, std): 60 | super(Normalization, self).__init__() 61 | # .view the mean and std to make them [C x 1 x 1] so that they can 62 | # directly work with image Tensor of shape [B x C x H x W]. 63 | # B is batch size. C is number of channels. H is height and W is width. 
64 | self.mean = torch.tensor(mean).view(-1, 1, 1) 65 | self.std = torch.tensor(std).view(-1, 1, 1) 66 | 67 | def forward(self, img): 68 | # normalize img 69 | return (img - self.mean) / self.std 70 | 71 | ####################################### 72 | ## HELPERS 73 | 74 | def tensor_to_image(tensor): 75 | t = transforms.ToPILImage() # reconvert into PIL image 76 | image = tensor.cpu().clone() # we clone the tensor to not do changes on it 77 | image = image.squeeze(0) # remove the fake batch dimension 78 | image = t(image) 79 | return image 80 | 81 | def image_loader(image_name, size, device): 82 | # transform images to same size 83 | t = transforms.Compose([transforms.Resize(size), # scale imported image 84 | transforms.ToTensor()]) # transform it into a torch tensor 85 | image = Image.open(image_name) 86 | if image.mode != 'RGB': 87 | image = image.convert('RGB') 88 | # fake batch dimension required to fit network's input dimensions 89 | image = image.resize((size, size), PIL.Image.ANTIALIAS) 90 | image = t(image).unsqueeze(0) 91 | return image.to(device, torch.float) 92 | 93 | 94 | def gram_matrix(input): 95 | a, b, c, d = input.size() # a=batch size(=1) 96 | # b=number of feature maps 97 | # (c,d)=dimensions of a f. map (N=c*d) 98 | 99 | features = input.view(a * b, c * d) # resise F_XL into \hat F_XL 100 | 101 | G = torch.mm(features, features.t()) # compute the gram product 102 | 103 | # we 'normalize' the values of the gram matrix 104 | # by dividing by the number of element in each feature maps. 
105 | return G.div(a * b * c * d) 106 | 107 | def get_input_optimizer(input_img): 108 | # this line to show that input is a parameter that requires a gradient 109 | optimizer = optim.LBFGS([input_img.requires_grad_()]) 110 | return optimizer 111 | 112 | def get_style_model_and_losses(cnn, normalization_mean, normalization_std, 113 | style_img, content_img, device, 114 | content_layers = CONTENT_LAYERS_DEFAULT, 115 | style_layers = CONTENT_LAYERS_DEFAULT): 116 | cnn = copy.deepcopy(cnn) 117 | 118 | # normalization module 119 | normalization = Normalization(normalization_mean, normalization_std).to(device) 120 | 121 | # just in order to have an iterable access to or list of content/syle 122 | # losses 123 | content_losses = [] 124 | style_losses = [] 125 | 126 | # assuming that cnn is a nn.Sequential, so we make a new nn.Sequential 127 | # to put in modules that are supposed to be activated sequentially 128 | model = nn.Sequential(normalization) 129 | 130 | i = 0 # increment every time we see a conv 131 | for layer in cnn.children(): 132 | if isinstance(layer, nn.Conv2d): 133 | i += 1 134 | name = 'conv_{}'.format(i) 135 | elif isinstance(layer, nn.ReLU): 136 | name = 'relu_{}'.format(i) 137 | # The in-place version doesn't play very nicely with the ContentLoss 138 | # and StyleLoss we insert below. So we replace with out-of-place 139 | # ones here. 
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = 'pool_{}'.format(i)
        elif isinstance(layer, nn.BatchNorm2d):
            name = 'bn_{}'.format(i)
        else:
            raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))

        model.add_module(name, layer)

        if name in content_layers:
            # add content loss: probe the activations of the content image here
            target = model(content_img).detach()
            content_loss = ContentLoss(target)
            model.add_module("content_loss_{}".format(i), content_loss)
            content_losses.append(content_loss)

        if name in style_layers:
            # add style loss: probe the style image's Gram statistics here
            target_feature = model(style_img).detach()
            style_loss = StyleLoss(target_feature)
            model.add_module("style_loss_{}".format(i), style_loss)
            style_losses.append(style_loss)

    # now we trim off the layers after the last content and style losses
    for i in range(len(model) - 1, -1, -1):
        if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss):
            break

    model = model[:(i + 1)]

    return model, style_losses, content_losses

###########################################

def run_style_transfer(cnn, content_img, style_img, input_img, device,
                       normalization_mean = CNN_NORMALIZATION_MEAN,
                       normalization_std = CNN_NORMALIZATION_STD,
                       content_layers = CONTENT_LAYERS_DEFAULT,
                       style_layers = STYLE_LAYERS_DEFAULT,
                       num_steps = 300,
                       style_weight = 1000000,
                       content_weight = 1
                       ):
    """Run the style transfer."""
    print('Building the style transfer model..')
    model, style_losses, content_losses = get_style_model_and_losses(cnn,
        normalization_mean.to(device), normalization_std.to(device), style_img, content_img, device,
        content_layers, style_layers)
    optimizer = get_input_optimizer(input_img)

    # one-element lists so the nested closure below can rebind them
    best_img = [None]
    best_score = [None]

    print('Optimizing..')
    run = [0]
    while run[0] <= num_steps:

        # L-BFGS calls this closure (possibly several times per .step())
        def closure():
            # correct the values of updated input image
            input_img.data.clamp_(0, 1)

            optimizer.zero_grad()
            model(input_img)  # forward pass populates each probe's .loss
            style_score = 0
            content_score = 0

            for sl in style_losses:
                style_score += sl.loss
            for cl in content_losses:
                content_score += cl.loss

            style_score *= style_weight
            content_score *= content_weight

            loss = style_score + content_score
            loss.backward()

            run[0] += 1
            if run[0] % 50 == 0:
                print("run {}:".format(run))
                print('Style Loss : {:4f} Content Loss: {:4f}'.format(
                    style_score.item(), content_score.item()))
                print()
            # keep the best-scoring image seen so far, clamped to valid range
            current_score = style_score.item() + content_score.item()
            if best_img[0] is None or current_score <= best_score[0]:
                best_img[0] = input_img.clone()
                best_img[0].data.clamp_(0, 1)
                best_score[0] = current_score

            return style_score + content_score

        optimizer.step(closure)
    # a last correction...
    # not sure I need to do this
    input_img.data.clamp_(0, 1)

    return best_img[0]

##########################################
def process_layers_spec(spec):
    # Parse a spec like '1, 2, 3' into ['conv_1', 'conv_2', 'conv_3'].
    # -1 or 0 anywhere means "use all five default style layers".
    spec = str(spec)
    layers = re.findall(r'[\-0-9]+', spec)
    layers = [int(num) for num in layers]
    if -1 in layers or 0 in layers:
        return STYLE_LAYERS_DEFAULT
    layers = list(filter(lambda x: x >= 1 and x <= 5, layers))
    layers = sorted(layers)
    layers = ['conv_' + str(num) for num in layers]
    return layers

##########################################

def run(content_image_path, style_image_path, output_path,
        image_size = 512, num_steps = 300, style_weight = 1000000, content_weight = 1,
        content_layers_spec='4',
        style_layers_spec = '1, 2, 3, 4, 5'):
    # CUDA or CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # process content layers specification
    content_layers = process_layers_spec(content_layers_spec)
    style_layers = process_layers_spec(style_layers_spec)

    # Load images
    style_img = image_loader(style_image_path, image_size, device)
    content_img = image_loader(content_image_path, image_size, device)

    assert style_img.size() == content_img.size(), "style and content images must be the same size"

    # Load the VGG CNN. Download if necessary.
    # NOTE(review): `pretrained=True` is deprecated in newer torchvision in
    # favor of `weights=...` — confirm the pinned torchvision version.
    cnn = models.vgg19(pretrained=True).features.to(device).eval()

    input_img = content_img.clone()
    # if you want to use white noise instead uncomment the below line:
    # input_img = torch.randn(content_img.data.size(), device=device)

    output = run_style_transfer(cnn,
        content_img, style_img, input_img, device,
        num_steps = num_steps,
        style_weight = style_weight,
        content_weight = content_weight,
        content_layers = content_layers,
        style_layers = style_layers)
    img = tensor_to_image(output)
    img.save(output_path, "JPEG")
    return img

--------------------------------------------------------------------------------
/readWikipedia.py:
--------------------------------------------------------------------------------
import json
import re
from xml.etree import ElementTree as etree
import os
import sys
from bs4 import BeautifulSoup
import functools

wikiDir = "" # The directory to find all the wiki files
outFilename = "" # The filename to dump plots to
titleFilename = "" # The filename to dump title names to

pattern = '*:plot'
breakSentences = True

okTags = ['b', 'i', 'a', 'strong', 'em']
listTags = ['ul', 'ol']
contentTypes = ['text', 'list']


# Check command line parameters
'''
if len(sys.argv) > 3:
	wikiDir = sys.argv[1]
	outFilename = sys.argv[2]
	titleFilename = sys.argv[3]
else:
	print "usage:", sys.argv[0], "wikidirectory resultfile titlefile"
	exit()
'''


########################
### HELPER FUNCTIONS

# NOTE(review): the second parameter shadows the builtin `str`; harmless
# within this function but worth renaming in a behavior-changing pass.
def matchPattern(pattern, str):
    # '*' matches anything; otherwise pattern is '|'-separated alternatives,
    # each matched as a regex followed by a word boundary. Returns the matched
    # text, or '' when nothing matches.
    if pattern == '*':
        return str
    else:
        pattern = pattern.split('|')
        for p in pattern:
            match = re.search(p+r'\b', str)
            if match is not None:
                return match.group(0)
        return ''


def fixList(soup, depth = 0):
    # Flatten a (possibly nested) ul/ol soup node into dash-indented lines,
    # one dash per nesting level.
    result = ''
    if soup.name is None:
        lines = [s.strip() for s in soup.splitlines()]
        for line in lines:
            if len(line) > 0 and line[0] == '-':
                dashes = ''
                for n in range(depth):
                    dashes = dashes + '-'
                result = result + dashes + ' ' + line[1:].strip() + '\n'
    elif soup.name in listTags:
        for child in soup.children:
            result = result + fixList(child, depth + 1)
    return result



########################

def ReadWikipedia(wiki_directory, pattern, categoryPattern, out_file, titles_file):
    """Scan wikiextractor output under `wiki_directory` for articles whose
    title matches the head of `pattern` and whose categories match
    `categoryPattern`, then extract the section named by the rest of
    `pattern` (e.g. '*:plot'), writing section text to `out_file` and the
    matching article titles to `titles_file`."""
    pattern = pattern.split(':')
    titlePattern = pattern.pop(0)
    headerList = pattern

    # NOTE(review): contentType persists across articles once set inside the
    # loop below — a 'list' match can leak into the next article; verify.
    contentType = 'text' # The type of content to grab 'text' or 'list'. Should be the tail of the headerPattern.

    files = [] # All the wiki files

    # Get all the wiki files left by wikiextractor
    for dirname, dirnames, filenames in os.walk(os.path.join('.', wiki_directory)):
        for filename in filenames:
            if filename[0] != '.':
                files.append(os.path.join(dirname, filename))

    with open(out_file, "w") as outfile:
        # Opened the output file
        with open(titles_file, "w") as titlefile:
            # Opened the title file
            # Walk through each file. Each file has a json for each wikipedia
            # article. Look for jsons with matching subheaders.
            for file in files:
                #print >> outfile, "file:", file #FOR DEBUGGING
                data = [] # Each element is a json record
                # Read the file and get all the json records
                # NOTE(review): file handle is never closed explicitly.
                for line in open(file, 'r'):
                    data.append(json.loads(line))
                # Look for pattern matches in heading tags inside the text of the json
                for j in data:
                    # j is a json record
                    titleMatch = matchPattern(titlePattern, j['title'])
                    categoryMatch = matchPattern(categoryPattern, j['categories'])
                    if len(titleMatch) > 0 and len(categoryMatch) > 0:
                        print("title:", titleMatch, "in", j['title'])
                        print("category:", categoryMatch, "in", j['categories'])
                        # This json record is a match to titlePattern
                        #print >> outfile, j['title'].encode('utf-8') #FOR DEBUGGING
                        # Text element contains HTML
                        soup = BeautifulSoup(j['text'].encode('utf-8'), "html.parser")
                        result = "" # The result found (if any)
                        inresult = False # Am I inside a result section of the article?
                        previousHeaders = []
                        headerIndex = 0
                        # Walk through each element in the html soup object
                        for n in range(len(soup.contents)):
                            current = soup.contents[n] # The current html element
                            if len(headerList) == 0:
                                # If only title information is given, we just get everything
                                if current is not None and current.name is None:
                                    result = result + current.strip() + ' '
                            elif not inresult and current is not None and current.name is not None and current.name == 'h' + str(headerIndex + 2): # start with h2
                                # Let's see if this header matches the current expected pattern
                                #print >> outfile, "current(1):", current.name.encode('utf-8'), current.encode('utf-8')
                                match = False
                                if len(headerList) == 0:
                                    match = True
                                elif len(headerList) > 0:
                                    match = matchPattern(headerList[headerIndex].lower(), current.get_text().lower())
                                if match:
                                    # this header matches
                                    previousHeaders.append(current.get_text())
                                    #print >> outfile, "previousheaders(a):", map(lambda x: x.encode('utf-8'), previousHeaders)
                                    headerIndex = headerIndex + 1
                                    if headerIndex >= len(headerList):
                                        inresult = True
                                    elif headerList[headerIndex].lower() in contentTypes:
                                        inresult = True
                                        contentType = headerList[headerIndex].lower()
                                else:
                                    previousHeaders = []
                            elif inresult and current is not None and current.name is not None and current.name[0] == 'h' and int(current.name[1]) >= (headerIndex + 2):
                                # I'm probably seeing a sub-heading inside of what I want
                                previousHeaders.append(current.get_text())
                                #print >> outfile, "previousheaders(b):", map(lambda x: x.encode('utf-8'), previousHeaders)
                            elif inresult and current is not None and current.name is not None and current.name in listTags:
                                # found a list inside what I am looking for
                                result = result + '\n' + fixList(current) + '\n '
                            elif inresult and current is not None and (current.name is None or current.name.lower() in okTags):
                                # I'm probably looking at text inside of what I want
                                #print >> outfile, "current(3):", current.encode('utf-8')
                                if contentType != 'list':
                                    current = current.strip()
                                # Sometimes we see the header name duplicated inside the text block that succeeds the sub-section header. Crop it off
                                if len(current) > 0:
                                    if len(previousHeaders) > 0:
                                        #print >> outfile, "previousheaders(c):", map(lambda x: x.encode('utf-8'), previousHeaders)
                                        headerLength = functools.reduce(lambda x,y: x+y, map(lambda z: len(z)+2, previousHeaders)) # add 2 for period and space.
                                        result = result + current[headerLength:].strip() + ' '
                                    else:
                                        result = result + current.strip() + ' '
                                    # Forget the previous header. It was either consumed or wasn't duplicated in the first place.
                                    previousHeaders = []
                            elif inresult and current is not None and current.name is not None and current.name[0] == 'h' and int(current.name[1]) < (headerIndex + 2):
                                # Probably left the block. All done with this json!
                                break
                            elif not inresult and current is not None and current.name is not None and current.name[0] == 'h' and int(current.name[1]) > 1 and int(current.name[1]) < (headerIndex + 2):
                                # not in the result, but we went up one level
                                headerIndex = headerIndex - 1
                                if len(previousHeaders) > 0:
                                    previousHeaders.pop()
                                # Let's see if this header matches the current expected pattern
                                #print >> outfile, "current(2):", current.name.encode('utf-8'), current.encode('utf-8')
                                match = matchPattern(headerList[headerIndex].lower(), current.get_text().lower())
                                if len(match) > 0:
                                    # this header matches
                                    previousHeaders.append(current.get_text())
                                    #print >> outfile, "previousheaders(d):", map(lambda x: x.encode('utf-8'), previousHeaders)
                                    headerIndex = headerIndex + 1
                                    if headerIndex >= len(headerList):
                                        inresult = True
                                    elif headerList[headerIndex].lower() in contentTypes:
                                        inresult = True
                                        contentType = headerList[headerIndex].lower()
                            elif not inresult and current is not None and current.name is not None and current.name[0] == 'h' and int(current.name[1]) > 1 and int(current.name[1]) >= (headerIndex + 2):
                                # I'm not in the result block and I saw something that wasn't a header.
                                previousHeaders = []
                                #print >> outfile, "previous header cleared (e)", current.encode('utf-8')
                            elif not inresult and current is not None and current.name is None and len(current.strip()) > 0:
                                previousHeaders = []
                                #print >> outfile, "previous header cleared (f)", current.encode('utf-8')
                        # Did we find what we were looking for?
                        if len(result) > 0:
                            # ASSERT: I have a result
                            # Record the name of the article with the result
                            #titlefile.write(j['title'].encode('utf-8') + '\n')
                            titlefile.write(j['title'] + '\n')
                            '''
                            # remove newlines
                            #result = result.replace('\n', ' ').replace('\r', '').strip()
                            # remove html tags (probably mainly hyperlinks)
                            result = re.sub('<[^<]+?>', '', result)
                            # remove character name initials and take periods off mr/mrs/ms/dr/etc.
                            result = re.sub(' [M|m]r\.', ' mr', result)
                            result = re.sub(' [M|m]rs\.', ' mrs', result)
                            result = re.sub(' [M|m]s\.', ' ms', result)
                            result = re.sub(' [D|d]r\.', ' dr', result)
                            #result = re.sub(' [M|m]d\.', ' md', result)
                            #result = re.sub(' [P|p][H|h][D|d]\.', ' phd', result)
                            #result = re.sub(' [E|e][S|s][Q|q]\.', ' esq', result)
                            result = re.sub(' [L|l][T|t]\.', ' lt', result)
                            result = re.sub(' [G|g][O|o][V|v]\.', ' lt', result)
                            result = re.sub(' [C|c][P|p][T|t]\.', ' cpt', result)
                            result = re.sub(' [S|s][T|t]\.', ' st', result)
                            # handle i.e. and cf.
                            result = re.sub('i\.e\. ', 'ie ', result)
                            result = re.sub('cf\. ', 'cf', result)
                            # deal with periods in quotes
                            result = re.sub('\.\"', '\".', result)
                            # remove single letter initials
                            p4 = re.compile(r'([ \()])([A-Z|a-z])\.')
                            result = p4.sub(r'\1\2', result)
                            # Acroymns with periods are not fun. Need two steps to get rid of those periods.
                            # I don't think this is working quite right
                            p1 = re.compile('([A-Z|a-z])\.([)|\"|\,])')
                            result = p1.sub(r'\1\2', result)
                            p2 = re.compile('\.([A-Z|a-z])')
                            result = p2.sub(r'\1', result)
                            # periods in numbers
                            p3 = re.compile('([0-9]+)\.([0-9]+)')
                            result = p3.sub(r'\1\2', result)
                            '''
                            # Print result
                            if contentType == 'text':
                                #print >> outfile, result.strip().encode('utf-8')
                                outfile.write(result.strip() + '\n')
                            elif contentType == 'list':
                                lines = [s.strip() for s in result.splitlines()]
                                for line in lines:
                                    if len(line) > 0 and line[0] == '-':
                                        #rint >> outfile, line.strip().encode('utf-8')
                                        outfile.write(line.strip())
                                        outfile.write('\n')


### TEST RUN
#readWikipedia(wikiDir, pattern, outFilename, titleFilename, breakSentences)




'''
TODO: future modules:
- Remove Empty lines
- Remove newlines
- Segment sentences into separate lines
- Remove characters/words/substrings
- strip all lines
'''
--------------------------------------------------------------------------------
/stylegan_runner.py:
--------------------------------------------------------------------------------
import os
import pdb
import sys
import pickle
import random
import math
import argparse
import numpy as np
from PIL import Image
from tqdm import tqdm_notebook as tqdm

# Fine-tune a StyleGAN model on the images in images_path, resuming from the
# checkpoint at model_path. (Definition continues beyond this chunk.)
def easygen_train(model_path, images_path, dataset_path, start_kimg=7000, max_kimg=25000, schedule='', seed=1000):
    #import stylegan
    #from stylegan import config
    ##from stylegan import dnnlib
    #from stylegan.dnnlib import EasyDict

    #images_dir = '/content/raw'
    #max_kimg = 25000
    #start_kimg = 7000
    #schedule = ''
    #model_in = '/content/karras2019stylegan-cats-256x256.pkl'

    #dataset_dir = '/content/stylegan_dataset' #os.path.join(cwd,
'cache', 'stylegan_dataset') 25 | 26 | import config 27 | config.data_dir = '/content/datasets' 28 | config.results_dir = '/content/results' 29 | config.cache_dir = '/contents/cache' 30 | run_dir_ignore = ['/contents/results', '/contents/datasets', 'contents/cache'] 31 | import copy 32 | import dnnlib 33 | from dnnlib import EasyDict 34 | from metrics import metric_base 35 | # Prep dataset 36 | import dataset_tool 37 | print("prepping dataset...") 38 | dataset_tool.create_from_images(tfrecord_dir=dataset_path, image_dir=images_path, shuffle=False) 39 | # Set up training parameters 40 | desc = 'sgan' # Description string included in result subdir name. 41 | train = EasyDict(run_func_name='training.training_loop.training_loop') # Options for training loop. 42 | G = EasyDict(func_name='training.networks_stylegan.G_style') # Options for generator network. 43 | D = EasyDict(func_name='training.networks_stylegan.D_basic') # Options for discriminator network. 44 | G_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for generator optimizer. 45 | D_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for discriminator optimizer. 46 | G_loss = EasyDict(func_name='training.loss.G_logistic_nonsaturating') # Options for generator loss. 47 | D_loss = EasyDict(func_name='training.loss.D_logistic_simplegp', r1_gamma=10.0) # Options for discriminator loss. 48 | dataset = EasyDict() # Options for load_dataset(). 49 | sched = EasyDict() # Options for TrainingSchedule. 50 | grid = EasyDict(size='1080p', layout='random') # Options for setup_snapshot_image_grid(). 51 | #metrics = [metric_base.fid50k] # Options for MetricGroup. 52 | submit_config = dnnlib.SubmitConfig() # Options for dnnlib.submit_run(). 53 | tf_config = {'rnd.np_random_seed': seed} # Options for tflib.init_tf(). 54 | # Dataset 55 | desc += '-custom' 56 | dataset = EasyDict(tfrecord_dir=dataset_path) 57 | train.mirror_augment = False 58 | # Number of GPUs. 
59 | desc += '-1gpu' 60 | submit_config.num_gpus = 1 61 | sched.minibatch_base = 4 62 | sched.minibatch_dict = {4: 128, 8: 128, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8, 512: 4} #{4: 256, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16, 256: 16} 63 | # Default options. 64 | train.total_kimg = max_kimg 65 | sched.lod_initial_resolution = 8 66 | sched.G_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003} 67 | sched.D_lrate_dict = EasyDict(sched.G_lrate_dict) 68 | # schedule 69 | schedule_dict = {4: 160, 8:140, 16:120, 32:100, 64:80, 128:60, 256:40, 512:30, 1024:20} #{4: 2, 8:2, 16:2, 32:2, 64:2, 128:2, 256:2, 512:2, 1024:2} # Runs faster for small datasets 70 | if len(schedule) >=5 and schedule[0] == '{' and schedule[-1] == '}' and ':' in schedule: 71 | # is schedule a string of a dict? 72 | try: 73 | temp = eval(schedule) 74 | schedule_dict = dict(temp) 75 | # assert: it is a dict 76 | except: 77 | pass 78 | elif len(schedule) > 0: 79 | # is schedule an int? 80 | try: 81 | schedule_int = int(schedule) 82 | #assert: schedule is an int 83 | schedule_dict = {} 84 | for i in range(1, 10): 85 | schedule_dict[int(math.pow(2, i+1))] = schedule_int 86 | except: 87 | pass 88 | print('schedule:', str(schedule_dict)) 89 | sched.tick_kimg_dict = schedule_dict 90 | # resume kimg 91 | resume_kimg = start_kimg 92 | # path to model 93 | resume_run_id = model_path 94 | # tick snapshots 95 | image_snapshot_ticks = 1 96 | network_snapshot_ticks = 1 97 | # Submit run 98 | kwargs = EasyDict(train) 99 | kwargs.update(G_args=G, D_args=D, G_opt_args=G_opt, D_opt_args=D_opt, G_loss_args=G_loss, D_loss_args=D_loss) 100 | kwargs.update(dataset_args=dataset, sched_args=sched, grid_args=grid, tf_config=tf_config) 101 | kwargs.update(resume_kimg=resume_kimg, resume_run_id=resume_run_id) 102 | kwargs.update(image_snapshot_ticks=image_snapshot_ticks, network_snapshot_ticks=network_snapshot_ticks) 103 | kwargs.submit_config = copy.deepcopy(submit_config) 104 | kwargs.submit_config.run_dir_root 
= dnnlib.submission.submit.get_template_from_path(config.result_dir) 105 | kwargs.submit_config.run_dir_ignore += config.run_dir_ignore 106 | kwargs.submit_config.run_desc = desc 107 | dnnlib.submit_run(**kwargs) 108 | 109 | def easygen_run(model_path, images_path, num=1): 110 | # from https://github.com/ak9250/stylegan-art/blob/master/styleganportraits.ipynb 111 | truncation = 0.7 # hard coding because everyone uses this value 112 | import dnnlib 113 | import dnnlib.tflib as tflib 114 | import config 115 | tflib.init_tf() 116 | #num = 10 117 | #model = '/content/karras2019stylegan-cats-256x256.pkl' 118 | #images_dir = '/content/cache/run_out' 119 | #truncation = 0.7 120 | _G = None 121 | _D = None 122 | Gs = None 123 | with open(model_path, 'rb') as f: 124 | _G, _D, Gs = pickle.load(f) 125 | fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True) 126 | synthesis_kwargs = dict(output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True), minibatch_size=8) 127 | latents = np.random.RandomState(int(1000*random.random())).randn(num, *Gs.input_shapes[0][1:]) 128 | labels = np.zeros([latents.shape[0]] + Gs.input_shapes[1][1:]) 129 | images = Gs.run(latents, None, truncation_psi=truncation, randomize_noise=False, output_transform=fmt) 130 | for n, image in enumerate(images): 131 | # img = Image.fromarray(images[0]) 132 | img = Image.fromarray(image) 133 | img.save(os.path.join(images_path, str(n) + '.jpg'), "JPEG") 134 | 135 | def get_latent_interpolation(endpoints, num_frames_per, mode = 'linear', shuffle = False): 136 | if shuffle: 137 | random.shuffle(endpoints) 138 | num_endpoints, dim = len(endpoints), len(endpoints[0]) 139 | num_frames = num_frames_per * num_endpoints 140 | endpoints = np.array(endpoints) 141 | latents = np.zeros((num_frames, dim)) 142 | for e in range(num_endpoints): 143 | e1, e2 = e, (e+1)%num_endpoints 144 | for t in range(num_frames_per): 145 | frame = e * num_frames_per + t 146 | r = 0.5 - 0.5 * 
def easygen_movie(model_path, movie_path, num=10, interp=10, duration=10):
    """Render an animated GIF that interpolates between random StyleGAN samples.

    Picks ``num`` random latent waypoints, interpolates ``interp`` frames
    between each consecutive pair (wrapping back to the start), renders the
    frames in batches, and writes a looping GIF.
    Based on https://github.com/ak9250/stylegan-art/blob/master/styleganportraits.ipynb

    Args:
        model_path: trained StyleGAN .pkl file.
        movie_path: output GIF filename.
        num: number of random waypoint latents.
        interp: frames generated between consecutive waypoints.
        duration: per-frame duration passed to PIL's GIF writer.
    """
    import dnnlib  # noqa: F401 -- needed so the pickle's dnnlib references resolve
    import dnnlib.tflib as tflib
    import config  # noqa: F401
    from PIL import Image
    from tqdm import tqdm_notebook as tqdm
    tflib.init_tf()
    truncation = 0.7  # what everyone uses
    # Load the trained generator.
    # NOTE(review): pickle.load executes arbitrary code -- only load trusted models.
    with open(model_path, 'rb') as f:
        _G, _D, Gs = pickle.load(f)
    # Random waypoints in latent space, then a closed interpolation loop
    # through them.
    waypoint_latents = np.random.RandomState(int(1000 * random.random())).randn(num, *Gs.input_shapes[0][1:])
    interp_latents = get_latent_interpolation(waypoint_latents, interp)
    fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    batch_size = 8
    num_frames = interp_latents.shape[0]
    num_batches = int(np.ceil(num_frames / batch_size))
    frames = []
    for b in tqdm(range(num_batches)):
        # BUG FIX: the upper slice bound was min(..., num_frames - 1), which
        # silently dropped the last interpolation frame; slices are exclusive
        # at the top, so the bound must be num_frames.
        batch = interp_latents[b * batch_size:min((b + 1) * batch_size, num_frames), :]
        new_images = Gs.run(batch, None, truncation_psi=truncation, randomize_noise=False, output_transform=fmt)
        for img in new_images:
            frames.append(Image.fromarray(img))  # convert to PIL.Image
    # Write all frames as a looping GIF.
    frames[0].save(movie_path, "GIF",
                   save_all=True,
                   append_images=frames[1:],
                   duration=duration,
                   loop=0)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Process runner commands.')
    # Mode flags: exactly one of --train / --run / --movie is expected.
    parser.add_argument('--train', action="store_true", default=False)
    parser.add_argument('--run', action="store_true", default=False)
    parser.add_argument('--movie', action="store_true", default=False)
    parser.add_argument("--model", help="model to load", default="")
    parser.add_argument("--images_in", help="directory containing training images", default="")
    parser.add_argument("--images_out", help="directory to store generated images", default="")
    parser.add_argument("--movie_out", help="directory to save movie", default="")
    parser.add_argument("--dataset_temp", help="where to store prepared image data", default="")
    parser.add_argument("--schedule", help="training schedule", default="")
    parser.add_argument("--max_kimg", help="iteration to stop training at", type=int, default=25000)
    parser.add_argument("--start_kimg", help="iteration to start training at", type=int, default=7000)
    parser.add_argument("--num", help="number of images to generate", type=int, default=1)
    parser.add_argument("--interp", help="number of images to interpolate", type=int, default=10)
    parser.add_argument("--duration", help="how long for each image in movie", type=int, default=10)
    parser.add_argument("--seed", help="seed number", type=int, default=1000)
    args = parser.parse_args()
    if args.train:
        easygen_train(model_path=args.model,
                      images_path=args.images_in,
                      dataset_path=args.dataset_temp,
                      start_kimg=args.start_kimg,
                      max_kimg=args.max_kimg,
                      schedule=args.schedule,
                      seed=args.seed)
    elif args.run:
        easygen_run(model_path=args.model,
                    images_path=args.images_out,
                    num=args.num)
    elif args.movie:
        easygen_movie(model_path=args.model,
                      movie_path=args.movie_out,
                      num=args.num,
                      interp=args.interp,
                      duration=args.duration)
Clone the [EasyGen notebook](https://drive.google.com/open?id=1XNiOuNtMnItl5CPGvRjEvj9C78nDuvXj) by following the link and selecting File -> Save a copy in Drive.\n","\n","2. Turn on GPU support under Edit -> Notebook setting.\n","\n","3. Run the cells in Sections 1. Some are optional if you know you aren't going to be using particular features.\n","\n","4. Run the cell in Section 2. If you know there are any models or datasets that you won't be using you can skip them.\n","\n","5. Run the cell in Section 3. This creates a blank area below the cell in which you can use the buttons to create your visual program. An example program is loaded automatically. You can clear it with the \"clear\" button below it. Afterwards you can create your own programs. Selecting \"Make New Module\" will cause the new module appears graphically above and can be dragged around. The inputs and outputs of different modules can be connected together by clicking on an output (red) and dragging to an input (green). Gray boxes are parameters that can be edited. Clicking on a gray box causes a text input field to appear at the bottom of the editing area, just above the \"Make New Module\" controls.\n","\n","6. Save your program by entering a program name and pressing the \"save\" button.\n","\n","7. Run your program by editing the program name in the cell in Section 4 and then running the cell."]},{"cell_type":"markdown","metadata":{"id":"1PJwtOgUXpAG","colab_type":"text"},"source":["# 1. 
Setup\n","\n","Download EasyGen"]},{"cell_type":"code","metadata":{"id":"cXGp8MWKPumP","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":34},"outputId":"0b48a281-edce-4a16-b1d2-ddbabaed86d6","executionInfo":{"status":"ok","timestamp":1564620470959,"user_tz":240,"elapsed":3309,"user":{"displayName":"Mark Riedl","photoUrl":"https://lh4.googleusercontent.com/-VkUvOUfU44c/AAAAAAAAAAI/AAAAAAAAAG4/lS4Rpm-pe_0/s64/photo.jpg","userId":"17417011882129852150"}}},"source":["!git clone https://github.com/markriedl/easygen.git\n","!cp easygen/*.js /usr/local/share/jupyter/nbextensions/google.colab/\n","!cp easygen/images/*.png /usr/local/share/jupyter/nbextensions/google.colab/"],"execution_count":1,"outputs":[{"output_type":"stream","text":["fatal: destination path 'easygen' already exists and is not an empty directory.\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"9XikLU2ez0Yz","colab_type":"text"},"source":["Install requirements"]},{"cell_type":"code","metadata":{"id":"jlgTqMtFX3NU","colab_type":"code","colab":{}},"source":["!apt-get update\n","!apt-get install chromium-chromedriver\n","!pip install -r easygen/requirements.txt"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"gqeftQQ6z3we","colab_type":"text"},"source":["Download StyleGAN"]},{"cell_type":"code","metadata":{"id":"l9jSPMsgj0hI","colab_type":"code","colab":{}},"source":["!git clone https://github.com/NVlabs/stylegan.git\n","!cp easygen/stylegan_runner.py stylegan"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"BJNElaVez7JW","colab_type":"text"},"source":["Download GPT-2"]},{"cell_type":"code","metadata":{"id":"3yHkEkVNj2Ih","colab_type":"code","colab":{}},"source":["!git clone https://github.com/nshepperd/gpt-2"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"4LW3Il9nzg5M","colab_type":"text"},"source":["Create backend hooks for saving and loading 
programs"]},{"cell_type":"code","metadata":{"id":"bV_TJbUjzfza","colab_type":"code","colab":{}},"source":["import IPython\n","from google.colab import output\n","\n","def python_save_hook(file_text, filename):\n"," import easygen\n"," import hooks\n"," status = hooks.python_save_hook_aux(file_text, filename)\n"," ret_status = 'true' if status else 'false'\n"," return IPython.display.JSON({'result': ret_status})\n","\n","def python_load_hook(filename):\n"," import easygen\n"," import hooks\n"," result = hooks.python_load_hook_aux(filename)\n"," return IPython.display.JSON({'result': result})\n","\n","def python_cwd_hook(dir):\n"," import easygen\n"," import hooks\n"," result = hooks.python_cwd_hook_aux(dir)\n"," return IPython.display.JSON({'result': result})\n","\n","def python_copy_hook(path1, path2):\n"," import easygen\n"," import hooks\n"," status = hooks.python_copy_hook_aux(path1, path2)\n"," ret_status = 'true' if status else 'false'\n"," return IPython.display.JSON({'result': ret_status})\n","\n","def python_move_hook(path1, path2):\n"," import easygen\n"," import hooks\n"," status = hooks.python_move_hook_aux(path1, path2)\n"," ret_status = 'true' if status else 'false'\n"," return IPython.display.JSON({'result': ret_status})\n","\n","def python_open_text_hook(path):\n"," import easygen\n"," import hooks\n"," status = hooks.python_open_text_hook_aux(path)\n"," ret_status = 'true' if status else 'false'\n"," return IPython.display.JSON({'result': ret_status})\n"," \n","def python_open_image_hook(path):\n"," import easygen\n"," import hooks\n"," status = hooks.python_open_image_hook_aux(path)\n"," ret_status = 'true' if status else 'false'\n"," return IPython.display.JSON({'result': ret_status})\n","\n","def python_mkdir_hook(path, dir_name):\n"," import easygen\n"," import hooks\n"," status = hooks.python_mkdir_hook_aux(path, dir_name)\n"," ret_status = 'true' if status else 'false'\n"," return IPython.display.JSON({'result': ret_status})\n","\n","def 
python_trash_hook(path):\n"," import easygen\n"," import hooks\n"," status = hooks.python_trash_hook_aux(path)\n"," ret_status = 'true' if status else 'false'\n"," return IPython.display.JSON({'result': ret_status})\n","\n","def python_run_hook(path):\n"," import easygen\n"," program_file_name = path\n"," easygen.main(program_file_name)\n"," return IPython.display.JSON({'result': 'true'})\n","\n","output.register_callback('notebook.python_cwd_hook', python_cwd_hook)\n","output.register_callback('notebook.python_copy_hook', python_copy_hook)\n","output.register_callback('notebook.python_move_hook', python_move_hook)\n","output.register_callback('notebook.python_open_text_hook', python_open_text_hook)\n","output.register_callback('notebook.python_open_image_hook', python_open_image_hook)\n","output.register_callback('notebook.python_save_hook', python_save_hook)\n","output.register_callback('notebook.python_load_hook', python_load_hook)\n","output.register_callback('notebook.python_mkdir_hook', python_mkdir_hook)\n","output.register_callback('notebook.python_trash_hook', python_trash_hook)\n","output.register_callback('notebook.python_run_hook', python_run_hook)"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"R8WwVL1g3MSh","colab_type":"text"},"source":["Import EasyGen"]},{"cell_type":"code","metadata":{"id":"cZ9sZmeR3KsR","colab_type":"code","colab":{}},"source":["import sys\n","sys.path.insert(0, 'easygen')\n","import easygen"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"nvMUzJGmXy44","colab_type":"text"},"source":["# 2. Download pre-trained neural network models\n","\n","## 2.1 Download GPT-2 and StyleGan models\n","\n","Download the GPT-2 small 117M model. 
Will save to ```models/117M``` directory.\n"]},{"cell_type":"code","metadata":{"id":"vqaLwY9GXpuO","colab_type":"code","colab":{}},"source":["!python gpt-2/download_model.py 117M"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wwUI0RL79R3m","colab_type":"text"},"source":["Download the GPT-2 medium 345M model. Will save to ```models/345M``` directory."]},{"cell_type":"code","metadata":{"id":"QZe-0_gY9FWx","colab_type":"code","colab":{}},"source":["!python gpt-2/download_model.py 345M"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"YHmVcSNbkVvo","colab_type":"text"},"source":["Download the StyleGAN cats model (256x256). Will save as \"cats256x256.pkl\" in the home directory."]},{"cell_type":"code","metadata":{"id":"njbbbOBRkVXF","colab_type":"code","colab":{}},"source":["!wget -O cats256x256.pkl https://www.dropbox.com/s/1w97383h0nrj4ea/karras2019stylegan-cats-256x256.pkl?dl=0"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"zEqyG5DMX7CI","colab_type":"text"},"source":["## 2.2 Download Wikipedia\n","\n","You only need to do this if you are using the ```ReadWikipedia``` functionality. This takes a long time. You may want to skip it if you know you wont be scraping data from Wikipedia."]},{"cell_type":"code","metadata":{"id":"A7LWxcOzWjaV","colab_type":"code","colab":{}},"source":["!wget -O wiki.zip https://www.dropbox.com/s/39w6mj1akwy2a0r/wiki.zip?dl=0\n","!unzip wiki.zip"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"oX_AG49uYV23","colab_type":"text"},"source":["# 3. Run the GUI\n","\n","Run the cell below. This will load a default example program that generate new, fictional paint names. Use the \"clear\" button to clear it and make your own.\n","\n","When done, name the program and press the \"save\" button. 
You should see your file appear in the file listing in the left panel."]},{"cell_type":"code","metadata":{"id":"NFpPTx0oN8N9","colab_type":"code","colab":{}},"source":["%%html\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","
    \n","

    text

    \n"," text\n"," \n"," \n","
    \n","
    \n","

    Make New Module

    \n"," \n"," \n","
    \n","
    \n","

    Save Program

    \n"," \n"," \n","
    \n","
    \n","

    Load Program

    \n"," \n"," \n","
    \n","
    \n","

    Clear Program

    \n"," \n","
    \n","\n",""],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xUmugXljEFw8","colab_type":"text"},"source":["# 4. Run Your Program"]},{"cell_type":"markdown","metadata":{"id":"GA0sfrdnaDAV","colab_type":"text"},"source":["This will run a default example program that will generate new, fictional paint names. If you don't want to run that program, skip to the next cell."]},{"cell_type":"code","metadata":{"id":"4klnhCIosXQU","colab_type":"code","colab":{}},"source":["program_file_name = 'easygen/examples/make_new_colors'\n","easygen.main(program_file_name)"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"vnc0wSWEp2m2","colab_type":"text"},"source":["Once you've made your own program, run the cell below, enter the program below, and the press the run button.\n"]},{"cell_type":"code","metadata":{"id":"2BNineAvps7j","colab_type":"code","colab":{}},"source":["%%html\n","\n","\n","\n","\n","Run Program: \n","\n",""],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"3umTq8CjEWnl","colab_type":"text"},"source":["# 5. View Your Output Files"]},{"cell_type":"markdown","metadata":{"id":"L0fH_pqlaJTP","colab_type":"text"},"source":["Run the cell below to launch a file manager so you can view your text and image files. \n","\n","You can use the panel to the left to download any files written to disk."]},{"cell_type":"code","metadata":{"id":"Sf7C3Y2J3mUw","colab_type":"code","colab":{}},"source":["%%html\n","\n","\n","\n","

    Manage Files

    \n","\n"," \n"," \n"," \n","
    /content/content



    \n","

    Make Directories

    \n","\n","\n",""],"execution_count":0,"outputs":[]}]} -------------------------------------------------------------------------------- /gpt2.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import numpy as np 4 | import tensorflow as tf 5 | from tensorflow.core.protobuf import rewriter_config_pb2 6 | import sys 7 | import time 8 | import pdb 9 | import shutil 10 | 11 | CUR_PATH = os.getcwd() 12 | GPT_PATH = os.path.join(CUR_PATH, 'gpt-2') 13 | sys.path.insert(0, os.path.join(GPT_PATH, 'src')) 14 | 15 | 16 | import model 17 | import sample 18 | import encoder 19 | import load_dataset 20 | from load_dataset import load_dataset, Sampler 21 | from accumulate import AccumulatingOptimizer 22 | import memory_saving_gradients 23 | 24 | 25 | 26 | #################################################### 27 | 28 | def run_gpt( 29 | model_in_path, 30 | model_name='117M', 31 | raw_text = ' ', 32 | seed=None, 33 | nsamples=1, 34 | batch_size=1, 35 | length=None, 36 | temperature=1, 37 | top_k=0, 38 | top_p=0.0 39 | ): 40 | """ 41 | Interactively run the model 42 | :model_name=117M : String, which model to use 43 | :seed=None : Integer seed for random number generators, fix seed to reproduce 44 | results 45 | :nsamples=1 : Number of samples to return total 46 | :batch_size=1 : Number of batches (only affects speed/memory). Must divide nsamples. 47 | :length=None : Number of tokens in generated text, if None (default), is 48 | determined by model hyperparameters 49 | :temperature=1 : Float value controlling randomness in boltzmann 50 | distribution. Lower temperature results in less random completions. As the 51 | temperature approaches zero, the model will become deterministic and 52 | repetitive. Higher temperature results in more random completions. 53 | :top_k=0 : Integer value controlling diversity. 
1 means only 1 word is 54 | considered for each step (token), resulting in deterministic completions, 55 | while 40 means 40 words are considered at each step. 0 (default) is a 56 | special setting meaning no restrictions. 40 generally is a good value. 57 | :top_p=0.0 : Float value controlling diversity. Implements nucleus sampling, 58 | overriding top_k if set to a value > 0. A good setting is 0.9. 59 | """ 60 | output_text = '' 61 | 62 | if batch_size is None: 63 | batch_size = 1 64 | assert nsamples % batch_size == 0 65 | 66 | enc = get_encoder(model_in_path) 67 | hparams = model.default_hparams() 68 | with open(os.path.join(model_in_path, 'hparams.json')) as f: 69 | hparams.override_from_dict(json.load(f)) 70 | 71 | if length is None: 72 | length = hparams.n_ctx // 2 73 | elif length > hparams.n_ctx: 74 | raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx) 75 | 76 | with tf.Session(graph=tf.Graph()) as sess: 77 | context = tf.placeholder(tf.int32, [batch_size, None]) 78 | np.random.seed(seed) 79 | tf.set_random_seed(seed) 80 | output = sample.sample_sequence( 81 | hparams=hparams, length=length, 82 | context=context, 83 | batch_size=batch_size, 84 | temperature=temperature, top_k=top_k, top_p=top_p 85 | ) 86 | 87 | saver = tf.train.Saver() 88 | ckpt = tf.train.latest_checkpoint(model_in_path) #os.path.join('models', model_name)) 89 | saver.restore(sess, ckpt) 90 | 91 | if len(raw_text) == 0: 92 | raw_text = ' ' 93 | context_tokens = enc.encode(raw_text) 94 | generated = 0 95 | for n in range(nsamples // batch_size): 96 | out = sess.run(output, feed_dict={ 97 | context: [context_tokens for _ in range(batch_size)] 98 | })[:, len(context_tokens):] 99 | for i in range(batch_size): 100 | text = enc.decode(out[i]) 101 | generated = generated + 1 102 | #print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40) 103 | #print(text) 104 | if n == 0: 105 | output_text = text 106 | else: 107 | output_text = output_text + '\n' + text 108 | 
return output_text 109 | 110 | ############################################################################## 111 | 112 | def maketree(path): 113 | try: 114 | os.makedirs(path) 115 | except: 116 | pass 117 | 118 | 119 | def randomize(context, hparams, p): 120 | if p > 0: 121 | mask = tf.random.uniform(shape=tf.shape(context)) < p 122 | noise = tf.random.uniform(shape=tf.shape(context), minval=0, maxval=hparams.n_vocab, dtype=tf.int32) 123 | return tf.where(mask, noise, context) 124 | else: 125 | return context 126 | 127 | 128 | def get_encoder(path): 129 | with open(os.path.join(path, 'encoder.json'), 'r') as f: 130 | enc = json.load(f) 131 | with open(os.path.join(path, 'vocab.bpe'), 'r', encoding="utf-8") as f: 132 | bpe_data = f.read() 133 | bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]] 134 | return encoder.Encoder( 135 | encoder=enc, 136 | bpe_merges=bpe_merges, 137 | ) 138 | 139 | def train(dataset, model_in_path, model_out_path, 140 | model_name = '117M', 141 | steps = 1000, 142 | combine = 50000, 143 | batch_size = 1, 144 | learning_rate = 0.00002, 145 | accumulate_gradients = 1, 146 | memory_saving_gradients = False, 147 | only_train_transformer_layers = False, 148 | optimizer = 'adam', 149 | noise = 0.0, 150 | top_k = 40, 151 | top_p = 0.0, 152 | restore_from = 'latest', 153 | sample_every = 100, 154 | sample_length = 1023, 155 | sample_num = 1, 156 | save_every = 1000, 157 | val_dataset = None): 158 | # Reset the TF computation graph 159 | tf.reset_default_graph() 160 | # Get the checkpoint and sample directories 161 | #checkpoint_dir = os.path.dirname(model_path) 162 | #sample_dir = checkpoint_dir 163 | #run_name = os.path.basename(model_path) 164 | # Load the encoder 165 | enc = get_encoder(model_in_path) 166 | hparams = model.default_hparams() 167 | with open(os.path.join(model_in_path, 'hparams.json')) as f: 168 | hparams.override_from_dict(json.load(f)) 169 | 170 | if sample_length > hparams.n_ctx: 171 | raise 
ValueError( 172 | "Can't get samples longer than window size: %s" % hparams.n_ctx) 173 | 174 | # Size matters 175 | if model_name == '345M': 176 | memory_saving_gradients = True 177 | if optimizer == 'adam': 178 | only_train_transformer_layers = True 179 | 180 | # Configure TF 181 | config = tf.ConfigProto() 182 | config.gpu_options.allow_growth = True 183 | config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF 184 | # Start the session 185 | with tf.Session(config=config) as sess: 186 | context = tf.placeholder(tf.int32, [batch_size, None]) 187 | context_in = randomize(context, hparams, noise) 188 | output = model.model(hparams=hparams, X=context_in) 189 | loss = tf.reduce_mean( 190 | tf.nn.sparse_softmax_cross_entropy_with_logits( 191 | labels=context[:, 1:], logits=output['logits'][:, :-1])) 192 | 193 | tf_sample = sample.sample_sequence( 194 | hparams=hparams, 195 | length=sample_length, 196 | context=context, 197 | batch_size=batch_size, 198 | temperature=1.0, 199 | top_k=top_k, 200 | top_p=top_p) 201 | 202 | all_vars = [v for v in tf.trainable_variables() if 'model' in v.name] 203 | train_vars = [v for v in all_vars if '/h' in v.name] if only_train_transformer_layers else all_vars 204 | 205 | if optimizer == 'adam': 206 | opt = tf.train.AdamOptimizer(learning_rate=learning_rate) 207 | elif optimizer == 'sgd': 208 | opt = tf.train.GradientDescentOptimizer(learning_rate=learning_rate) 209 | else: 210 | exit('Bad optimizer:', optimizer) 211 | 212 | if accumulate_gradients > 1: 213 | if memory_saving_gradients: 214 | exit("Memory saving gradients are not implemented for gradient accumulation yet.") 215 | opt = AccumulatingOptimizer( 216 | opt=opt, 217 | var_list=train_vars) 218 | opt_reset = opt.reset() 219 | opt_compute = opt.compute_gradients(loss) 220 | opt_apply = opt.apply_gradients() 221 | summary_loss = tf.summary.scalar('loss', opt_apply) 222 | else: 223 | if memory_saving_gradients: 224 | opt_grads = 
memory_saving_gradients.gradients(loss, train_vars) 225 | else: 226 | opt_grads = tf.gradients(loss, train_vars) 227 | opt_grads = list(zip(opt_grads, train_vars)) 228 | opt_apply = opt.apply_gradients(opt_grads) 229 | summary_loss = tf.summary.scalar('loss', loss) 230 | 231 | summary_lr = tf.summary.scalar('learning_rate', learning_rate) 232 | summaries = tf.summary.merge([summary_lr, summary_loss]) 233 | 234 | summary_log = tf.summary.FileWriter( 235 | #os.path.join(checkpoint_dir, run_name) 236 | model_out_path 237 | ) 238 | 239 | saver = tf.train.Saver( 240 | var_list=all_vars, 241 | max_to_keep=1) 242 | sess.run(tf.global_variables_initializer()) 243 | 244 | if restore_from == 'latest': 245 | ckpt = tf.train.latest_checkpoint( 246 | #os.path.join(checkpoint_dir, run_name) 247 | model_in_path 248 | ) 249 | if ckpt is None: 250 | # Get fresh GPT weights if new run. 251 | ckpt = tf.train.latest_checkpoint( 252 | model_in_path)#os.path.join('models', model_name)) 253 | elif restore_from == 'fresh': 254 | ckpt = tf.train.latest_checkpoint( 255 | model_in_path)#os.path.join('models', model_name)) 256 | else: 257 | ckpt = tf.train.latest_checkpoint(restore_from) 258 | print('Loading checkpoint', ckpt) 259 | saver.restore(sess, ckpt) 260 | 261 | print('Loading dataset...') 262 | chunks = load_dataset(enc, dataset, combine) 263 | data_sampler = Sampler(chunks) 264 | print('dataset has', data_sampler.total_size, 'tokens') 265 | print('Training...') 266 | counter = 1 267 | counter_path = os.path.join(model_in_path, 'counter') #os.path.join(checkpoint_dir, run_name, 'counter') 268 | if restore_from == 'latest' and os.path.exists(counter_path): 269 | # Load the step number if we're resuming a run 270 | # Add 1 so we don't immediately try to save again 271 | with open(counter_path, 'r') as fp: 272 | counter = int(fp.read()) + 1 273 | 274 | def save(): 275 | #maketree(os.path.join(checkpoint_dir, run_name)) 276 | maketree(model_out_path) 277 | print( 278 | 'Saving', 279 | 
#os.path.join(checkpoint_dir, run_name, 'model-{}').format(counter) 280 | os.path.join(model_out_path, 'model-{}').format(counter) 281 | ) 282 | saver.save( 283 | sess, 284 | #os.path.join(checkpoint_dir, run_name, 'model'), 285 | os.path.join(model_out_path, 'model'), 286 | global_step=counter) 287 | with open(os.path.join(model_out_path, 'counter'), 'w') as fp: 288 | fp.write(str(counter) + '\n') 289 | 290 | def generate_samples(): 291 | print('Generating samples...') 292 | context_tokens = data_sampler.sample(1) 293 | all_text = [] 294 | index = 0 295 | while index < sample_num: 296 | out = sess.run( 297 | tf_sample, 298 | feed_dict={context: batch_size * [context_tokens]}) 299 | for i in range(min(sample_num - index, batch_size)): 300 | text = enc.decode(out[i]) 301 | text = '======== SAMPLE {} ========\n{}\n'.format( 302 | index + 1, text) 303 | all_text.append(text) 304 | index += 1 305 | print(text) 306 | #maketree(os.path.join(sample_dir, run_name)) 307 | maketree(model_out_path) 308 | with open(os.path.join(model_out_path, 'samples-{}').format(counter), 'w') as fp: 309 | fp.write('\n'.join(all_text)) 310 | 311 | def sample_batch(): 312 | return [data_sampler.sample(1024) for _ in range(batch_size)] 313 | 314 | 315 | avg_loss = (0.0, 0.0) 316 | start_time = time.time() 317 | 318 | stop = steps + counter 319 | 320 | try: 321 | while counter < stop + 1: 322 | if counter % save_every == 0: 323 | save() 324 | ''' 325 | if counter % sample_every == 0: 326 | generate_samples() 327 | ''' 328 | 329 | if accumulate_gradients > 1: 330 | sess.run(opt_reset) 331 | for _ in range(accumulate_gradients): 332 | sess.run( 333 | opt_compute, feed_dict={context: sample_batch()}) 334 | (v_loss, v_summary) = sess.run((opt_apply, summaries)) 335 | else: 336 | (_, v_loss, v_summary) = sess.run( 337 | (opt_apply, loss, summaries), 338 | feed_dict={context: sample_batch()}) 339 | 340 | summary_log.add_summary(v_summary, counter) 341 | 342 | avg_loss = (avg_loss[0] * 0.99 + v_loss, 
343 | avg_loss[1] * 0.99 + 1.0) 344 | 345 | print( 346 | '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}' 347 | .format( 348 | counter=counter, 349 | time=time.time() - start_time, 350 | loss=v_loss, 351 | avg=avg_loss[0] / avg_loss[1])) 352 | 353 | counter += 1 354 | print('done!') 355 | save() 356 | except KeyboardInterrupt: 357 | print('interrupted') 358 | save() 359 | 360 | ##################################################################### 361 | 362 | if __name__ == '__main__': 363 | cache_path = os.path.join(os.getcwd(), 'cache') 364 | os.chdir('gpt-2') 365 | train(os.path.join(cache_path, 'text0'), cache_path, cache_path, steps=1) 366 | #text = run_gpt(top_k=40) 367 | #print(text) 368 | os.chdir('..') -------------------------------------------------------------------------------- /lstm.py: -------------------------------------------------------------------------------- 1 | import time 2 | import math 3 | import os 4 | import torch 5 | import codecs 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | import pdb 9 | 10 | 11 | ######################################################## 12 | ### MODEL 13 | ######################################################## 14 | 15 | class RNNModel(nn.Module): 16 | """Container module with an encoder, a recurrent module, and a decoder.""" 17 | 18 | def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5, tie_weights=False): 19 | super(RNNModel, self).__init__() 20 | self.drop = nn.Dropout(dropout) 21 | self.encoder = nn.Embedding(ntoken, ninp) 22 | if rnn_type in ['LSTM', 'GRU']: 23 | self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout) 24 | else: 25 | try: 26 | nonlinearity = {'RNN_TANH': 'tanh', 'RNN_RELU': 'relu'}[rnn_type] 27 | except KeyError: 28 | raise ValueError( """An invalid option for `--model` was supplied, 29 | options are ['LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU']""") 30 | self.rnn = nn.RNN(ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout) 31 | 
self.decoder = nn.Linear(nhid, ntoken) 32 | 33 | # Optionally tie weights as in: 34 | # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016) 35 | # https://arxiv.org/abs/1608.05859 36 | # and 37 | # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016) 38 | # https://arxiv.org/abs/1611.01462 39 | if tie_weights: 40 | if nhid != ninp: 41 | raise ValueError('When using the tied flag, nhid must be equal to emsize') 42 | self.decoder.weight = self.encoder.weight 43 | 44 | self.init_weights() 45 | 46 | self.rnn_type = rnn_type 47 | self.nhid = nhid 48 | self.nlayers = nlayers 49 | 50 | def init_weights(self): 51 | initrange = 0.1 52 | self.encoder.weight.data.uniform_(-initrange, initrange) 53 | self.decoder.bias.data.zero_() 54 | self.decoder.weight.data.uniform_(-initrange, initrange) 55 | 56 | def forward(self, input, hidden): 57 | emb = self.drop(self.encoder(input)) 58 | output, hidden = self.rnn(emb, hidden) 59 | output = self.drop(output) 60 | decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2))) 61 | return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden 62 | 63 | def init_hidden(self, bsz): 64 | weight = next(self.parameters()) 65 | if self.rnn_type == 'LSTM': 66 | return (weight.new_zeros(self.nlayers, bsz, self.nhid), 67 | weight.new_zeros(self.nlayers, bsz, self.nhid)) 68 | else: 69 | return weight.new_zeros(self.nlayers, bsz, self.nhid) 70 | 71 | 72 | ####################################################### 73 | ### DICTIONARY 74 | ####################################################### 75 | 76 | class Dictionary(object): 77 | def __init__(self): 78 | self.token2idx = {} 79 | self.idx2token = [] 80 | 81 | def add_token(self, word): 82 | if word not in self.token2idx: 83 | self.idx2token.append(word) 84 | self.token2idx[word] = len(self.idx2token) - 1 85 | return self.token2idx[word] 86 | 87 | def __len__(self): 88 | return 
class WordCorpus(Corpus):
    """Corpus tokenized at word granularity (whitespace split, with an
    empty-string token appended to every line as an end-of-line marker)."""

    def tokenize(self, path):
        super(WordCorpus, self).tokenize(path)
        # Read the file once and cache the per-line word lists; the original
        # read the file twice, and each line contributes its words plus ''.
        with codecs.open(path, 'r', encoding="utf8") as f:
            per_line = [line.split() + [''] for line in f]

        # First sweep: grow the dictionary in order of first appearance.
        for words in per_line:
            for word in words:
                self.dictionary.add_token(word)

        # Second sweep: map every word to its id.
        total = sum(len(words) for words in per_line)
        ids = torch.LongTensor(total)
        pos = 0
        for words in per_line:
            for word in words:
                ids[pos] = self.dictionary.token2idx[word]
                pos += 1
        return ids
SEED = 1 154 | HISTORY = 35 155 | LAYERS = 2 156 | EPOCHS = 50 157 | HIDDEN_NODES = 512 158 | BATCH_SIZE = 10 159 | MODEL_TYPE = 'GRU' 160 | DROPOUT = 0.2 161 | TIED = False 162 | EMBED_SIZE = HIDDEN_NODES 163 | CLIP = 0.25 164 | LR = 0.0001 165 | LR_DECAY = 0.1 166 | LOG_INTERVAL = 10 167 | 168 | ############################################################################# 169 | # MAIN entrypoints 170 | ############################################################################ 171 | 172 | 173 | def wordLSTM_Run(model_path, dictionary_path, output_path, seed = SEED, 174 | steps = RUN_STEPS, temperature = RUN_TEMPERATURE, k = 0): 175 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 176 | 177 | # Load the model 178 | model = None 179 | with open(model_path, 'rb') as f: 180 | model = torch.load(f) 181 | model = model.to(device) 182 | model.eval() 183 | 184 | # Load the dictionary 185 | dictionary = None 186 | with open(dictionary_path, 'rb') as f: 187 | dictionary = torch.load(f) 188 | ntokens = len(dictionary) 189 | 190 | hidden = model.init_hidden(1) 191 | 192 | input = torch.randint(ntokens, (1, 1), dtype=torch.long).to(device) 193 | 194 | if seed is not None: 195 | seed_words = seed.strip().split() 196 | if len(seed_words) > 1: 197 | for i in range(len(seed_words)-1): 198 | word = seed_words[i] 199 | if word in dictionary.idx2token: 200 | input = torch.tensor([[dictionary.token2idx[word]]], dtype=torch.long).to(device) 201 | output, hidden = model(input, hidden) 202 | if len(seed_words) > 0: 203 | input = torch.tensor([[dictionary.token2idx[seed_words[-1]]]], dtype=torch.long).to(device) 204 | 205 | with codecs.open(output_path, 'w', encoding="utf8") as outf: 206 | if seed is not None and len(seed) > 0: 207 | outf.write(seed.strip() + ' ') 208 | with torch.no_grad(): # no tracking history 209 | for i in range(steps): 210 | output, hidden = model(input, hidden) 211 | word_weights = output.squeeze().div(temperature).exp().cpu() 212 | 
### Top K sampling
def top_k_sample(logits, k):
    """Sample one index from `logits`, restricted to its k largest entries.

    Entries below the k-th largest value are zeroed before sampling, so
    `logits` is expected to hold non-negative sampling weights (the callers
    pass exp'd, temperature-scaled model outputs).
    """
    kth_best = torch.topk(logits, k).values.min()
    keep = (logits >= kth_best).float()
    return torch.multinomial(logits * keep, 1)[0]
def wordLSTM_Train(train_data_path,
		dictionary_path, model_out_path,
		history = HISTORY, layers = LAYERS, epochs = EPOCHS, hidden_nodes = HIDDEN_NODES,
		batch_size = BATCH_SIZE, model_type=MODEL_TYPE, dropout = DROPOUT, tied = TIED, embed_size = EMBED_SIZE,
		clip = CLIP, lr = LR,
		lr_decay = LR_DECAY,
		log_interval = LOG_INTERVAL):
	"""Train a word-level RNN language model.

	Thin wrapper: forwards every argument to train() with a WordCorpus
	tokenizer, so the data is split on whitespace rather than characters.
	"""
	train(train_data_path, dictionary_path, model_out_path,
		history=history, layers=layers, epochs=epochs,
		hidden_nodes=hidden_nodes, batch_size=batch_size,
		model_type=model_type, dropout=dropout, tied=tied,
		embed_size=embed_size, clip=clip, lr=lr,
		lr_decay=lr_decay, log_interval=log_interval,
		corpus_type=WordCorpus)
def train_more(train_data_path, model_in_path, dictionary_path, model_out_path,
		history = HISTORY, epochs = EPOCHS, batch_size = BATCH_SIZE,
		clip = CLIP, lr = LR, lr_decay = LR_DECAY, log_interval = LOG_INTERVAL,
		corpus_type = WordCorpus):
	"""Continue training a previously saved model on `train_data_path`.

	Loads the serialized dictionary and model, moves the model to the
	available device, and hands off to train_loop(), which saves the
	result to `model_out_path`.
	"""
	device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
	corpus = corpus_type(train_data_path)
	# BUG FIX: these artifacts are being *read*. The original opened both
	# files with 'wb', which truncates the saved dictionary/model on open
	# and then makes torch.load fail on a write-only, empty handle.
	with open(dictionary_path, 'rb') as f:
		dictionary = torch.load(f)
	# NOTE(review): the corpus built above uses its own freshly-built
	# dictionary, so token ids may not line up with the loaded model's
	# vocabulary unless training resumes on the same data — confirm.
	with open(model_in_path, 'rb') as f:
		model = torch.load(f)
	model = model.to(device)
	train_loop(model, corpus, model_out_path,
		history = history, epochs = epochs, batch_size = batch_size,
		clip = clip, lr = lr, lr_decay = lr_decay, log_interval = log_interval)
def train(train_data_path, dictionary_path, model_out_path,
		history = HISTORY, layers = LAYERS, epochs = EPOCHS, hidden_nodes = HIDDEN_NODES,
		batch_size = BATCH_SIZE, model_type=MODEL_TYPE, dropout = DROPOUT, tied = TIED, embed_size = EMBED_SIZE,
		clip = CLIP, lr = LR,
		lr_decay = LR_DECAY,
		log_interval = LOG_INTERVAL,
		corpus_type = WordCorpus):
	"""Build a fresh RNN language model, save its dictionary, and train it.

	The corpus is tokenized by `corpus_type`; the resulting dictionary is
	serialized to `dictionary_path` (so the *_Run entry points can decode
	model output later) and the trained model is written to
	`model_out_path` by train_loop().
	"""
	device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

	corpus = corpus_type(train_data_path)
	vocab = corpus.dictionary
	# Persist the dictionary alongside the model.
	with open(dictionary_path, 'wb') as f:
		torch.save(vocab, f)

	### BUILD THE MODEL
	model = RNNModel(model_type, len(vocab), embed_size, hidden_nodes, layers, dropout, tied).to(device)
	train_loop(model, corpus, model_out_path,
		history = history, epochs = epochs, batch_size = batch_size,
		clip = clip, lr = lr, lr_decay = lr_decay, log_interval = log_interval)
def train_loop(model, corpus, model_path,
		history = HISTORY, epochs = EPOCHS, batch_size = BATCH_SIZE,
		clip = CLIP, lr = LR, lr_decay = LR_DECAY, log_interval = LOG_INTERVAL):
	"""Core training loop shared by train() and train_more().

	Trains `model` on `corpus.train_data` for `epochs` epochs, evaluating
	after every epoch and saving the model to `model_path` whenever the
	evaluation loss ties or beats the best seen so far; otherwise the
	learning rate is multiplied by `lr_decay`.
	"""
	device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
	# Starting from sequential data, batchify arranges the dataset into
	# columns; the columns are treated as independent sequences, which
	# loses cross-column dependencies but enables efficient batching.
	train_data = batchify(corpus.train_data, batch_size, device)
	# NOTE(review): "validation" data is just the first 10% of the training
	# data, not a held-out split — confirm this is intentional.
	val_data = batchify(corpus.train_data[0:corpus.train_data.size()[0]//10], batch_size, device)
	dictionary = corpus.dictionary
	ntokens = len(dictionary)

	criterion = nn.CrossEntropyLoss()

	best_val_loss = None
	log_interval = max(1, (len(train_data) // history) // log_interval)

	### TRAIN
	for epoch in range(1, epochs+1):
		epoch_start_time = time.time()

		model.train()
		# Re-create the optimizer each epoch so an annealed `lr` (see the
		# end of the loop) actually takes effect.
		optimizer = optim.Adam(model.parameters(), lr=lr)

		total_loss = 0.0
		start_time = time.time()
		hidden = model.init_hidden(batch_size)

		for batch, i in enumerate(range(0, train_data.size(0) - 1, history)):
			# BUG FIX: get_batch's third argument is the sequence length.
			# The original passed batch_size, so the chunk length ignored
			# `history` and mismatched this loop's stride (skipping or
			# overlapping rows of the dataset).
			data, targets = get_batch(train_data, i, history)
			# Detach the hidden state so backprop stops at the batch
			# boundary instead of unrolling to the start of the dataset.
			hidden = repackage_hidden(hidden)
			optimizer.zero_grad()
			output, hidden = model(data, hidden)
			loss = criterion(output.view(-1, ntokens), targets)
			loss.backward()
			# BUG FIX: `clip_grad_norm_` must run between backward() and
			# step(); the original clipped *after* the parameter update,
			# which did nothing to prevent exploding gradients.
			torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
			optimizer.step()

			total_loss += loss.item()

			if batch % log_interval == 0 and batch > 0:
				cur_loss = total_loss / log_interval
				elapsed = time.time() - start_time
				print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.5f} | ms/batch {:5.2f} | '
						'loss {:5.2f} | ppl {:8.2f}'.format(
					epoch, batch, len(train_data) // history, lr,
					elapsed * 1000 / log_interval, cur_loss, math.exp(cur_loss) if cur_loss < 1000 else float('inf')))
				total_loss = 0
				start_time = time.time()

		### EVALUATE
		val_loss = evaluate(model, val_data, criterion, dictionary, batch_size, history)
		print('-' * 89)
		print('| end of epoch {:3d} | time: {:5.2f}s | train loss {:5.2f} | '
				'train ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
										   val_loss, math.exp(val_loss) if val_loss < 1000 else float('inf')))
		print('-' * 89)
		# Save only on a tie-or-improvement of the best loss so far;
		# otherwise anneal the learning rate.
		if best_val_loss is None:
			best_val_loss = val_loss
		if val_loss <= best_val_loss:
			with open(model_path, 'wb') as f:
				torch.save(model, f)
			best_val_loss = val_loss
		else:
			lr = lr * lr_decay
# get_batch subdivides the source data into chunks of length `history`
# along dimension 0 (the sequence dimension) — batching was already
# handled by batchify(), so dimension 1 is left alone.
def get_batch(source, i, history):
    """Return the (seq_len, batch) chunk at row `i` plus its targets.

    The targets are the same chunk shifted forward one time step and
    flattened to 1-D for CrossEntropyLoss. The chunk is shortened near the
    end of `source` so the shifted targets stay in bounds.
    """
    span = min(history, len(source) - 1 - i)
    chunk = source[i:i + span]
    shifted = source[i + 1:i + 1 + span]
    return chunk, shifted.view(-1)
def evaluate(model, data_source, criterion, dictionary, batch_size, history):
	"""Compute the length-weighted average per-chunk loss of `model` over
	`data_source`.

	Runs in eval mode (disabling dropout) with gradients off, carrying the
	hidden state across chunks and detaching it between them.
	"""
	# Turn on evaluation mode which disables dropout.
	model.eval()
	total_loss = 0.0
	ntokens = len(dictionary)
	hidden = model.init_hidden(batch_size)
	with torch.no_grad():
		for i in range(0, data_source.size(0) - 1, history):
			# BUG FIX: get_batch's third argument is the sequence length.
			# The original passed batch_size, which mismatched this loop's
			# `history` stride (overlapping or skipping rows of the data).
			data, targets = get_batch(data_source, i, history)
			output, hidden = model(data, hidden)
			output_flat = output.view(-1, ntokens)
			total_loss += len(data) * criterion(output_flat, targets).item()
			hidden = repackage_hidden(hidden)
	return total_loss / (len(data_source) - 1)
| {"name" : "epochs", "type" : "int", "default" : 50}, 14 | {"name" : "learning_rate", "type" : "float", "default": 0.0001}, 15 | {"name" : "model", "type" : "model", "out" : true}, 16 | {"name" : "dictionary", "type" : "dictionary", "out" : true}], 17 | "category" : "RNN"}, 18 | {"name" : "WordRNN_Run", 19 | "params" : [{"name" : "model", "in" : true, "type" : "model"}, 20 | {"name" : "dictionary", "in" : true, "type" : "dictionary"}, 21 | {"name" : "seed", "in" : true, "type" : "text"}, 22 | {"name" : "steps", "type" : "int", "default" : "600"}, 23 | {"name" : "temperature", "type" : "float", "default" : "0.5"}, 24 | {"name" : "k", "type" : "int", "default" : "40"}, 25 | {"name" : "output", "out" : true, "type" : "text"}], 26 | "category" : "RNN"}, 27 | {"name" : "MakeString", 28 | "params" : [{"name" : "string", "type" : "string"}, 29 | {"name" : "output", "out" : true, "type" : "text"}], 30 | "category" : "Input"}, 31 | {"name" : "RemoveEmptyLines", 32 | "params" : [{"name" : "input", "type" : "text", "in" : true}, 33 | {"name" : "output", "type" : "text", "out" : true}], 34 | "category" : "Text"}, 35 | {"name" : "SplitSentences", 36 | "params" : [{"name" : "input", "type" : "text", "in" : true}, 37 | {"name" : "output", "type" : "text", "out" : true}], 38 | "category" : "Text"}, 39 | {"name" : "ReplaceCharacters", 40 | "params" : [{"name" : "input", "type" : "text", "in" : true}, 41 | {"name" : "find", "type" : "string"}, 42 | {"name" : "replace", "type" : "string"}, 43 | {"name" : "output", "type" : "text", "out" : true}], 44 | "category" : "Text"}, 45 | {"name" : "ReadTextFile", 46 | "params" : [{"name" : "file", "type" : "directory"}, 47 | {"name" : "output", "type" : "text", "out" : true}], 48 | "category" : "File"}, 49 | {"name" : "WriteTextFile", 50 | "params" : [{"name" : "input", "type" : "text", "in" : true}, 51 | {"name" : "file", "type" : "directory"}], 52 | "category" : "File"}, 53 | {"name" : "MakeLowercase", 54 | "params" : [{"name" : "input", 
"type" : "text", "in" : true}, 55 | {"name" : "output", "type" : "text", "out" : true}], 56 | "category" : "Text"}, 57 | {"name" : "Wordify", 58 | "params" : [{"name" : "input", "type" : "text", "in" : true}, 59 | {"name" : "output", "type" : "text", "out" : true}], 60 | "category" : "Text"}, 61 | {"name" : "RemoveTags", 62 | "params" : [{"name" : "input", "type" : "text", "in" : true}, 63 | {"name" : "output", "type" : "text", "out" : true}], 64 | "category" : "HTML"}, 65 | {"name" : "CleanText", 66 | "params" : [{"name" : "input", "type" : "text", "in" : true}, 67 | {"name" : "output", "type" : "text", "out" : true}], 68 | "category" : "Text"}, 69 | {"name" : "SaveModel", 70 | "params" : [{"name" : "model", "in" : true, "type" : "model"}, 71 | {"name" : "file", "type" : "directory"}], 72 | "category" : "File"}, 73 | {"name" : "LoadModel", 74 | "params" : [{"name" : "file", "type" : "directory"}, 75 | {"name" : "model", "out" : true, "type" : "model"}], 76 | "category" : "File"}, 77 | {"name" : "SaveDictionary", 78 | "params" : [{"name" : "dictionary", "in" : true, "type" : "dictionary"}, 79 | {"name" : "file", "type" : "directory"}], 80 | "category" : "File"}, 81 | {"name" : "LoadDictionary", 82 | "params" : [{"name" : "file", "type" : "directory"}, 83 | {"name" : "dictionary", "out" : true, "type" : "dictionary"}], 84 | "category" : "File"}, 85 | {"name" : "SplitHTML", 86 | "params" : [{"name" : "input", "type" : "text", "in" : true}, 87 | {"name" : "output", "type" : "text", "out" : true}], 88 | "category" : "HTML"}, 89 | {"name" : "RandomSequence", 90 | "params" : [{"name" : "input", "type" : "text", "in" : true}, 91 | {"name" : "length", "type" : "int", "default" : 100}, 92 | {"name" : "output", "type" : "text", "out" : true}], 93 | "category" : "Text"}, 94 | {"name" : "ConcatenateTextFiles", 95 | "params" : [{"name" : "input_1", "type" : "text", "in" : true}, 96 | {"name" : "input_2", "type" : "text", "in" : true}, 97 | {"name" : "output", "type" : "text", 
"out" : true}], 98 | "category" : "Text"}, 99 | {"name" : "RandomizeLines", 100 | "params" : [{"name" : "input", "type" : "text", "in" : true}, 101 | {"name" : "output", "type" : "text", "out" : true}], 102 | "category" : "Text"}, 103 | {"name" : "KeepFirstLine", 104 | "params" : [{"name" : "input", "type" : "text", "in" : true}, 105 | {"name" : "output", "type" : "text", "out" : true}], 106 | "category" : "Text"}, 107 | {"name" : "DeleteFirstLine", 108 | "params" : [{"name" : "input", "type" : "text", "in" : true}, 109 | {"name" : "output", "type" : "text", "out" : true}], 110 | "category" : "Text"}, 111 | {"name" : "DeleteLastLine", 112 | "params" : [{"name" : "input", "type" : "text", "in" : true}, 113 | {"name" : "output", "type" : "text", "out" : true}], 114 | "category" : "Text"}, 115 | {"name" : "KeepLineWhen", 116 | "params" : [{"name" : "input", "type" : "text", "in" : true}, 117 | {"name" : "match", "type" : "string"}, 118 | {"name" : "output", "type" : "text", "out" : true}], 119 | "category" : "Text"}, 120 | {"name" : "KeepLineUnless", 121 | "params" : [{"name" : "input", "type" : "text", "in" : true}, 122 | {"name" : "match", "type" : "string"}, 123 | {"name" : "output", "type" : "text", "out" : true}], 124 | "category" : "Text"}, 125 | {"name" : "Sort", 126 | "params" : [{"name" : "input", "type" : "text", "in" : true}, 127 | {"name" : "output", "type" : "text", "out" : true}], 128 | "category" : "Text"}, 129 | {"name" : "Reverse", 130 | "params" : [{"name" : "input", "type" : "text", "in" : true}, 131 | {"name" : "output", "type" : "text", "out" : true}], 132 | "category" : "Text"}, 133 | {"name" : "GPT2_FineTune", 134 | "params" : [{"name" : "model_in", "type" : "model", "in" : true}, 135 | {"name" : "data", "type" : "text", "in" : true}, 136 | {"name" : "model_size", "type" : "string", "default" : "117M"}, 137 | {"name" : "steps", "type" : "int", "default" : 1000}, 138 | {"name" : "model_out", "type" : "model", "out" : true}], 139 | "category" : 
"GPT2"}, 140 | {"name" : "GPT2_Run", 141 | "params" : [{"name" : "model_in", "type" : "model", "in" : true}, 142 | {"name" : "prompt", "type" : "text", "in" : true}, 143 | {"name" : "model_size", "type" : "string", "default" : "117M"}, 144 | {"name" : "top_k", "type" : "int", "default" : 40}, 145 | {"name" : "temperature", "type" : "float", "default" : 1.0}, 146 | {"name" : "num_samples", "type" : "int", "default" : 1}, 147 | {"name" : "output", "type" : "text", "out" : true}], 148 | "category" : "GPT2"}, 149 | {"name" : "CharRNN_Train", 150 | "params" : [{"name" : "data", "in" : true, "type" : "text"}, 151 | {"name" : "history", "type" : "int", "default" : 35}, 152 | {"name" : "layers", "type" : "int", "default": 2}, 153 | {"name" : "hidden_nodes", "type" : "int", "default" : 512}, 154 | {"name" : "epochs", "type" : "int", "default" : 50}, 155 | {"name" : "learning_rate", "type" : "float", "default": 0.0001}, 156 | {"name" : "model", "type" : "model", "out" : true}, 157 | {"name" : "dictionary", "type" : "dictionary", "out" : true}], 158 | "category" : "RNN"}, 159 | {"name" : "CharRNN_Run", 160 | "params" : [{"name" : "model", "in" : true, "type" : "model"}, 161 | {"name" : "dictionary", "in" : true, "type" : "dictionary"}, 162 | {"name" : "seed", "in" : true, "type" : "text"}, 163 | {"name" : "steps", "type" : "int", "default" : "600"}, 164 | {"name" : "temperature", "type" : "float", "default" : "0.5"}, 165 | {"name" : "output", "out" : true, "type" : "text"}], 166 | "category" : "RNN"}, 167 | {"name" : "UserInput", 168 | "params" : [{"name" : "prompt", "type" : "string", "default" : "prompt"}, 169 | {"name" : "output", "type" : "text", "out" : true}], 170 | "category" : "Input"}, 171 | {"name" : "Regex_Search", 172 | "params" : [{"name" : "input", "type" : "text", "in" : true}, 173 | {"name" : "expression", "type" : "string", "default" : "*"}, 174 | {"name" : "output", "type" : "text", "out" : true}, 175 | {"name" : "group_1", "type" : "text", "out" : true}, 
176 | {"name" : "group_2", "type" : "text", "out" : true}], 177 | "category" : "Regex"}, 178 | {"name" : "Regex_Sub", 179 | "params" : [{"name" : "input", "type" : "text", "in" : true}, 180 | {"name" : "expression", "type" : "string", "default" : ""}, 181 | {"name" : "replacement", "type" : "string", "default" : ""}, 182 | {"name" : "output", "type" : "text", "out" : true}], 183 | "category" : "Regex"}, 184 | {"name" : "PrintText", 185 | "params" : [{"name" : "input", "type" : "text", "in" : true}], 186 | "category" : "Utils"}, 187 | {"name": "ReadFromWeb", 188 | "params" : [{"name" : "url", "type" : "string", "default" : ""}, 189 | {"name" : "data", "type" : "text", "out" : true}], 190 | "category" : "Web"}, 191 | {"name" : "MakeCountFile", 192 | "params" : [{"name" : "num", "type" : "int", "default" : "10"}, 193 | {"name" : "prefix", "type" : "string", "default" : ""}, 194 | {"name" : "postfix", "type" : "string", "default" : ""}, 195 | {"name" : "output", "type" : "text", "out" : true}], 196 | "category" : "Utils"}, 197 | {"name" : "ReadAllFromWeb", 198 | "params" : [{"name" : "urls", "type" : "text", "in" : true}, 199 | {"name" : "data", "type" : "text", "out" : true}], 200 | "category" : "Web"}, 201 | {"name" : "RemoveDuplicates", 202 | "params" : [{"name" : "input", "type" : "text", "in" : true}, 203 | {"name" : "output", "type" : "text", "out" : true}], 204 | "category" : "Text"}, 205 | {"name" : "StripLines", 206 | "params" : [{"name" : "input", "type" : "text", "in" : true}, 207 | {"name" : "output", "type" : "text", "out" : true}], 208 | "category" : "Text"}, 209 | {"name" : "TextSubtract", 210 | "params" : [{"name" : "main", "type" : "text", "in" : true}, 211 | {"name" : "subtract", "type" : "text", "in" : true}, 212 | {"name" : "diff", "type" : "text", "out" : true}], 213 | "category" : "Text"}, 214 | {"name" : "DuplicateText", 215 | "params" : [{"name" : "input", "type" : "text", "in" : true}, 216 | {"name" : "count", "type" : "int", "default" : 1}, 
217 | {"name" : "output", "type" : "text", "out" : true}], 218 | "category" : "Text"}, 219 | {"name" : "Spellcheck", 220 | "params" : [{"name" : "input", "type" : "text", "in" : true}, 221 | {"name" : "output", "type" : "text", "out" : true}], 222 | "category" : "Text"}, 223 | {"name" : "WebCrawl", 224 | "params" : [{"name" : "url", "type" : "string", "default" : ""}, 225 | {"name" : "link_id", "type" : "string", "default" : ""}, 226 | {"name" : "link_text", "type" : "string", "default" : ""}, 227 | {"name" : "max_hops", "type" : "int", "default" : "10"}, 228 | {"name" : "output", "type" : "text", "out" : true}], 229 | "category" : "Web"}, 230 | {"name" : "ScrapePinterest", 231 | "params" : [{"name" : "url", "type" : "string", "default" : ""}, 232 | {"name" : "username", "type" : "string", "default" : ""}, 233 | {"name" : "password", "type" : "string", "default" : ""}, 234 | {"name" : "target", "type" : "int", "default" : "100"}, 235 | {"name" : "output", "type" : "images", "out" : true}], 236 | "category" : "Images"}, 237 | {"name" : "LoadImages", 238 | "params" : [{"name" : "directory", "type" : "directory", "default" : ""}, 239 | {"name" : "images", "type" : "images", "out" : true}], 240 | "category" : "File"}, 241 | {"name" : "SaveImages", 242 | "params" : [{"name" : "images", "type" : "images", "in" : true}, 243 | {"name" : "directory", "type" : "directory", "default" : ""}], 244 | "category" : "File"}, 245 | {"name" : "ResizeImages", 246 | "params" : [{"name" : "input", "type" : "images", "in" : true}, 247 | {"name" : "size", "type" : "int", "default" : "256"}, 248 | {"name" : "output", "type" : "images", "out" : true}], 249 | "category" : "Images"}, 250 | {"name" : "RemoveGrayscale", 251 | "params" : [{"name" : "input", "type" : "images", "in" : true}, 252 | {"name" : "output", "type" : "images", "out" : true}, 253 | {"name" : "rejects", "type" : "images", "out" : true}], 254 | "category" : "Images"}, 255 | {"name" : "CropFaces", 256 | "params" : [{"name" : 
"input", "type" : "images", "in" : true}, 257 | {"name" : "size", "type" : "int", "default" : "256"}, 258 | {"name" : "output", "type" : "images", "out" : true}, 259 | {"name" : "rejects", "type" : "images", "out" : true}], 260 | "category" : "Images"}, 261 | {"name" : "StyleGAN_FineTune", 262 | "params" : [{"name" : "model_in", "type" : "model", "in" : true}, 263 | {"name" : "images", "type" : "images", "in" : true}, 264 | {"name" : "start_kimg", "type" : "int", "default" : "7000"}, 265 | {"name" : "max_kimg", "type" : "int", "default" : "25000"}, 266 | {"name" : "seed", "type" : "int", "default" : "1000"}, 267 | {"name" : "schedule", "type" : "string", "default" : ""}, 268 | {"name" : "model_out", "type" : "model", "out" : true}, 269 | {"name" : "animation", "type" : "images", "out" : true}], 270 | "category" : "Images"}, 271 | {"name" : "StyleGAN_Run", 272 | "params" : [{"name" : "model", "type" : "model", "in" : true}, 273 | {"name" : "num", "type" : "int", "default" : "1"}, 274 | {"name" : "images", "type" : "images", "out" : true}], 275 | "category" : "Images"}, 276 | {"name" : "StyleGAN_Movie", 277 | "params" : [{"name" : "model", "type" : "model", "in" : true}, 278 | {"name" : "length", "type" : "int", "default" : "10"}, 279 | {"name" : "interp", "type" : "int", "default" : "10"}, 280 | {"name" : "duration", "type" : "int", "default" : "10"}, 281 | {"name" : "movie", "type" : "images", "out" : true}], 282 | "category" : "Images"}, 283 | {"name" : "MakeMovie", 284 | "params" : [{"name" : "images", "type" : "images", "in" : true}, 285 | {"name" : "duration", "type" : "int", "default" : "10"}, 286 | {"name" : "movie", "type" : "images", "out" : true}], 287 | "category" : "Images"}, 288 | {"name" : "Gridify", 289 | "params" : [{"name" : "input", "type" : "images", "in" : true}, 290 | {"name" : "size", "type" : "int", "default" : "256"}, 291 | {"name" : "columns", "type" : "int", "default" : "4"}, 292 | {"name" : "output", "type" : "images", "out" : true}], 293 
| "category" : "Images"}, 294 | {"name" : "Degridify", 295 | "params" : [{"name" : "input", "type" : "images", "in" : true}, 296 | {"name" : "columns", "type" : "int", "default" : "4"}, 297 | {"name" : "rows", "type" : "int", "default" : "4"}, 298 | {"name" : "output", "type" : "images", "out" : true}], 299 | "category" : "Images"}, 300 | {"name" : "StyleTransfer", 301 | "params" : [{"name" : "content_image", "type" : "images", "in" : true}, 302 | {"name" : "style_image", "type" : "images", "in" : true}, 303 | {"name" : "steps", "type" : "int", "default" : "1000"}, 304 | {"name" : "size", "type" : "int", "default" : "512"}, 305 | {"name" : "style_weight", "type" : "int", "default" : "1000000"}, 306 | {"name" : "content_weight", "type" : "int", "default" : "1"}, 307 | {"name" : "content_layers", "type" : "string", "default" : "4"}, 308 | {"name" : "style_layers", "type" : "string", "default" : "1, 2, 3, 4, 5"}, 309 | {"name" : "output", "type" : "images", "out" : true}], 310 | "category" : "Images"}, 311 | {"name" : "JoinImageDirectories", 312 | "params" : [{"name" : "dir1", "type" : "images", "in" : true}, 313 | {"name" : "dir2", "type" : "images", "in" : true}, 314 | {"name" : "output", "type" : "images", "out" : true}], 315 | "category" : "File"}, 316 | {"name" : "SquareCrop", 317 | "params" : [{"name" : "input", "type" : "images", "in" : true}, 318 | {"name" : "output", "type" : "images", "out" : true}], 319 | "category" : "Images"}, 320 | {"name" : "UnmakeMovie", 321 | "params" : [{"name" : "movie", "type" : "images", "in" : true}, 322 | {"name" : "output", "type" : "images", "out" : true}], 323 | "category" : "Images"} 324 | ]; 325 | /* 326 | 327 | {"name" : "MakePredictionData", "params" : "data(in,text);x(out,text);y(out,text)", "tip" : "Prepare data for prediction--each line will try to predict the next line", "category" : "no"}, \ 328 | {"name" : "DCGAN_Train", "params" : 
"input_images(images,in);epochs(int=10);input_height(int=108);output_height(int=108);filetype(string=jpg);model(out,model);animation(out,image)", "tip" : "Train a generateive adversarial network to make images", "category" : "no"}, \ 329 | {"name" : "DCGAN_Run", "params" : "input_images(images,in);model(in,model);input_height(int=108);output_height(int=108);filetype(string=jpg);output_image(out,image)", "tip" : "Generate an image from a generative adversarial network", "category" : "no"}, \ 330 | {"name" : "ReadImages", "params" : "data_directory(directory);output_images(out,images)", "tip" : "Read in a directory of image files", "category" : "File"}, \ 331 | {"name" : "WriteImages", "params" : "input_images(in,images);output_directory(directory)", "tip" : "Save a group of images to a directory", "category" : "File"}, \ 332 | {"name" : "PickFromWikipedia", "params" : "wiki_directory(directory,tip=Directory where wikipedia files are stored);input(in,text);catgories(string=*,tip=What categories if any?);section_name(string,tip=What section to pull text from if any);output(out,text);break_sentences(bool=false,tip=Should text be broken into one sentence per line?)", "tip" : "Pull text from wikipedia for the articles specified (file with one title per line)", "category" : "Wikipedia"}, \ 333 | {"name" : "Repeat", "params" : "input(in,text);output(out,text);times(int)", "category" : "Do not use"}, \ 334 | {"name" : "StyleNet_Train", "params" : "style_image(in,image);test_image(in,image);epochs(int=2,tip=how long to run);model(out,model);animation(out,image)", "tip" : "Draw the target image in the style of the style image", "category" : "no"}, \ 335 | {"name" : "StyleNet_Run", "params" : "model(in,model);target_image(in,image,tip=Image to stylize);output_image(out,image)", "tip" : "Apply a style learned by a neural net to an image", "category" : "no"}, \ 336 | {"name" : "ReadImageFile", "params" : "file(string,tip=Name of image file to read in);output(out,image)", "tip" : 
"Read an image file in", "category" : "File"}, \ 337 | {"name" : "WriteImageFile", "params" : "input(in,image);file(string,tip=Name of image file to write to)", "tip" : "Write an image to file", "category" : "File"}, \ 338 | {"name" : "MakeEmptyText", "params" : "output(out,text)", "tip" : "Create an empty text file", "category" : "Utils"}, \ 339 | ]'; 340 | */ -------------------------------------------------------------------------------- /image_modules.py: -------------------------------------------------------------------------------- 1 | from module import * 2 | import requests 3 | import os 4 | import re 5 | import time 6 | import random 7 | import pickle 8 | import numpy as np 9 | import math 10 | import copy 11 | from torchvision import transforms 12 | from torchvision import utils 13 | from PIL import Image 14 | import shutil 15 | import pdb 16 | import subprocess 17 | import sys 18 | 19 | #aaah 20 | 21 | 22 | 23 | ############################## 24 | 25 | class ScrapePinterest(Module): 26 | 27 | def __init__(self, url, username, password, target, output): 28 | self.url = url # Initial url (must be at Pinterest) 29 | self.username = username # Pinterest username 30 | self.password = password # Pinterest password 31 | self.target = target # target number of images to download 32 | self.output = output # path to output directory to store image files 33 | self.ready = True 34 | self.output_files = [output] 35 | 36 | def run(self): 37 | # Do some input checking 38 | if 'pinterest.com' not in self.url.lower(): 39 | print("url (" + self.url + ") is not a pinterest.com url.") 40 | return 41 | if len(self.password) == 0: 42 | print("password is empty") 43 | return 44 | if len(self.username) == 0: 45 | print("username is empty") 46 | return 47 | from selenium import webdriver 48 | from selenium.webdriver.common.keys import Keys 49 | # Set up Chrome Driver 50 | options = webdriver.ChromeOptions() 51 | options.add_argument('--headless') 52 | 
options.add_argument('--no-sandbox') 53 | options.add_argument('--disable-dev-shm-usage') 54 | wd = webdriver.Chrome('chromedriver',options=options) 55 | # open Pinterest and log in 56 | print("logging in to Pinterest...") 57 | wd.get("https://www.pinterest.com") 58 | emailElem = wd.find_element_by_name('id') 59 | emailElem.send_keys(self.username) 60 | passwordElem = wd.find_element_by_name('password') 61 | passwordElem.send_keys(self.password) 62 | passwordElem.send_keys(Keys.RETURN) 63 | time.sleep(5 + random.randint(1, 5)) 64 | # Get the first page 65 | print("Going to first page...") 66 | wd.get(self.url) 67 | # sleep 68 | time.sleep(5 + random.randint(1, 5)) 69 | # get urls for images 70 | results = set() # the urls 71 | url_count = 0 # how many image urls have we found? 72 | miss_count = 0 # how many times have we failed to get more images? 73 | max_miss_count = 10 # how many times are we willing to try again? 74 | print("getting image urls...") 75 | while len(results) < self.target and miss_count < 5: 76 | # Find image elements in the web page 77 | images = wd.find_elements_by_tag_name("img") 78 | # Iterate through the elements and try to get the urls, which is the source of the element 79 | for i in images: 80 | try: 81 | src = i.get_attribute("src") 82 | results.add(src) 83 | except: 84 | pass 85 | # Did we fail to get new image urls? 
86 | print(len(results)) 87 | if len(results) == url_count: 88 | # If so, increment miss count 89 | miss_count = miss_count + 1 90 | else: 91 | miss_count = 0 92 | # Remember how many image urls we had last time around 93 | url_count = len(results) 94 | # sleep 95 | time.sleep(5 + random.randint(1, 5)) 96 | # Send page down signal 97 | try: 98 | dummy = wd.find_element_by_tag_name('a') 99 | dummy.send_keys(Keys.PAGE_DOWN) 100 | except: 101 | print("can't page down") 102 | miss_count = miss_count + 1 103 | # convert results to list 104 | results = list(results) 105 | # Prep download directory 106 | prep_output_dir(self.output) 107 | # download images 108 | print("downloading images...") 109 | for result in results: 110 | res = requests.get(result) 111 | filename = os.path.join(self.output, os.path.basename(result)) 112 | file = open(filename, 'wb') 113 | for chunk in res.iter_content(100000): 114 | file.write(chunk) 115 | file.close() 116 | 117 | 118 | #################################### 119 | 120 | class ResizeImages(Module): 121 | 122 | def __init__(self, input, size, output): 123 | self.input = input # path to input files (directory) 124 | self.output = output # path to output files (directory) 125 | self.size = size # width and height (int) 126 | self.ready = checkFiles(input) 127 | self.output_files = [output] 128 | 129 | def run(self): 130 | # Prep output file directory 131 | prep_output_dir(self.output) 132 | # The transformation 133 | p = transforms.Compose([transforms.Resize((self.size,self.size))]) 134 | # Apply to all files in input directory 135 | if os.path.exists(self.input) and os.path.isdir(self.input): 136 | for file in os.listdir(self.input): 137 | try: 138 | img = Image.open(os.path.join(self.input, file)) 139 | img2 = p(img) 140 | img2.save(os.path.join(self.output, file), 'JPEG') 141 | except: 142 | print(file, "did not load") 143 | else: 144 | print(self.input, "is not a directory") 145 | 146 | ################################### 147 | 148 | 
### Module for face detection/close cropping faces

class CropFaces(Module):
    """Crop faces out of every image in a directory using the `autocrop` CLI tool.

    Cropped faces go to `output`; images in which no face was detected are
    copied to `rejects`.
    """

    def __init__(self, input, size, output, rejects):
        self.input = input      # path to directory containing images
        self.output = output    # path to directory to save new images
        self.rejects = rejects  # path to directory to save rejected images
        self.size = size        # (int) size of height and width of output files
        self.ready = checkFiles(input)
        self.output_files = [output, rejects]

    def run(self):
        # Prep output file directories
        prep_output_dir(self.output)
        prep_output_dir(self.rejects)
        # Run the autocrop program.
        # NOTE(review): the command line is built by string concatenation and
        # executed through the shell; paths containing spaces or shell
        # metacharacters will break (or be interpreted by the shell). Consider
        # subprocess.run() with an argument list if inputs are untrusted.
        cmd = 'autocrop -i ' + self.input + ' -o ' + self.output + ' -w ' + str(self.size) + ' -H ' + str(self.size) + ' --facePercent 50 -r ' + self.rejects
        status = os.system(cmd)


### Module for removing grayscale images
class RemoveGrayscale(Module):
    """Split a directory of images into color images (`output`) and
    grayscale images (`rejects`), judged by the number of color bands."""

    def __init__(self, input, output, rejects):
        self.input = input      # path to directory containing images
        self.output = output    # path to directory to put non-grayscale images
        self.rejects = rejects  # path to directory to put grayscale images
        self.ready = checkFiles(input)
        self.output_files = [output, rejects]

    def run(self):
        # prep output file directories
        prep_output_dir(self.output, makedir=True)
        prep_output_dir(self.rejects, makedir=True)
        # Check each file
        for file in os.listdir(self.input):
            try:
                img = Image.open(os.path.join(self.input, file))
            except:
                print(file, "could not be loaded")
                # BUG FIX: skip files that fail to load. Previously execution
                # fell through, raising NameError on the first failure or
                # silently re-classifying the image from the previous
                # iteration.
                continue
            if len(img.getbands()) == 3:
                # this image has 3 color channels (not grayscale)
                shutil.copyfile(os.path.join(self.input, file), os.path.join(self.output, file))
            else:
                # this image is grayscale
                shutil.copyfile(os.path.join(self.input, file), os.path.join(self.rejects, file))

#####################
class MakeGrayscale(Module):
    """Convert every image in a directory to grayscale-with-alpha ('LA') and
    save the converted copies to the output directory."""

    def __init__(self, input, output):
        self.input = input    # path to directory containing images
        self.output = output  # path to directory
        self.ready = checkFiles(input)
        self.output_files = [output]

    def run(self):
        # prep output file directory
        prep_output_dir(self.output)
        # load each file
        for file in os.listdir(self.input):
            try:
                # load and convert ('LA' = luminance + alpha)
                img = Image.open(os.path.join(self.input, file)).convert('LA')
                # save to target destination
                # NOTE(review): saving mode 'LA' under a .jpg extension fails
                # (JPEG cannot store alpha); such files are reported below.
                img.save(os.path.join(self.output, file))
            except:
                # ROBUSTNESS FIX: report and skip files that cannot be
                # loaded or saved instead of aborting the whole batch
                # (consistent with ResizeImages).
                print(file, "did not load")

#############################

### Module for saving images (directory)
class SaveImages(Module):
    """Copy a pipeline image directory to a user-specified destination."""

    def __init__(self, images, directory):
        self.images = images        # path to directory containing images
        self.directory = directory  # path name to save to
        self.ready = checkFiles(images)
        self.output_files = [directory]

    def run(self):
        # copytree requires that the destination directory not already exist
        prep_output_dir(self.directory, makedir=False)
        shutil.copytree(self.images, self.directory)


###############################

class LoadImages(Module):
    """Copy images from a user-specified directory (or a single file) into
    the pipeline's image directory."""

    def __init__(self, directory, images):
        self.images = images        # path to save images into
        self.directory = directory  # path to load images from
        self.ready = True
        self.output_files = [images]

    def run(self):
        # Check for path existence
        if os.path.exists(self.directory):
            # If it's a directory, copy all files in the directory
            if os.path.isdir(self.directory):
                prep_output_dir(self.images, makedir=False)
                shutil.copytree(self.directory, self.images)
            else:
                # Not a directory: make a directory and copy the single file into it
                prep_output_dir(self.images)
                shutil.copy(self.directory, self.images)

################################

### Module for fine-tuning StyleGAN
### (include
### output for an animation)

class StyleGAN_FineTune(Module):
    """Fine-tune a StyleGAN model on a directory of images by launching
    stylegan/stylegan_runner.py in a subprocess, then harvest the newest
    checkpoint and progress snapshots from the results/ directory."""

    def __init__(self, model_in, images, start_kimg, max_kimg, seed, schedule, model_out, animation):
        self.model_in = model_in    # path to model (pkl file)
        self.images = images        # path to image directory
        self.start_kimg = start_kimg  # (int) iteration to begin with (closer to zero means retrain more) default: 7000
        self.max_kimg = max_kimg    # (int) max kimg
        self.schedule = schedule    # (int or string or dict) the number of kimgs per resolution level
        self.model_out = model_out  # path to fine-tuned model (pkl file)
        self.animation = animation  # path to an animation file
        self.seed = seed            # (int) default = 1000
        self.ready = checkFiles(model_in, images)
        self.output_files = [model_out, animation]

    def run(self):
        print("Keyboard interrupt will stop training but program will try to continue.")
        # Run the stylegan program in a separate sub-process
        cwd = os.getcwd()
        schedule = str(self.schedule)
        # Command-line parameters for the runner script; an empty schedule is
        # sent as a single space so the flag still receives an argument.
        params = {'model': os.path.join(cwd, self.model_in),
                  'images_in': os.path.join(cwd, self.images),
                  'dataset_temp': os.path.join(cwd, 'stylegan_dataset'),
                  'start_kimg': self.start_kimg,
                  'max_kimg': self.max_kimg,
                  'schedule': schedule if len(schedule) > 0 else ' ',
                  'seed': self.seed
                  }
        command = 'python stylegan/stylegan_runner.py --train'
        for key in params.keys():
            val = params[key]
            command = command + ' --' + str(key) + ' ' + str(val)
        print("launching", command)
        try:
            # NOTE(review): stderr is piped but never read; if the child
            # writes a lot to stderr the pipe can fill and stall it — verify.
            process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
            # Stream the child's stdout as it trains (b'' sentinel for Python 3).
            for line in iter(process.stdout.readline, b''):
                sys.stdout.write(line)
        except KeyboardInterrupt:
            print("Keyboard interrupt")
        ### ASSERT: we are done training
        # Get final model... it's the pkl with the biggest number
        # First, get the latest run directory: runner output dirs begin with a
        # run number, so keep the directory with the largest leading integer.
        run_dir = ''
        latest = -1
        for file in os.listdir(os.path.join(cwd, 'results')):
            match = re.match(r'([0-9]+)', file)
            if match is not None and match.group(1) is not None:
                current = int(match.group(1))
                if current > latest:
                    latest = current
                    run_dir = file
        run_dir = os.path.join(cwd, 'results', run_dir)
        # Now get the newest pkl file and image files.
        # NOTE(review): run_dir is already absolute here; the joins below only
        # work because os.path.join discards earlier parts when a later part
        # is absolute — confirm before refactoring.
        model_filename = ''
        image_files = []
        latest = 0.0
        for file in os.listdir(run_dir):
            match_png = re.search(r'[\w\W]*?[0-9]+?.png', file)
            match_pkl = re.search(r'[\w\W]*?[0-9]+?.pkl', file)
            if match_pkl is not None:
                # Newest pkl by creation time wins.
                current = os.stat(os.path.join(cwd, 'results', run_dir, file)).st_ctime
                if current > latest:
                    latest = current
                    model_filename = file
            elif match_png is not None:
                image_files.append(os.path.join(cwd, 'results', run_dir, file))
        model_filename = os.path.join(cwd, 'results', run_dir, model_filename)
        # save animation images
        prep_output_dir(self.animation)
        for file in image_files:
            shutil.copy(file, self.animation)
        # Save fine-tuned model
        shutil.copyfile(model_filename, self.model_out)



############################



class StyleGAN_Run(Module):
    """Generate `num` images from a trained StyleGAN model by launching
    stylegan/stylegan_runner.py in a subprocess."""

    def __init__(self, model, num, images):
        self.model = model    # path to model
        self.num = num        # number of images to generate (int)
        self.images = images  # path to output image
        self.ready = checkFiles(model)
        self.output_files = [images]

    def run(self):
        prep_output_dir(self.images)
        # Run the stylegan program in a separate sub-process
        cwd = os.getcwd()
        params = {'model': os.path.join(cwd, self.model),
                  'images_out': os.path.join(cwd, self.images),
                  'num': self.num,
                  }
        command = 'python stylegan/stylegan_runner.py --run'
        for key in params.keys():
            val = params[key]
            command = command + ' --' + str(key) + ' ' + str(val)
        print("launching", command)
        try:
            process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
            # Stream the child's stdout (b'' sentinel for Python 3).
            for line in iter(process.stdout.readline, b''):
                sys.stdout.write(line)
        except Exception as e:
            print("something broke")
            print(e)


#############

# Not tested. Needs to launch stylegan_runner.py
class StyleGAN_Movie(Module):
    """Render a latent-space interpolation movie from a StyleGAN model by
    launching stylegan/stylegan_runner.py in a subprocess."""

    def __init__(self, model, length, interp, duration, movie):
        self.model = model        # path to input model
        self.length = length      # (int) number of way points interpolated between (default = 10)
        self.interp = interp      # (int) number of interpolations between waypoints (default = 30)
        self.duration = duration  # (int) duration of animation (default=1)
        self.movie = movie        # path to output movie directory
        self.ready = checkFiles(model)
        self.output_files = [movie]

    def run(self):
        prep_output_dir(self.movie)
        # Run the stylegan program in a separate sub-process
        cwd = os.getcwd()
        params = {'model': os.path.join(cwd, self.model),
                  'movie_out': os.path.join(cwd, self.movie, 'movie.gif'),
                  'num': self.length,
                  'interp': self.interp,
                  'duration': self.duration
                  }
        command = 'python stylegan/stylegan_runner.py --movie'
        for key in params.keys():
            val = params[key]
            command = command + ' --' + str(key) + ' ' + str(val)
        print("launching", command)
        try:
            process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
            # Stream the child's stdout (b'' sentinel for Python 3).
            for line in iter(process.stdout.readline, b''):
                sys.stdout.write(line)
        except Exception as e:
            print("something broke")
            print(e)
######################

class MakeMovie(Module):
    """Assemble a directory of images (in sorted filename order) into a
    single animated GIF saved as movie.gif in the output directory."""

    def __init__(self, images, duration, movie):
        self.images = images      # path to directory with images
        self.duration = duration  # (int) duration of animation (default=10)
        self.movie = movie        # path to output movie
        self.ready = checkFiles(images)
        self.output_files = [movie]

    def run(self):
        images = []
        files = []
        # Get list of filenames
        for file in os.listdir(self.images):
            files.append(file)
        # Sort the file names (lexicographic filename order)
        sorted_files = sorted(files)
        # Load each image in sorted order
        for file in sorted_files:
            try:
                img = Image.open(os.path.join(self.images, file))
                images.append(img)
            except:
                print(file, "did not load.")
        # ROBUSTNESS FIX: bail out rather than raise IndexError on images[0]
        # when nothing could be loaded.
        if len(images) == 0:
            print("no images loaded; no movie made")
            return
        # Prep the destination, a directory to save a single file
        prep_output_dir(self.movie)
        # Make and save the animation
        images[0].save(os.path.join(self.movie, 'movie.gif'), "GIF",
                       save_all=True,
                       append_images=images[1:],
                       duration=self.duration,
                       loop=0)

##########################

class Degridify(Module):
    """Split grid images (or grid animations) back into individual cell
    images, one output file per grid cell."""

    def __init__(self, input, columns, rows, output):
        self.input = input      # path to directory containing images
        self.output = output    # path to output directory
        self.columns = columns  # (int) number of columns
        self.rows = rows        # (int) number of rows
        self.ready = checkFiles(input)
        self.output_files = [output]

    def run(self):
        columns = self.columns
        rows = self.rows
        prep_output_dir(self.output)
        # iterate through images in the directory
        for file in os.listdir(self.input):
            # Load image
            grid = Image.open(os.path.join(self.input, file))
            # compute the size of each grid cell
            grid_width, grid_height = grid.size
            image_width = grid_width // columns
            image_height = grid_height // rows
            # Start cropping. ROBUSTNESS FIX: not every format exposes
            # n_frames (e.g. plain JPEGs), so treat a missing attribute as a
            # single-frame image instead of raising AttributeError.
            if getattr(grid, 'n_frames', 1) > 1:
                self.cropAnim(grid, columns, rows, image_width, image_height)
            else:
                # BUG FIX: cropImage requires the cell dimensions; the original
                # call omitted them, raising TypeError for every still image.
                self.cropImage(grid, columns, rows, image_width, image_height)

    def cropImage(self, grid, columns, rows, image_width, image_height):
        # Cut a single still image into columns x rows cells and save each
        # cell as <col>-<row>.gif.
        for i in range(columns):
            for j in range(rows):
                img = grid.crop((i*image_width, j*image_height, (i+1)*image_width, (j+1)*image_height))
                img.save(os.path.join(self.output, str(i).zfill(5) + '-' + str(j).zfill(5) + '.gif'), "GIF")

    def cropAnim(self, anim, columns, rows, image_width, image_height):
        # Cut an animation into cells: explode the frames into a temp
        # directory, crop every frame, then reassemble one animation per cell.
        cwd = os.getcwd()
        temp_dir = os.path.join(cwd, '.degridify_temp')
        if os.path.exists(temp_dir):
            shutil.rmtree(temp_dir)
        os.mkdir(temp_dir)
        frame_paths = []
        anim_dict = {}  # (column, row) -> list of cropped frames, in order
        for n in range(anim.n_frames):
            anim.seek(n)
            file = str(n).zfill(5) + '.gif'
            path = os.path.join(temp_dir, file)
            anim.save(path, "GIF")
            frame_paths.append(path)
        for file in frame_paths:
            frame = Image.open(file)
            for i in range(columns):
                for j in range(rows):
                    cropped = frame.crop((i*image_width, j*image_height, (i+1)*image_width, (j+1)*image_height))
                    if (i, j) not in anim_dict:
                        anim_dict[(i, j)] = []
                    anim_dict[(i, j)].append(cropped)
        for key in anim_dict:
            i, j = key
            filename = str(i).zfill(5) + '-' + str(j).zfill(5) + '.gif'
            frames = anim_dict[key]
            frames[0].save(os.path.join(self.output, filename), "GIF",
                           save_all=True,
                           append_images=frames[1:],
                           duration=100,
                           loop=0)
        shutil.rmtree(temp_dir)


class Gridify(Module):

    def __init__(self, input, size, columns, output):
        self.input = input      # path to directory containing input files
        # BUG FIX: was `self.output = ouput`, a NameError on construction.
        self.output = output    # path containing a grid file
        self.size = size        # (int) size of a cell (default=256)
        self.columns = columns  # (int) number of columns required
(rows computed automatically) (default=4) 520 | self.ready = checkFiles(input) 521 | self.output_files = [output] 522 | 523 | def run(self): 524 | columns = self.columns # Just making it easier to use this variable 525 | # load and resize images 526 | images = [] # all the images 527 | files = [] # all the filenames 528 | p = transforms.Compose([transforms.Resize((self.size,self.size))]) # the transform 529 | # Sort the files by creation date 530 | for file in os.listdir(self.input): 531 | files.append(file) 532 | sorted_files = sorted(files) 533 | # Iterate through all images in the directory 534 | for file in sorted_files: 535 | # Load the image 536 | img = Image.open(os.path.join(self.input, file)) 537 | # Resize it 538 | img2 = p(img) 539 | # Save it in order 540 | images.append(img2) 541 | # compute number of rows 542 | rows = len(images) // columns 543 | leftovers = len(images) % columns 544 | # round up 545 | if leftovers > 0: 546 | rows = rows + 1 547 | for _ in range(leftovers): 548 | empty = Image.new('RGB', (self.size, self.size)) 549 | images.append(empty) 550 | # make new image 551 | grid = Image.new('RGB', (columns*self.size, rows*self.size)) 552 | # paste images into new image 553 | counter = 0 554 | for i in range(columns): 555 | for j in range(rows): 556 | cur_img = images[counter] 557 | grid.paste(cur_img, (i*self.size, j*self.size)) 558 | counter = counter + 1 559 | # save new image 560 | prep_output_dir(self.output) 561 | grid.save(os.path.join(self.output, 'grid.jpg'), "JPEG") 562 | 563 | 564 | ########################### 565 | 566 | class StyleTransfer(Module): 567 | 568 | def __init__(self, content_image, style_image, steps, size, style_weight, content_weight, 569 | content_layers, style_layers, output): 570 | self.content_image = content_image # path to content image directory 571 | self.style_image = style_image # path to style image directory 572 | self.steps = steps # (int) number of steps to run (Default=1000) 573 | self.size = size # 
(int) image size (default=512) 574 | self.style_weight = style_weight # (int) style weight 575 | self.content_weight = content_weight # (int) content weight 576 | self.content_layers = content_layers # (str) consisting of numbers [1-5] (default="4") 577 | self.style_layers = style_layers # (str) consisting of numbers [1-5] (default="1, 2, 3, 4, 5") 578 | self.output = output # path to output directory 579 | self.ready = checkFiles(content_image, style_image) 580 | self.output_files = [output] 581 | 582 | def run(self): 583 | import style_transfer 584 | prep_output_dir(self.output) 585 | count = 0 586 | # sort the content and style images 587 | content_images = [] 588 | style_images = [] 589 | sorted_content_images = [] 590 | sorted_style_images = [] 591 | for content_file in os.listdir(self.content_image): 592 | content_images.append(content_file) 593 | for style_file in os.listdir(self.style_image): 594 | style_images.append(style_file) 595 | sorted_content_images = sorted(content_images) 596 | sorted_style_images = sorted(style_images) 597 | # If there are multiple input content and style images, run all combinations 598 | for style_file in sorted_style_images: 599 | for content_file in sorted_content_images: 600 | count = count + 1 601 | print("Running with content=" + content_file + " style=" + style_file) 602 | style_transfer.run(os.path.join(self.content_image, content_file), 603 | os.path.join(self.style_image, style_file), 604 | os.path.join(self.output, str(count).zfill(5) + '.jpg'), 605 | image_size = self.size, 606 | num_steps = self.steps, 607 | style_weight = self.style_weight, 608 | content_weight = self.content_weight, 609 | content_layers_spec = self.content_layers, 610 | style_layers_spec = self.style_layers) 611 | 612 | ################################# 613 | 614 | class JoinImageDirectories(Module): 615 | 616 | def __init__(self, dir1, dir2, output): 617 | self.dir1 = dir1 618 | self.dir2 = dir2 619 | self.output = output 620 | self.ready = 
checkFiles(dir1, dir2) 621 | self.output_files = [output] 622 | 623 | def run(self): 624 | prep_output_dir(self.output) 625 | for file in os.listdir(self.dir1): 626 | shutil.copy(os.path.join(self.dir1, file), self.output) 627 | for file in os.listdir(self.dir2): 628 | shutil.copy(os.path.join(self.dir2, file), self.output) 629 | 630 | ################################# 631 | 632 | class SquareCrop(Module): 633 | 634 | def __init__(self, input, output): 635 | self.input = input 636 | self.output = output 637 | self.ready = checkFiles(input) 638 | self.output_files = [output] 639 | 640 | def run(self): 641 | prep_output_dir(self.output) 642 | for file in os.listdir(self.input): 643 | img = Image.open(os.path.join(self.input, file)) 644 | square_img = img 645 | width, height = img.size 646 | if width > height: 647 | diff = width - height 648 | box = (diff//2, 0, width - diff//2, height) 649 | square_img = img.crop(box) 650 | elif height > width: 651 | diff = height - width 652 | box = (0, diff//2, width, height - diff//2) 653 | square_img = img.crop(box) 654 | square_img.save(os.path.join(self.output, file), "JPEG") 655 | 656 | #################################### 657 | 658 | class UnmakeMovie(Module): 659 | 660 | def __init__(self, movie, output): 661 | self.movie = movie # path to directory containing movies 662 | self.output = output # path to directory to save images 663 | self.ready = checkFiles(movie) 664 | self.output_files = [output] 665 | 666 | def run(self): 667 | prep_output_dir(self.output) 668 | for file in os.listdir(self.movie): 669 | anim = Image.open(os.path.join(self.movie, file)) 670 | for n in range(anim.n_frames): 671 | anim.seek(n) 672 | new_filename = os.path.splitext(file)[0] + '-' + str(n).zfill(5) +'.gif' 673 | anim.save(os.path.join(self.output, new_filename), "GIF") 674 | 675 | -------------------------------------------------------------------------------- /gui.js: 
// GLOBALS ////////////////////////////////////
const screen_width = 2048;				// default screen width
const screen_height = 600;				// default screen height
const module_width = 200;				// default module width
const module_height = 40;				// default module height
const parameter_offset = 10;			// how much to indent parameters
const parameter_width = module_width - (parameter_offset * 2);	// default parameter width
const parameter_height = module_height;	// default parameter height
const module_spacing = 20;				// how much space between modules
const parameter_spacing = 5;			// how much space between parameters
const cache_path = "cache/";			// cache directory
const image_path = '/nbextensions/google.colab/';	// where to find images
const text_size = 16;					// default text size
const font_size = "" + text_size + "px";	// default font size for html5

// What images to use for different parameter types
var type_images = {"model" : "model.png",
				   "dictionary" : "dictionary.png",
				   "image" : "image.png",
				   "images" : "images.png",
				   "text" : "text.png",
				   "string" : "string.png",
				   "int" : "number.png",
				   "float" : "number.png"};

var drag = null;			// Thing being dragged
var clicked = null;			// Origin parameter of an in-progress connection line
var open_param = null;		// Which parameter is being edited in the text input
var mouseX;					// Mouse X
var mouseY;					// Mouse Y
var modules = [];			// List of modules
var parameters = [];		// List of parameters
var connections = [];		// List of connections
var clickables = [];		// List of clickable things
var dragables = [];			// List of dragable things
var module_counters = {};	// Dictionary of counts of each module type used (for display names)
var drag_offset_x = 0;		// Keep track of the offset due to scrolling
var drag_offset_y = 0;		// Keep track of the offset due to scrolling
var module_counter = 0;		// keep count of number of modules
var parameter_counter = 0;	// keep count of number of parameters


// Is the current mouse position inside obj's bounding box, corrected for
// the canvas's scroll/position offset? Used by every mouse/key handler.
function hit_test(obj, offset) {
	return mouseX >= obj.x - offset.x && mouseX <= obj.x - offset.x + obj.width &&
	       mouseY >= obj.y - offset.y && mouseY <= obj.y - offset.y + obj.height;
}


// START GAME /////////////////////////////////
// Set up the canvas, populate the module-maker dropdown, and install all
// mouse/keyboard event handlers.
function startGame() {
	// HIDE HTML OBJECTS
	hideInput();
	// POPULATE MODULE MAKING HTML
	populateModuleMaker();
	// START THE GUI RUNNING
	myGameArea.start();
	// EVENT HANDLERS
	// Mouse down: begin a drag, open a parameter editor, or start a connection.
	myGameArea.canvas.onmousedown = function (e) {
		var offset = recursive_offset(myGameArea.canvas);
		var i;
		// Iterate through Dragables
		for (i = 0; i < dragables.length; i++) {
			var d_obj = dragables[i];
			if (hit_test(d_obj, offset)) {
				// found the dragable we clicked on
				drag = d_obj;	// we are dragging this object now
				drag_offset_x = mouseX - d_obj.x;
				drag_offset_y = mouseY - d_obj.y;
				// Stop showing the text entry
				hideInput();
				return;
			}
		}
		// Iterate through clickables
		for (i = 0; i < clickables.length; i++) {
			var c_obj = clickables[i];
			if (hit_test(c_obj, offset)) {
				// found the clickable we clicked on
				if (c_obj.is_out) {
					// Clicked on an output element; start drawing a line from it
					clicked = c_obj;
					break;
				}
				else if (!c_obj.is_in && !c_obj.is_out) {
					// Clicked on an editable (non-input, non-output) parameter:
					// show the text-entry form populated with its current value.
					var inp = document.getElementById("inp");
					inp.style.display = "block";
					var inp_module = document.getElementById("inp_module");
					inp_module.innerHTML = c_obj.parent.name;
					var inp_param = document.getElementById("inp_param");
					inp_param.innerHTML = c_obj.name;
					var val = document.getElementById("inp_val");
					val.value = c_obj.value;
					open_param = c_obj;
					return;
				}
			}
		}
		// Make text input invisible
		hideInput();
	};
	// Mouse Move: track the pointer for hit tests and line drawing.
	myGameArea.canvas.onmousemove = function (e) {
		mouseX = e.clientX;
		mouseY = e.clientY;
	};
	// Mouse up: finish a drag, or complete an in-progress connection.
	myGameArea.canvas.onmouseup = function (e) {
		if (drag) {
			// Done dragging
			drag = null;
		}
		if (clicked) {
			// We are drawing a line. Where did we release it?
			var offset = recursive_offset(myGameArea.canvas);
			var i;
			for (i = 0; i < parameters.length; i++) {
				var obj = parameters[i];
				if (hit_test(obj, offset)) {
					// Found the parameter we mouse-uped on. Only connect to an
					// unconnected input of matching type on a different module.
					if (obj.is_in && !obj.connected && obj.parent.name != clicked.parent.name && obj.type === clicked.type) {
						// Make a connection
						var c = new connection(clicked, obj);
						c.id = connections.length;
						connections.push(c);
						obj.connected = true;
						clicked.connected = true;
						// Both endpoints share a cache filename so save/load
						// can rediscover the link by matching values.
						var fname = cache_path + obj.type + clicked.id;
						obj.value = fname;
						clicked.value = fname;
					}
				}
			}
			// Not making a line any more
			clicked = null;
		}
	};
	// Double click: collapse or expand the module under the pointer.
	myGameArea.canvas.ondblclick = function (e) {
		var offset = recursive_offset(myGameArea.canvas);
		var i;
		for (i = 0; i < dragables.length; i++) {
			var obj = dragables[i];
			if (hit_test(obj, offset)) {
				obj.collapsed = !obj.collapsed;
				break;
			}
		}
	};
	// Key Press: Backspace/Delete removes the module under the pointer.
	document.addEventListener('keydown', function (e) {
		var offset = recursive_offset(myGameArea.canvas);
		if (e.keyCode == 8 || e.keyCode == 46) {
			// DELETE
			var i;
			for (i = 0; i < dragables.length; i++) {
				var obj = dragables[i];
				if (hit_test(obj, offset)) {
					// Found the dragable under the pointer when delete was pressed
					delete_module(obj);
					break;
				}
			}
		}
	});
}

// MYGAMEAREA //////////////////////////////////////////
// Owns the canvas and the redraw timer.
var myGameArea = {
	canvas : document.createElement("canvas"),
	start : function () {
		this.canvas.width = screen_width;
		this.canvas.height = screen_height;
		this.context = this.canvas.getContext("2d");
		document.body.insertBefore(this.canvas, document.body.childNodes[0]);
		this.frameNo = 0;
		// Redraw roughly 50 times a second.
		this.interval = setInterval(updateGameArea, 20);
	},
	clear : function () {
		this.context.clearRect(0, 0, this.canvas.width, this.canvas.height);
	}
};


// MODULE ///////////////////////////////////////////////
// A draggable box on the canvas representing one program step; owns a list
// of parameter boxes drawn beneath it (or hidden when collapsed).
function module(x, y, type, name, category, id) {
	this.width = module_width;		// module width
	this.height = module_height;	// module height
	this.color = "gray";			// module color
	this.category = category;		// what category this belongs in
	this.x = x;						// x location
	this.y = y;						// y location
	this.name = name;				// label to display to user
	this.type = type;				// module type
	this.params = [];				// list of parameters
	this.font = "Arial";			// font to use
	this.font_size = font_size;		// font size
	this.id = id;					// unique id number
	this.collapsed = false;			// am I collapsed?
	this.up = new Image();			// up arrow
	this.up.src = image_path + "up.png";	// load the up arrow image
	this.down = new Image();		// down arrow
	this.down.src = image_path + "down.png";	// load the down arrow image
	// UPDATE FUNCTION: reposition (if dragged) and draw myself and my parameters.
	this.update = function () {
		var ctx = myGameArea.context;
		if (drag == this) {
			// I AM BEING DRAGGED
			this.x = mouseX - drag_offset_x;
			this.y = mouseY - drag_offset_y;
		}
		// DRAW ME
		// Draw my parameters first; when collapsed they all stack at index 0
		// (directly under the header) so connections still have an anchor.
		var i;
		for (i = 0; i < this.params.length; i++) {
			var param = this.params[i];
			var index = i + 1;
			if (this.collapsed) {
				index = 0;
			}
			param.x = this.x + parameter_offset;
			param.y = this.y + (this.height + parameter_spacing) * (index);
			param.update();
		}
		// Draw the module header
		ctx.fillStyle = this.color;
		ctx.fillRect(this.x, this.y, this.width, this.height);
		ctx.font = this.font_size + " " + this.font;
		ctx.fillStyle = "black";
		var text_x = this.x + 5;
		var text_y = this.y + (this.height / 2.0) + (text_size / 3.0);
		var img_height = (this.height - 10) / 4;
		var img_width = (this.height - 10) / 2;
		ctx.fillText(this.name, text_x, text_y);
		// Show up arrow or down arrow?
		if (this.collapsed) {
			ctx.drawImage(this.down, this.x + this.width - img_width - 5, this.y + 5, img_width, img_height);
		}
		else {
			ctx.drawImage(this.up, this.x + this.width - img_width - 5, this.y + 5, img_width, img_height);
		}
	};
}

// PARAMETER //////////////////////////
// One input/output/value slot on a module. Green = input, red = output,
// lightgray = user-editable value.
function parameter(name, is_in, is_out, type, default_value, parent, id) {
	this.color = "lightgray";		// parameter color (red is output, green is input, lightgray otherwise)
	this.width = parameter_width;	// parameter width
	this.height = parameter_height;	// parameter height
	this.is_in = is_in;				// am I an input?
	this.is_out = is_out;			// am I an output?
	this.type = type;				// what type am I? (string, text, int, float, dictionary, model, etc.)
	this.value = default_value;		// my value
	this.x = 0;						// x location (parent module will set me)
	this.y = 0;						// y location (parent module will set me)
	this.name = name;				// my label to show the user
	this.connected = false;			// am I connected to another parameter?
	this.parent = parent;			// who is my parent module?
	this.font = "Arial";			// my font
	this.font_size = font_size;		// my font size
	this.id = id;					// unique identifier
	// Set my color based on whether I am an input or output. Also set the
	// default filename I create when linked.
	if (this.is_in) {
		this.color = "green";
		this.value = cache_path + name;
	}
	else if (this.is_out) {
		this.color = "red";
		this.value = cache_path + name;
	}
	// Do I have an icon to show? (null when the type has no registered image)
	this.img = null;
	if (this.type in type_images) {
		this.img = new Image();
		this.img.src = image_path + type_images[this.type];
	}
	// UPDATE FUNCTION: draw my box, icon (if any), and label.
	this.update = function () {
		var ctx = myGameArea.context;
		// DRAW ME
		ctx.fillStyle = this.color;
		ctx.fillRect(this.x, this.y, this.width, this.height);
		ctx.font = this.font_size + " " + this.font;
		ctx.fillStyle = "black";
		var text_x = this.x + 5;
		var text_y = this.y + (this.height / 2.0) + (text_size / 3.0);
		var img_height = this.height - 10;
		var img_width = img_height;
		// Show my icon. Guard against a missing icon (type not in
		// type_images) so drawImage is never called with null.
		if (this.is_out) {
			if (this.img) {
				ctx.drawImage(this.img, this.x + this.width - img_width - 5, this.y + 5, img_width, img_height);
			}
		}
		else if (this.img) {
			ctx.drawImage(this.img, text_x, this.y + 5, img_width, img_height);
			text_x = text_x + img_width + 5;
		}
		ctx.fillText(this.name, text_x, text_y);
	};
}

// CONNECTION ////////////////////////////////
// A line with dot endpoints from an output parameter to an input parameter.
function connection(origin, target) {
	this.origin = origin;	// my origin parameter
	this.target = target;	// my target parameter
	this.id = null;			// unique identifier
	this.update = function () {
		var ctx = myGameArea.context;
		// DRAW ME
		var origin_x = this.origin.x + this.origin.width;
		var origin_y = this.origin.y + this.origin.height / 2;
		var target_x = this.target.x;
		var target_y = this.target.y + this.target.height / 2;
		// When an endpoint's module is collapsed, anchor the line to the
		// collapsed header instead of the (hidden) parameter edge.
		if (this.origin.parent.collapsed) {
			origin_x = this.origin.x + this.origin.width / 2;
			origin_y = this.origin.y + this.target.height;
		}
		if (this.target.parent.collapsed) {
			target_x = this.target.x + this.target.width / 2;
			target_y = this.target.y;
		}
		ctx.lineWidth = 2;
		ctx.beginPath();
		ctx.moveTo(origin_x, origin_y);
		ctx.lineTo(target_x, target_y);
		ctx.stroke();
		ctx.beginPath();
		ctx.arc(origin_x, origin_y, 5, 0, 2 * Math.PI);
		ctx.fill();
		ctx.beginPath();
		ctx.arc(target_x, target_y, 5, 0, 2 * Math.PI);
		ctx.fill();
	};
}

// UPDATE CANVAS ///////////////////////////////////
// Periodic redraw: modules, connections, and any in-progress connection line.
function updateGameArea() {
	myGameArea.clear();
	myGameArea.frameNo += 1;
	// Draw my modules
	var i;
	for (i = 0; i < modules.length; i++) {
		modules[i].update();
	}
	// Draw connections
	var j;
	for (j = 0; j < connections.length; j++) {
		connections[j].update();
	}
	// Is there a line being drawn that hasn't been connected yet?
	if (clicked) {
		var offset = recursive_offset(myGameArea.canvas);
		var ctx = myGameArea.context;
		var dot_x = clicked.x + clicked.width;
		var dot_y = clicked.y + clicked.height / 2;
		ctx.lineWidth = 2;
		ctx.beginPath();
		ctx.moveTo(dot_x, dot_y);
		ctx.lineTo(mouseX + offset.x, mouseY + offset.y);
		ctx.stroke();
		ctx.beginPath();
		ctx.arc(clicked.x + clicked.width, clicked.y + clicked.height / 2, 5, 0, 2 * Math.PI);
		ctx.fill();
	}
}

// BUTTON HANDLERS /////////////////////////////

// Commit the text-input value to the parameter being edited, then hide the form.
function do_input_button_up() {
	if (open_param != null) {
		var val = document.getElementById("inp_val");
		open_param.value = val.value;
	}
	hideInput();
}

// Create a new module of the type currently selected in the dropdown.
function do_make_module_button_up() {
	var sel = document.getElementById("module_select");
	var val = sel.options[sel.selectedIndex].value;	// Name of the module type
	var i;
	// Figure out what type of module was selected
	for (i = 0; i < module_dicts.length; i++) {
		var module_dict = module_dicts[i];
		if ("name" in module_dict) {
			var name = module_dict["name"];
			if (name == val) {
				// Found it
				make_module(module_dict, module_counter);
			}
		}
	}
}


// HELPERS ///////////////////////////////////////


// Instantiate a module (and its parameters) from a module_dicts definition,
// position it in a grid on the canvas, and register it with the global lists.
// Returns the new module.
function make_module(module_json, id) {
	var category = "";
	var name = "";
	var type = "";
	var module_count = 0;
	// type
	if ("name" in module_json) {
		type = module_json["name"];
	}
	// name: append a per-type counter so "LSTM (2)" is distinct from "LSTM (1)"
	if (!(type in module_counters)) {
		module_counters[type] = 0;
	}
	module_count = module_counters[type] + 1;
	module_counters[type] = module_count;
	name = type + " (" + module_count + ")";
	// category
	if ("category" in module_json) {
		category = module_json["category"];
	}
	// Make the new module, laid out left-to-right then wrapping to a new row.
	var new_module = new module(((module_width + module_spacing) * modules.length) % (screen_width - (module_width + module_spacing)),
								parseInt((module_width + module_spacing) * modules.length / (screen_width - (module_width + module_spacing))) * (module_width + module_spacing),
								type, name, category, id);
	module_counter = module_counter + 1;
	modules.push(new_module);
	dragables.push(new_module);
	// parameters
	if ("params" in module_json) {
		var params = module_json["params"];
		var j;
		for (j = 0; j < params.length; j++) {
			var param_json = params[j];
			var p_name = "";
			var is_in = false;
			var is_out = false;
			// Note: named p_type (not "type") so var hoisting cannot clobber
			// the module's own type variable above.
			var p_type = "";
			var default_value = "";
			// parameter name
			if ("name" in param_json) {
				p_name = param_json["name"];
			}
			// is it an input?
			if ("in" in param_json && param_json["in"] == true) {
				is_in = true;
			}
			// is it an output?
			if ("out" in param_json && param_json["out"] == true) {
				is_out = true;
			}
			// What type of value does it store?
			if ("type" in param_json) {
				p_type = param_json["type"];
			}
			// What is the default value?
			if ("default" in param_json) {
				default_value = param_json["default"];
			}
			// Make the parameter
			var new_param = new parameter(p_name, is_in, is_out, p_type, default_value, new_module, new_module.id + "-" + j);
			parameter_counter = parameter_counter + 1;
			new_module.params.push(new_param);
			parameters.push(new_param);
			if (!is_in) {
				clickables.push(new_param);
			}
		}
	}
	return new_module;
}

// Hide the parameter text-entry form.
function hideInput() {
	var inp = document.getElementById("inp");
	inp.style.display = "none";
}

// Fill the module-maker <select> with an <optgroup> per category, one
// <option> per module definition in module_dicts.
function populateModuleMaker() {
	// Grab the select object
	var sel = document.getElementById("module_select");
	var categories_dict = {};	// keys are category names, val is list of module names
	// seed with misc category
	categories_dict["misc"] = [];
	// Collect up category names and module names
	var i;
	for (i = 0; i < module_dicts.length; i++) {
		var module_dict = module_dicts[i];
		if ("name" in module_dict && "category" in module_dict) {
			var category = module_dict["category"];
			var name = module_dict["name"];
			if (!(category in categories_dict)) {
				categories_dict[category] = [];
			}
			categories_dict[category].push(name);
		}
		else if ("name" in module_dict) {
			// No category: file it under misc
			categories_dict["misc"].push(module_dict["name"]);
		}
	}
	// Iterate through categories
	for (var key in categories_dict) {
		var module_names = categories_dict[key];
		// If this category has modules, then add them to the select object
		if (module_names.length > 0) {
			// Make an optgroup
			var group = document.createElement('OPTGROUP');
			group.label = key;
			// Iterate through modules
			for (i = 0; i < module_names.length; i++) {
				// Make an option
				var module_name = module_names[i];
				var opt = document.createElement('OPTION');
				opt.textContent = module_name;
				opt.value = module_name;
				group.appendChild(opt);
			}
			sel.appendChild(group);
		}
	}
}

// True on every nth frame of the redraw loop.
function everyinterval(n) {
	if ((myGameArea.frameNo / n) % 1 == 0) { return true; }
	return false;
}


// Walk up the DOM accumulating scroll minus element offsets, so canvas
// hit tests can translate client mouse coordinates into canvas space.
function recursive_offset(aobj) {
	var currOffset = {
		x: 0,
		y: 0
	};
	var newOffset = {
		x: 0,
		y: 0
	};

	if (aobj !== null) {

		if (aobj.scrollLeft) {
			currOffset.x = aobj.scrollLeft;
		}

		if (aobj.scrollTop) {
			currOffset.y = aobj.scrollTop;
		}

		if (aobj.offsetLeft) {
			currOffset.x -= aobj.offsetLeft;
		}

		if (aobj.offsetTop) {
			currOffset.y -= aobj.offsetTop;
		}

		if (aobj.parentNode !== undefined) {
			newOffset = recursive_offset(aobj.parentNode);
		}

		currOffset.x = currOffset.x + newOffset.x;
		currOffset.y = currOffset.y + newOffset.y;
	}
	return currOffset;
}


// Serialize the current program (modules topologically sorted by their
// connections) to JSON and hand it to the Python save hook.
function save_program() {
	// Don't save if there is a blank program.
	if (modules.length == 0) {
		return;
	}
	var filename = "myprogram";
	var input_save = document.getElementById("inp_save");
	if (input_save.value.length > 0) {
		filename = input_save.value;
	}
	var program = [];
	// Deep copy connections as [origin module id, target module id] pairs
	var module_links = [];
	var module_ids = [];
	var i;
	for (i = 0; i < modules.length; i++) {
		module_ids.push(parseInt(modules[i].id));
	}
	for (i = 0; i < connections.length; i++) {
		module_links.push([parseInt(connections[i].origin.parent.id), parseInt(connections[i].target.parent.id)]);
	}
	// Put all the modules in order (repeatedly peel off the modules with no
	// incoming links)
	while (module_ids.length > 0) {
		var ready = get_ready_modules(module_ids, module_links);
		if (ready.length == 0) {
			// No module is ready but some remain: the connection graph has a
			// cycle. Bail out rather than spinning forever.
			console.log("save_program: cycle detected among modules", module_ids);
			break;
		}
		program = program.concat(ready);
		var filtered_module_ids = module_ids.filter(function (value, index, arr) {
			var n;
			var is_in = false;
			for (n = 0; n < ready.length; n++) {
				if (value == ready[n]) {
					is_in = true;
					break;
				}
			}
			return !is_in;
		});
		var filtered_module_links = module_links.filter(function (value, index, arr) {
			var n;
			var is_in = false;
			for (n = 0; n < ready.length; n++) {
				if (value[0] == ready[n] || value[1] == ready[n]) {
					is_in = true;
					break;
				}
			}
			return !is_in;
		});
		module_ids = filtered_module_ids;
		module_links = filtered_module_links;
	}
	// assert: program is module ids in executable order
	var prog_json = "";
	var is_first = true;
	console.log(program);
	for (i = 0; i < program.length; i++) {
		var current = program[i];
		var current_module = get_module_by_id(current);
		var module_json = module_to_json(current_module);
		if (is_first) {
			prog_json = module_json;
		}
		else {
			prog_json = prog_json + "," + module_json;
		}
		console.log(prog_json);
		is_first = false;
	}
	// Put brackets around the json
	prog_json = "[" + prog_json + "]";
	// Call python
	(async function () {
		const result = await google.colab.kernel.invokeFunction(
			'notebook.python_save_hook',	// The callback name.
			[prog_json, filename],			// The arguments.
			{});							// kwargs
		const res = result.data['application/json'];
	})();
}

// Find a module by its unique id, or null if no module has that id.
function get_module_by_id(id) {
	var i = 0;
	for (i = 0; i < modules.length; i++) {
		if (modules[i].id == id) {
			return modules[i];
		}
	}
	return null;
}


// Return the ids in mods that have no incoming link in cons
// (cons is a list of [origin id, target id] pairs).
function get_ready_modules(mods, cons) {
	var ready = [];
	var i;
	for (i = 0; i < mods.length; i++) {
		var current = mods[i];
		var is_ready = true;
		var j;
		for (j = 0; j < cons.length; j++) {
			if (cons[j][1] == current) {
				is_ready = false;
				break;
			}
		}
		if (is_ready) {
			ready.push(current);
		}
	}
	return ready;
}

// Serialize one module (header fields plus every parameter as a string) to a
// JSON string.
function module_to_json(module) {
	var json = {};
	json["module"] = module.type;
	json["name"] = module.name;
	json["x"] = module.x;
	json["y"] = module.y;
	json["id"] = module.id;
	json["collapsed"] = module.collapsed;
	// Parameters
	var i = 0;
	for (i = 0; i < module.params.length; i++) {
		var param = module.params[i];
		json[param.name] = "" + param.value;	// Make all values strings
	}
	return JSON.stringify(json);
}

// Reset every piece of global editor state to empty.
function clear_program() {
	modules = [];
	parameters = [];
	connections = [];
	clickables = [];
	dragables = [];
	module_counters = {};
	module_counter = 0;
	parameter_counter = 0;
	drag_offset_x = 0;
	drag_offset_y = 0;
	drag = null;
	clicked = null;
	open_param = null;
}

// Load a saved program: fetch its JSON via the Python load hook, rebuild the
// modules/parameters, and rediscover connections by matching cache filenames.
function load_program(loadfile = "") {
	// Need to clear the existing program
	clear_program();
	var filename = loadfile;
	if (filename.length == 0) {
		// Filename wasn't passed in so we need to get the filename from the html
		var input_load = document.getElementById("inp_load");
		if (input_load.value.length > 0) {
			filename = input_load.value;
		}
		else {
			return;
		}
	}
	// ASSERT: filename is non-empty
	// Call python
	async function foo() {
		const result = await google.colab.kernel.invokeFunction(
			'notebook.python_load_hook',	// The callback name.
			[filename],						// The arguments.
			{});							// kwargs
		return result;
	}
	foo().then(function (value) {
		// The save hook wrote strict JSON (see save_program), so parse it
		// instead of eval-ing kernel-returned text.
		var program = JSON.parse(value.data['application/json'].result);
		console.log(program);
		var i;
		var m;
		// Iterate through each module in the program
		for (m = 0; m < program.length; m++) {
			var module_json = program[m];	// The json for this part of the program
			var mod = null;
			var id = module_counter;
			if ("id" in module_json) {
				id = module_json["id"];
			}
			// Find the corresponding module definition
			for (i = 0; i < module_dicts.length; i++) {
				var module_dict = module_dicts[i];
				if ("name" in module_dict) {
					var name = module_dict["name"];
					if (name == module_json.module) {
						// Make the module (with default parameter values)
						mod = make_module(module_dict, id);
						break;
					}
				}
			} // end i
			if (mod === null) {
				// Save file references a module type we no longer know about;
				// skip its spec rather than crash on mod.x below.
				console.log("load_program: unknown module type " + module_json.module);
				continue;
			}
			// Move the module to the saved location
			mod.x = module_json.x;
			mod.y = module_json.y;
			if ("collapsed" in module_json) {
				mod.collapsed = module_json.collapsed;
			}
			// Update default parameters
			// Iterate through each of the specs from the file
			var module_keys = Object.keys(module_json);
			for (i = 0; i < module_keys.length; i++) {
				var module_key = module_keys[i];
				var module_val = module_json[module_key];
				var p;
				// Find the corresponding parameter
				for (p = 0; p < mod.params.length; p++) {
					var param = mod.params[p];
					if (module_key === param.name) {
						// Found the right parameter
						param.value = module_val;
						break;
					}
				}
			}
		} // end m
		// Make connections
		console.log(parameters);
		var j;
		// Iterate through all parameter pairs, looking for matching file names
		// (connected endpoints were saved sharing the same cache filename).
		for (i = 0; i < parameters.length; i++) {
			for (j = 0; j < parameters.length; j++) {
				var param1 = parameters[i];
				var param2 = parameters[j];
				if (param1.is_out && param2.is_in) {
					// param1 has an outlink and param2 has an inlink
					var fname1 = param1.value;
					var fname2 = param2.value;
					if (fname1 === fname2) {
						// Make a connection
						var c = new connection(param1, param2);
						c.id = connections.length;
						connections.push(c);
						param1.connected = true;
						param2.connected = true;
					}
				}
			}
		}
		// Update module counter past the max loaded id (coerce ids to numbers
		// so string ids from the save file don't compare lexicographically).
		var mod_i;
		var max_id = 0;
		for (mod_i = 0; mod_i < modules.length; mod_i++) {
			var cur_id = Number(modules[mod_i].id);
			if (cur_id > max_id) {
				max_id = cur_id;
			}
		}
		module_counter = max_id + 1;
	}); // end then function
	// Anything after this is not guaranteed to execute after the file is loaded.
}

// Remove a module and everything attached to it: its entries in every global
// registry and all connections touching it (whose endpoints are freed).
function delete_module(module) {
	// Remove the module itself from the drawing and dragging lists.
	modules = modules.filter(function (value, index, arr) {
		return value.id != module.id;
	});
	dragables = dragables.filter(function (value, index, arr) {
		return value.id != module.id;
	});
	// Remove the module's parameters so they can no longer be hit-tested,
	// clicked, or connected after deletion.
	parameters = parameters.filter(function (value, index, arr) {
		return value.parent.id != module.id;
	});
	clickables = clickables.filter(function (value, index, arr) {
		return value.parent.id != module.id;
	});
	// Split connections into survivors and those touching the deleted module.
	var deleted_connections = connections.filter(function (value, index, arr) {
		return value.origin.parent.id == module.id || value.target.parent.id == module.id;
	});
	connections = connections.filter(function (value, index, arr) {
		return value.origin.parent.id != module.id && value.target.parent.id != module.id;
	});
	// Free up the endpoints of the removed connections.
	var i;
	for (i = 0; i < deleted_connections.length; i++) {
		var c = deleted_connections[i];
		c.origin.connected = false;
		c.target.connected = false;
	}
}

// START THE GUI ///////////////////////////////////////////
startGame();