","output":"cache/output","group_1":"cache/text4-3","group_2":"cache/group_2"},{"module":"ConcatenateTextFiles","name":"ConcatenateTextFiles (1)","x":451,"y":176,"id":6,"input_1":"cache/text1-3","input_2":"cache/text4-3","output":"cache/text6-2"},{"module":"RemoveDuplicates","name":"RemoveDuplicates (1)","x":672,"y":51,"id":7,"input":"cache/text6-2","output":"cache/text7-1"},{"module":"MakeLowercase","name":"MakeLowercase (1)","x":676,"y":206,"id":8,"input":"cache/text7-1","output":"cache/text8-1"},{"module":"WriteTextFile","name":"WriteTextFile (1)","x":678,"y":370,"id":9,"input":"cache/text8-1","file":"training_data"},{"module":"CharRNN_Train","name":"CharRNN_Train (1)","x":910,"y":6,"id":10,"data":"cache/text8-1","history":"10","layers":"2","hidden_nodes":"64","epochs":"100","learning_rate":"0.0001","model":"cache/model10-6","dictionary":"cache/dictionary10-7"},{"module":"RandomSequence","name":"RandomSequence (1)","x":914,"y":423,"id":12,"input":"cache/text8-1","length":"10","output":"cache/text12-2"},{"module":"CharRNN_Run","name":"CharRNN_Run (1)","x":1152,"y":297,"id":11,"model":"cache/model10-6","dictionary":"cache/dictionary10-7","seed":"cache/text12-2","steps":"6000","temperature":"1.0","output":"cache/text11-5"},{"module":"TextSubtract","name":"TextSubtract (1)","x":1374,"y":131,"id":13,"main":"cache/text11-5","subtract":"cache/text8-1","diff":"cache/text13-2"},{"module":"WriteTextFile","name":"WriteTextFile (2)","x":1379,"y":343,"id":14,"input":"cache/text13-2","file":"my_output_file"}]
--------------------------------------------------------------------------------
/examples/star_trek_novels:
--------------------------------------------------------------------------------
1 | [{"module":"MakeCountFile","name":"MakeCountFile (1)","x":0,"y":0,"id":0,"num":"50","prefix":"https://www.barnesandnoble.com/b/books/romance/historical-romance/_/N-29Z8q8Z17yg?Nrpp=40&page=","postfix":"","output":"cache/text0-3"},{"module":"MakeCountFile","name":"MakeCountFile (2)","x":2,"y":234,"id":1,"num":"50","prefix":"https://www.barnesandnoble.com/b/books/science-fiction-fantasy/star-trek-fiction/_/N-29Z8q8Z182c?Nrpp=40&page=","postfix":"","output":"cache/text1-3"},{"module":"ReadAllFromWeb","name":"ReadAllFromWeb (1)","x":240,"y":24,"id":2,"urls":"cache/text0-3","data":"cache/text2-1"},{"module":"ReadAllFromWeb","name":"ReadAllFromWeb (2)","x":216,"y":287,"id":3,"urls":"cache/text1-3","data":"cache/text3-1"},{"module":"ConcatenateTextFiles","name":"ConcatenateTextFiles (1)","x":443,"y":142,"id":4,"input_1":"cache/text2-1","input_2":"cache/text3-1","output":"cache/text4-2"},{"module":"Regex_Search","name":"Regex_Search (1)","x":687,"y":13,"id":5,"input":"cache/text4-2","expression":"Title: ([\\w\\W ]+?), Author:","output":"cache/output","group_1":"cache/text5-3","group_2":"cache/group_2"},{"module":"MakeLowercase","name":"MakeLowercase (1)","x":680,"y":301,"id":6,"input":"cache/text5-3","output":"cache/text6-1"},{"module":"Regex_Sub","name":"Regex_Sub (1)","x":935,"y":9,"id":7,"input":"cache/text6-1","expression":"star trek[: \\#0-9]*","replacement":"","output":"cache/text7-3"},{"module":"Regex_Sub","name":"Regex_Sub (2)","x":942,"y":246,"id":17,"input":"cache/text7-3","expression":"\\([\\w\\W]+?\\)","replacement":"","output":"cache/text17-3"},{"module":"RemoveEmptyLines","name":"RemoveEmptyLines (1)","x":943,"y":475,"id":18,"input":"cache/text17-3","output":"cache/text18-1"},{"module":"RemoveDuplicates","name":"RemoveDuplicates (1)","x":1184,"y":16,"id":9,"input":"cache/text18-1","output":"cache/text9-1"},{"module":"RandomizeLines","name":"RandomizeLines (1)","x":1193,"y":180,"id":10,"input":"cache/text9-1","output":"cache/text10-1"},{"module":"WriteTextFile","name":"WriteTextFile (1)","x":1191,"y":407,"id":11,"input":"cache/text10-1","file":"training_data"},{"module":"CharRNN_Train","name":"CharRNN_Train (1)","x":1432,"y":4,"id":12,"data":"cache/text10-1","history":"10","layers":"2","hidden_nodes":"64","epochs":"150","learning_rate":"0.0001","model":"cache/model12-6","dictionary":"cache/dictionary12-7"},{"module":"RandomSequence","name":"RandomSequence (1)","x":1429,"y":415,"id":14,"input":"cache/text10-1","length":"10","output":"cache/text14-2"},{"module":"CharRNN_Run","name":"CharRNN_Run (1)","x":1640,"y":269,"id":13,"model":"cache/model12-6","dictionary":"cache/dictionary12-7","seed":"cache/text14-2","steps":"6000","temperature":"1.0","output":"cache/text13-5"},{"module":"TextSubtract","name":"TextSubtract (1)","x":1848,"y":48,"id":15,"main":"cache/text13-5","subtract":"cache/text10-1","diff":"cache/text15-2"},{"module":"Spellcheck","name":"Spellcheck (1)","x":1845,"y":248,"id":19,"input":"cache/text15-2","output":"cache/text19-1"},{"module":"WriteTextFile","name":"WriteTextFile (2)","x":1843,"y":413,"id":20,"input":"cache/text19-1","file":"my_output_file"}]
--------------------------------------------------------------------------------
/hooks.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | from pathlib import Path
4 | from PIL import Image
5 | from IPython.display import display
6 |
7 | TRASH_PATH = '/content/.trash'
8 |
9 | def python_cwd_hook_aux(dir):
10 | result = {}
11 | for file in os.listdir(dir):
12 | path = os.path.join(dir, file)
13 | if os.path.isdir(path):
14 | file = file + '/'
15 | result[file] = path
16 | result['./'] = os.getcwd()
17 | if dir != '/':
18 | parent_dir = str(Path(dir).parent)
19 | result['../'] = parent_dir
20 | return result
21 |
22 | def python_move_hook_aux(path1, path2):
23 | status = False
24 | if os.path.exists(path1):
25 | shutil.move(path1, path2)
26 | status = True
27 | return status
28 |
29 | def python_copy_hook_aux(path1, path2):
30 | status = False
31 | if os.path.exists(path1):
32 | # path 1 exists
33 | if os.path.isdir(path1):
34 | # path1 is a directory
35 | if os.path.exists(path2):
36 | # path2 exists
37 | if os.path.isdir(path2):
38 | # copying directory to selected directory
39 | # make new directory inside with same name as path1
40 | basename1 = os.path.basename(os.path.normpath(path1))
41 | path2 = os.path.join(path2, basename1)
42 | shutil.copytree(path1, path2)
43 | status = True
44 | else:
45 | # copy directory to a file
46 | # can't do that
47 | print("cannot copy directory inside a file")
48 | else:
49 | # path1 is a file
50 | # doesn't matter if path2 is a file or directory
51 | shutil.copy(path1, path2)
52 | status = True
53 | return status
54 |
55 | def python_open_text_hook_aux(path):
56 | status = False
57 | if os.path.exists(path) and not os.path.isdir(path):
58 | with open(path, 'r') as file:
59 | try:
60 | text = file.read()
61 | print(text)
62 | status = True
63 |             except Exception:
64 | print("Cannot read text file", path)
65 | return status
66 |
67 | def python_open_image_hook_aux(path):
68 | status = False
69 | if os.path.exists(path) and not os.path.isdir(path):
70 | try:
71 | pil_im = Image.open(path, 'r')
72 | display(pil_im)
73 | status = True
74 |         except Exception:
75 | print("Cannot open image file", path)
76 | return status
77 |
78 | def python_save_hook_aux(file_text, filename):
79 |     status = False
80 |     try:
81 |         with open(filename, 'w') as f:
82 |             f.write(file_text)
83 |         status = True
84 |     except Exception:
85 |         print("Could not write to", filename)
86 |     return status
87 |
88 | def python_load_hook_aux(filename):
89 |     file_text = ''
90 |     try:
91 |         with open(filename, 'r') as f:
92 |             file_text = f.read()
93 |     except Exception:
94 |         # covers both a failed open and a failed read
95 |         print("Could not read", filename)
96 |     return file_text
97 |
98 | def python_mkdir_hook_aux(path, dir_name):
99 | status = False
100 | try:
101 | os.mkdir(os.path.join(path, dir_name))
102 | status = True
103 |     except Exception:
104 | print("Could not create directory " + dir_name + " in " + path)
105 | return status
106 |
107 | def python_trash_hook_aux(path):
108 | status = False
109 | try:
110 | if not os.path.exists(TRASH_PATH):
111 | os.mkdir(TRASH_PATH)
112 | if os.path.exists(path):
113 | shutil.move(path, TRASH_PATH)
114 | status = True
115 |     except Exception:
116 | print("Could not move " + path + " to " + TRASH_PATH)
117 | return status
118 |
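
These `*_aux` functions are the Python half of the Colab file manager: each is wrapped by a callback registered in `/Easygen.ipynb` and invoked from `/file_manager.js`. A quick sanity-check sketch of their contracts, assuming a scratch directory (the paths are throwaway examples, and the trash hook assumes Colab's `/content` layout):

```python
import os
import hooks

os.makedirs('demo_dir', exist_ok=True)
hooks.python_save_hook_aux('hello world', 'demo_dir/a.txt')     # -> True
hooks.python_copy_hook_aux('demo_dir/a.txt', 'demo_dir/b.txt')  # file-to-file copy -> True
print(hooks.python_load_hook_aux('demo_dir/b.txt'))             # prints 'hello world'
listing = hooks.python_cwd_hook_aux('demo_dir')                 # {'a.txt': ..., './': ..., '../': ...}
print(sorted(listing))
hooks.python_trash_hook_aux('demo_dir/b.txt')                   # moves it into /content/.trash
```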
--------------------------------------------------------------------------------
/easygen.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import copy
4 | from modules import *
5 | from image_modules import *
6 | import re
7 | import shutil
8 | import argparse
9 |
10 | TEMP_DIRECTORY = './cache'
11 | HISTORY_PATH = os.path.join(TEMP_DIRECTORY, '.history')
12 |
13 |
14 | parser = argparse.ArgumentParser(description='Run an easygen program.')
15 | parser.add_argument('program', help="the file containing the easygen program.")
16 | parser.add_argument('--clear_cache', help="Clear the cache", action='store_true')
17 |
18 | # Inputs:
19 | # - json description of the module
20 | # - dictionary of history for caching
21 | # Return 2 values:
22 | # - Whether the module was run (could have not run because of caching)
23 | # - The paths to files that were produced as outputs of the module (or None if the module wasn't ready)
24 | def runModule(module_json, history = {}):
25 | module_json_copy = copy.deepcopy(module_json)
26 | if 'module' in module_json_copy:
27 | module = module_json_copy['module']
28 | ## Convert the json to a set of parameters to pass into a class of the same name as the module name
29 | ## Take the module name out
30 | del module_json_copy['module']
31 | ## Take out the x and y
32 | if 'x' in module_json_copy:
33 | del module_json_copy['x']
34 | if 'y' in module_json_copy:
35 | del module_json_copy['y']
36 | if 'collapsed' in module_json_copy:
37 | del module_json_copy['collapsed']
38 | ## Take out the module name and id
39 | if 'name' in module_json_copy:
40 | del module_json_copy['name']
41 | if 'id' in module_json_copy:
42 | del module_json_copy['id']
43 |
44 | rest = str(module_json_copy)
45 |
46 | params = rest.replace('{', '').replace('}', '')
47 | #params = re.sub(r'u\'', '', params)
48 | #params = re.sub(r'\'', '', params)
49 | #params = re.sub(r': ', '=', params)
50 | p1 = re.compile(r'\'([0-9a-zA-Z\_]+)\':[\s]*(\'[\<\>\(\)0-9a-zA-Z\_\.\,\/\:\*\-\?\+\=\#\&\@\%\\\[\]\|\"\^ ]*\')')
51 | params = p1.sub(r"\1=\2", params)
52 | params = re.sub(r'\'True\'', 'True', params)
53 | params = re.sub(r'\'true\'', 'True', params)
54 | params = re.sub(r'\'False\'', 'False', params)
55 | params = re.sub(r'\'false\'', 'False', params)
56 | p2 = re.compile(r'\'([0-9]*[\.]*[0-9]+)\'')
57 | params = p2.sub(r'\1', params)
58 | ## Put the module name back on as class name
59 | evalString = module + '(' + convertHexToASCII(params) + ')'
60 | print("Module:", evalString)
61 | ## If everything went well, we can now evaluate the string and create a new class.
62 | module = eval(evalString)
63 | ## Do we need to run it?
64 | done_files = [] # The files that are in temp and generated by an identical process
65 | for key in history.keys():
66 | # key is a filename
67 | value = history[key] # string descriptor of the process that generated the file
68 | # Is the descriptor the same and does the file exist?
69 | if str(value) == str(module_json) and os.path.exists(key):
70 | # and if it's a directory is it non-empty?
71 | if os.path.isfile(key) or (os.path.isdir(key) and len(os.listdir(key)) > 0):
72 | # It is!
73 | done_files.append(key)
74 | output_files = module.output_files # These are the files we would expect to see
75 | # Check if all the output files are already done
76 | if len(set(output_files).difference(set(done_files))) == 0:
77 | # All the files we were going to output have been generated by an identical process
78 | print("Using cached data.")
79 | return False, output_files
80 | else:
81 | ## Run the class.
82 | if module.ready:
83 | print("Running...")
84 | module.run()
85 | print("Done.")
86 | return True, output_files
87 | else:
88 | print("Module not ready.")
89 | return False, None
90 |
91 | def main(program_path):
92 | ### Make sure required directories exist
93 | if not os.path.exists(TEMP_DIRECTORY):
94 | os.makedirs(TEMP_DIRECTORY)
95 |
96 | ### Read the history file
97 | history = {}
98 | if os.path.exists(HISTORY_PATH):
99 | with open(HISTORY_PATH, 'r') as historyfile:
100 | history_text = historyfile.read().strip()
101 | if len(history_text) > 0:
102 | history = eval(history_text)
103 |
104 | ### Read in the program file
105 | data_text = ''
106 | data = None
107 | for line in open(program_path, "r"):
108 | data_text = data_text + line.strip()
109 |
110 | ### Convert to json dictionary
111 | data = json.loads(data_text)
112 | #data = byteify(data)
113 |
114 | use_caching = True # Should we use the caching system?
115 |
116 | ### Run each module
117 | if data is not None:
118 | print("Running", program_path)
119 | for d in data:
120 | # Run the module!
121 | executed, output_files = runModule(d, history if use_caching else {})
122 | # If the module ran and produced output_files then we shouldn't use caching anymore
123 | # If the module didn't run and output_files is None then it wasn't ready
124 | # and we are probably dead in the water
125 | # If the module didn't run and produced output_files then it is drawing from the cache,
126 | # keep doing so
127 | if (executed and output_files is not None) or (not executed and output_files is None):
128 | use_caching = False
129 | # Update the history
130 | if output_files is not None:
131 | for file in output_files:
132 | history[file] = d
133 |
134 | # Write the history to file
135 | with open(HISTORY_PATH, 'w') as historyfile:
136 | # Clean up the history file
137 | history_copy = copy.deepcopy(history)
138 | for file in history_copy.keys():
139 | if not os.path.exists(file):
140 | del history[file]
141 | # Now write
142 | historyfile.write(str(history))
143 | else:
144 | # No data to run
145 | print("Program is empty")
146 |
147 | if __name__ == '__main__':
148 | args = parser.parse_args()
149 | if args.clear_cache:
150 |         shutil.rmtree(TEMP_DIRECTORY, ignore_errors=True)
151 | main(args.program)
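
The trickiest step in `runModule` is the string rewriting that turns a module record into a constructor call: the record's dict `repr` is converted into keyword arguments by regex substitution and then `eval`'d as `ModuleName(key=value, ...)`. A condensed, standalone sketch of that transformation (the value character class is simplified here; the real code also strips `x`, `y`, `name`, and `id` first and hex-decodes parameters):

```python
import re

record = {'input': 'cache/text7-1', 'output': 'cache/text8-1'}
params = str(record).replace('{', '').replace('}', '')
# "'input': 'cache/text7-1', 'output': 'cache/text8-1'"
params = re.sub(r"'([0-9a-zA-Z_]+)':\s*('[^']*')", r"\1=\2", params)
# "input='cache/text7-1', output='cache/text8-1'"
eval_string = 'MakeLowercase' + '(' + params + ')'
print(eval_string)  # MakeLowercase(input='cache/text7-1', output='cache/text8-1')
# runModule would now eval(eval_string) to instantiate the module class.
```

The `.history` file under the cache directory then maps each output file back to the exact record that produced it, which is how an unchanged prefix of a program gets skipped on re-runs.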
--------------------------------------------------------------------------------
/file_manager.js:
--------------------------------------------------------------------------------
1 | ////////////////////////////
2 | // GLOBALS
3 |
4 | var path1 = '/content' // the cwd of the first file list box
5 | var path2 = '/content' // the cwd of the second file list box
6 | var selected1 = '/content' // the path to a file selected in the first file list box
7 | var selected2 = '/content' // the path to a file selected in the second file list box
8 |
9 | // Call python and get a dictionary containing file names and their paths
10 | function get_files(path, list_id) {
11 | async function foo() {
12 | console.log(path);
13 | const result = await google.colab.kernel.invokeFunction(
14 | 'notebook.python_cwd_hook', // The callback name.
15 | [path], // The arguments.
16 | {}); // kwargs
17 | return result;
18 | };
19 | foo().then(function(value) {
20 | // parse the return value
21 | var returned = value.data['application/json'];
22 | var dict = eval(returned.result); // dictionary of filenames and full paths
23 | var file_list = document.getElementById(list_id); // list box html element
24 | // Clear the list box
25 | removeOptions(file_list);
26 | var files = []; // filenames
27 | var key;
28 |     // Move filenames (keys) out of dictionary into files list
29 | for (key in dict) {
30 | files.push(key);
31 | }
32 | // Sort the files
33 | var sorted_files = files.sort(); // the sorted file list
34 | // But make sure . is at the top of the list
35 | var files_temp = []
36 | files_temp.push('./')
37 | var i;
38 | for (i = 0; i < sorted_files.length; i++) {
39 | if (sorted_files[i] !== './') {
40 | files_temp.push(sorted_files[i])
41 | }
42 | }
43 | sorted_files = files_temp;
44 | // ASSERT: files are sorted and . is at the top of the list
45 | // Populate the list box
46 | var i;
47 | for (i = 0; i < sorted_files.length; i++) {
48 | var file = sorted_files[i];
49 | var val_path = dict[file];
50 | var opt = document.createElement('option');
51 | opt.value = val_path;
52 | opt.innerHTML = file;
53 | file_list.appendChild(opt);
54 | }
55 | });
56 | // ASSERT: nothing after here guaranteed to be executed before foo returns
57 | }
58 |
59 | // Remove all options from a select box
60 | function removeOptions(selectbox) {
61 | var i;
62 | for(i = selectbox.options.length - 1 ; i >= 0 ; i--) {
63 | selectbox.remove(i);
64 | }
65 | }
66 |
67 |
68 |
69 | // Set up the select boxes with double_click callbacks
70 | var file_list1 = document.getElementById("file_list1");
71 | file_list1.ondblclick = function(){
72 | var filename = this.options[this.selectedIndex].innerHTML;
73 | var path = this.options[this.selectedIndex].value;
74 | if (filename[filename.length-1] === "/") {
75 | // this is a directory
76 | path1 = path;
77 | selected1 = path;
78 | update_gui(path, "path1", "file_list1");
79 | }
80 | };
81 | file_list1.onclick = function() {
82 | var filename = this.options[this.selectedIndex].innerHTML;
83 | var path = this.options[this.selectedIndex].value;
84 | selected1 = path;
85 | };
86 |
87 | var file_list2 = document.getElementById("file_list2");
88 | file_list2.ondblclick = function(){
89 | var filename = this.options[this.selectedIndex].innerHTML;
90 | var path = this.options[this.selectedIndex].value;
91 | if (filename[filename.length-1] === "/") {
92 | // this is a directory
93 | path2 = path;
94 | selected2 = path;
95 | update_gui(path, "path2", "file_list2");
96 | }
97 | };
98 | file_list2.onclick = function() {
99 | var filename = this.options[this.selectedIndex].innerHTML;
100 | var path = this.options[this.selectedIndex].value;
101 | selected2 = path;
102 | };
103 |
104 | // update the gui if a path has changed
105 | function update_gui(path, dir_id, list_id) {
106 | var path_text = document.getElementById(dir_id);
107 | path_text.innerHTML = path;
108 | get_files(path, list_id);
109 | }
110 |
111 | // Copy button
112 | function do_copy_mouse_up() {
113 |   (async function() {
114 |     await google.colab.kernel.invokeFunction(
115 |         'notebook.python_copy_hook', // The callback name.
116 |         [selected1, selected2], // The arguments.
117 |         {}); // kwargs
118 |     // Refresh the file lists only after the copy has completed
119 |     update_gui(path1, "path1", "file_list1");
120 |     update_gui(path2, "path2", "file_list2");
121 |   })();
122 | }
123 |
124 | // Move button
125 | function do_move_mouse_up() {
126 |   (async function() {
127 |     await google.colab.kernel.invokeFunction(
128 |         'notebook.python_move_hook', // The callback name.
129 |         [selected1, selected2], // The arguments.
130 |         {}); // kwargs
131 |     // Refresh the file lists only after the move has completed
132 |     update_gui(path1, "path1", "file_list1");
133 |     update_gui(path2, "path2", "file_list2");
134 |   })();
135 | }
136 |
137 | // Open text button
138 | function do_open_text_mouse_up() {
139 | (async function() {
140 | const result = await google.colab.kernel.invokeFunction(
141 | 'notebook.python_open_text_hook', // The callback name.
142 | [selected1], // The arguments.
143 | {}); // kwargs
144 | const res = result.data['application/json'];
145 | })();
146 | }
147 |
148 | // Open image button
149 | function do_open_image_mouse_up() {
150 | (async function() {
151 | const result = await google.colab.kernel.invokeFunction(
152 | 'notebook.python_open_image_hook', // The callback name.
153 | [selected1], // The arguments.
154 | {}); // kwargs
155 | const res = result.data['application/json'];
156 | })();
157 | }
158 |
159 | function do_mkdir_mouse_up() {
160 |   var input_box = document.getElementById('mkdir_input');
161 |   var dir_name = input_box.value;
162 |   (async function() {
163 |     await google.colab.kernel.invokeFunction(
164 |         'notebook.python_mkdir_hook', // The callback name.
165 |         [selected1, dir_name], // The arguments.
166 |         {}); // kwargs
167 |     // Refresh the file lists only after the directory has been created
168 |     update_gui(path1, "path1", "file_list1");
169 |     update_gui(path2, "path2", "file_list2");
170 |   })();
171 | }
172 |
173 | function do_trash_mouse_up() {
174 |   (async function() {
175 |     await google.colab.kernel.invokeFunction(
176 |         'notebook.python_trash_hook', // The callback name.
177 |         [selected1], // The arguments.
178 |         {}); // kwargs
179 |     // Refresh the file lists only after the file has been trashed
180 |     update_gui(path1, "path1", "file_list1");
181 |     update_gui(path2, "path2", "file_list2");
182 |   })();
183 | }
184 |
185 | // GO
186 | update_gui(path1, "path1", "file_list1")
187 | update_gui(path2, "path2", "file_list2")
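
Every button above round-trips through `google.colab.kernel.invokeFunction`; the matching Python side must be registered under the same name with `google.colab.output.register_callback` and return an `IPython.display.JSON` payload, which the JavaScript reads back as `value.data['application/json']`. A minimal sketch of one such pair's Python half, assuming a Colab runtime (the full set of registrations lives in `/Easygen.ipynb`):

```python
import IPython
from google.colab import output
import hooks

def python_cwd_hook(dir):
    # get_files() in file_manager.js evals returned.result, so the payload
    # is a plain dict mapping file names to their full paths.
    result = hooks.python_cwd_hook_aux(dir)
    return IPython.display.JSON({'result': result})

output.register_callback('notebook.python_cwd_hook', python_cwd_hook)
```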
--------------------------------------------------------------------------------
/style_transfer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.optim as optim
5 |
6 | import PIL
7 | from PIL import Image
8 | import matplotlib.pyplot as plt
9 |
10 | import torchvision.transforms as transforms
11 | import torchvision.models as models
12 |
13 | import copy
14 | import re
15 |
16 | ######################################
17 | ## GLOBALS
18 |
19 | # desired depth layers to compute style/content losses :
20 | CONTENT_LAYERS_DEFAULT = ['conv_4']
21 | STYLE_LAYERS_DEFAULT = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
22 | CNN_NORMALIZATION_MEAN = torch.tensor([0.485, 0.456, 0.406])
23 | CNN_NORMALIZATION_STD = torch.tensor([0.229, 0.224, 0.225])
24 |
25 |
26 |
27 | #######################################
28 | ## CLASSES
29 |
30 | class ContentLoss(nn.Module):
31 |
32 |     def __init__(self, target):
33 | super(ContentLoss, self).__init__()
34 | # we 'detach' the target content from the tree used
35 | # to dynamically compute the gradient: this is a stated value,
36 | # not a variable. Otherwise the forward method of the criterion
37 | # will throw an error.
38 | self.target = target.detach()
39 |
40 | def forward(self, input):
41 | self.loss = F.mse_loss(input, self.target)
42 | return input
43 |
44 |
45 | class StyleLoss(nn.Module):
46 |
47 | def __init__(self, target_feature):
48 | super(StyleLoss, self).__init__()
49 | self.target = gram_matrix(target_feature).detach()
50 |
51 | def forward(self, input):
52 | G = gram_matrix(input)
53 | self.loss = F.mse_loss(G, self.target)
54 | return input
55 |
56 | # create a module to normalize input image so we can easily put it in a
57 | # nn.Sequential
58 | class Normalization(nn.Module):
59 | def __init__(self, mean, std):
60 | super(Normalization, self).__init__()
61 | # .view the mean and std to make them [C x 1 x 1] so that they can
62 | # directly work with image Tensor of shape [B x C x H x W].
63 | # B is batch size. C is number of channels. H is height and W is width.
64 | self.mean = torch.tensor(mean).view(-1, 1, 1)
65 | self.std = torch.tensor(std).view(-1, 1, 1)
66 |
67 | def forward(self, img):
68 | # normalize img
69 | return (img - self.mean) / self.std
70 |
71 | #######################################
72 | ## HELPERS
73 |
74 | def tensor_to_image(tensor):
75 | t = transforms.ToPILImage() # reconvert into PIL image
76 | image = tensor.cpu().clone() # we clone the tensor to not do changes on it
77 | image = image.squeeze(0) # remove the fake batch dimension
78 | image = t(image)
79 | return image
80 |
81 | def image_loader(image_name, size, device):
82 | # transform images to same size
83 | t = transforms.Compose([transforms.Resize(size), # scale imported image
84 | transforms.ToTensor()]) # transform it into a torch tensor
85 | image = Image.open(image_name)
86 | if image.mode != 'RGB':
87 | image = image.convert('RGB')
88 | # fake batch dimension required to fit network's input dimensions
89 | image = image.resize((size, size), PIL.Image.ANTIALIAS)
90 | image = t(image).unsqueeze(0)
91 | return image.to(device, torch.float)
92 |
93 |
94 | def gram_matrix(input):
95 | a, b, c, d = input.size() # a=batch size(=1)
96 | # b=number of feature maps
97 | # (c,d)=dimensions of a f. map (N=c*d)
98 |
99 |     features = input.view(a * b, c * d)  # reshape F_XL into \hat F_XL
100 |
101 | G = torch.mm(features, features.t()) # compute the gram product
102 |
103 | # we 'normalize' the values of the gram matrix
104 |     # by dividing by the number of elements in each feature map.
105 | return G.div(a * b * c * d)
106 |
107 | def get_input_optimizer(input_img):
108 | # this line to show that input is a parameter that requires a gradient
109 | optimizer = optim.LBFGS([input_img.requires_grad_()])
110 | return optimizer
111 |
112 | def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
113 | style_img, content_img, device,
114 | content_layers = CONTENT_LAYERS_DEFAULT,
115 |                                style_layers = STYLE_LAYERS_DEFAULT):
116 | cnn = copy.deepcopy(cnn)
117 |
118 | # normalization module
119 | normalization = Normalization(normalization_mean, normalization_std).to(device)
120 |
121 |     # just in order to have iterable access to the lists of content/style
122 | # losses
123 | content_losses = []
124 | style_losses = []
125 |
126 | # assuming that cnn is a nn.Sequential, so we make a new nn.Sequential
127 | # to put in modules that are supposed to be activated sequentially
128 | model = nn.Sequential(normalization)
129 |
130 | i = 0 # increment every time we see a conv
131 | for layer in cnn.children():
132 | if isinstance(layer, nn.Conv2d):
133 | i += 1
134 | name = 'conv_{}'.format(i)
135 | elif isinstance(layer, nn.ReLU):
136 | name = 'relu_{}'.format(i)
137 | # The in-place version doesn't play very nicely with the ContentLoss
138 | # and StyleLoss we insert below. So we replace with out-of-place
139 | # ones here.
140 | layer = nn.ReLU(inplace=False)
141 | elif isinstance(layer, nn.MaxPool2d):
142 | name = 'pool_{}'.format(i)
143 | elif isinstance(layer, nn.BatchNorm2d):
144 | name = 'bn_{}'.format(i)
145 | else:
146 | raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))
147 |
148 | model.add_module(name, layer)
149 |
150 | if name in content_layers:
151 | # add content loss:
152 | target = model(content_img).detach()
153 | content_loss = ContentLoss(target)
154 | model.add_module("content_loss_{}".format(i), content_loss)
155 | content_losses.append(content_loss)
156 |
157 | if name in style_layers:
158 | # add style loss:
159 | target_feature = model(style_img).detach()
160 | style_loss = StyleLoss(target_feature)
161 | model.add_module("style_loss_{}".format(i), style_loss)
162 | style_losses.append(style_loss)
163 |
164 | # now we trim off the layers after the last content and style losses
165 | for i in range(len(model) - 1, -1, -1):
166 | if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss):
167 | break
168 |
169 | model = model[:(i + 1)]
170 |
171 | return model, style_losses, content_losses
172 |
173 | ###########################################
174 |
175 | def run_style_transfer(cnn, content_img, style_img, input_img, device,
176 | normalization_mean = CNN_NORMALIZATION_MEAN,
177 | normalization_std = CNN_NORMALIZATION_STD,
178 | content_layers = CONTENT_LAYERS_DEFAULT,
179 | style_layers = STYLE_LAYERS_DEFAULT,
180 | num_steps = 300,
181 | style_weight = 1000000,
182 | content_weight = 1
183 | ):
184 | """Run the style transfer."""
185 | print('Building the style transfer model..')
186 | model, style_losses, content_losses = get_style_model_and_losses(cnn,
187 | normalization_mean.to(device), normalization_std.to(device), style_img, content_img, device,
188 | content_layers, style_layers)
189 | optimizer = get_input_optimizer(input_img)
190 |
191 | best_img = [None]
192 | best_score = [None]
193 |
194 | print('Optimizing..')
195 | run = [0]
196 | while run[0] <= num_steps:
197 |
198 | def closure():
199 | # correct the values of updated input image
200 | input_img.data.clamp_(0, 1)
201 |
202 | optimizer.zero_grad()
203 | model(input_img)
204 | style_score = 0
205 | content_score = 0
206 |
207 | for sl in style_losses:
208 | style_score += sl.loss
209 | for cl in content_losses:
210 | content_score += cl.loss
211 |
212 | style_score *= style_weight
213 | content_score *= content_weight
214 |
215 | loss = style_score + content_score
216 | loss.backward()
217 |
218 | run[0] += 1
219 | if run[0] % 50 == 0:
220 | print("run {}:".format(run))
221 | print('Style Loss : {:4f} Content Loss: {:4f}'.format(
222 | style_score.item(), content_score.item()))
223 | print()
224 | current_score = style_score.item() + content_score.item()
225 | if best_img[0] is None or current_score <= best_score[0]:
226 | best_img[0] = input_img.clone()
227 | best_img[0].data.clamp_(0, 1)
228 | best_score[0] = current_score
229 |
230 | return style_score + content_score
231 |
232 | optimizer.step(closure)
233 | # a last correction... # not sure I need to do this
234 | input_img.data.clamp_(0, 1)
235 |
236 | return best_img[0]
237 |
238 | ##########################################
239 | def process_layers_spec(spec):
240 | spec = str(spec)
241 | layers = re.findall(r'[\-0-9]+', spec)
242 | layers = [int(num) for num in layers]
243 | if -1 in layers or 0 in layers:
244 | return STYLE_LAYERS_DEFAULT
245 | layers = list(filter(lambda x: x >= 1 and x <= 5, layers))
246 | layers = sorted(layers)
247 | layers = ['conv_' + str(num) for num in layers]
248 | return layers
249 |
250 | ##########################################
251 |
252 | def run(content_image_path, style_image_path, output_path,
253 | image_size = 512, num_steps = 300, style_weight = 1000000, content_weight = 1,
254 | content_layers_spec='4',
255 | style_layers_spec = '1, 2, 3, 4, 5'):
256 | # CUDA or CPU
257 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
258 |
259 | # process content layers specification
260 | content_layers = process_layers_spec(content_layers_spec)
261 | style_layers = process_layers_spec(style_layers_spec)
262 |
263 | # Load images
264 | style_img = image_loader(style_image_path, image_size, device)
265 | content_img = image_loader(content_image_path, image_size, device)
266 |
267 | assert style_img.size() == content_img.size(), "style and content images must be the same size"
268 |
269 | # Load the VGG CNN. Download if necessary.
270 | cnn = models.vgg19(pretrained=True).features.to(device).eval()
271 |
272 | input_img = content_img.clone()
273 | # if you want to use white noise instead uncomment the below line:
274 | # input_img = torch.randn(content_img.data.size(), device=device)
275 |
276 | output = run_style_transfer(cnn,
277 | content_img, style_img, input_img, device,
278 | num_steps = num_steps,
279 | style_weight = style_weight,
280 | content_weight = content_weight,
281 | content_layers = content_layers,
282 | style_layers = style_layers)
283 | img = tensor_to_image(output)
284 | img.save(output_path, "JPEG")
285 | return img
286 |
287 |
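
`run` is the only entry point a caller needs: it loads both images at `image_size`, builds the truncated VGG-19 with the content and style loss modules inserted, and optimizes the input image with L-BFGS. A minimal usage sketch (the file names are placeholders; the pretrained VGG-19 weights are downloaded on first use):

```python
from style_transfer import run

# Blend the structure of one image with the textures of another.
# Both images are resized to image_size x image_size, so any inputs work.
img = run(content_image_path='content.jpg',  # placeholder path
          style_image_path='style.jpg',      # placeholder path
          output_path='stylized.jpg',
          image_size=256,                    # smaller runs faster
          num_steps=300,
          style_weight=1000000,              # emphasis on style textures
          content_weight=1,                  # emphasis on content structure
          style_layers_spec='1, 2, 3')       # restrict style loss to early conv layers
```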
--------------------------------------------------------------------------------
/readWikipedia.py:
--------------------------------------------------------------------------------
1 | import json
2 | import re
3 | from xml.etree import ElementTree as etree
4 | import os
5 | import sys
6 | from bs4 import BeautifulSoup
7 | import functools
8 |
9 | wikiDir = "" # The directory to find all the wiki files
10 | outFilename = "" # The filename to dump plots to
11 | titleFilename = "" # The filename to dump title names to
12 |
13 | pattern = '*:plot'
14 | breakSentences = True
15 |
16 | okTags = ['b', 'i', 'a', 'strong', 'em']
17 | listTags = ['ul', 'ol']
18 | contentTypes = ['text', 'list']
19 |
20 |
21 | # Check command line parameters
22 | '''
23 | if len(sys.argv) > 3:
24 | wikiDir = sys.argv[1]
25 | outFilename = sys.argv[2]
26 | titleFilename = sys.argv[3]
27 | else:
28 | print "usage:", sys.argv[0], "wikidirectory resultfile titlefile"
29 | exit()
30 | '''
31 |
32 |
33 | ########################
34 | ### HELPER FUNCTIONS
35 |
36 | def matchPattern(pattern, str):
37 | if pattern == '*':
38 | return str
39 | else:
40 | pattern = pattern.split('|')
41 | for p in pattern:
42 | match = re.search(p+r'\b', str)
43 | if match is not None:
44 | return match.group(0)
45 | return ''
46 |
47 |
48 | def fixList(soup, depth = 0):
49 | result = ''
50 | if soup.name is None:
51 | lines = [s.strip() for s in soup.splitlines()]
52 | for line in lines:
53 | if len(line) > 0 and line[0] == '-':
54 | dashes = ''
55 | for n in range(depth):
56 | dashes = dashes + '-'
57 | result = result + dashes + ' ' + line[1:].strip() + '\n'
58 | elif soup.name in listTags:
59 | for child in soup.children:
60 | result = result + fixList(child, depth + 1)
61 | return result
62 |
63 |
64 |
65 | ########################
66 |
67 | def ReadWikipedia(wiki_directory, pattern, categoryPattern, out_file, titles_file):
68 | pattern = pattern.split(':')
69 | titlePattern = pattern.pop(0)
70 | headerList = pattern
71 |
72 | contentType = 'text' # The type of content to grab 'text' or 'list'. Should be the tail of the headerPattern.
73 |
74 | files = [] # All the wiki files
75 |
76 | # Get all the wiki files left by wikiextractor
77 | for dirname, dirnames, filenames in os.walk(os.path.join('.', wiki_directory)):
78 | for filename in filenames:
79 | if filename[0] != '.':
80 | files.append(os.path.join(dirname, filename))
81 |
82 | with open(out_file, "w") as outfile:
83 | # Opened the output file
84 | with open(titles_file, "w") as titlefile:
85 | # Opened the title file
86 | # Walk through each file. Each file has a json for each wikipedia article. Look for jsons with "plot" subheaders
87 | for file in files:
88 | #print >> outfile, "file:", file #FOR DEBUGGING
89 | data = [] # Each element is a json record
90 | # Read the file and get all the json records
91 | for line in open(file, 'r'):
92 | data.append(json.loads(line))
93 | # Look for pattern matches in heading tags inside the text of the json
94 | for j in data:
95 | # j is a json record
96 | titleMatch = matchPattern(titlePattern, j['title'])
97 | categoryMatch = matchPattern(categoryPattern, j['categories'])
98 | if len(titleMatch) > 0 and len(categoryMatch) > 0:
99 | print("title:", titleMatch, "in", j['title'])
100 | print("category:", categoryMatch, "in", j['categories'])
101 | # This json record is a match to titlePattern
102 | #print >> outfile, j['title'].encode('utf-8') #FOR DEBUGGING
103 | # Text element contains HTML
104 | soup = BeautifulSoup(j['text'].encode('utf-8'), "html.parser")
105 | result = "" # The result found (if any)
106 | inresult = False # Am I inside a result section of the article?
107 | previousHeaders = []
108 | headerIndex = 0
109 | # Walk through each element in the html soup object
110 | for n in range(len(soup.contents)):
111 | current = soup.contents[n] # The current html element
112 | if len(headerList) == 0:
113 | # If only title information is given, we just get everything
114 | if current is not None and current.name is None:
115 | result = result + current.strip() + ' '
116 | elif not inresult and current is not None and current.name is not None and current.name == 'h' + str(headerIndex + 2): # start with h2
117 | # Let's see if this header matches the current expected pattern
118 | #print >> outfile, "current(1):", current.name.encode('utf-8'), current.encode('utf-8')
119 | match = False
120 | if len(headerList) == 0:
121 | match = True
122 | elif len(headerList) > 0:
123 | match = matchPattern(headerList[headerIndex].lower(), current.get_text().lower())
124 | if match:
125 | # this header matches
126 | previousHeaders.append(current.get_text())
127 | #print >> outfile, "previousheaders(a):", map(lambda x: x.encode('utf-8'), previousHeaders)
128 | headerIndex = headerIndex + 1
129 | if headerIndex >= len(headerList):
130 | inresult = True
131 | elif headerList[headerIndex].lower() in contentTypes:
132 | inresult = True
133 | contentType = headerList[headerIndex].lower()
134 | else:
135 | previousHeaders = []
136 | elif inresult and current is not None and current.name is not None and current.name[0] == 'h' and int(current.name[1]) >= (headerIndex + 2):
137 | # I'm probably seeing a sub-heading inside of what I want
138 | previousHeaders.append(current.get_text())
139 | #print >> outfile, "previousheaders(b):", map(lambda x: x.encode('utf-8'), previousHeaders)
140 | elif inresult and current is not None and current.name is not None and current.name in listTags:
141 | # found a list inside what I am looking for
142 | result = result + '\n' + fixList(current) + '\n '
143 | elif inresult and current is not None and (current.name is None or current.name.lower() in okTags):
144 | # I'm probably looking at text inside of what I want
145 | #print >> outfile, "current(3):", current.encode('utf-8')
146 | if contentType != 'list':
147 | current = current.strip()
148 | # Sometimes we see the header name duplicated inside the text block that succeeds the sub-section header. Crop it off
149 | if len(current) > 0:
150 | if len(previousHeaders) > 0:
151 | #print >> outfile, "previousheaders(c):", map(lambda x: x.encode('utf-8'), previousHeaders)
152 | headerLength = functools.reduce(lambda x,y: x+y, map(lambda z: len(z)+2, previousHeaders)) # add 2 for period and space.
153 | result = result + current[headerLength:].strip() + ' '
154 | else:
155 | result = result + current.strip() + ' '
156 | # Forget the previous header. It was either consumed or wasn't duplicated in the first place.
157 | previousHeaders = []
158 | elif inresult and current is not None and current.name is not None and current.name[0] == 'h' and int(current.name[1]) < (headerIndex + 2):
159 | # Probably left the block. All done with this json!
160 | break
161 | elif not inresult and current is not None and current.name is not None and current.name[0] == 'h' and int(current.name[1]) > 1 and int(current.name[1]) < (headerIndex + 2):
162 | # not in the result, but we went up one level
163 | headerIndex = headerIndex - 1
164 | if len(previousHeaders) > 0:
165 | previousHeaders.pop()
166 | # Let's see if this header matches the current expected pattern
167 | #print >> outfile, "current(2):", current.name.encode('utf-8'), current.encode('utf-8')
168 | match = matchPattern(headerList[headerIndex].lower(), current.get_text().lower())
169 | if len(match) > 0:
170 | # this header matches
171 | previousHeaders.append(current.get_text())
172 | #print >> outfile, "previousheaders(d):", map(lambda x: x.encode('utf-8'), previousHeaders)
173 | headerIndex = headerIndex + 1
174 | if headerIndex >= len(headerList):
175 | inresult = True
176 | elif headerList[headerIndex].lower() in contentTypes:
177 | inresult = True
178 | contentType = headerList[headerIndex].lower()
179 | elif not inresult and current is not None and current.name is not None and current.name[0] == 'h' and int(current.name[1]) > 1 and int(current.name[1]) >= (headerIndex + 2):
180 | # I'm not in the result block and I saw something that wasn't a header.
181 | previousHeaders = []
182 | #print >> outfile, "previous header cleared (e)", current.encode('utf-8')
183 | elif not inresult and current is not None and current.name is None and len(current.strip()) > 0:
184 | previousHeaders = []
185 | #print >> outfile, "previous header cleared (f)", current.encode('utf-8')
186 | # Did we find what we were looking for?
187 | if len(result) > 0:
188 | # ASSERT: I have a result
189 | # Record the name of the article with the result
190 | #titlefile.write(j['title'].encode('utf-8') + '\n')
191 | titlefile.write(j['title'] + '\n')
192 | '''
193 | # remove newlines
194 | #result = result.replace('\n', ' ').replace('\r', '').strip()
195 | # remove html tags (probably mainly hyperlinks)
196 | result = re.sub('<[^<]+?>', '', result)
197 | # remove character name initials and take periods off mr/mrs/ms/dr/etc.
198 | result = re.sub(' [M|m]r\.', ' mr', result)
199 | result = re.sub(' [M|m]rs\.', ' mrs', result)
200 | result = re.sub(' [M|m]s\.', ' ms', result)
201 | result = re.sub(' [D|d]r\.', ' dr', result)
202 | #result = re.sub(' [M|m]d\.', ' md', result)
203 | #result = re.sub(' [P|p][H|h][D|d]\.', ' phd', result)
204 | #result = re.sub(' [E|e][S|s][Q|q]\.', ' esq', result)
205 | result = re.sub(' [L|l][T|t]\.', ' lt', result)
206 | result = re.sub(' [G|g][O|o][V|v]\.', ' gov', result)
207 | result = re.sub(' [C|c][P|p][T|t]\.', ' cpt', result)
208 | result = re.sub(' [S|s][T|t]\.', ' st', result)
209 | # handle i.e. and cf.
210 | result = re.sub('i\.e\. ', 'ie ', result)
211 | result = re.sub('cf\. ', 'cf', result)
212 | # deal with periods in quotes
213 | result = re.sub('\.\"', '\".', result)
214 | # remove single letter initials
215 | p4 = re.compile(r'([ \()])([A-Z|a-z])\.')
216 | result = p4.sub(r'\1\2', result)
217 | # Acroymns with periods are not fun. Need two steps to get rid of those periods.
218 | # I don't think this is working quite right
219 | p1 = re.compile('([A-Z|a-z])\.([)|\"|\,])')
220 | result = p1.sub(r'\1\2', result)
221 | p2 = re.compile('\.([A-Z|a-z])')
222 | result = p2.sub(r'\1', result)
223 | # periods in numbers
224 | p3 = re.compile('([0-9]+)\.([0-9]+)')
225 | result = p3.sub(r'\1\2', result)
226 | '''
227 | # Print result
228 | if contentType == 'text':
229 | #print >> outfile, result.strip().encode('utf-8')
230 | outfile.write(result.strip() + '\n')
231 | elif contentType == 'list':
232 | lines = [s.strip() for s in result.splitlines()]
233 | for line in lines:
234 | if len(line) > 0 and line[0] == '-':
235 | #rint >> outfile, line.strip().encode('utf-8')
236 | outfile.write(line.strip())
237 | outfile.write('\n')
238 |
239 |
240 | ### TEST RUN
241 | #ReadWikipedia(wikiDir, pattern, '*', outFilename, titleFilename)
242 |
243 |
244 |
245 |
246 | '''
247 | TODO: future modules:
248 | - Remove Empty lines
249 | - Remove newlines
250 | - Segment sentences into separate lines
251 | - Remove characters/words/substrings
252 | - strip all lines
253 | '''
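
The `pattern` argument is a small `:`-separated path language: the first component matches article titles, each later component matches one level of section headers, `*` matches anything, `|` separates alternatives, and a trailing `text` or `list` selects the content type. A sketch of how the pieces behave (the `wiki` directory and output file names are placeholders for wikiextractor output):

```python
from readWikipedia import matchPattern, ReadWikipedia

# matchPattern returns the matched text (truthy) or '' (no match).
print(matchPattern('Plot|Synopsis', 'Plot summary'))  # -> 'Plot'
print(matchPattern('Plot', 'Reception'))              # -> ''

# '*:plot' means: any article title, grab the text under a 'plot' header;
# '*' as the category pattern accepts every category.
ReadWikipedia('wiki', '*:plot', '*', 'plots.txt', 'titles.txt')
```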
--------------------------------------------------------------------------------
/stylegan_runner.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pdb
3 | import sys
4 | import pickle
5 | import random
6 | import math
7 | import argparse
8 | import numpy as np
9 | from PIL import Image
10 | from tqdm import tqdm_notebook as tqdm
11 |
12 | def easygen_train(model_path, images_path, dataset_path, start_kimg=7000, max_kimg=25000, schedule='', seed=1000):
13 | #import stylegan
14 | #from stylegan import config
15 | ##from stylegan import dnnlib
16 | #from stylegan.dnnlib import EasyDict
17 |
18 | #images_dir = '/content/raw'
19 | #max_kimg = 25000
20 | #start_kimg = 7000
21 | #schedule = ''
22 | #model_in = '/content/karras2019stylegan-cats-256x256.pkl'
23 |
24 | #dataset_dir = '/content/stylegan_dataset' #os.path.join(cwd, 'cache', 'stylegan_dataset')
25 |
26 | import config
27 | config.data_dir = '/content/datasets'
28 |     config.result_dir = '/content/results'
29 |     config.cache_dir = '/content/cache'
30 |     config.run_dir_ignore += ['/content/results', '/content/datasets', '/content/cache']
31 | import copy
32 | import dnnlib
33 | from dnnlib import EasyDict
34 | from metrics import metric_base
35 | # Prep dataset
36 | import dataset_tool
37 | print("prepping dataset...")
38 | dataset_tool.create_from_images(tfrecord_dir=dataset_path, image_dir=images_path, shuffle=False)
39 | # Set up training parameters
40 | desc = 'sgan' # Description string included in result subdir name.
41 | train = EasyDict(run_func_name='training.training_loop.training_loop') # Options for training loop.
42 | G = EasyDict(func_name='training.networks_stylegan.G_style') # Options for generator network.
43 | D = EasyDict(func_name='training.networks_stylegan.D_basic') # Options for discriminator network.
44 | G_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for generator optimizer.
45 | D_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for discriminator optimizer.
46 | G_loss = EasyDict(func_name='training.loss.G_logistic_nonsaturating') # Options for generator loss.
47 | D_loss = EasyDict(func_name='training.loss.D_logistic_simplegp', r1_gamma=10.0) # Options for discriminator loss.
48 | dataset = EasyDict() # Options for load_dataset().
49 | sched = EasyDict() # Options for TrainingSchedule.
50 | grid = EasyDict(size='1080p', layout='random') # Options for setup_snapshot_image_grid().
51 | #metrics = [metric_base.fid50k] # Options for MetricGroup.
52 | submit_config = dnnlib.SubmitConfig() # Options for dnnlib.submit_run().
53 | tf_config = {'rnd.np_random_seed': seed} # Options for tflib.init_tf().
54 | # Dataset
55 | desc += '-custom'
56 | dataset = EasyDict(tfrecord_dir=dataset_path)
57 | train.mirror_augment = False
58 | # Number of GPUs.
59 | desc += '-1gpu'
60 | submit_config.num_gpus = 1
61 | sched.minibatch_base = 4
62 | sched.minibatch_dict = {4: 128, 8: 128, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8, 512: 4} #{4: 256, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16, 256: 16}
63 | # Default options.
64 | train.total_kimg = max_kimg
65 | sched.lod_initial_resolution = 8
66 | sched.G_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
67 | sched.D_lrate_dict = EasyDict(sched.G_lrate_dict)
68 | # schedule
69 | schedule_dict = {4: 160, 8:140, 16:120, 32:100, 64:80, 128:60, 256:40, 512:30, 1024:20} #{4: 2, 8:2, 16:2, 32:2, 64:2, 128:2, 256:2, 512:2, 1024:2} # Runs faster for small datasets
70 | if len(schedule) >=5 and schedule[0] == '{' and schedule[-1] == '}' and ':' in schedule:
71 | # is schedule a string of a dict?
72 | try:
73 | temp = eval(schedule)
74 | schedule_dict = dict(temp)
75 | # assert: it is a dict
76 |         except Exception:
77 | pass
78 | elif len(schedule) > 0:
79 | # is schedule an int?
80 | try:
81 | schedule_int = int(schedule)
82 | #assert: schedule is an int
83 | schedule_dict = {}
84 | for i in range(1, 10):
85 | schedule_dict[int(math.pow(2, i+1))] = schedule_int
86 |         except Exception:
87 | pass
88 | print('schedule:', str(schedule_dict))
89 | sched.tick_kimg_dict = schedule_dict
90 | # resume kimg
91 | resume_kimg = start_kimg
92 | # path to model
93 | resume_run_id = model_path
94 | # tick snapshots
95 | image_snapshot_ticks = 1
96 | network_snapshot_ticks = 1
97 | # Submit run
98 | kwargs = EasyDict(train)
99 | kwargs.update(G_args=G, D_args=D, G_opt_args=G_opt, D_opt_args=D_opt, G_loss_args=G_loss, D_loss_args=D_loss)
100 | kwargs.update(dataset_args=dataset, sched_args=sched, grid_args=grid, tf_config=tf_config)
101 | kwargs.update(resume_kimg=resume_kimg, resume_run_id=resume_run_id)
102 | kwargs.update(image_snapshot_ticks=image_snapshot_ticks, network_snapshot_ticks=network_snapshot_ticks)
103 | kwargs.submit_config = copy.deepcopy(submit_config)
104 | kwargs.submit_config.run_dir_root = dnnlib.submission.submit.get_template_from_path(config.result_dir)
105 | kwargs.submit_config.run_dir_ignore += config.run_dir_ignore
106 | kwargs.submit_config.run_desc = desc
107 | dnnlib.submit_run(**kwargs)
108 |
109 | def easygen_run(model_path, images_path, num=1):
110 | # from https://github.com/ak9250/stylegan-art/blob/master/styleganportraits.ipynb
111 | truncation = 0.7 # hard coding because everyone uses this value
112 | import dnnlib
113 | import dnnlib.tflib as tflib
114 | import config
115 | tflib.init_tf()
116 | #num = 10
117 | #model = '/content/karras2019stylegan-cats-256x256.pkl'
118 | #images_dir = '/content/cache/run_out'
119 | #truncation = 0.7
120 | _G = None
121 | _D = None
122 | Gs = None
123 | with open(model_path, 'rb') as f:
124 | _G, _D, Gs = pickle.load(f)
125 | fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
126 | synthesis_kwargs = dict(output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True), minibatch_size=8)
127 | latents = np.random.RandomState(int(1000*random.random())).randn(num, *Gs.input_shapes[0][1:])
128 | labels = np.zeros([latents.shape[0]] + Gs.input_shapes[1][1:])
129 | images = Gs.run(latents, None, truncation_psi=truncation, randomize_noise=False, output_transform=fmt)
130 | for n, image in enumerate(images):
131 | # img = Image.fromarray(images[0])
132 | img = Image.fromarray(image)
133 | img.save(os.path.join(images_path, str(n) + '.jpg'), "JPEG")
134 |
135 | def get_latent_interpolation(endpoints, num_frames_per, mode = 'linear', shuffle = False):
136 | if shuffle:
137 | random.shuffle(endpoints)
138 | num_endpoints, dim = len(endpoints), len(endpoints[0])
139 | num_frames = num_frames_per * num_endpoints
140 | endpoints = np.array(endpoints)
141 | latents = np.zeros((num_frames, dim))
142 | for e in range(num_endpoints):
143 | e1, e2 = e, (e+1)%num_endpoints
144 | for t in range(num_frames_per):
145 | frame = e * num_frames_per + t
146 | r = 0.5 - 0.5 * np.cos(np.pi*t/(num_frames_per-1)) if mode == 'ease' else float(t) / num_frames_per
147 | latents[frame, :] = (1.0-r) * endpoints[e1,:] + r * endpoints[e2,:]
148 | return latents
149 |
150 |
151 |
152 | def easygen_movie(model_path, movie_path, num=10, interp=10, duration=10):
153 | # from https://github.com/ak9250/stylegan-art/blob/master/styleganportraits.ipynb
154 | import dnnlib
155 | import dnnlib.tflib as tflib
156 | import config
157 | tflib.init_tf()
158 | truncation = 0.7 # what everyone uses
159 | # Get model
160 | _G = None
161 | _D = None
162 | Gs = None
163 | with open(model_path, 'rb') as f:
164 | _G, _D, Gs = pickle.load(f)
165 | # Make waypoints
166 | #fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
167 | #synthesis_kwargs = dict(output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True), minibatch_size=8)
168 | waypoint_latents = np.random.RandomState(int(1000*random.random())).randn(num, *Gs.input_shapes[0][1:])
169 | #waypoint_labels = np.zeros([waypoint_latents.shape[0]] + Gs.input_shapes[1][1:])
170 | #waypoint_images = Gs.run(latents, None, truncation_psi=truncation, randomize_noise=False, output_transform=fmt)
171 | # interpolate
172 | interp_latents = get_latent_interpolation(waypoint_latents, interp)
173 | interp_labels = np.zeros([interp_latents.shape[0]] + Gs.input_shapes[1][1:])
174 | fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
175 | synthesis_kwargs = dict(output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True), minibatch_size=8)
176 | batch_size = 8
177 | num_frames = interp_latents.shape[0]
178 | num_batches = int(np.ceil(num_frames/batch_size))
179 | images = []
180 | for b in tqdm(range(num_batches)):
181 |         new_images = Gs.run(interp_latents[b*batch_size:min((b+1)*batch_size, num_frames), :], None, truncation_psi=truncation, randomize_noise=False, output_transform=fmt)
182 | for img in new_images:
183 | images.append(Image.fromarray(img)) # convert to PIL.Image
184 | images[0].save(movie_path, "GIF",
185 | save_all=True,
186 | append_images=images[1:],
187 | duration=duration,
188 | loop=0)
189 |
190 |
191 | if __name__ == '__main__':
192 | parser = argparse.ArgumentParser(description='Process runner commands.')
193 | parser.add_argument('--train', action="store_true", default=False)
194 | parser.add_argument('--run', action="store_true", default=False)
195 | parser.add_argument('--movie', action="store_true", default=False)
196 | parser.add_argument("--model", help="model to load", default="")
197 | parser.add_argument("--images_in", help="directory containing training images", default="")
198 | parser.add_argument("--images_out", help="diretory to store generated images", default="")
199 | parser.add_argument("--movie_out", help="directory to save movie", default="")
200 | parser.add_argument("--dataset_temp", help="where to store prepared image data", default="")
201 | parser.add_argument("--schedule", help="training schedule", default="")
202 | parser.add_argument("--max_kimg", help="iteration to stop training at", type=int, default=25000)
203 | parser.add_argument("--start_kimg", help="iteration to start training at", type=int, default=7000)
204 | parser.add_argument("--num", help="number of images to generate", type=int, default=1)
205 | parser.add_argument("--interp", help="number of images to interpolate", type=int, default=10)
206 | parser.add_argument("--duration", help="how long for each image in movie", type=int, default=10)
207 | parser.add_argument("--seed", help="seed number", type=int, default=1000)
208 | args = parser.parse_args()
209 | if args.train:
210 | easygen_train(model_path=args.model,
211 | images_path=args.images_in,
212 | dataset_path=args.dataset_temp,
213 | start_kimg=args.start_kimg,
214 | max_kimg=args.max_kimg,
215 | schedule=args.schedule,
216 | seed=args.seed)
217 | elif args.run:
218 | easygen_run(model_path=args.model,
219 | images_path=args.images_out,
220 | num=args.num)
221 | elif args.movie:
222 | easygen_movie(model_path=args.model,
223 | movie_path=args.movie_out,
224 | num=args.num,
225 | interp=args.interp,
226 | duration=args.duration)
227 |
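
`--schedule` accepts either a dict literal mapping resolution to tick length in kimg or a single integer applied to every resolution from 4 to 1024. A condensed sketch of how `easygen_train` parses it, preceded by typical invocations as comments (`parse_schedule` is a name introduced here for illustration, and all paths are placeholders):

```python
# python stylegan/stylegan_runner.py --train --model /content/cats256x256.pkl \
#     --images_in /content/raw --dataset_temp /content/stylegan_dataset --schedule 20
# python stylegan/stylegan_runner.py --run --model /content/cats256x256.pkl \
#     --images_out /content/cache/run_out --num 10
import math

def parse_schedule(schedule):
    # Default mirrors schedule_dict in easygen_train above.
    schedule_dict = {4: 160, 8: 140, 16: 120, 32: 100, 64: 80,
                     128: 60, 256: 40, 512: 30, 1024: 20}
    if len(schedule) >= 5 and schedule[0] == '{' and schedule[-1] == '}':
        schedule_dict = dict(eval(schedule))  # e.g. '{4: 2, 8: 2}'
    elif len(schedule) > 0:
        schedule_dict = {int(math.pow(2, i + 1)): int(schedule)
                         for i in range(1, 10)}  # e.g. '20' for every resolution
    return schedule_dict

print(parse_schedule('20'))
print(parse_schedule('{4: 2, 8: 2}'))
```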
--------------------------------------------------------------------------------
/Easygen.ipynb:
--------------------------------------------------------------------------------
1 | {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"Easygen.ipynb","version":"0.3.2","provenance":[{"file_id":"1PpvfuRxrR93GV6z_Asm7QSG3C9wiNX4l","timestamp":1561078408259}],"collapsed_sections":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"XzgHTDSyzMRP","colab_type":"text"},"source":["# EasyGen\n","\n","EasyGen is a visual user interface to help set up simple neural network generation tasks.\n","\n","There are a number of neural network frameworks (Tensorflow, TFLearn, Keras, PyTorch, etc.) that implement standard algorithms for generating text (e.g., recurrent neural networks, sequence2sequence) and images (e.g., generative adversarial networks). They require some familairity with coding. Beyond that, just running a simple experiment may require a complex set of steps to clean and prepare the data.\n","\n","EasyGen allows one to quickly set up the data cleaning and neural network training pipeline using a graphical user interface and a number of self-contained \"modules\" that implement stardard data preparation routines. EasyGen differs from other neural network user interfaces in that it doesn't focus on the graphical instantiation of the neural network itself. Instead, it provides an easy to use way to instantiate some of the most common neural network algorithms used for generation. EasyGen focuses on the data preparation.\n","\n","For documentation see the [EasyGen Github repo](https://https://github.com/markriedl/easygen)\n","\n","To get started:\n","\n","1. Clone the [EasyGen notebook](https://drive.google.com/open?id=1XNiOuNtMnItl5CPGvRjEvj9C78nDuvXj) by following the link and selecting File -> Save a copy in Drive.\n","\n","2. Turn on GPU support under Edit -> Notebook setting.\n","\n","3. Run the cells in Sections 1. Some are optional if you know you aren't going to be using particular features.\n","\n","4. Run the cell in Section 2. If you know there are any models or datasets that you won't be using you can skip them.\n","\n","5. Run the cell in Section 3. This creates a blank area below the cell in which you can use the buttons to create your visual program. An example program is loaded automatically. You can clear it with the \"clear\" button below it. Afterwards you can create your own programs. Selecting \"Make New Module\" will cause the new module appears graphically above and can be dragged around. The inputs and outputs of different modules can be connected together by clicking on an output (red) and dragging to an input (green). Gray boxes are parameters that can be edited. Clicking on a gray box causes a text input field to appear at the bottom of the editing area, just above the \"Make New Module\" controls.\n","\n","6. Save your program by entering a program name and pressing the \"save\" button.\n","\n","7. Run your program by editing the program name in the cell in Section 4 and then running the cell."]},{"cell_type":"markdown","metadata":{"id":"1PJwtOgUXpAG","colab_type":"text"},"source":["# 1. 
Setup\n","\n","Download EasyGen"]},{"cell_type":"code","metadata":{"id":"cXGp8MWKPumP","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":34},"outputId":"0b48a281-edce-4a16-b1d2-ddbabaed86d6","executionInfo":{"status":"ok","timestamp":1564620470959,"user_tz":240,"elapsed":3309,"user":{"displayName":"Mark Riedl","photoUrl":"https://lh4.googleusercontent.com/-VkUvOUfU44c/AAAAAAAAAAI/AAAAAAAAAG4/lS4Rpm-pe_0/s64/photo.jpg","userId":"17417011882129852150"}}},"source":["!git clone https://github.com/markriedl/easygen.git\n","!cp easygen/*.js /usr/local/share/jupyter/nbextensions/google.colab/\n","!cp easygen/images/*.png /usr/local/share/jupyter/nbextensions/google.colab/"],"execution_count":1,"outputs":[{"output_type":"stream","text":["fatal: destination path 'easygen' already exists and is not an empty directory.\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"9XikLU2ez0Yz","colab_type":"text"},"source":["Install requirements"]},{"cell_type":"code","metadata":{"id":"jlgTqMtFX3NU","colab_type":"code","colab":{}},"source":["!apt-get update\n","!apt-get install chromium-chromedriver\n","!pip install -r easygen/requirements.txt"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"gqeftQQ6z3we","colab_type":"text"},"source":["Download StyleGAN"]},{"cell_type":"code","metadata":{"id":"l9jSPMsgj0hI","colab_type":"code","colab":{}},"source":["!git clone https://github.com/NVlabs/stylegan.git\n","!cp easygen/stylegan_runner.py stylegan"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"BJNElaVez7JW","colab_type":"text"},"source":["Download GPT-2"]},{"cell_type":"code","metadata":{"id":"3yHkEkVNj2Ih","colab_type":"code","colab":{}},"source":["!git clone https://github.com/nshepperd/gpt-2"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"4LW3Il9nzg5M","colab_type":"text"},"source":["Create backend hooks for saving and loading programs"]},{"cell_type":"code","metadata":{"id":"bV_TJbUjzfza","colab_type":"code","colab":{}},"source":["import IPython\n","from google.colab import output\n","\n","def python_save_hook(file_text, filename):\n"," import easygen\n"," import hooks\n"," status = hooks.python_save_hook_aux(file_text, filename)\n"," ret_status = 'true' if status else 'false'\n"," return IPython.display.JSON({'result': ret_status})\n","\n","def python_load_hook(filename):\n"," import easygen\n"," import hooks\n"," result = hooks.python_load_hook_aux(filename)\n"," return IPython.display.JSON({'result': result})\n","\n","def python_cwd_hook(dir):\n"," import easygen\n"," import hooks\n"," result = hooks.python_cwd_hook_aux(dir)\n"," return IPython.display.JSON({'result': result})\n","\n","def python_copy_hook(path1, path2):\n"," import easygen\n"," import hooks\n"," status = hooks.python_copy_hook_aux(path1, path2)\n"," ret_status = 'true' if status else 'false'\n"," return IPython.display.JSON({'result': ret_status})\n","\n","def python_move_hook(path1, path2):\n"," import easygen\n"," import hooks\n"," status = hooks.python_move_hook_aux(path1, path2)\n"," ret_status = 'true' if status else 'false'\n"," return IPython.display.JSON({'result': ret_status})\n","\n","def python_open_text_hook(path):\n"," import easygen\n"," import hooks\n"," status = hooks.python_open_text_hook_aux(path)\n"," ret_status = 'true' if status else 'false'\n"," return IPython.display.JSON({'result': ret_status})\n"," \n","def python_open_image_hook(path):\n"," import easygen\n"," import hooks\n"," status = 
hooks.python_open_image_hook_aux(path)\n"," ret_status = 'true' if status else 'false'\n"," return IPython.display.JSON({'result': ret_status})\n","\n","def python_mkdir_hook(path, dir_name):\n"," import easygen\n"," import hooks\n"," status = hooks.python_mkdir_hook_aux(path, dir_name)\n"," ret_status = 'true' if status else 'false'\n"," return IPython.display.JSON({'result': ret_status})\n","\n","def python_trash_hook(path):\n"," import easygen\n"," import hooks\n"," status = hooks.python_trash_hook_aux(path)\n"," ret_status = 'true' if status else 'false'\n"," return IPython.display.JSON({'result': ret_status})\n","\n","def python_run_hook(path):\n"," import easygen\n"," program_file_name = path\n"," easygen.main(program_file_name)\n"," return IPython.display.JSON({'result': 'true'})\n","\n","output.register_callback('notebook.python_cwd_hook', python_cwd_hook)\n","output.register_callback('notebook.python_copy_hook', python_copy_hook)\n","output.register_callback('notebook.python_move_hook', python_move_hook)\n","output.register_callback('notebook.python_open_text_hook', python_open_text_hook)\n","output.register_callback('notebook.python_open_image_hook', python_open_image_hook)\n","output.register_callback('notebook.python_save_hook', python_save_hook)\n","output.register_callback('notebook.python_load_hook', python_load_hook)\n","output.register_callback('notebook.python_mkdir_hook', python_mkdir_hook)\n","output.register_callback('notebook.python_trash_hook', python_trash_hook)\n","output.register_callback('notebook.python_run_hook', python_run_hook)"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"R8WwVL1g3MSh","colab_type":"text"},"source":["Import EasyGen"]},{"cell_type":"code","metadata":{"id":"cZ9sZmeR3KsR","colab_type":"code","colab":{}},"source":["import sys\n","sys.path.insert(0, 'easygen')\n","import easygen"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"nvMUzJGmXy44","colab_type":"text"},"source":["# 2. Download pre-trained neural network models\n","\n","## 2.1 Download GPT-2 and StyleGan models\n","\n","Download the GPT-2 small 117M model. Will save to ```models/117M``` directory.\n"]},{"cell_type":"code","metadata":{"id":"vqaLwY9GXpuO","colab_type":"code","colab":{}},"source":["!python gpt-2/download_model.py 117M"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wwUI0RL79R3m","colab_type":"text"},"source":["Download the GPT-2 medium 345M model. Will save to ```models/345M``` directory."]},{"cell_type":"code","metadata":{"id":"QZe-0_gY9FWx","colab_type":"code","colab":{}},"source":["!python gpt-2/download_model.py 345M"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"YHmVcSNbkVvo","colab_type":"text"},"source":["Download the StyleGAN cats model (256x256). Will save as \"cats256x256.pkl\" in the home directory."]},{"cell_type":"code","metadata":{"id":"njbbbOBRkVXF","colab_type":"code","colab":{}},"source":["!wget -O cats256x256.pkl https://www.dropbox.com/s/1w97383h0nrj4ea/karras2019stylegan-cats-256x256.pkl?dl=0"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"zEqyG5DMX7CI","colab_type":"text"},"source":["## 2.2 Download Wikipedia\n","\n","You only need to do this if you are using the ```ReadWikipedia``` functionality. This takes a long time. 
You may want to skip it if you know you won't be scraping data from Wikipedia."]},{"cell_type":"code","metadata":{"id":"A7LWxcOzWjaV","colab_type":"code","colab":{}},"source":["!wget -O wiki.zip https://www.dropbox.com/s/39w6mj1akwy2a0r/wiki.zip?dl=0\n","!unzip wiki.zip"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"oX_AG49uYV23","colab_type":"text"},"source":["# 3. Run the GUI\n","\n","Run the cell below. This will load a default example program that generates new, fictional paint names. Use the \"clear\" button to clear it and make your own.\n","\n","When done, name the program and press the \"save\" button. You should see your file appear in the file listing in the left panel."]},{"cell_type":"code","metadata":{"id":"NFpPTx0oN8N9","colab_type":"code","colab":{}},"source":["%%html\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","
\n","
text
\n"," text\n"," \n"," \n","
\n","
\n","
Make New Module
\n"," \n"," \n","
\n","
\n","
Save Program
\n"," \n"," \n","
\n","
\n","
Load Program
\n"," \n"," \n","
\n","
\n","
Clear Program
\n"," \n","
\n","\n",""],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xUmugXljEFw8","colab_type":"text"},"source":["# 4. Run Your Program"]},{"cell_type":"markdown","metadata":{"id":"GA0sfrdnaDAV","colab_type":"text"},"source":["This will run a default example program that will generate new, fictional paint names. If you don't want to run that program, skip to the next cell."]},{"cell_type":"code","metadata":{"id":"4klnhCIosXQU","colab_type":"code","colab":{}},"source":["program_file_name = 'easygen/examples/make_new_colors'\n","easygen.main(program_file_name)"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"vnc0wSWEp2m2","colab_type":"text"},"source":["Once you've made your own program, run the cell below, enter the program below, and the press the run button.\n"]},{"cell_type":"code","metadata":{"id":"2BNineAvps7j","colab_type":"code","colab":{}},"source":["%%html\n","\n","\n","\n","\n","Run Program: \n","\n",""],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"3umTq8CjEWnl","colab_type":"text"},"source":["# 5. View Your Output Files"]},{"cell_type":"markdown","metadata":{"id":"L0fH_pqlaJTP","colab_type":"text"},"source":["Run the cell below to launch a file manager so you can view your text and image files. \n","\n","You can use the panel to the left to download any files written to disk."]},{"cell_type":"code","metadata":{"id":"Sf7C3Y2J3mUw","colab_type":"code","colab":{}},"source":["%%html\n","\n","\n","\n","
[HTML markup lost in extraction: this %%html cell renders the "Manage Files" file-manager panel rooted at /content, with a "Make Directories" control]
\n","\n","\n",""],"execution_count":0,"outputs":[]}]}
--------------------------------------------------------------------------------
/gpt2.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import numpy as np
4 | import tensorflow as tf
5 | from tensorflow.core.protobuf import rewriter_config_pb2
6 | import sys
7 | import time
8 | import pdb
9 | import shutil
10 |
11 | CUR_PATH = os.getcwd()
12 | GPT_PATH = os.path.join(CUR_PATH, 'gpt-2')
13 | sys.path.insert(0, os.path.join(GPT_PATH, 'src'))
14 |
15 |
16 | import model
17 | import sample
18 | import encoder
19 | import load_dataset
20 | from load_dataset import load_dataset, Sampler
21 | from accumulate import AccumulatingOptimizer
22 | import memory_saving_gradients as memory_saving_gradients_module  # aliased: train() has a boolean parameter of the same name that would otherwise shadow the module
23 |
24 |
25 |
26 | ####################################################
27 |
28 | def run_gpt(
29 | model_in_path,
30 | model_name='117M',
31 | raw_text = ' ',
32 | seed=None,
33 | nsamples=1,
34 | batch_size=1,
35 | length=None,
36 | temperature=1,
37 | top_k=0,
38 | top_p=0.0
39 | ):
40 | """
41 | Interactively run the model
42 | :model_name=117M : String, which model to use
43 | :seed=None : Integer seed for random number generators, fix seed to reproduce
44 | results
45 | :nsamples=1 : Number of samples to return total
46 | :batch_size=1 : Number of samples per batch (only affects speed/memory). Must divide nsamples.
47 | :length=None : Number of tokens in generated text, if None (default), is
48 | determined by model hyperparameters
49 | :temperature=1 : Float value controlling randomness in Boltzmann
50 | distribution. Lower temperature results in less random completions. As the
51 | temperature approaches zero, the model will become deterministic and
52 | repetitive. Higher temperature results in more random completions.
53 | :top_k=0 : Integer value controlling diversity. 1 means only 1 word is
54 | considered for each step (token), resulting in deterministic completions,
55 | while 40 means 40 words are considered at each step. 0 (default) is a
56 | special setting meaning no restrictions. 40 generally is a good value.
57 | :top_p=0.0 : Float value controlling diversity. Implements nucleus sampling,
58 | overriding top_k if set to a value > 0. A good setting is 0.9.
59 | """
60 | output_text = ''
61 |
62 | if batch_size is None:
63 | batch_size = 1
64 | assert nsamples % batch_size == 0
65 |
66 | enc = get_encoder(model_in_path)
67 | hparams = model.default_hparams()
68 | with open(os.path.join(model_in_path, 'hparams.json')) as f:
69 | hparams.override_from_dict(json.load(f))
70 |
71 | if length is None:
72 | length = hparams.n_ctx // 2
73 | elif length > hparams.n_ctx:
74 | raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)
75 |
76 | with tf.Session(graph=tf.Graph()) as sess:
77 | context = tf.placeholder(tf.int32, [batch_size, None])
78 | np.random.seed(seed)
79 | tf.set_random_seed(seed)
80 | output = sample.sample_sequence(
81 | hparams=hparams, length=length,
82 | context=context,
83 | batch_size=batch_size,
84 | temperature=temperature, top_k=top_k, top_p=top_p
85 | )
86 |
87 | saver = tf.train.Saver()
88 | ckpt = tf.train.latest_checkpoint(model_in_path) #os.path.join('models', model_name))
89 | saver.restore(sess, ckpt)
90 |
91 | if len(raw_text) == 0:
92 | raw_text = ' '
93 | context_tokens = enc.encode(raw_text)
94 | generated = 0
95 | for n in range(nsamples // batch_size):
96 | out = sess.run(output, feed_dict={
97 | context: [context_tokens for _ in range(batch_size)]
98 | })[:, len(context_tokens):]
99 | for i in range(batch_size):
100 | text = enc.decode(out[i])
101 | generated = generated + 1
102 | #print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
103 | #print(text)
104 | if n == 0:
105 | output_text = text
106 | else:
107 | output_text = output_text + '\n' + text
108 | return output_text
109 |
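
# Illustration (added; not part of the pipeline): a minimal numpy sketch of
# what the temperature, top_k, and top_p arguments above do to a next-token
# distribution. The real sampling happens inside sample.sample_sequence from
# the gpt-2 repo; this helper name is hypothetical and exists only as
# documentation.
def _adjust_distribution_sketch(logits, temperature=1.0, top_k=0, top_p=0.0):
    logits = np.asarray(logits, dtype=np.float64) / temperature  # temperature rescales logits
    probs = np.exp(logits - logits.max())
    probs = probs / probs.sum()                                  # softmax
    if top_p > 0.0:
        # Nucleus sampling: keep the smallest set of tokens whose cumulative
        # probability reaches top_p (overrides top_k, as documented above).
        order = np.argsort(probs)[::-1]                          # most to least likely
        cum = np.cumsum(probs[order])
        keep = order[:np.searchsorted(cum, top_p) + 1]
        mask = np.zeros_like(probs)
        mask[keep] = 1.0
        probs = probs * mask
    elif top_k > 0:
        # Top-k: zero out everything below the k-th largest probability.
        cutoff = np.sort(probs)[-top_k]
        probs = np.where(probs >= cutoff, probs, 0.0)
    return probs / probs.sum()                                   # renormalize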
110 | ##############################################################################
111 |
112 | def maketree(path):
113 | try:
114 | os.makedirs(path)
115 | except OSError:
116 | pass
117 |
118 |
119 | def randomize(context, hparams, p):
120 | if p > 0:
121 | mask = tf.random.uniform(shape=tf.shape(context)) < p
122 | noise = tf.random.uniform(shape=tf.shape(context), minval=0, maxval=hparams.n_vocab, dtype=tf.int32)
123 | return tf.where(mask, noise, context)
124 | else:
125 | return context
126 |
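# Note (added for clarity): randomize() is a simple input-noising regularizer.
# With probability p each context token is replaced by a uniformly random
# vocabulary id before the forward pass; noise=0.0 (the default) disables it.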
127 |
128 | def get_encoder(path):
129 | with open(os.path.join(path, 'encoder.json'), 'r') as f:
130 | enc = json.load(f)
131 | with open(os.path.join(path, 'vocab.bpe'), 'r', encoding="utf-8") as f:
132 | bpe_data = f.read()
133 | bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
134 | return encoder.Encoder(
135 | encoder=enc,
136 | bpe_merges=bpe_merges,
137 | )
138 |
139 | def train(dataset, model_in_path, model_out_path,
140 | model_name = '117M',
141 | steps = 1000,
142 | combine = 50000,
143 | batch_size = 1,
144 | learning_rate = 0.00002,
145 | accumulate_gradients = 1,
146 | memory_saving_gradients = False,
147 | only_train_transformer_layers = False,
148 | optimizer = 'adam',
149 | noise = 0.0,
150 | top_k = 40,
151 | top_p = 0.0,
152 | restore_from = 'latest',
153 | sample_every = 100,
154 | sample_length = 1023,
155 | sample_num = 1,
156 | save_every = 1000,
157 | val_dataset = None):
158 | # Reset the TF computation graph
159 | tf.reset_default_graph()
160 | # Get the checkpoint and sample directories
161 | #checkpoint_dir = os.path.dirname(model_path)
162 | #sample_dir = checkpoint_dir
163 | #run_name = os.path.basename(model_path)
164 | # Load the encoder
165 | enc = get_encoder(model_in_path)
166 | hparams = model.default_hparams()
167 | with open(os.path.join(model_in_path, 'hparams.json')) as f:
168 | hparams.override_from_dict(json.load(f))
169 |
170 | if sample_length > hparams.n_ctx:
171 | raise ValueError(
172 | "Can't get samples longer than window size: %s" % hparams.n_ctx)
173 |
174 | # Size matters
175 | if model_name == '345M':
176 | memory_saving_gradients = True
177 | if optimizer == 'adam':
178 | only_train_transformer_layers = True
179 |
180 | # Configure TF
181 | config = tf.ConfigProto()
182 | config.gpu_options.allow_growth = True
183 | config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF
184 | # Start the session
185 | with tf.Session(config=config) as sess:
186 | context = tf.placeholder(tf.int32, [batch_size, None])
187 | context_in = randomize(context, hparams, noise)
188 | output = model.model(hparams=hparams, X=context_in)
189 | loss = tf.reduce_mean(
190 | tf.nn.sparse_softmax_cross_entropy_with_logits(
191 | labels=context[:, 1:], logits=output['logits'][:, :-1]))
192 |
193 | tf_sample = sample.sample_sequence(
194 | hparams=hparams,
195 | length=sample_length,
196 | context=context,
197 | batch_size=batch_size,
198 | temperature=1.0,
199 | top_k=top_k,
200 | top_p=top_p)
201 |
202 | all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
203 | train_vars = [v for v in all_vars if '/h' in v.name] if only_train_transformer_layers else all_vars
204 |
205 | if optimizer == 'adam':
206 | opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
207 | elif optimizer == 'sgd':
208 | opt = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
209 | else:
210 | exit('Bad optimizer: ' + optimizer)
211 |
212 | if accumulate_gradients > 1:
213 | if memory_saving_gradients:
214 | exit("Memory saving gradients are not implemented for gradient accumulation yet.")
215 | opt = AccumulatingOptimizer(
216 | opt=opt,
217 | var_list=train_vars)
218 | opt_reset = opt.reset()
219 | opt_compute = opt.compute_gradients(loss)
220 | opt_apply = opt.apply_gradients()
221 | summary_loss = tf.summary.scalar('loss', opt_apply)
222 | else:
223 | if memory_saving_gradients:
224 | opt_grads = memory_saving_gradients_module.gradients(loss, train_vars)
225 | else:
226 | opt_grads = tf.gradients(loss, train_vars)
227 | opt_grads = list(zip(opt_grads, train_vars))
228 | opt_apply = opt.apply_gradients(opt_grads)
229 | summary_loss = tf.summary.scalar('loss', loss)
230 |
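# Note (added for clarity): with accumulate_gradients = N, AccumulatingOptimizer
# sums gradients over N successive minibatches and applies them in one step, so
# the effective batch size is batch_size * N while peak memory stays at that of
# a single minibatch. opt.apply_gradients() in that branch returns the
# accumulated average loss, which is why it is logged directly as the 'loss'
# summary above.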
231 | summary_lr = tf.summary.scalar('learning_rate', learning_rate)
232 | summaries = tf.summary.merge([summary_lr, summary_loss])
233 |
234 | summary_log = tf.summary.FileWriter(
235 | #os.path.join(checkpoint_dir, run_name)
236 | model_out_path
237 | )
238 |
239 | saver = tf.train.Saver(
240 | var_list=all_vars,
241 | max_to_keep=1)
242 | sess.run(tf.global_variables_initializer())
243 |
244 | if restore_from == 'latest':
245 | ckpt = tf.train.latest_checkpoint(
246 | #os.path.join(checkpoint_dir, run_name)
247 | model_in_path
248 | )
249 | if ckpt is None:
250 | # Get fresh GPT weights if new run.
251 | ckpt = tf.train.latest_checkpoint(
252 | model_in_path)#os.path.join('models', model_name))
253 | elif restore_from == 'fresh':
254 | ckpt = tf.train.latest_checkpoint(
255 | model_in_path)#os.path.join('models', model_name))
256 | else:
257 | ckpt = tf.train.latest_checkpoint(restore_from)
258 | print('Loading checkpoint', ckpt)
259 | saver.restore(sess, ckpt)
260 |
261 | print('Loading dataset...')
262 | chunks = load_dataset(enc, dataset, combine)
263 | data_sampler = Sampler(chunks)
264 | print('dataset has', data_sampler.total_size, 'tokens')
265 | print('Training...')
266 | counter = 1
267 | counter_path = os.path.join(model_in_path, 'counter') #os.path.join(checkpoint_dir, run_name, 'counter')
268 | if restore_from == 'latest' and os.path.exists(counter_path):
269 | # Load the step number if we're resuming a run
270 | # Add 1 so we don't immediately try to save again
271 | with open(counter_path, 'r') as fp:
272 | counter = int(fp.read()) + 1
273 |
274 | def save():
275 | #maketree(os.path.join(checkpoint_dir, run_name))
276 | maketree(model_out_path)
277 | print(
278 | 'Saving',
279 | #os.path.join(checkpoint_dir, run_name, 'model-{}').format(counter)
280 | os.path.join(model_out_path, 'model-{}').format(counter)
281 | )
282 | saver.save(
283 | sess,
284 | #os.path.join(checkpoint_dir, run_name, 'model'),
285 | os.path.join(model_out_path, 'model'),
286 | global_step=counter)
287 | with open(os.path.join(model_out_path, 'counter'), 'w') as fp:
288 | fp.write(str(counter) + '\n')
289 |
290 | def generate_samples():
291 | print('Generating samples...')
292 | context_tokens = data_sampler.sample(1)
293 | all_text = []
294 | index = 0
295 | while index < sample_num:
296 | out = sess.run(
297 | tf_sample,
298 | feed_dict={context: batch_size * [context_tokens]})
299 | for i in range(min(sample_num - index, batch_size)):
300 | text = enc.decode(out[i])
301 | text = '======== SAMPLE {} ========\n{}\n'.format(
302 | index + 1, text)
303 | all_text.append(text)
304 | index += 1
305 | print(text)
306 | #maketree(os.path.join(sample_dir, run_name))
307 | maketree(model_out_path)
308 | with open(os.path.join(model_out_path, 'samples-{}').format(counter), 'w') as fp:
309 | fp.write('\n'.join(all_text))
310 |
311 | def sample_batch():
312 | return [data_sampler.sample(1024) for _ in range(batch_size)]
313 |
314 |
315 | avg_loss = (0.0, 0.0)
316 | start_time = time.time()
317 |
318 | stop = steps + counter
319 |
320 | try:
321 | while counter < stop:
322 | if counter % save_every == 0:
323 | save()
324 | '''
325 | if counter % sample_every == 0:
326 | generate_samples()
327 | '''
328 |
329 | if accumulate_gradients > 1:
330 | sess.run(opt_reset)
331 | for _ in range(accumulate_gradients):
332 | sess.run(
333 | opt_compute, feed_dict={context: sample_batch()})
334 | (v_loss, v_summary) = sess.run((opt_apply, summaries))
335 | else:
336 | (_, v_loss, v_summary) = sess.run(
337 | (opt_apply, loss, summaries),
338 | feed_dict={context: sample_batch()})
339 |
340 | summary_log.add_summary(v_summary, counter)
341 |
342 | avg_loss = (avg_loss[0] * 0.99 + v_loss,
343 | avg_loss[1] * 0.99 + 1.0)
344 |
345 | print(
346 | '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
347 | .format(
348 | counter=counter,
349 | time=time.time() - start_time,
350 | loss=v_loss,
351 | avg=avg_loss[0] / avg_loss[1]))
352 |
353 | counter += 1
354 | print('done!')
355 | save()
356 | except KeyboardInterrupt:
357 | print('interrupted')
358 | save()
359 |
360 | #####################################################################
361 |
362 | if __name__ == '__main__':
363 | cache_path = os.path.join(os.getcwd(), 'cache')
364 | os.chdir('gpt-2')
365 | train(os.path.join(cache_path, 'text0'), cache_path, cache_path, steps=1)
366 | #text = run_gpt(top_k=40)
367 | #print(text)
368 | os.chdir('..')
--------------------------------------------------------------------------------
/lstm.py:
--------------------------------------------------------------------------------
1 | import time
2 | import math
3 | import os
4 | import torch
5 | import codecs
6 | import torch.nn as nn
7 | import torch.optim as optim
8 | import pdb
9 |
10 |
11 | ########################################################
12 | ### MODEL
13 | ########################################################
14 |
15 | class RNNModel(nn.Module):
16 | """Container module with an encoder, a recurrent module, and a decoder."""
17 |
18 | def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5, tie_weights=False):
19 | super(RNNModel, self).__init__()
20 | self.drop = nn.Dropout(dropout)
21 | self.encoder = nn.Embedding(ntoken, ninp)
22 | if rnn_type in ['LSTM', 'GRU']:
23 | self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout)
24 | else:
25 | try:
26 | nonlinearity = {'RNN_TANH': 'tanh', 'RNN_RELU': 'relu'}[rnn_type]
27 | except KeyError:
28 | raise ValueError( """An invalid option for `--model` was supplied,
29 | options are ['LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU']""")
30 | self.rnn = nn.RNN(ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout)
31 | self.decoder = nn.Linear(nhid, ntoken)
32 |
33 | # Optionally tie weights as in:
34 | # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
35 | # https://arxiv.org/abs/1608.05859
36 | # and
37 | # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
38 | # https://arxiv.org/abs/1611.01462
39 | if tie_weights:
40 | if nhid != ninp:
41 | raise ValueError('When using the tied flag, nhid must be equal to emsize')
42 | self.decoder.weight = self.encoder.weight
43 |
44 | self.init_weights()
45 |
46 | self.rnn_type = rnn_type
47 | self.nhid = nhid
48 | self.nlayers = nlayers
49 |
50 | def init_weights(self):
51 | initrange = 0.1
52 | self.encoder.weight.data.uniform_(-initrange, initrange)
53 | self.decoder.bias.data.zero_()
54 | self.decoder.weight.data.uniform_(-initrange, initrange)
55 |
56 | def forward(self, input, hidden):
57 | emb = self.drop(self.encoder(input))
58 | output, hidden = self.rnn(emb, hidden)
59 | output = self.drop(output)
60 | decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))
61 | return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden
62 |
63 | def init_hidden(self, bsz):
64 | weight = next(self.parameters())
65 | if self.rnn_type == 'LSTM':
66 | return (weight.new_zeros(self.nlayers, bsz, self.nhid),
67 | weight.new_zeros(self.nlayers, bsz, self.nhid))
68 | else:
69 | return weight.new_zeros(self.nlayers, bsz, self.nhid)
70 |
71 |
72 | #######################################################
73 | ### DICTIONARY
74 | #######################################################
75 |
76 | class Dictionary(object):
77 | def __init__(self):
78 | self.token2idx = {}
79 | self.idx2token = []
80 |
81 | def add_token(self, word):
82 | if word not in self.token2idx:
83 | self.idx2token.append(word)
84 | self.token2idx[word] = len(self.idx2token) - 1
85 | return self.token2idx[word]
86 |
87 | def __len__(self):
88 | return len(self.idx2token)
89 |
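# Usage sketch (illustrative): a Dictionary assigns each new token the next
# free integer id and keeps both directions of the mapping, e.g.
#   d = Dictionary()
#   d.add_token('the'); d.add_token('cat'); d.add_token('the')
#   d.token2idx['the'] == 0; d.idx2token[1] == 'cat'; len(d) == 2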
90 | ###############################################################################
91 | ### CORPUS
92 | ################################################################################
93 |
94 | class Corpus(object):
95 |
96 | def __init__(self, train_path):
97 | self.dictionary = Dictionary()
98 | self.train_data = self.tokenize(train_path) if train_path is not None and os.path.exists(train_path) else None
99 |
100 | def tokenize(self, path):
101 | """Tokenizes a text file."""
102 | pass
103 |
104 | class WordCorpus(Corpus):
105 |
106 | def tokenize(self, path):
107 | super(WordCorpus, self).tokenize(path)
108 | # Add words to the dictionary
109 | tokens = 0
110 | with codecs.open(path, 'r', encoding="utf8") as f:
111 | for line in f:
112 | words = line.split() + ['<eos>']
113 | tokens += len(words)
114 | for word in words:
115 | self.dictionary.add_token(word)
116 |
117 | # Tokenize file content
118 | with codecs.open(path, 'r', encoding="utf8") as f:
119 | ids = torch.LongTensor(tokens)
120 | token = 0
121 | for line in f:
122 | words = line.split() + ['<eos>']
123 | for word in words:
124 | ids[token] = self.dictionary.token2idx[word]
125 | token += 1
126 | return ids
127 |
128 | class CharCorpus(Corpus):
129 |
130 | def tokenize(self, path):
131 | super(CharCorpus, self).tokenize(path)
132 | tokens = 0
133 | with codecs.open(path, 'r', encoding="utf8") as f:
134 | for line in f:
135 | tokens += len(line)
136 | for c in line:
137 | self.dictionary.add_token(c)
138 | with codecs.open(path, 'r', encoding="utf8") as f:
139 | ids = torch.LongTensor(tokens)
140 | token = 0
141 | for line in f:
142 | for c in line:
143 | ids[token] = self.dictionary.token2idx[c]
144 | token += 1
145 | return ids
146 |
147 | ###############################################################################
148 | # Globals
149 | ###############################################################################
150 |
151 | RUN_STEPS = 600
152 | RUN_TEMPERATURE = 0.5
153 | SEED = 1
154 | HISTORY = 35
155 | LAYERS = 2
156 | EPOCHS = 50
157 | HIDDEN_NODES = 512
158 | BATCH_SIZE = 10
159 | MODEL_TYPE = 'GRU'
160 | DROPOUT = 0.2
161 | TIED = False
162 | EMBED_SIZE = HIDDEN_NODES
163 | CLIP = 0.25
164 | LR = 0.0001
165 | LR_DECAY = 0.1
166 | LOG_INTERVAL = 10
167 |
168 | #############################################################################
169 | # MAIN entrypoints
170 | ############################################################################
171 |
172 |
173 | def wordLSTM_Run(model_path, dictionary_path, output_path, seed = '',
174 | steps = RUN_STEPS, temperature = RUN_TEMPERATURE, k = 0):
175 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
176 |
177 | # Load the model
178 | model = None
179 | with open(model_path, 'rb') as f:
180 | model = torch.load(f)
181 | model = model.to(device)
182 | model.eval()
183 |
184 | # Load the dictionary
185 | dictionary = None
186 | with open(dictionary_path, 'rb') as f:
187 | dictionary = torch.load(f)
188 | ntokens = len(dictionary)
189 |
190 | hidden = model.init_hidden(1)
191 |
192 | input = torch.randint(ntokens, (1, 1), dtype=torch.long).to(device)
193 |
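# Prime the hidden state on the seed text: every seed word except the last is
# fed through the model (its output is discarded), and the last seed word
# becomes the first live input for the generation loop below.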
194 | if seed is not None:
195 | seed_words = seed.strip().split()
196 | if len(seed_words) > 1:
197 | for i in range(len(seed_words)-1):
198 | word = seed_words[i]
199 | if word in dictionary.idx2token:
200 | input = torch.tensor([[dictionary.token2idx[word]]], dtype=torch.long).to(device)
201 | output, hidden = model(input, hidden)
202 | if len(seed_words) > 0:
203 | input = torch.tensor([[dictionary.token2idx[seed_words[-1]]]], dtype=torch.long).to(device)
204 |
205 | with codecs.open(output_path, 'w', encoding="utf8") as outf:
206 | if seed is not None and len(seed) > 0:
207 | outf.write(seed.strip() + ' ')
208 | with torch.no_grad(): # no tracking history
209 | for i in range(steps):
210 | output, hidden = model(input, hidden)
211 | word_weights = output.squeeze().div(temperature).exp().cpu()
212 | word_idx = None
213 | if k > 0:
214 | # top-k sampling
215 | word_idx = top_k_sample(word_weights, k)
216 | else:
217 | word_idx = torch.multinomial(word_weights, 1)[0]
218 | input.fill_(word_idx)
219 | word = dictionary.idx2token[word_idx]
220 |
221 | outf.write(word + (' ' if i < steps-1 else ''))
222 |
223 | ### Top K sampling
224 | def top_k_sample(logits, k):
225 | values, _ = torch.topk(logits, k)
226 | min_value = values.min()
227 | mask = logits >= min_value
228 | new_logits = logits * mask.float()
229 | return torch.multinomial(new_logits, 1)[0]
230 |
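# Illustration (added): top_k_sample zeroes every weight below the k-th largest
# and samples from what remains, so k=1 is greedy decoding and larger k trades
# determinism for diversity. Note the argument is exp(logit/temperature)
# weights (always positive), so masked multinomial sampling is well-defined.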
231 |
232 |
233 |
234 | def charLSTM_Run(model_path, dictionary_path, output_path, seed = '',
235 | steps = RUN_STEPS, temperature = RUN_TEMPERATURE):
236 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
237 |
238 | model = None
239 | with open(model_path, 'rb') as f:
240 | model = torch.load(f)
241 | model = model.to(device)
242 | model.eval()
243 |
244 | dictionary = None
245 | with open(dictionary_path, 'rb') as f:
246 | dictionary = torch.load(f)
247 | ntokens = len(dictionary)
248 |
249 | hidden = model.init_hidden(1)
250 |
251 | input = torch.randint(ntokens, (1, 1), dtype=torch.long).to(device)
252 |
253 | if seed is not None:
254 | if len(seed) > 1:
255 | for i in range(len(seed)-1):
256 | char = seed[i]
257 | if char in dictionary.idx2token:
258 | input = torch.tensor([[dictionary.token2idx[char]]], dtype=torch.long).to(device)
259 | output, hidden = model(input, hidden)
260 | if len(seed) > 0:
261 | input = torch.tensor([[dictionary.token2idx[seed[-1]]]], dtype=torch.long).to(device)
262 |
263 |
264 | text = seed
265 | with torch.no_grad(): # no tracking history
266 | for i in range(steps):
267 | output, hidden = model(input, hidden)
268 | char_weights = output.squeeze().div(temperature).exp().cpu()
269 | char_idx = torch.multinomial(char_weights, 1)[0]
270 | input.fill_(char_idx)
271 | char = dictionary.idx2token[char_idx]
272 | text = text + char
273 |
274 | with codecs.open(output_path, 'w', encoding="utf8") as outf:
275 | outf.write(text)
276 |
277 | def wordLSTM_Train(train_data_path,
278 | dictionary_path, model_out_path,
279 | history = HISTORY, layers = LAYERS, epochs = EPOCHS, hidden_nodes = HIDDEN_NODES,
280 | batch_size = BATCH_SIZE, model_type=MODEL_TYPE, dropout = DROPOUT, tied = TIED, embed_size = EMBED_SIZE,
281 | clip = CLIP, lr = LR,
282 | lr_decay = LR_DECAY,
283 | log_interval = LOG_INTERVAL):
284 | train(train_data_path,
285 | dictionary_path, model_out_path,
286 | history = history, layers = layers, epochs = epochs, hidden_nodes = hidden_nodes,
287 | batch_size = batch_size, model_type=model_type, dropout = dropout, tied = tied, embed_size = embed_size,
288 | clip = clip, lr = lr,
289 | lr_decay = lr_decay,
290 | log_interval = log_interval,
291 | corpus_type = WordCorpus)
292 |
293 | def wordLSTM_Train_More(train_data_path, model_in_path, dictionary_path, model_out_path,
294 | history = HISTORY, epochs = EPOCHS, batch_size = BATCH_SIZE,
295 | clip = CLIP, lr = LR, lr_decay = LR_DECAY, log_interval = LOG_INTERVAL):
296 | train_more(train_data_path, model_in_path, dictionary_path, model_out_path,
297 | history = history, epochs = epochs, batch_size = batch_size,
298 | clip = clip, lr = lr, lr_decay = lr_decay, log_interval = log_interval,
299 | corpus_type = WordCorpus)
300 |
301 | def charLSTM_Train(train_data_path,
302 | dictionary_path, model_out_path,
303 | history = HISTORY, layers = LAYERS, epochs = EPOCHS, hidden_nodes = HIDDEN_NODES,
304 | batch_size = BATCH_SIZE, model_type=MODEL_TYPE, dropout = DROPOUT, tied = TIED, embed_size = EMBED_SIZE,
305 | clip = CLIP, lr = LR,
306 | lr_decay = LR_DECAY,
307 | log_interval = LOG_INTERVAL):
308 | train(train_data_path,
309 | dictionary_path, model_out_path,
310 | history = history, layers = layers, epochs = epochs, hidden_nodes = hidden_nodes,
311 | batch_size = batch_size, model_type=model_type, dropout = dropout, tied = tied, embed_size = embed_size,
312 | clip = clip, lr = lr,
313 | lr_decay = lr_decay,
314 | log_interval = log_interval,
315 | corpus_type = CharCorpus)
316 |
317 | def charLSTM_Train_More(train_data_path, model_in_path, dictionary_path, model_out_path,
318 | history = HISTORY, epochs = EPOCHS, batch_size = BATCH_SIZE,
319 | clip = CLIP, lr = LR, lr_decay = LR_DECAY, log_interval = LOG_INTERVAL):
320 | train_more(train_data_path, model_in_path, dictionary_path, model_out_path,
321 | history = history, epochs = epochs, batch_size = batch_size,
322 | clip = clip, lr = lr, lr_decay = lr_decay, log_interval = log_interval,
323 | corpus_type = CharCorpus)
324 |
325 |
326 |
327 | #################################################
328 | ### TRAIN
329 | ##################################################
330 |
331 | def train_more(train_data_path, model_in_path, dictionary_path, model_out_path,
332 | history = HISTORY, epochs = EPOCHS, batch_size = BATCH_SIZE,
333 | clip = CLIP, lr = LR, lr_decay = LR_DECAY, log_interval = LOG_INTERVAL,
334 | corpus_type = WordCorpus):
335 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
336 | corpus = corpus_type(train_data_path)
337 | dictionary = corpus.dictionary
338 | with open(dictionary_path, 'rb') as f:
339 | dictionary = torch.load(f)
340 | with open(model_in_path, 'rb') as f:
341 | model = torch.load(f)
342 | model = model.to(device)
343 | train_loop(model, corpus, model_out_path,
344 | history = history, epochs = epochs, batch_size = batch_size,
345 | clip = clip, lr = lr, lr_decay = lr_decay, log_interval = log_interval)
346 |
347 |
348 | def train(train_data_path, dictionary_path, model_out_path,
349 | history = HISTORY, layers = LAYERS, epochs = EPOCHS, hidden_nodes = HIDDEN_NODES,
350 | batch_size = BATCH_SIZE, model_type=MODEL_TYPE, dropout = DROPOUT, tied = TIED, embed_size = EMBED_SIZE,
351 | clip = CLIP, lr = LR,
352 | lr_decay = LR_DECAY,
353 | log_interval = LOG_INTERVAL,
354 | corpus_type = WordCorpus):
355 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
356 |
357 | corpus = corpus_type(train_data_path)
358 | dictionary = corpus.dictionary
359 | with open(dictionary_path, 'wb') as f:
360 | torch.save(dictionary, f)
361 |
362 | ### BUILD THE MODEL
363 | ntokens = len(dictionary)
364 | model = RNNModel(model_type, ntokens, embed_size, hidden_nodes, layers, dropout, tied)
365 | model = model.to(device)
366 | train_loop(model, corpus, model_out_path,
367 | history = history, epochs = epochs, batch_size = batch_size,
368 | clip = clip, lr = lr, lr_decay = lr_decay, log_interval = log_interval)
369 |
370 | def train_loop(model, corpus, model_path,
371 | history = HISTORY, epochs = EPOCHS, batch_size = BATCH_SIZE,
372 | clip = CLIP, lr = LR, lr_decay = LR_DECAY, log_interval = LOG_INTERVAL):
373 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
374 | # Starting from sequential data, batchify arranges the dataset into columns.
375 | # For instance, with the alphabet as the sequence and batch size 4, we'd get
376 | # | a g m s |
377 | # | b h n t |
378 | # | c i o u |
379 | # | d j p v |
380 | # | e k q w |
381 | # | f l r x |.
382 | # These columns are treated as independent by the model, which means that the
383 | # dependence of e. g. 'g' on 'f' can not be learned, but allows more efficient
384 | # batch processing.
385 | train_data = batchify(corpus.train_data, batch_size, device)
386 | val_data = batchify(corpus.train_data[0:corpus.train_data.size()[0]//10], batch_size, device)
387 | dictionary = corpus.dictionary
388 | ntokens = len(dictionary)
389 |
390 | criterion = nn.CrossEntropyLoss()
391 |
392 | best_val_loss = None
393 | log_interval = max(1, (len(train_data) // history) // log_interval)
394 |
395 | ### TRAIN
396 | for epoch in range(1, epochs+1):
397 | epoch_start_time = time.time()
398 |
399 | model.train()
400 | optimizer = optim.Adam(model.parameters(), lr=lr)
401 |
402 | total_loss = 0.0
403 | start_time = time.time()
404 | hidden = model.init_hidden(batch_size)
405 |
406 | for batch, i in enumerate(range(0, train_data.size(0) - 1, history)):
407 | data, targets = get_batch(train_data, i, history)
408 | # Starting each batch, we detach the hidden state from how it was previously produced.
409 | # If we didn't, the model would try backpropagating all the way to start of the dataset.
410 | hidden = repackage_hidden(hidden)
411 | #model.zero_grad()
412 | optimizer.zero_grad()
413 | output, hidden = model(data, hidden)
414 | loss = criterion(output.view(-1, ntokens), targets)
415 | loss.backward()
416 | # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs;
417 | # clip after backward() and before step() so the clipped gradients are applied.
418 | torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
419 | optimizer.step()
420 | '''
421 | for p in model.parameters():
422 | p.data.add_(-lr, p.grad.data)
423 | '''
424 |
425 | total_loss += loss.item()
426 |
427 | if batch % log_interval == 0 and batch > 0:
428 | cur_loss = total_loss / log_interval
429 | elapsed = time.time() - start_time
430 | print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.5f} | ms/batch {:5.2f} | '
431 | 'loss {:5.2f} | ppl {:8.2f}'.format(
432 | epoch, batch, len(train_data) // history, lr,
433 | elapsed * 1000 / log_interval, cur_loss, math.exp(cur_loss) if cur_loss < 1000 else float('inf')))
434 | total_loss = 0
435 | start_time = time.time()
436 |
437 | ### EVALUATE
438 | val_loss = evaluate(model, val_data, criterion, dictionary, batch_size, history)
439 | print('-' * 89)
440 | print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
441 | 'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
442 | val_loss, math.exp(val_loss) if val_loss < 1000 else float('inf')))
443 | print('-' * 89)
444 | # Save the model if the validation loss is the best we've seen so far.
445 | if best_val_loss is None:
446 | best_val_loss = val_loss
447 | if val_loss <= best_val_loss:
448 | with open(model_path, 'wb') as f:
449 | torch.save(model, f)
450 | best_val_loss = val_loss
451 | else:
452 | # Anneal the learning rate if no improvement has been seen in the validation dataset.
453 | lr = lr * lr_decay
454 |
455 |
456 |
457 | #######################################################
458 | ### HELPERS
459 | ######################################################
460 |
461 |
462 | def batchify(data, bsz, device):
463 | # Work out how cleanly we can divide the dataset into bsz parts.
464 | nbatch = data.size(0) // bsz
465 | # Trim off any extra elements that wouldn't cleanly fit (remainders).
466 | data = data.narrow(0, 0, nbatch * bsz)
467 | # Evenly divide the data across the bsz batches.
468 | data = data.view(bsz, -1).t().contiguous()
469 | return data.to(device)
470 |
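# Illustration (added): for a 26-token sequence and bsz=4, batchify trims the
# 2 leftover tokens and returns a 6 x 4 tensor (26 // 4 = 6 rows), whose
# columns are the four independent streams shown in the train_loop comment:
#   batchify(torch.arange(26), 4, 'cpu').shape == torch.Size([6, 4])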
471 | def repackage_hidden(h):
472 | """Wraps hidden states in new Tensors, to detach them from their history."""
473 | if isinstance(h, torch.Tensor):
474 | return h.detach()
475 | else:
476 | return tuple(repackage_hidden(v) for v in h)
477 |
478 | # get_batch subdivides the source data into chunks of length args.bptt.
479 | # If source is equal to the example output of the batchify function, with
480 | # a bptt-limit of 2, we'd get the following two Variables for i = 0:
481 | # | a g m s | | b h n t |
482 | # | b h n t | | c i o u |
483 | # Note that despite the name of the function, the subdivison of data is not
484 | # done along the batch dimension (i.e. dimension 1), since that was handled
485 | # by the batchify function. The chunks are along dimension 0, corresponding
486 | # to the seq_len dimension in the LSTM.
487 | def get_batch(source, i, history):
488 | seq_len = min(history, len(source) - 1 - i)
489 | data = source[i:i+seq_len]
490 | target = source[i+1:i+1+seq_len].view(-1)
491 | return data, target
492 |
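# Illustration (added): get_batch pairs each row with the next row as its
# target. With source = batchify(...) as above and history = 2, i = 0 yields
#   data   = source[0:2]            # rows 0 and 1
#   target = source[1:3].view(-1)   # the same streams shifted one step
# i.e. the model learns to predict position t+1 from position t, per column.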
493 |
494 | def evaluate(model, data_source, criterion, dictionary, batch_size, history):
495 | # Turn on evaluation mode which disables dropout.
496 | model.eval()
497 | total_loss = 0.0
498 | ntokens = len(dictionary)
499 | hidden = model.init_hidden(batch_size)
500 | with torch.no_grad():
501 | for i in range(0, data_source.size(0) - 1, history):
502 | data, targets = get_batch(data_source, i, history)
503 | output, hidden = model(data, hidden)
504 | output_flat = output.view(-1, ntokens)
505 | total_loss += len(data) * criterion(output_flat, targets).item()
506 | hidden = repackage_hidden(hidden)
507 | return total_loss / (len(data_source) - 1)
508 |
509 | ######################################
510 | ### TESTING
511 |
512 |
513 |
514 | if __name__ == "__main__":
515 | print("running")
516 | train_data_path = 'datasets/origin_train'
517 | val_data_path = 'datasets/origin_valid'
518 | dictionary_path = 'origin_dictionary'
519 | model_out_path = 'origin.model'
520 | output_path = 'foo_out.txt'
521 | seed = 'the'
522 | print("training")
523 | wordLSTM_Train(train_data_path,
524 | dictionary_path,
525 | model_out_path,
526 | epochs = 1)
527 | print("running")
528 | wordLSTM_Run(model_out_path, dictionary_path, output_path, seed = seed, k = 20)
529 |
--------------------------------------------------------------------------------
/module_dicts.js:
--------------------------------------------------------------------------------
1 | var module_dicts = [{"name" : "ReadWikipedia",
2 | "params" : [{"name" : "wiki_directory", "type" : "directory", "default" : "wiki"},
3 | {"name": "pattern", "type" : "string", "default": "*"},
4 | {"name" : "categories", "type" : "string", "default" : "*"},
5 | {"name" : "out_file", "type" : "text", "out" : true},
6 | {"name" : "titles_file", "type" : "text", "out" : true}],
7 | "category" : "Wikipedia"},
8 | {"name" : "WordRNN_Train",
9 | "params" : [{"name" : "data", "in" : true, "type" : "text"},
10 | {"name" : "history", "type" : "int", "default" : 35},
11 | {"name" : "layers", "type" : "int", "default": 2},
12 | {"name" : "hidden_nodes", "type" : "int", "default" : 512},
13 | {"name" : "epochs", "type" : "int", "default" : 50},
14 | {"name" : "learning_rate", "type" : "float", "default": 0.0001},
15 | {"name" : "model", "type" : "model", "out" : true},
16 | {"name" : "dictionary", "type" : "dictionary", "out" : true}],
17 | "category" : "RNN"},
18 | {"name" : "WordRNN_Run",
19 | "params" : [{"name" : "model", "in" : true, "type" : "model"},
20 | {"name" : "dictionary", "in" : true, "type" : "dictionary"},
21 | {"name" : "seed", "in" : true, "type" : "text"},
22 | {"name" : "steps", "type" : "int", "default" : "600"},
23 | {"name" : "temperature", "type" : "float", "default" : "0.5"},
24 | {"name" : "k", "type" : "int", "default" : "40"},
25 | {"name" : "output", "out" : true, "type" : "text"}],
26 | "category" : "RNN"},
27 | {"name" : "MakeString",
28 | "params" : [{"name" : "string", "type" : "string"},
29 | {"name" : "output", "out" : true, "type" : "text"}],
30 | "category" : "Input"},
31 | {"name" : "RemoveEmptyLines",
32 | "params" : [{"name" : "input", "type" : "text", "in" : true},
33 | {"name" : "output", "type" : "text", "out" : true}],
34 | "category" : "Text"},
35 | {"name" : "SplitSentences",
36 | "params" : [{"name" : "input", "type" : "text", "in" : true},
37 | {"name" : "output", "type" : "text", "out" : true}],
38 | "category" : "Text"},
39 | {"name" : "ReplaceCharacters",
40 | "params" : [{"name" : "input", "type" : "text", "in" : true},
41 | {"name" : "find", "type" : "string"},
42 | {"name" : "replace", "type" : "string"},
43 | {"name" : "output", "type" : "text", "out" : true}],
44 | "category" : "Text"},
45 | {"name" : "ReadTextFile",
46 | "params" : [{"name" : "file", "type" : "directory"},
47 | {"name" : "output", "type" : "text", "out" : true}],
48 | "category" : "File"},
49 | {"name" : "WriteTextFile",
50 | "params" : [{"name" : "input", "type" : "text", "in" : true},
51 | {"name" : "file", "type" : "directory"}],
52 | "category" : "File"},
53 | {"name" : "MakeLowercase",
54 | "params" : [{"name" : "input", "type" : "text", "in" : true},
55 | {"name" : "output", "type" : "text", "out" : true}],
56 | "category" : "Text"},
57 | {"name" : "Wordify",
58 | "params" : [{"name" : "input", "type" : "text", "in" : true},
59 | {"name" : "output", "type" : "text", "out" : true}],
60 | "category" : "Text"},
61 | {"name" : "RemoveTags",
62 | "params" : [{"name" : "input", "type" : "text", "in" : true},
63 | {"name" : "output", "type" : "text", "out" : true}],
64 | "category" : "HTML"},
65 | {"name" : "CleanText",
66 | "params" : [{"name" : "input", "type" : "text", "in" : true},
67 | {"name" : "output", "type" : "text", "out" : true}],
68 | "category" : "Text"},
69 | {"name" : "SaveModel",
70 | "params" : [{"name" : "model", "in" : true, "type" : "model"},
71 | {"name" : "file", "type" : "directory"}],
72 | "category" : "File"},
73 | {"name" : "LoadModel",
74 | "params" : [{"name" : "file", "type" : "directory"},
75 | {"name" : "model", "out" : true, "type" : "model"}],
76 | "category" : "File"},
77 | {"name" : "SaveDictionary",
78 | "params" : [{"name" : "dictionary", "in" : true, "type" : "dictionary"},
79 | {"name" : "file", "type" : "directory"}],
80 | "category" : "File"},
81 | {"name" : "LoadDictionary",
82 | "params" : [{"name" : "file", "type" : "directory"},
83 | {"name" : "dictionary", "out" : true, "type" : "dictionary"}],
84 | "category" : "File"},
85 | {"name" : "SplitHTML",
86 | "params" : [{"name" : "input", "type" : "text", "in" : true},
87 | {"name" : "output", "type" : "text", "out" : true}],
88 | "category" : "HTML"},
89 | {"name" : "RandomSequence",
90 | "params" : [{"name" : "input", "type" : "text", "in" : true},
91 | {"name" : "length", "type" : "int", "default" : 100},
92 | {"name" : "output", "type" : "text", "out" : true}],
93 | "category" : "Text"},
94 | {"name" : "ConcatenateTextFiles",
95 | "params" : [{"name" : "input_1", "type" : "text", "in" : true},
96 | {"name" : "input_2", "type" : "text", "in" : true},
97 | {"name" : "output", "type" : "text", "out" : true}],
98 | "category" : "Text"},
99 | {"name" : "RandomizeLines",
100 | "params" : [{"name" : "input", "type" : "text", "in" : true},
101 | {"name" : "output", "type" : "text", "out" : true}],
102 | "category" : "Text"},
103 | {"name" : "KeepFirstLine",
104 | "params" : [{"name" : "input", "type" : "text", "in" : true},
105 | {"name" : "output", "type" : "text", "out" : true}],
106 | "category" : "Text"},
107 | {"name" : "DeleteFirstLine",
108 | "params" : [{"name" : "input", "type" : "text", "in" : true},
109 | {"name" : "output", "type" : "text", "out" : true}],
110 | "category" : "Text"},
111 | {"name" : "DeleteLastLine",
112 | "params" : [{"name" : "input", "type" : "text", "in" : true},
113 | {"name" : "output", "type" : "text", "out" : true}],
114 | "category" : "Text"},
115 | {"name" : "KeepLineWhen",
116 | "params" : [{"name" : "input", "type" : "text", "in" : true},
117 | {"name" : "match", "type" : "string"},
118 | {"name" : "output", "type" : "text", "out" : true}],
119 | "category" : "Text"},
120 | {"name" : "KeepLineUnless",
121 | "params" : [{"name" : "input", "type" : "text", "in" : true},
122 | {"name" : "match", "type" : "string"},
123 | {"name" : "output", "type" : "text", "out" : true}],
124 | "category" : "Text"},
125 | {"name" : "Sort",
126 | "params" : [{"name" : "input", "type" : "text", "in" : true},
127 | {"name" : "output", "type" : "text", "out" : true}],
128 | "category" : "Text"},
129 | {"name" : "Reverse",
130 | "params" : [{"name" : "input", "type" : "text", "in" : true},
131 | {"name" : "output", "type" : "text", "out" : true}],
132 | "category" : "Text"},
133 | {"name" : "GPT2_FineTune",
134 | "params" : [{"name" : "model_in", "type" : "model", "in" : true},
135 | {"name" : "data", "type" : "text", "in" : true},
136 | {"name" : "model_size", "type" : "string", "default" : "117M"},
137 | {"name" : "steps", "type" : "int", "default" : 1000},
138 | {"name" : "model_out", "type" : "model", "out" : true}],
139 | "category" : "GPT2"},
140 | {"name" : "GPT2_Run",
141 | "params" : [{"name" : "model_in", "type" : "model", "in" : true},
142 | {"name" : "prompt", "type" : "text", "in" : true},
143 | {"name" : "model_size", "type" : "string", "default" : "117M"},
144 | {"name" : "top_k", "type" : "int", "default" : 40},
145 | {"name" : "temperature", "type" : "float", "default" : 1.0},
146 | {"name" : "num_samples", "type" : "int", "default" : 1},
147 | {"name" : "output", "type" : "text", "out" : true}],
148 | "category" : "GPT2"},
149 | {"name" : "CharRNN_Train",
150 | "params" : [{"name" : "data", "in" : true, "type" : "text"},
151 | {"name" : "history", "type" : "int", "default" : 35},
152 | {"name" : "layers", "type" : "int", "default": 2},
153 | {"name" : "hidden_nodes", "type" : "int", "default" : 512},
154 | {"name" : "epochs", "type" : "int", "default" : 50},
155 | {"name" : "learning_rate", "type" : "float", "default": 0.0001},
156 | {"name" : "model", "type" : "model", "out" : true},
157 | {"name" : "dictionary", "type" : "dictionary", "out" : true}],
158 | "category" : "RNN"},
159 | {"name" : "CharRNN_Run",
160 | "params" : [{"name" : "model", "in" : true, "type" : "model"},
161 | {"name" : "dictionary", "in" : true, "type" : "dictionary"},
162 | {"name" : "seed", "in" : true, "type" : "text"},
163 | {"name" : "steps", "type" : "int", "default" : "600"},
164 | {"name" : "temperature", "type" : "float", "default" : "0.5"},
165 | {"name" : "output", "out" : true, "type" : "text"}],
166 | "category" : "RNN"},
167 | {"name" : "UserInput",
168 | "params" : [{"name" : "prompt", "type" : "string", "default" : "prompt"},
169 | {"name" : "output", "type" : "text", "out" : true}],
170 | "category" : "Input"},
171 | {"name" : "Regex_Search",
172 | "params" : [{"name" : "input", "type" : "text", "in" : true},
173 | {"name" : "expression", "type" : "string", "default" : "*"},
174 | {"name" : "output", "type" : "text", "out" : true},
175 | {"name" : "group_1", "type" : "text", "out" : true},
176 | {"name" : "group_2", "type" : "text", "out" : true}],
177 | "category" : "Regex"},
178 | {"name" : "Regex_Sub",
179 | "params" : [{"name" : "input", "type" : "text", "in" : true},
180 | {"name" : "expression", "type" : "string", "default" : ""},
181 | {"name" : "replacement", "type" : "string", "default" : ""},
182 | {"name" : "output", "type" : "text", "out" : true}],
183 | "category" : "Regex"},
184 | {"name" : "PrintText",
185 | "params" : [{"name" : "input", "type" : "text", "in" : true}],
186 | "category" : "Utils"},
187 | {"name": "ReadFromWeb",
188 | "params" : [{"name" : "url", "type" : "string", "default" : ""},
189 | {"name" : "data", "type" : "text", "out" : true}],
190 | "category" : "Web"},
191 | {"name" : "MakeCountFile",
192 | "params" : [{"name" : "num", "type" : "int", "default" : "10"},
193 | {"name" : "prefix", "type" : "string", "default" : ""},
194 | {"name" : "postfix", "type" : "string", "default" : ""},
195 | {"name" : "output", "type" : "text", "out" : true}],
196 | "category" : "Utils"},
197 | {"name" : "ReadAllFromWeb",
198 | "params" : [{"name" : "urls", "type" : "text", "in" : true},
199 | {"name" : "data", "type" : "text", "out" : true}],
200 | "category" : "Web"},
201 | {"name" : "RemoveDuplicates",
202 | "params" : [{"name" : "input", "type" : "text", "in" : true},
203 | {"name" : "output", "type" : "text", "out" : true}],
204 | "category" : "Text"},
205 | {"name" : "StripLines",
206 | "params" : [{"name" : "input", "type" : "text", "in" : true},
207 | {"name" : "output", "type" : "text", "out" : true}],
208 | "category" : "Text"},
209 | {"name" : "TextSubtract",
210 | "params" : [{"name" : "main", "type" : "text", "in" : true},
211 | {"name" : "subtract", "type" : "text", "in" : true},
212 | {"name" : "diff", "type" : "text", "out" : true}],
213 | "category" : "Text"},
214 | {"name" : "DuplicateText",
215 | "params" : [{"name" : "input", "type" : "text", "in" : true},
216 | {"name" : "count", "type" : "int", "default" : 1},
217 | {"name" : "output", "type" : "text", "out" : true}],
218 | "category" : "Text"},
219 | {"name" : "Spellcheck",
220 | "params" : [{"name" : "input", "type" : "text", "in" : true},
221 | {"name" : "output", "type" : "text", "out" : true}],
222 | "category" : "Text"},
223 | {"name" : "WebCrawl",
224 | "params" : [{"name" : "url", "type" : "string", "default" : ""},
225 | {"name" : "link_id", "type" : "string", "default" : ""},
226 | {"name" : "link_text", "type" : "string", "default" : ""},
227 | {"name" : "max_hops", "type" : "int", "default" : "10"},
228 | {"name" : "output", "type" : "text", "out" : true}],
229 | "category" : "Web"},
230 | {"name" : "ScrapePinterest",
231 | "params" : [{"name" : "url", "type" : "string", "default" : ""},
232 | {"name" : "username", "type" : "string", "default" : ""},
233 | {"name" : "password", "type" : "string", "default" : ""},
234 | {"name" : "target", "type" : "int", "default" : "100"},
235 | {"name" : "output", "type" : "images", "out" : true}],
236 | "category" : "Images"},
237 | {"name" : "LoadImages",
238 | "params" : [{"name" : "directory", "type" : "directory", "default" : ""},
239 | {"name" : "images", "type" : "images", "out" : true}],
240 | "category" : "File"},
241 | {"name" : "SaveImages",
242 | "params" : [{"name" : "images", "type" : "images", "in" : true},
243 | {"name" : "directory", "type" : "directory", "default" : ""}],
244 | "category" : "File"},
245 | {"name" : "ResizeImages",
246 | "params" : [{"name" : "input", "type" : "images", "in" : true},
247 | {"name" : "size", "type" : "int", "default" : "256"},
248 | {"name" : "output", "type" : "images", "out" : true}],
249 | "category" : "Images"},
250 | {"name" : "RemoveGrayscale",
251 | "params" : [{"name" : "input", "type" : "images", "in" : true},
252 | {"name" : "output", "type" : "images", "out" : true},
253 | {"name" : "rejects", "type" : "images", "out" : true}],
254 | "category" : "Images"},
255 | {"name" : "CropFaces",
256 | "params" : [{"name" : "input", "type" : "images", "in" : true},
257 | {"name" : "size", "type" : "int", "default" : "256"},
258 | {"name" : "output", "type" : "images", "out" : true},
259 | {"name" : "rejects", "type" : "images", "out" : true}],
260 | "category" : "Images"},
261 | {"name" : "StyleGAN_FineTune",
262 | "params" : [{"name" : "model_in", "type" : "model", "in" : true},
263 | {"name" : "images", "type" : "images", "in" : true},
264 | {"name" : "start_kimg", "type" : "int", "default" : "7000"},
265 | {"name" : "max_kimg", "type" : "int", "default" : "25000"},
266 | {"name" : "seed", "type" : "int", "default" : "1000"},
267 | {"name" : "schedule", "type" : "string", "default" : ""},
268 | {"name" : "model_out", "type" : "model", "out" : true},
269 | {"name" : "animation", "type" : "images", "out" : true}],
270 | "category" : "Images"},
271 | {"name" : "StyleGAN_Run",
272 | "params" : [{"name" : "model", "type" : "model", "in" : true},
273 | {"name" : "num", "type" : "int", "default" : "1"},
274 | {"name" : "images", "type" : "images", "out" : true}],
275 | "category" : "Images"},
276 | {"name" : "StyleGAN_Movie",
277 | "params" : [{"name" : "model", "type" : "model", "in" : true},
278 | {"name" : "length", "type" : "int", "default" : "10"},
279 | {"name" : "interp", "type" : "int", "default" : "10"},
280 | {"name" : "duration", "type" : "int", "default" : "10"},
281 | {"name" : "movie", "type" : "images", "out" : true}],
282 | "category" : "Images"},
283 | {"name" : "MakeMovie",
284 | "params" : [{"name" : "images", "type" : "images", "in" : true},
285 | {"name" : "duration", "type" : "int", "default" : "10"},
286 | {"name" : "movie", "type" : "images", "out" : true}],
287 | "category" : "Images"},
288 | {"name" : "Gridify",
289 | "params" : [{"name" : "input", "type" : "images", "in" : true},
290 | {"name" : "size", "type" : "int", "default" : "256"},
291 | {"name" : "columns", "type" : "int", "default" : "4"},
292 | {"name" : "output", "type" : "images", "out" : true}],
293 | "category" : "Images"},
294 | {"name" : "Degridify",
295 | "params" : [{"name" : "input", "type" : "images", "in" : true},
296 | {"name" : "columns", "type" : "int", "default" : "4"},
297 | {"name" : "rows", "type" : "int", "default" : "4"},
298 | {"name" : "output", "type" : "images", "out" : true}],
299 | "category" : "Images"},
300 | {"name" : "StyleTransfer",
301 | "params" : [{"name" : "content_image", "type" : "images", "in" : true},
302 | {"name" : "style_image", "type" : "images", "in" : true},
303 | {"name" : "steps", "type" : "int", "default" : "1000"},
304 | {"name" : "size", "type" : "int", "default" : "512"},
305 | {"name" : "style_weight", "type" : "int", "default" : "1000000"},
306 | {"name" : "content_weight", "type" : "int", "default" : "1"},
307 | {"name" : "content_layers", "type" : "string", "default" : "4"},
308 | {"name" : "style_layers", "type" : "string", "default" : "1, 2, 3, 4, 5"},
309 | {"name" : "output", "type" : "images", "out" : true}],
310 | "category" : "Images"},
311 | {"name" : "JoinImageDirectories",
312 | "params" : [{"name" : "dir1", "type" : "images", "in" : true},
313 | {"name" : "dir2", "type" : "images", "in" : true},
314 | {"name" : "output", "type" : "images", "out" : true}],
315 | "category" : "File"},
316 | {"name" : "SquareCrop",
317 | "params" : [{"name" : "input", "type" : "images", "in" : true},
318 | {"name" : "output", "type" : "images", "out" : true}],
319 | "category" : "Images"},
320 | {"name" : "UnmakeMovie",
321 | "params" : [{"name" : "movie", "type" : "images", "in" : true},
322 | {"name" : "output", "type" : "images", "out" : true}],
323 | "category" : "Images"}
324 | ];
325 | /*
326 |
327 | {"name" : "MakePredictionData", "params" : "data(in,text);x(out,text);y(out,text)", "tip" : "Prepare data for prediction--each line will try to predict the next line", "category" : "no"}, \
328 | {"name" : "DCGAN_Train", "params" : "input_images(images,in);epochs(int=10);input_height(int=108);output_height(int=108);filetype(string=jpg);model(out,model);animation(out,image)", "tip" : "Train a generateive adversarial network to make images", "category" : "no"}, \
329 | {"name" : "DCGAN_Run", "params" : "input_images(images,in);model(in,model);input_height(int=108);output_height(int=108);filetype(string=jpg);output_image(out,image)", "tip" : "Generate an image from a generative adversarial network", "category" : "no"}, \
330 | {"name" : "ReadImages", "params" : "data_directory(directory);output_images(out,images)", "tip" : "Read in a directory of image files", "category" : "File"}, \
331 | {"name" : "WriteImages", "params" : "input_images(in,images);output_directory(directory)", "tip" : "Save a group of images to a directory", "category" : "File"}, \
332 | {"name" : "PickFromWikipedia", "params" : "wiki_directory(directory,tip=Directory where wikipedia files are stored);input(in,text);catgories(string=*,tip=What categories if any?);section_name(string,tip=What section to pull text from if any);output(out,text);break_sentences(bool=false,tip=Should text be broken into one sentence per line?)", "tip" : "Pull text from wikipedia for the articles specified (file with one title per line)", "category" : "Wikipedia"}, \
333 | {"name" : "Repeat", "params" : "input(in,text);output(out,text);times(int)", "category" : "Do not use"}, \
334 | {"name" : "StyleNet_Train", "params" : "style_image(in,image);test_image(in,image);epochs(int=2,tip=how long to run);model(out,model);animation(out,image)", "tip" : "Draw the target image in the style of the style image", "category" : "no"}, \
335 | {"name" : "StyleNet_Run", "params" : "model(in,model);target_image(in,image,tip=Image to stylize);output_image(out,image)", "tip" : "Apply a style learned by a neural net to an image", "category" : "no"}, \
336 | {"name" : "ReadImageFile", "params" : "file(string,tip=Name of image file to read in);output(out,image)", "tip" : "Read an image file in", "category" : "File"}, \
337 | {"name" : "WriteImageFile", "params" : "input(in,image);file(string,tip=Name of image file to write to)", "tip" : "Write an image to file", "category" : "File"}, \
338 | {"name" : "MakeEmptyText", "params" : "output(out,text)", "tip" : "Create an empty text file", "category" : "Utils"}, \
339 | ]';
340 | */
--------------------------------------------------------------------------------
/image_modules.py:
--------------------------------------------------------------------------------
1 | from module import *
2 | import requests
3 | import os
4 | import re
5 | import time
6 | import random
7 | import pickle
8 | import numpy as np
9 | import math
10 | import copy
11 | from torchvision import transforms
12 | from torchvision import utils
13 | from PIL import Image
14 | import shutil
15 | import pdb
16 | import subprocess
17 | import sys
18 |
19 | # Modules for scraping, transforming, and generating images.
20 |
21 |
22 |
23 | ##############################
24 |
25 | class ScrapePinterest(Module):
26 |
27 | def __init__(self, url, username, password, target, output):
28 | self.url = url # Initial url (must be at Pinterest)
29 | self.username = username # Pinterest username
30 | self.password = password # Pinterest password
31 | self.target = target # target number of images to download
32 | self.output = output # path to output directory to store image files
33 | self.ready = True
34 | self.output_files = [output]
35 |
36 | def run(self):
37 | # Do some input checking
38 | if 'pinterest.com' not in self.url.lower():
39 | print("url (" + self.url + ") is not a pinterest.com url.")
40 | return
41 | if len(self.password) == 0:
42 | print("password is empty")
43 | return
44 | if len(self.username) == 0:
45 | print("username is empty")
46 | return
47 | from selenium import webdriver
48 | from selenium.webdriver.common.keys import Keys
49 | # Set up Chrome Driver
50 | options = webdriver.ChromeOptions()
51 | options.add_argument('--headless')
52 | options.add_argument('--no-sandbox')
53 | options.add_argument('--disable-dev-shm-usage')
54 | wd = webdriver.Chrome('chromedriver',options=options)
55 | # open Pinterest and log in
56 | print("logging in to Pinterest...")
57 | wd.get("https://www.pinterest.com")
58 | emailElem = wd.find_element_by_name('id')
59 | emailElem.send_keys(self.username)
60 | passwordElem = wd.find_element_by_name('password')
61 | passwordElem.send_keys(self.password)
62 | passwordElem.send_keys(Keys.RETURN)
63 | time.sleep(5 + random.randint(1, 5))
64 | # Get the first page
65 | print("Going to first page...")
66 | wd.get(self.url)
67 | # sleep
68 | time.sleep(5 + random.randint(1, 5))
69 | # get urls for images
70 | results = set() # the urls
71 | url_count = 0 # how many image urls have we found?
72 | miss_count = 0 # how many times have we failed to get more images?
73 | max_miss_count = 10 # how many times are we willing to try again?
74 | print("getting image urls...")
75 | while len(results) < self.target and miss_count < max_miss_count:
76 | # Find image elements in the web page
77 | images = wd.find_elements_by_tag_name("img")
78 | # Iterate through the elements and try to get the urls, which is the source of the element
79 | for i in images:
80 | try:
81 | src = i.get_attribute("src")
82 | results.add(src)
83 | except:
84 | pass
85 | # Did we fail to get new image urls?
86 | print(len(results))
87 | if len(results) == url_count:
88 | # If so, increment miss count
89 | miss_count = miss_count + 1
90 | else:
91 | miss_count = 0
92 | # Remember how many image urls we had last time around
93 | url_count = len(results)
94 | # sleep
95 | time.sleep(5 + random.randint(1, 5))
96 | # Send page down signal
97 | try:
98 | dummy = wd.find_element_by_tag_name('a')
99 | dummy.send_keys(Keys.PAGE_DOWN)
100 | except:
101 | print("can't page down")
102 | miss_count = miss_count + 1
103 | # convert results to list
104 | results = list(results)
105 | # Prep download directory
106 | prep_output_dir(self.output)
107 | # download images
108 | print("downloading images...")
109 | for result in [r for r in results if r]:  # skip entries whose src attribute was missing
110 | res = requests.get(result)
111 | filename = os.path.join(self.output, os.path.basename(result))
112 | file = open(filename, 'wb')
113 | for chunk in res.iter_content(100000):
114 | file.write(chunk)
115 | file.close()
116 |
117 |
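# A minimal usage sketch (hypothetical URL and credentials; assumes selenium
# and a matching chromedriver are installed, which run() itself requires):
#
#   scraper = ScrapePinterest("https://www.pinterest.com/example/board/",
#                             "user@example.com", "secret", target=100,
#                             output="cache/pins")
#   if scraper.ready:
#       scraper.run()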
118 | ####################################
119 |
120 | class ResizeImages(Module):
121 |
122 | def __init__(self, input, size, output):
123 | self.input = input # path to input files (directory)
124 | self.output = output # path to output files (directory)
125 | self.size = size # width and height (int)
126 | self.ready = checkFiles(input)
127 | self.output_files = [output]
128 |
129 | def run(self):
130 | # Prep output file directory
131 | prep_output_dir(self.output)
132 | # The transformation
133 | p = transforms.Compose([transforms.Resize((self.size,self.size))])
134 | # Apply to all files in input directory
135 | if os.path.exists(self.input) and os.path.isdir(self.input):
136 | for file in os.listdir(self.input):
137 | try:
138 | img = Image.open(os.path.join(self.input, file))
139 | img2 = p(img).convert('RGB')  # JPEG cannot store an alpha channel
140 | img2.save(os.path.join(self.output, file), 'JPEG')
141 | except:
142 | print(file, "could not be processed")
143 | else:
144 | print(self.input, "is not a directory")
145 |
146 | ###################################
147 |
148 | ### Module for face detection/close cropping faces
149 |
150 | class CropFaces(Module):
151 |
152 | def __init__(self, input, size, output, rejects):
153 | self.input = input # path to directory containing images
154 | self.output = output # path to directory to save new images
155 | self.rejects = rejects # path to directory to save rejected images
156 | self.size = size # (int) size of height and width of output files
157 | self.ready = checkFiles(input)
158 | self.output_files = [output, rejects]
159 |
160 | def run(self):
161 | # Prep output file directory
162 | prep_output_dir(self.output)
163 | prep_output_dir(self.rejects)
164 | # Run autocrop program
165 | cmd = 'autocrop -i ' + self.input + ' -o ' + self.output + ' -w ' + str(self.size) + ' -H ' + str(self.size) + ' --facePercent 50 -r ' + self.rejects
166 | status = os.system(cmd)
167 |
168 | ### Module for removing grayscale images
169 | class RemoveGrayscale(Module):
170 |
171 | def __init__(self, input, output, rejects):
172 | self.input = input # path to directory containing images
173 | self.output = output # path to directory to put non-grayscale images
174 | self.rejects = rejects # path to directory to put grayscale images
175 | self.ready = checkFiles(input)
176 | self.output_files = [output, rejects]
177 |
178 | def run(self):
179 | # prep output file directory
180 | prep_output_dir(self.output, makedir=True)
181 | prep_output_dir(self.rejects, makedir=True)
182 | # Check each file
183 | for file in os.listdir(self.input):
184 | try:
185 | img = Image.open(os.path.join(self.input, file))
186 | except:
187 | print(file, "could not be loaded")
188 | if len(img.getbands()) == 3:
189 | # this images has 3 color channels (not grayscale)
190 | shutil.copyfile(os.path.join(self.input, file), os.path.join(self.output, file))
191 | else:
192 | # this image is grayscale
193 | shutil.copyfile(os.path.join(self.input, file), os.path.join(self.rejects, file))
194 |
195 | #####################
196 |
197 | class MakeGrayscale(Module):
198 |
199 | def __init__(self, input, output):
200 | self.input = input # path to directory containing images
201 | self.output = output # path to directory
202 | self.ready = checkFiles(input)
203 | self.output_files = [output]
204 |
205 | def run(self):
206 | # prep output file directory
207 | prep_output_dir(self.output)
208 | # load each file
209 | for file in os.listdir(self.input):
210 | # load and convert
211 | img = Image.open(os.path.join(self.input, file)).convert('LA')
212 | # save to target destination
213 | img.save(os.path.join(self.output, file))
214 |
215 | #############################
216 |
217 | ### Module for saving images (directory)
218 | class SaveImages(Module):
219 |
220 | def __init__(self, images, directory):
221 | self.images = images # path to directory containing images
222 | self.directory = directory # path name to save to
223 | self.ready = checkFiles(images)
224 | self.output_files = [directory]
225 |
226 | def run(self):
227 | prep_output_dir(self.directory, makedir=False)
228 | shutil.copytree(self.images, self.directory)
229 |
230 |
231 | ###############################
232 |
233 | class LoadImages(Module):
234 |
235 | def __init__(self, directory, images):
236 | self.images = images # path to save images into
237 | self.directory = directory # path to load images from
238 | self.ready = True
239 | self.output_files = [images]
240 |
241 | def run(self):
242 | # Check for path existence
243 | if os.path.exists(self.directory):
244 | # If it's a directory, copy all files in the directory
245 | if os.path.isdir(self.directory):
246 | prep_output_dir(self.images, makedir=False)
247 | shutil.copytree(self.directory, self.images)
248 | else:
249 | # Not a directory, make a directory and copy the single file into it
250 | prep_output_dir(self.images)
251 | shutil.copy(self.directory, self.images)
252 |
253 | ################################
254 |
255 | ### Module for fine-tuning StyleGAN
256 | ### (include output for an animation)
257 |
258 | class StyleGAN_FineTune(Module):
259 |
260 | def __init__(self, model_in, images, start_kimg, max_kimg, seed, schedule, model_out, animation):
261 | self.model_in = model_in # path to model (pkl file)
262 | self.images = images # path to image directory
263 | self.start_kimg = start_kimg # (int) iteration to begin with (closer to zero means retrain more) default: 7000
264 | self.max_kimg = max_kimg # (int) max kimg
265 | self.schedule = schedule # (int or string or dict) the number of kimgs per resolution level
266 | self.model_out = model_out # path to fine tuned model (pkl file)
267 | self.animation = animation # path to an animation file
268 | self.seed = seed # (int) default = 1000
269 | self.ready = checkFiles(model_in, images)
270 | self.output_files = [model_out, animation]
271 |
272 | def run(self):
273 | print("Keyboard interrupt will stop training but program will try to continue.")
274 | # Run the stylegan program in a separate sub-process
275 | cwd = os.getcwd()
276 | schedule = str(self.schedule)
277 | params = {'model': os.path.join(cwd, self.model_in),
278 | 'images_in' : os.path.join(cwd, self.images),
279 | 'dataset_temp' : os.path.join(cwd, 'stylegan_dataset'),
280 | 'start_kimg' : self.start_kimg,
281 | 'max_kimg' : self.max_kimg,
282 | 'schedule' : schedule if len(schedule) > 0 else ' ',
283 | 'seed' : self.seed
284 | }
285 | command = 'python stylegan/stylegan_runner.py --train'
286 | for key in params.keys():
287 | val = params[key]
288 | command = command + ' --' + str(key) + ' ' + str(val)
289 | print("launching", command)
290 | try:
291 | process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True )
292 | for line in iter(process.stdout.readline, b''):  # b'' sentinel: stdout yields bytes in Python 3
293 | sys.stdout.write(line.decode('utf-8', errors='replace'))
294 | except KeyboardInterrupt:
295 | print("Keyboard interrupt")
296 | ### ASSERT: we are done training
297 | # Get final model... it's the pkl with the biggest number
298 | # First, get the latest run directory
299 | run_dir = ''
300 | latest = -1
301 | for file in os.listdir(os.path.join(cwd, 'results')):
302 | match = re.match(r'([0-9]+)', file)
303 | if match is not None and match.group(1) is not None:
304 | current = int(match.group(1))
305 | if current > latest:
306 | latest = current
307 | run_dir = file
308 | run_dir = os.path.join(cwd, 'results', run_dir)
309 | # Now get the newest pkl file and image files
310 | model_filename = ''
311 | image_files = []
312 | latest = 0.0
313 | for file in os.listdir(run_dir):
314 | match_png = re.search(r'[\w\W]*?[0-9]+?\.png', file)
315 | match_pkl = re.search(r'[\w\W]*?[0-9]+?\.pkl', file)
316 | if match_pkl is not None:
317 | current = os.stat(os.path.join(run_dir, file)).st_ctime  # run_dir is already absolute
318 | if current > latest:
319 | latest = current
320 | model_filename = file
321 | elif match_png is not None:
322 | image_files.append(os.path.join(run_dir, file))
323 | model_filename = os.path.join(run_dir, model_filename)
324 | # save animation images
325 | prep_output_dir(self.animation)
326 | for file in image_files:
327 | shutil.copy(file, self.animation)
328 | # Save fine tuned model
329 | shutil.copyfile(model_filename, self.model_out)
330 |
331 |
332 |
333 | ############################
334 |
335 |
336 |
337 | class StyleGAN_Run(Module):
338 |
339 | def __init__(self, model, num, images):
340 | self.model = model # path to model
341 | self.num = num # number of images to generate (int)
342 | self.images = images # path to output image
343 | self.ready = checkFiles(model)
344 | self.output_files = [images]
345 |
346 | def run(self):
347 | prep_output_dir(self.images)
348 | # Run the stylegan program in a separate sub-process
349 | cwd = os.getcwd()
350 | params = {'model': os.path.join(cwd, self.model),
351 | 'images_out' : os.path.join(cwd, self.images),
352 | 'num' : self.num,
353 | }
354 | command = 'python stylegan/stylegan_runner.py --run'
355 | for key in params.keys():
356 | val = params[key]
357 | command = command + ' --' + str(key) + ' ' + str(val)
358 | print("launching", command)
359 | try:
360 | process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True )
361 | for line in iter(process.stdout.readline, b''):  # b'' sentinel: stdout yields bytes in Python 3
362 | sys.stdout.write(line.decode('utf-8', errors='replace'))
363 | except Exception as e:
364 | print("something broke")
365 | print(e)
366 |
367 |
368 | #############
369 |
370 | # Not tested. Needs to launch stylegan_runner.py
371 | class StyleGAN_Movie(Module):
372 |
373 | def __init__(self, model, length, interp, duration, movie):
374 | self.model = model # path to input model
375 | self.length = length # (int) number of way points interpolated between (default = 10)
376 | self.interp = interp # (int) number of interpolations between waypoints (default = 30)
377 | self.duration = duration # (int) duration of animation (default=1)
378 | self.movie = movie
379 | self.ready = checkFiles(model)
380 | self.output_files = [movie]
381 |
382 | def run(self):
383 | prep_output_dir(self.movie)
384 | # Run the stylegan program in a separate sub-process
385 | cwd = os.getcwd()
386 | params = {'model': os.path.join(cwd, self.model),
387 | 'movie_out' : os.path.join(cwd, self.movie, 'movie.gif'),
388 | 'num' : self.length,
389 | 'interp' : self.interp,
390 | 'duration' : self.duration
391 | }
392 | command = 'python stylegan/stylegan_runner.py --movie'
393 | for key in params.keys():
394 | val = params[key]
395 | command = command + ' --' + str(key) + ' ' + str(val)
396 | print("launching", command)
397 | try:
398 | process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True )
399 | for line in iter(process.stdout.readline, b''):  # b'' sentinel: stdout yields bytes in Python 3
400 | sys.stdout.write(line.decode('utf-8', errors='replace'))
401 | except Exception as e:
402 | print("something broke")
403 | print(e)
404 |
405 |
406 |
407 |
408 | ######################
409 |
410 | class MakeMovie(Module):
411 |
412 | def __init__(self, images, duration, movie):
413 | self.images = images # path to directory with images
414 | self.duration = duration # (int) duration of animation (default=10)
415 | self.movie = movie # path to output movie
416 | self.ready = checkFiles(images)
417 | self.output_files = [movie]
418 |
419 | def run(self):
420 | images = []
421 | files = []
422 | # Get list of filenames
423 | for file in os.listdir(self.images):
424 | files.append(file)
425 | # Sort the file names alphabetically so frames are appended in a stable order
426 | sorted_files = sorted(files)
427 | # Load each image in sorted order
428 | for file in sorted_files:
429 | try:
430 | img = Image.open(os.path.join(self.images, file))
431 | images.append(img)
432 | except:
433 | print(file, "did not load.")
434 | # Prep the destination, a directory to save a single file
435 | prep_output_dir(self.movie)
436 | # Make and save the animation
437 | images[0].save(os.path.join(self.movie, 'movie.gif'), "GIF",
438 | save_all=True,
439 | append_images=images[1:],
440 | duration=self.duration,
441 | loop=0)
442 |
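# Usage note: PIL interprets the GIF `duration` argument as milliseconds per
# frame, so the default of 10 plays back very fast; a value like duration=100
# (hypothetical) gives roughly ten frames per second.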
443 | ##########################
444 |
445 | class Degridify(Module):
446 |
447 | def __init__(self, input, columns, rows, output):
448 | self.input = input # path to directory containing images
449 | self.output = output # path to output directory
450 | self.columns = columns # (int) number of columns
451 | self.rows = rows # (int) number of rows
452 | self.ready = checkFiles(input)
453 | self.output_files = [output]
454 |
455 | def run(self):
456 | columns = self.columns
457 | rows = self.rows
458 | prep_output_dir(self.output)
459 | # iterate through images in the directory
460 | for file in os.listdir(self.input):
461 | # Load image
462 | grid = Image.open(os.path.join(self.input, file))
463 | # compute the size of each grid cell
464 | grid_width, grid_height = grid.size
465 | image_width = grid_width // columns
466 | image_height = grid_height // rows
467 | # Start cropping
468 | if getattr(grid, 'n_frames', 1) > 1:  # only animated formats expose n_frames
469 | self.cropAnim(grid, columns, rows, image_width, image_height)
470 | else:
471 | self.cropImage(grid, columns, rows, image_width, image_height)
472 |
473 | def cropImage(self, grid, columns, rows, image_width, image_height):
474 | for i in range(columns):
475 | for j in range(rows):
476 | img = grid.crop((i*image_width, j*image_height, (i+1)*image_width, (j+1)*image_height))
477 | img.save(os.path.join(self.output, str(i).zfill(5) + '-' + str(j).zfill(5) + '.gif'), "GIF")
478 |
479 | def cropAnim(self, anim, columns, rows, image_width, image_height):
480 | cwd = os.getcwd()
481 | temp_dir = os.path.join(cwd, '.degridify_temp')
482 | if os.path.exists(temp_dir):
483 | shutil.rmtree(temp_dir)
484 | os.mkdir(temp_dir)
485 | frame_paths = []
486 | anim_dict = {}
487 | for n in range(anim.n_frames):
488 | anim.seek(n)
489 | file = str(n).zfill(5) + '.gif'
490 | path = os.path.join(temp_dir, file)
491 | anim.save(path, "GIF")
492 | frame_paths.append(path)
493 | for file in frame_paths:
494 | frame = Image.open(file)
495 | for i in range(columns):
496 | for j in range(rows):
497 | cropped = frame.crop((i*image_width, j*image_height, (i+1)*image_width, (j+1)*image_height))
498 | if (i, j) not in anim_dict:
499 | anim_dict[(i, j)] = []
500 | anim_dict[(i, j)].append(cropped)
501 | for key in anim_dict:
502 | i, j = key
503 | filename = str(i).zfill(5) + '-' + str(j).zfill(5) + '.gif'
504 | frames = anim_dict[key]
505 | frames[0].save(os.path.join(self.output, filename), "GIF",
506 | save_all=True,
507 | append_images=frames[1:],
508 | duration=100,
509 | loop=0)
510 | shutil.rmtree(temp_dir)
511 |
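# Worked example of the cell arithmetic above: a 1024x512 grid with columns=4
# and rows=2 yields 256x256 cells, and cell (i, j) is cropped from the box
# (i*256, j*256, (i+1)*256, (j+1)*256).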
512 |
513 | class Gridify(Module):
514 |
515 | def __init__(self, input, size, columns, output):
516 | self.input = input # path to directory containing input files
517 | self.output = output # path to output directory containing the grid file
518 | self.size = size # (int) size of a cell (default=256)
519 | self.columns = columns # (int) number of columns required (rows computed automatically) (default=4)
520 | self.ready = checkFiles(input)
521 | self.output_files = [output]
522 |
523 | def run(self):
524 | columns = self.columns # Just making it easier to use this variable
525 | # load and resize images
526 | images = [] # all the images
527 | files = [] # all the filenames
528 | p = transforms.Compose([transforms.Resize((self.size,self.size))]) # the transform
529 | # Sort the files alphabetically
530 | for file in os.listdir(self.input):
531 | files.append(file)
532 | sorted_files = sorted(files)
533 | # Iterate through all images in the directory
534 | for file in sorted_files:
535 | # Load the image
536 | img = Image.open(os.path.join(self.input, file))
537 | # Resize it
538 | img2 = p(img)
539 | # Save it in order
540 | images.append(img2)
541 | # compute number of rows
542 | rows = len(images) // columns
543 | leftovers = len(images) % columns
544 | # round up and pad the final partial row with blank cells
545 | if leftovers > 0:
546 | rows = rows + 1
547 | for _ in range(columns - leftovers):
548 | empty = Image.new('RGB', (self.size, self.size))
549 | images.append(empty)
550 | # make new image
551 | grid = Image.new('RGB', (columns*self.size, rows*self.size))
552 | # paste images into new image
553 | counter = 0
554 | for i in range(columns):
555 | for j in range(rows):
556 | cur_img = images[counter]
557 | grid.paste(cur_img, (i*self.size, j*self.size))
558 | counter = counter + 1
559 | # save new image
560 | prep_output_dir(self.output)
561 | grid.save(os.path.join(self.output, 'grid.jpg'), "JPEG")
562 |
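# Worked example of the row arithmetic above: 10 images with columns=4 gives
# rows=2 and leftovers=2, so rows is bumped to 3 and columns - leftovers = 2
# blank cells pad the grid out to 4x3 = 12 cells.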
563 |
564 | ###########################
565 |
566 | class StyleTransfer(Module):
567 |
568 | def __init__(self, content_image, style_image, steps, size, style_weight, content_weight,
569 | content_layers, style_layers, output):
570 | self.content_image = content_image # path to content image directory
571 | self.style_image = style_image # path to style image directory
572 | self.steps = steps # (int) number of steps to run (Default=1000)
573 | self.size = size # (int) image size (default=512)
574 | self.style_weight = style_weight # (int) style weight
575 | self.content_weight = content_weight # (int) content weight
576 | self.content_layers = content_layers # (str) consisting of numbers [1-5] (default="4")
577 | self.style_layers = style_layers # (str) consisting of numbers [1-5] (default="1, 2, 3, 4, 5")
578 | self.output = output # path to output directory
579 | self.ready = checkFiles(content_image, style_image)
580 | self.output_files = [output]
581 |
582 | def run(self):
583 | import style_transfer
584 | prep_output_dir(self.output)
585 | count = 0
586 | # sort the content and style images
587 | content_images = []
588 | style_images = []
589 | sorted_content_images = []
590 | sorted_style_images = []
591 | for content_file in os.listdir(self.content_image):
592 | content_images.append(content_file)
593 | for style_file in os.listdir(self.style_image):
594 | style_images.append(style_file)
595 | sorted_content_images = sorted(content_images)
596 | sorted_style_images = sorted(style_images)
597 | # If there are multiple input content and style images, run all combinations
598 | for style_file in sorted_style_images:
599 | for content_file in sorted_content_images:
600 | count = count + 1
601 | print("Running with content=" + content_file + " style=" + style_file)
602 | style_transfer.run(os.path.join(self.content_image, content_file),
603 | os.path.join(self.style_image, style_file),
604 | os.path.join(self.output, str(count).zfill(5) + '.jpg'),
605 | image_size = self.size,
606 | num_steps = self.steps,
607 | style_weight = self.style_weight,
608 | content_weight = self.content_weight,
609 | content_layers_spec = self.content_layers,
610 | style_layers_spec = self.style_layers)
611 |
612 | #################################
613 |
614 | class JoinImageDirectories(Module):
615 |
616 | def __init__(self, dir1, dir2, output):
617 | self.dir1 = dir1
618 | self.dir2 = dir2
619 | self.output = output
620 | self.ready = checkFiles(dir1, dir2)
621 | self.output_files = [output]
622 |
623 | def run(self):
624 | prep_output_dir(self.output)
625 | for file in os.listdir(self.dir1):
626 | shutil.copy(os.path.join(self.dir1, file), self.output)
627 | for file in os.listdir(self.dir2):
628 | shutil.copy(os.path.join(self.dir2, file), self.output)
629 |
630 | #################################
631 |
632 | class SquareCrop(Module):
633 |
634 | def __init__(self, input, output):
635 | self.input = input
636 | self.output = output
637 | self.ready = checkFiles(input)
638 | self.output_files = [output]
639 |
640 | def run(self):
641 | prep_output_dir(self.output)
642 | for file in os.listdir(self.input):
643 | img = Image.open(os.path.join(self.input, file))
644 | square_img = img
645 | width, height = img.size
646 | if width > height:
647 | diff = width - height
648 | box = (diff//2, 0, diff//2 + height, height)  # exact height x height square
649 | square_img = img.crop(box)
650 | elif height > width:
651 | diff = height - width
652 | box = (0, diff//2, width, diff//2 + width)  # exact width x width square
653 | square_img = img.crop(box)
654 | square_img.convert('RGB').save(os.path.join(self.output, file), "JPEG")
655 |
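# Worked example of the crop boxes above: a 300x200 landscape image has
# diff=100, so box=(50, 0, 250, 200) keeps the centered 200x200 region;
# portrait images are handled symmetrically on the vertical axis.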
656 | ####################################
657 |
658 | class UnmakeMovie(Module):
659 |
660 | def __init__(self, movie, output):
661 | self.movie = movie # path to directory containing movies
662 | self.output = output # path to directory to save images
663 | self.ready = checkFiles(movie)
664 | self.output_files = [output]
665 |
666 | def run(self):
667 | prep_output_dir(self.output)
668 | for file in os.listdir(self.movie):
669 | anim = Image.open(os.path.join(self.movie, file))
670 | for n in range(anim.n_frames):
671 | anim.seek(n)
672 | new_filename = os.path.splitext(file)[0] + '-' + str(n).zfill(5) +'.gif'
673 | anim.save(os.path.join(self.output, new_filename), "GIF")
674 |
675 |
--------------------------------------------------------------------------------
/gui.js:
--------------------------------------------------------------------------------
1 | // GLOBALS ////////////////////////////////////
2 | const screen_width = 2048; // default screen width
3 | const screen_height = 600; // default screen height
4 | const module_width = 200; // default module width
5 | const module_height = 40; // default module height
6 | const parameter_offset = 10; // how much to indent parameters
7 | const parameter_width = module_width - (parameter_offset * 2); // default parameter width
8 | const parameter_height = module_height; // default parameter height
9 | const module_spacing = 20; // how much space between modules
10 | const parameter_spacing = 5; // how much space between parameters
11 | const cache_path = "cache/"; // cache directory
12 | const image_path = '/nbextensions/google.colab/'; // where to find images
13 | const text_size = 16; // default text size
14 | const font_size = "" + text_size + "px"; // default font size for html5
15 |
16 | // What images to use for different parameter types
17 | var type_images = {"model" : "model.png",
18 | "dictionary" : "dictionary.png",
19 | "image" : "image.png",
20 | "images" : "images.png",
21 | "text" : "text.png",
22 | "string" : "string.png",
23 | "int" : "number.png",
24 | "float" : "number.png"};
25 |
26 | var drag = null; // Thing being dragged
27 | var clicked = null; // Thing that was clicked
28 | var open_param = null; // Which parameter is being edited
29 | var mouseX; // Mouse X
30 | var mouseY; // Mouse Y
31 | var modules = []; // List of modules
32 | var parameters = []; // list of parameters
33 | var connections = []; // List of connections
34 | var clickables = []; // List of clickable things
35 | var dragables = []; // List of dragable things
36 | var module_counters = {}; // Dictionary of counts of each module used
37 | var drag_offset_x = 0; // Keep track of the offset due to scrolling
38 | var drag_offset_y = 0; // Keep track of the offset due to scrolling
39 | var module_counter = 0; // keep count of number of modules
40 | var parameter_counter = 0; // keep count of number of parameters
41 |
42 |
43 | // START GAME /////////////////////////////////
44 | function startGame() {
45 | // HIDE HTML OBJECTS
46 | hideInput();
47 | // POPULATE MODULE MAKING HTML
48 | populateModuleMaker();
49 | // START THE GUI RUNNING
50 | myGameArea.start();
51 | // EVENT HANDLERS
52 | // Mouse down
53 | myGameArea.canvas.onmousedown = function(e) {
54 | var offset = recursive_offset(myGameArea.canvas)
55 | var i;
56 | // Iterate through Dragables
57 | for (i = 0; i < dragables.length; i++) {
58 | var obj = dragables[i];
59 | if (mouseX >= obj.x - offset.x && mouseX <= obj.x - offset.x + obj.width &&
60 | mouseY >= obj.y - offset.y && mouseY <= obj.y - offset.y + obj.height) {
61 | // found the dragable we clicked on
62 | drag = obj; // we are dragging this object now
63 | drag_offset_x = mouseX - obj.x;
64 | drag_offset_y = mouseY - obj.y;
65 | // Stop showing the text entry
66 | hideInput();
67 | return;
68 | }
69 | }
70 | // Iterate through clickables
71 | for (i = 0; i < clickables.length; i++) {
72 | var obj = clickables[i];
73 | if (mouseX >= obj.x - offset.x && mouseX <= obj.x - offset.x + obj.width &&
74 | mouseY >= obj.y - offset.y && mouseY <= obj.y - offset.y + obj.height) {
75 | // found the clickable we clicked on
76 | if (obj.is_out) {
77 | // Clicked on an output element; start drawing a connection line from it
78 | clicked = obj; // we are drawing a line from this object now
79 | break;
80 | }
81 | else if (!obj.is_in && !obj.is_out) {
82 | // clicked on a non-input connection
83 | // show input
84 | var inp = document.getElementById("inp");
85 | inp.style.display = "block";
86 | var inp_module = document.getElementById("inp_module");
87 | inp_module.innerHTML = obj.parent.name;
88 | var inp_param = document.getElementById("inp_param");
89 | inp_param.innerHTML = obj.name;
90 | var val = document.getElementById("inp_val");
91 | val.value = obj.value;
92 | open_param = obj;
93 | return;
94 | }
95 | }
96 | }
97 | // Make text input invisible
98 | hideInput();
99 | }
100 | // Mouse Move
101 | myGameArea.canvas.onmousemove = function(e) {
102 | mouseX = e.clientX;
103 | mouseY = e.clientY;
104 | }
105 | // Mouse up
106 | myGameArea.canvas.onmouseup = function(e) {
107 | if (drag) {
108 | // Done dragging
109 | drag = null;
110 | }
111 | if (clicked) {
112 | // Done clicking
113 | var offset = recursive_offset(myGameArea.canvas)
114 | var i;
115 | // we are making a line. Where did we drag the line to?
116 | for (i = 0; i < parameters.length; i++) {
117 | var obj = parameters[i];
118 | if (mouseX >= obj.x - offset.x && mouseX <= obj.x - offset.x + obj.width &&
119 | mouseY >= obj.y - offset.y && mouseY <= obj.y - offset.y + obj.height) {
120 | // Found the parameter we mouse-uped on
121 | if (obj.is_in && !obj.connected && obj.parent.name != clicked.parent.name && obj.type === clicked.type) {
122 | // clicked on an input that isn't already connected
123 | // Make a connection
124 | c = new connection(clicked, obj);
125 | c.id = connections.length;
126 | connections.push(c);
127 | obj.connected = true;
128 | clicked.connected = true;
129 | var fname = cache_path + obj.type + clicked.id;
130 | obj.value = fname;
131 | clicked.value = fname;
132 | }
133 | }
134 | }
135 | // Not making a line any more
136 | clicked = null;
137 | }
138 | }
139 | // double click
140 | myGameArea.canvas.ondblclick = function(e){
141 | var offset = recursive_offset(myGameArea.canvas)
142 | var i;
143 | // Dragables
144 | for (i = 0; i < dragables.length; i++) {
145 | var obj = dragables[i];
146 | if (mouseX >= obj.x - offset.x && mouseX <= obj.x - offset.x + obj.width &&
147 | mouseY >= obj.y - offset.y && mouseY <= obj.y - offset.y + obj.height) {
148 | // found the dragable we double-clicked on.
149 | // Collapse or un-collapse it
150 | obj.collapsed = !obj.collapsed;
151 | break;
152 | }
153 | }
154 | };
155 | // Key Press
156 | document.addEventListener('keydown', function(e){
157 | var offset = recursive_offset(myGameArea.canvas)
158 | if (e.keyCode == 8 || e.keyCode == 46) {
159 | // DELETE
160 | var i;
161 | // Did we delete a dragable?
162 | for (i = 0; i < dragables.length; i++) {
163 | var obj = dragables[i];
164 | if (mouseX >= obj.x - offset.x && mouseX <= obj.x - offset.x + obj.width &&
165 | mouseY >= obj.y - offset.y && mouseY <= obj.y - offset.y + obj.height) {
166 | // Found the dragable we were mousing over when the delete button was pressed
167 | delete_module(obj);
168 | break;
169 | }
170 | }
171 | }
172 | } );
173 | }
174 |
175 | // MYGAMEAREA //////////////////////////////////////////
176 | var myGameArea = {
177 | canvas : document.createElement("canvas"),
178 | start : function() {
179 | this.canvas.width = screen_width;
180 | this.canvas.height = screen_height;
181 | this.context = this.canvas.getContext("2d");
182 | document.body.insertBefore(this.canvas, document.body.childNodes[0]);
183 | this.frameNo = 0;
184 | this.interval = setInterval(updateGameArea, 20);
185 | },
186 | clear : function() {
187 | this.context.clearRect(0, 0, this.canvas.width, this.canvas.height);
188 | }
189 | }
190 |
191 |
192 | // MODULE ///////////////////////////////////////////////
193 | function module(x, y, type, name, category, id) {
194 | this.width = module_width; // module width
195 | this.height = module_height; // module height
196 | this.color = "gray"; // module color
197 | this.category = category; // what category did this belong in (not sure this is needed)
198 | this.x = x; // x location
199 | this.y = y; // y location
200 | this.name = name; // label to display to user
201 | this.type = type; // module type
202 | this.params = []; // list of parameters
203 | this.font = "Arial" // font to use
204 | this.font_size = font_size; // font size
205 | this.id = id; // unique id number
206 | this.collapsed = false; // am I collapsed?
207 | this.up = new Image(); // up arrow
208 | this.up.src = image_path + "up.png"; // load the up arrow image
209 | this.down = new Image(); // down arrow
210 | this.down.src = image_path + "down.png"; // load the down arrow image
211 | // UPDATE FUNCTION
212 | this.update = function() {
213 | ctx = myGameArea.context;
214 | if (drag == this) {
215 | // I AM BEING DRAGGED
216 | this.x = mouseX - drag_offset_x;
217 | this.y = mouseY - drag_offset_y;
218 | }
219 | // DRAW ME
220 | // draw my parameters
221 | var i;
222 | for (i = 0; i < this.params.length; i++) {
223 | param = this.params[i];
224 | var index = i + 1;
225 | if (this.collapsed) {
226 | index = 0;
227 | }
228 | param.x = this.x + parameter_offset;
229 | param.y = this.y + (this.height + parameter_spacing)*(index);
230 | param.update();
231 | }
232 | // Draw the module header
233 | ctx.fillStyle = this.color;
234 | ctx.fillRect(this.x, this.y, this.width, this.height);
235 | ctx.font = this.font_size + " " + this.font;
236 | ctx.fillStyle = "black";
237 | var text_x = this.x + 5;
238 | var text_y = this.y + (this.height / 2.0) + (text_size / 3.0);
239 | var img_height = (this.height - 10) / 4;
240 | var img_width = (this.height - 10) / 2;
241 | ctx.fillText(this.name, text_x, text_y);
242 | // Show up arrow or down arrow?
243 | if (this.collapsed) {
244 | ctx.drawImage(this.down, this.x + this.width - img_width - 5, this.y + 5, img_width, img_height);
245 | }
246 | else {
247 | ctx.drawImage(this.up, this.x + this.width - img_width - 5, this.y + 5, img_width, img_height);
248 | }
249 |
250 | }
251 | }
252 |
253 | // PARAMETER //////////////////////////
254 | function parameter(name, is_in, is_out, type, default_value, parent, id) {
255 | this.color = "lightgray"; // parameter color (red is output, green is input, lightgray otherwise)
256 | this.width = parameter_width; // parameter width
257 | this.height = parameter_height; // parameter height
258 | this.is_in = is_in; // am I an input?
259 | this.is_out = is_out; // am I an output?
260 | this.type = type; // what type am I? (string, text, int, float, dictionary, model, etc.)
261 | this.value = default_value; // my value
262 | this.x = 0; // x location (parent module will set me)
263 | this.y = 0; // y location (parent module will set me)
264 | this.name = name; // My label to show the user
265 | this.connected = false; // Am I connected to another parameter?
266 | this.parent = parent; // Who is my parent module?
267 | this.font = "Arial" // My font
268 | this.font_size = font_size // my font size
269 | this.id = id; // unique identifier
270 | // Set my color based on whether I am an input or output. Also set the default filename I create when linked
271 | if (this.is_in) {
272 | this.color = "green";
273 | this.value = cache_path + name;
274 | }
275 | else if (this.is_out) {
276 | this.color = "red";
277 | this.value = cache_path + name;
278 | }
279 | // Do I have an icon to show?
280 | this.img = null; // icon image
281 | if (this.type in type_images) {
282 | this.img = new Image();
283 | this.img.src = image_path + type_images[this.type];
284 | }
285 | // UPDATE FUNCTION
286 | this.update = function() {
287 | ctx = myGameArea.context;
288 | // DRAW ME
289 | ctx.fillStyle = this.color;
290 | ctx.fillRect(this.x, this.y, this.width, this.height);
291 | ctx.font = this.font_size + " " + this.font;
292 | ctx.fillStyle = "black";
293 | var text_x = this.x + 5;
294 | var text_y = this.y + (this.height / 2.0) + (text_size / 3.0);
295 | var img_height = this.height - 10;
296 | var img_width = img_height;
297 | // Show my icon
298 | if (this.is_out && this.img) {
299 | ctx.drawImage(this.img, this.x + this.width - img_width - 5, this.y + 5, img_width, img_height);
300 | }
301 | else if (this.img) {  // inputs and any other typed parameter draw the icon on the left
302 | ctx.drawImage(this.img, text_x, this.y + 5, img_width, img_height);
303 | text_x = text_x + img_width + 5;
304 | }
305 | ctx.fillText(this.name, text_x, text_y);
306 |
307 | }
308 | }
309 |
310 | // CONNECTION ////////////////////////////////
311 | function connection (origin, target) {
312 | this.origin = origin; // my origin parameter
313 | this.target = target; // my target parameter
314 | this.id = null; // unique identifier
315 | this.update = function() {
316 | var ctx = myGameArea.context;
317 | // DRAW ME
318 | var origin_x = this.origin.x + this.origin.width;
319 | var origin_y = this.origin.y + this.origin.height/2;
320 | var target_x = this.target.x;
321 | var target_y = this.target.y + this.target.height/2;
322 | if (this.origin.parent.collapsed) {
323 | origin_x = this.origin.x + this.origin.width/2;
324 | origin_y = this.origin.y + this.target.height;
325 | }
326 | if (this.target.parent.collapsed) {
327 | target_x = this.target.x + this.target.width/2;
328 | target_y = this.target.y;
329 | }
330 | ctx.lineWidth = 2;
331 | ctx.beginPath();
332 | ctx.moveTo(origin_x, origin_y);
333 | ctx.lineTo(target_x, target_y);
334 | ctx.stroke();
335 | ctx.beginPath();
336 | ctx.arc(origin_x, origin_y, 5, 0, 2 * Math.PI);
337 | ctx.fill();
338 | ctx.beginPath();
339 | ctx.arc(target_x, target_y, 5, 0, 2 * Math.PI);
340 | ctx.fill();
341 | }
342 | }
343 |
344 | // UPDATE CANVAS ///////////////////////////////////
345 | function updateGameArea() {
346 | var x, height, gap, minHeight, maxHeight, minGap, maxGap;
347 | myGameArea.clear();
348 | myGameArea.frameNo += 1;
349 | // Draw my modules
350 | var i;
351 | for (i = 0; i < modules.length; i++) {
352 | modules[i].update();
353 | }
354 | // Draw connections
355 | var j;
356 | for (j = 0; j < connections.length; j++) {
357 | connections[j].update();
358 | }
359 | // Is there a line being drawn that hasn't been connected yet?
360 | if (clicked) {
361 | var offset = recursive_offset(myGameArea.canvas)
362 | var ctx = myGameArea.context;
363 | var dot_x = clicked.x + clicked.width;
364 | var dot_y = clicked.y + clicked.height/2;
365 | ctx.lineWidth = 2;
366 | ctx.beginPath();
367 | ctx.moveTo(dot_x, dot_y);
368 | ctx.lineTo(mouseX+offset.x, mouseY+offset.y);
369 | ctx.stroke();
370 | ctx.beginPath();
371 | ctx.arc(clicked.x + clicked.width, clicked.y + clicked.height/2, 5, 0, 2 * Math.PI);
372 | ctx.fill();
373 | }
374 |
375 | }
376 |
377 | // BUTTON HANDLERS /////////////////////////////
378 |
379 | function do_input_button_up () {
380 | if (open_param != null) {
381 | var val = document.getElementById("inp_val");
382 | open_param.value = val.value;
383 | }
384 | hideInput();
385 | }
386 |
387 | function do_make_module_button_up() {
388 | var sel = document.getElementById("module_select");
389 | var val = sel.options[sel.selectedIndex].value; // Name of the module type
390 | var i;
391 | // Figure out what type of module was selected
392 | for (i = 0; i < module_dicts.length; i++) {
393 | var module_dict = module_dicts[i];
394 | if ("name" in module_dict) {
395 | var name = module_dict["name"];
396 | if (name == val) {
397 | // Found it
398 | make_module(module_dict, module_counter);
399 | }
400 | }
401 | }
402 | }
403 |
404 |
405 | // HELPERS ///////////////////////////////////////
406 |
407 |
408 | function make_module(module_json, id) {
409 | var category = "";
410 | var name = "";
411 | var type = "";
412 | var module_count = 0;
413 | // type
414 | if ("name" in module_json) {
415 | type = module_json["name"];
416 | }
417 | // name
418 | if (!(type in module_counters)) {
419 | module_counters[type] = 0;
420 | }
421 | module_count = module_counters[type] + 1;
422 | module_counters[type] = module_count;
423 | name = type + " (" + module_count + ")";
424 | // category
425 | if ("category" in module_json) {
426 | category = module_json["category"];
427 | }
428 | // make new module
429 | var new_module = new module(((module_width+module_spacing) * modules.length) % (screen_width-(module_width+module_spacing)),
430 | parseInt((module_width+module_spacing) * modules.length / (screen_width-(module_width+module_spacing))) * (module_width+module_spacing),
431 | type, name, category, id);
432 | module_counter = module_counter + 1;
433 | //new_module.id = module_counter;
434 | modules.push(new_module);
435 | dragables.push(new_module);
436 | // parameters
437 | if ("params" in module_json) {
438 | var params = module_json["params"];
439 | var j;
440 | for (j = 0; j < params.length; j ++ ) {
441 | param_json = params[j];
442 | var p_name = ""
443 | var is_in = false;
444 | var is_out = false;
445 | var type = "";
446 | var default_value = "";
447 | // parameter name/type
448 | if ("name" in param_json) {
449 | p_name = param_json["name"];
450 | }
451 | // is it an input?
452 | if ("in" in param_json && param_json["in"] == true) {
453 | is_in = true;
454 | }
455 | // is it an output?
456 | if ("out" in param_json && param_json["out"] == true) {
457 | is_out = true;
458 | }
459 | // What type of value does it store?
460 | if ("type" in param_json) {
461 | type = param_json["type"];
462 | }
463 | // What is the default value?
464 | if ("default" in param_json) {
465 | default_value = param_json["default"];
466 | }
467 | // Make the parameter
468 | new_param = new parameter(p_name, is_in, is_out, type, default_value, new_module, new_module.id + "-" + j);
469 | //new_param.id = new_module.id + "-" + j;
470 | parameter_counter = parameter_counter + 1;
471 | new_module.params.push(new_param);
472 | parameters.push(new_param);
473 | if (!is_in) {
474 | clickables.push(new_param);
475 | }
476 | }
477 | }
478 | return new_module;
479 | }
480 |
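// Example: given the "Gridify" definition from the module list
// ({"name" : "Gridify", "params" : [...], "category" : "Images"}),
// make_module creates a module labeled "Gridify (1)" ("(2)" for a repeat,
// and so on), registers it as dragable, and builds one parameter object per
// entry in "params", pushing the non-input parameters onto clickables.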
481 | function hideInput() {
482 | var inp = document.getElementById("inp");
483 | inp.style.display = "none";
484 | }
485 |
486 | function populateModuleMaker() {
487 | // Grab the select object
488 | var sel = document.getElementById("module_select");
489 | var categories_dict = {}; // keys are category names, val is list of module names
490 | // seed with misc category
491 | categories_dict["misc"] = [];
492 | // Collect up category names and module names
493 | var i;
494 | for (i = 0; i < module_dicts.length; i++) {
495 | var module_dict = module_dicts[i];
496 | if ("name" in module_dict && "category" in module_dict) {
497 | var category = module_dict["category"];
498 | var name = module_dict["name"];
499 | if (!(category in categories_dict)) {
500 | categories_dict[category] = [];
501 | }
502 | categories_dict[category].push(name);
503 | }
504 | else if ("name" in module_dict) {
505 | var name = module_dict["name"];
506 | categories_dict["misc"].push(name);
507 | }
508 | }
509 | // Iterate through categories
510 | for (var key in categories_dict) {
511 | var module_names = categories_dict[key];
512 | // If this category has modules, then add them to the select object
513 | if (module_names.length > 0) {
514 | // This category has modules
515 | // Make an optgroup
516 | var group = document.createElement('OPTGROUP');
517 | group.label = key;
518 | // Iterate through modules
519 | for (i=0; i < module_names.length; i++) {
520 | // Make an option
521 | var module_name = module_names[i];
522 | var opt = document.createElement('OPTION');
523 | opt.textContent = module_name;
524 | opt.value = module_name;
525 | group.appendChild(opt);
526 | }
527 | sel.appendChild(group);
528 | }
529 | }
530 | }
531 |
532 | function everyinterval(n) {
533 | if ((myGameArea.frameNo / n) % 1 == 0) {return true;}
534 | return false;
535 | }
536 |
537 |
538 | function recursive_offset (aobj) {
539 | var currOffset = {
540 | x: 0,
541 | y: 0
542 | }
543 | var newOffset = {
544 | x: 0,
545 | y: 0
546 | }
547 |
548 | if (aobj !== null) {
549 |
550 | if (aobj.scrollLeft) {
551 | currOffset.x = aobj.scrollLeft;
552 | }
553 |
554 | if (aobj.scrollTop) {
555 | currOffset.y = aobj.scrollTop;
556 | }
557 |
558 | if (aobj.offsetLeft) {
559 | currOffset.x -= aobj.offsetLeft;
560 | }
561 |
562 | if (aobj.offsetTop) {
563 | currOffset.y -= aobj.offsetTop;
564 | }
565 |
566 | if (aobj.parentNode !== undefined) {
567 | newOffset = recursive_offset(aobj.parentNode);
568 | }
569 |
570 | currOffset.x = currOffset.x + newOffset.x;
571 | currOffset.y = currOffset.y + newOffset.y;
572 | }
573 | return currOffset;
574 | }
575 |
576 |
577 | function save_program() {
578 | // Don't save if there is a blank program.
579 | if (modules.length == 0) {
580 | return;
581 | }
582 | var filename = "myprogram"
583 | var input_save = document.getElementById("inp_save");
584 | if (input_save.value.length > 0) {
585 | filename = input_save.value;
586 | }
587 | var program = [];
588 | // Deep copy connections
589 | var module_links = [];
590 | var module_ids = [];
591 | var i;
592 | for (i = 0; i < modules.length; i++) {
593 | module_ids.push(parseInt(modules[i].id));
594 | }
595 | for (i = 0; i < connections.length; i++) {
596 | module_links.push([parseInt(connections[i].origin.parent.id), parseInt(connections[i].target.parent.id)]);
597 | }
598 | // Put all the modules in order
599 | while (module_ids.length > 0) {
600 | var ready = get_ready_modules(module_ids, module_links);
601 | program = program.concat(ready);
602 | var filtered_module_ids = module_ids.filter(function(value, index, arr) {
603 | var n;
604 | var is_in = false;
605 | for (n=0; n < ready.length; n++) {
606 | if (value == ready[n]) {
607 | is_in = true;
608 | break;
609 | }
610 | }
611 | return !is_in;
612 | });
613 | var filtered_module_links = module_links.filter(function(value, index, arr) {
614 | var n;
615 | var is_in = false;
616 | for (n=0; n < ready.length; n++) {
617 | if (value[0] == ready[n] || value[1] == ready[n]) {
618 | is_in = true;
619 | break;
620 | }
621 | }
622 | return !is_in;
623 | });
624 | module_ids = filtered_module_ids;
625 | module_links = filtered_module_links;
626 | }
627 | // assert: program is module ids in executable order
628 | var prog_json = "";
629 | var is_first = true;
630 | console.log(program);
631 | for (i = 0; i < program.length; i++) {
632 | var current = program[i];
633 | var current_module = get_module_by_id(current)
634 | var module_json = module_to_json(current_module);
635 | if (is_first) {
636 | prog_json = module_json;
637 | }
638 | else {
639 | prog_json = prog_json + "," + module_json;
640 | }
641 | console.log(prog_json);
642 | is_first = false;
643 | }
644 | // Put brackets around the json
645 | prog_json = "[" + prog_json + "]";
646 | // Call python
647 | (async function() {
648 | const result = await google.colab.kernel.invokeFunction(
649 | 'notebook.python_save_hook', // The callback name.
650 | [prog_json, filename], // The arguments.
651 | {}); // kwargs
652 | const res = result.data['application/json'];
653 | //document.querySelector("#output-area").appendChild(document.createTextNode(text.result));
654 | })();
655 | }
656 |
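// The while-loop above is effectively Kahn's topological sort: each round,
// get_ready_modules returns the modules with no remaining incoming links.
// For example, with module ids [0, 1, 2] and links [[0, 1], [1, 2]], the
// rounds emit [0], then [1], then [2], so the saved program lists every
// module after all of the modules that feed it.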
657 | function get_module_by_id(id) {
658 | var i = 0;
659 | for (i=0; i < modules.length; i++) {
660 | if (modules[i].id == id) {
661 | return modules[i];
662 | }
663 | }
664 | return null;
665 | }
666 |
667 |
668 | function get_ready_modules(mods, cons) {
669 | var ready = [];
670 | var i;
671 | for (i=0; i < mods.length; i++) {
672 | var current = mods[i];
673 | var is_ready = true;
674 | var j;
675 | for (j=0; j < cons.length; j++) {
676 | if (cons[j][1] == current) {
677 | is_ready = false;
678 | break;
679 | }
680 | }
681 | if (is_ready) {
682 | ready.push(current);
683 | }
684 | }
685 | return ready;
686 | }
687 |
688 | function module_to_json(module) {
689 | var json = {};
690 | json["module"] = module.type;
691 | json["name"] = module.name;
692 | json["x"] = module.x;
693 | json["y"] = module.y;
694 | json["id"] = module.id;
695 | json["collapsed"] = module.collapsed;
696 | // Parameters
697 | var i=0;
698 | for (i=0; i < module.params.length; i++) {
699 | param = module.params[i];
700 | json[param.name] = ""+param.value; // Make all values strings
701 | }
702 | return JSON.stringify(json);
703 | }
704 |
705 | function clear_program() {
706 | modules = [];
707 | parameters = [];
708 | connections = [];
709 | clickables = [];
710 | dragables = [];
711 | module_counters = {};
712 | module_counter = 0;
713 | parameter_counter = 0;
714 | drag_offset_x = 0;
715 | drag_offset_y = 0;
716 | drag = null;
717 | clicked = null;
718 | open_param = null;
719 | }
720 |
721 | function load_program(loadfile="") {
722 | // Need to clear the existing program
723 | clear_program();
724 | var filename = loadfile;
725 | if (filename.length == 0) {
726 | // Filename wasn't passed in so we need to get the filename from the html
727 | var input_load = document.getElementById("inp_load");
728 | if (input_load.value.length > 0) {
729 | filename = input_load.value;
730 | }
731 | else {
732 | return;
733 | }
734 | }
735 | // ASSERT: filename is non-empty
736 | // Call python
737 | async function foo() {
738 | const result = await google.colab.kernel.invokeFunction(
739 | 'notebook.python_load_hook', // The callback name.
740 | [filename], // The arguments.
741 | {}); // kwargs
742 | //program = result.data['application/json'];
743 | //res = result.data['application/json'];
744 | //console.log(result);
745 | //document.querySelector("#output-area").appendChild(document.createTextNode(text.result));
746 | return result;
747 | };
748 | foo().then(function(value) {
749 | var program = eval(value.data['application/json'].result); // trusts the kernel's response; JSON.parse would be stricter if the payload is pure JSON
750 | console.log(program);
751 | // what to do with the program
752 | var i;
753 | var m;
754 | // Iterate through each module in the program
755 | for (m = 0; m < program.length; m++) {
756 | var module_json = program[m]; // The json for this part of the program
757 | var mod = null;
758 | var id = module_counter;
759 | if ("id" in module_json) {
760 | id = module_json["id"];
761 | }
762 | // Find the corresponding module definition
763 | for (i = 0; i < module_dicts.length; i++) {
764 | var module_dict = module_dicts[i];
765 |
766 | if ("name" in module_dict) {
767 | var name = module_dict["name"];
768 | if (name == module_json.module) {
769 | // Make the module
770 | mod = make_module(module_dict, id);
771 | // Except it only has default values
772 | break;
773 | }
774 | }
775 | } // end i
776 | if (mod == null) { continue; } // skip entries with no matching module definition
777 | mod.x = module_json.x; // move the module to its saved location
778 | mod.y = module_json.y;
779 | if ("collapsed" in module_json) {
780 | mod.collapsed = module_json.collapsed;
781 | }
782 | // Update default parameters
783 | // Iterate through each of the specs from the file
784 | var module_keys = Object.keys(module_json);
785 | for (i = 0; i < module_keys.length; i++) {
786 | var module_key = module_keys[i];
787 | var module_val = module_json[module_key];
788 | var p;
789 | // Find the corresponding parameter
790 | for (p = 0; p < mod.params.length; p++) {
791 | var param = mod.params[p];
792 | if (module_key === param.name) {
793 | // Found the right parameter
794 | param.value = module_val;
795 | break;
796 | }
797 | }
798 | }
799 | } //end m
800 | // Make connections
801 | console.log(parameters);
802 | var i;
803 | var j;
804 | // Iterate through all parameters, looking for matching file names
805 | for (i = 0; i < parameters.length; i++) {
806 | for (j = 0; j < parameters.length; j++) {
807 | var param1 = parameters[i];
808 | var param2 = parameters[j];
809 | if (param1.is_out && param2.is_in) {
810 | // param1 has an outlink and param2 has an inlink
811 | var fname1 = param1.value;
812 | var fname2 = param2.value;
813 | if (fname1 === fname2) {
814 | // Make a connection
815 | c = new connection(param1, param2);
816 | c.id = connections.length;
817 | connections.push(c);
818 | param1.connected = true;
819 | param2.connected = true;
820 | }
821 | }
822 | }
823 | }
824 | // update module count to be max id
825 | var mod;
826 | var max_id = 0;
827 | for (mod = 0; mod < modules.length; mod++) {
828 | var cur_id = modules[mod].id;
829 | if (cur_id > max_id) {
830 | max_id = cur_id;
831 | }
832 | }
833 | module_counter = max_id+1;
834 | }); // end then function
835 | // Anything after this is not guaranteed to execute after the file is loaded.
836 | }
837 |
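// Connection reconstruction relies on matching cache filenames: when one
// module's output parameter and another module's input parameter both hold a
// value such as "cache/text8-1", load_program links them, which is why saved
// programs never need to store connections explicitly.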
838 | function delete_module(module) {
839 | var filtered_modules = modules.filter(function(value, index, arr) {
840 | return value.id != module.id;
841 | });
842 | var filtered_connections = connections.filter(function(value, index, arr) {
843 | return value.origin.parent.id != module.id && value.target.parent.id != module.id;
844 | });
845 | var deleted_connections = connections.filter(function(value, index, err) {
846 | return value.origin.parent.id == module.id || value.target.parent.id == module.id;
847 | });
848 | var i;
849 | for (i = 0; i < deleted_connections.length; i++) {
850 | var c = deleted_connections[i];
851 | c.origin.connected = false;
852 | c.target.connected = false;
853 | }
854 | modules = filtered_modules;
855 | connections = filtered_connections;
856 | // Also remove the module and its parameters from the interaction lists,
857 | // otherwise an invisible "ghost" module could still be dragged or clicked.
858 | dragables = dragables.filter(function(value) { return value.id != module.id; });
859 | parameters = parameters.filter(function(value) { return value.parent.id != module.id; });
860 | clickables = clickables.filter(function(value) { return value.parent.id != module.id; });
861 | }
862 |
863 | // START THE GUI ///////////////////////////////////////////
864 | startGame()
--------------------------------------------------------------------------------