├── .gitignore ├── Dockerfile ├── api.py ├── display.py ├── docker-compose.yml ├── load-data ├── data_loader.py ├── select-langs.rs └── tokens.rs ├── network.py ├── readme.md ├── requirements.txt ├── sentence_parser.py ├── static └── index.html └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/select-langs 2 | **/tokens 3 | **/*.dSYM 4 | **/__pycache__ 5 | save/cache 6 | cache 7 | web.log 8 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.8 2 | WORKDIR /usr/src/app/ 3 | COPY requirements.txt /usr/src/ 4 | RUN apt-get update 5 | RUN pip install -r /usr/src/requirements.txt 6 | RUN pip install torch==1.8.0 -f https://download.pytorch.org/whl/torch_stable.html 7 | CMD [ "python", "api.py", "3080" ] 8 | -------------------------------------------------------------------------------- /api.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import json 3 | import asyncio 4 | import time 5 | 6 | from aiohttp import web 7 | import logging 8 | 9 | import torch 10 | import numpy as np 11 | 12 | logging.basicConfig( 13 | level=logging.INFO, 14 | format="%(asctime)s [%(levelname)s] %(message)s", 15 | handlers=[ 16 | logging.FileHandler("web.log"), 17 | logging.StreamHandler() 18 | ] 19 | ) 20 | 21 | 22 | logging.info("started") 23 | 24 | try: 25 | from sentence_parser import STYPE_SEC, STYPE_AUX, PRIM_GL, SEC_GL, AUX_GL 26 | from network import into_one_hot, generate_batch, load_from_save 27 | 28 | enc, sec_dec, aux_dec, *_ = load_from_save() 29 | 30 | LOADED = True 31 | except FileNotFoundError as e: 32 | logging.warning(f"Network not loaded: {e}") 33 | 34 | LOADED = False 35 | 36 | def make_json_response(data, status=200): 37 | return web.Response( 38 | body=json.dumps(data), 39 | content_type='application/json', 40 | status=status, 41 | ) 42 | 43 | 44 | class WebInterface: 45 | def __init__(self, app): 46 | self.app = app 47 | 48 | app.router.add_post("/translate", self.translate) 49 | 50 | self.currently_blocked_users = set() 51 | 52 | async def translate(self, req): 53 | if not LOADED: 54 | return make_json_response({"error": "Network not loaded. Please contact coral if this happens."}, status=500) 55 | 56 | ip, port = req.transport.get_extra_info("peername") 57 | if ip in self.currently_blocked_users: 58 | logging.info(f"Too many requests for {ip}") 59 | return make_json_response({"error": "Too many requests. 
Try again in a few seconds!"}, status=400) 60 | 61 | self.currently_blocked_users.add(ip) 62 | try: 63 | await asyncio.sleep(1) 64 | 65 | data = await req.json() 66 | 67 | start = time.time() 68 | 69 | bpe = PRIM_GL.str_to_bpe(data["input"]) 70 | xs = torch.LongTensor([bpe]) 71 | 72 | confidence_boost = data.get("confidence_boost", 1) 73 | confidence_boost = min(3, max(-3, float(confidence_boost))) 74 | 75 | logging.info(f"Translating {repr(data)}, confidence boost = {confidence_boost}") 76 | 77 | eof_idx = -1 78 | did_cuttof = False 79 | for e in range(5): 80 | ylen = 5 * 2 ** e 81 | ys = torch.LongTensor([[-1] * ylen]) 82 | 83 | hid = enc(xs) 84 | outs, atts, hard_outs = sec_dec(hid, ys, teacher_forcing_prob=0, choice=True, confidence_boost=confidence_boost) 85 | out, att, hard_out = outs[0], atts[0], hard_outs[0] 86 | 87 | hout_eofs = (hard_out == SEC_GL.n_tokens() - 1).nonzero() 88 | if len(hout_eofs) == 0: 89 | eof_idx = ylen 90 | continue 91 | eof_idx = hout_eofs[0] 92 | did_cuttof = True 93 | 94 | out = torch.exp(out[:eof_idx]) 95 | out /= out.sum(axis=1).unsqueeze(1) 96 | hard_out = hard_out[:eof_idx] 97 | att = att[:eof_idx] 98 | 99 | confidences = torch.gather(out, 1, hard_out.view(-1, 1)) 100 | confidence = confidences.prod().item() 101 | 102 | hy_words = [SEC_GL.bpe_to_str([word]) for word in hard_out] 103 | 104 | if did_cuttof: 105 | out = "".join(hy_words) 106 | else: 107 | out = "".join(hy_words) + "..." 108 | 109 | end = time.time() 110 | took = end - start 111 | logging.info(f"Got {repr(out)}, conf = {confidence}. Took {took} seconds") 112 | return make_json_response({"result": out, "confidence": confidence, "duration": took}) 113 | finally: 114 | self.currently_blocked_users.remove(ip) 115 | 116 | def run(self, port): 117 | web.run_app(self.app, port=port) 118 | 119 | 120 | app = web.Application() 121 | 122 | WEB_STATE = WebInterface(app) 123 | 124 | if len(sys.argv) == 2: 125 | WEB_STATE.run(port=int(sys.argv[1])) 126 | else: 127 | WEB_STATE.run(port=8080) 128 | -------------------------------------------------------------------------------- /display.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | import matplotlib.pyplot as plt 5 | 6 | from sentence_parser import STYPE_SEC, STYPE_AUX, PRIM_GL, SEC_GL, AUX_GL 7 | from network import into_one_hot, generate_batch, load_from_save 8 | 9 | enc, sec_dec, aux_dec, *_ = load_from_save() 10 | 11 | sec_info = ("sec", SEC_GL, sec_dec, STYPE_SEC) 12 | aux_info = ("aux", AUX_GL, aux_dec, STYPE_AUX) 13 | 14 | for name, gl, dec, stype in [sec_info, aux_info]: 15 | xs, ys = generate_batch(5, stype, max_length=-1) 16 | 17 | extra = input("Your own phrase> ") 18 | bpe = PRIM_GL.str_to_bpe(extra) 19 | bpe += [-1] * (xs.size(1) - len(bpe)) 20 | bpe = torch.LongTensor([bpe]) 21 | 22 | y = [-1] * ys.size(1) 23 | y = torch.LongTensor([y]) 24 | 25 | xs = torch.cat((xs, bpe), axis=0) 26 | ys = torch.cat((ys, y), axis=0) 27 | 28 | hid = enc(xs) 29 | outs, atts, hard_outs = dec(hid, ys, 0, choice=True) 30 | 31 | for i in range(len(xs)): 32 | plt.subplot(3, 2, i + 1) 33 | x = xs[i] 34 | 35 | y = ys[i] 36 | out = outs[i] 37 | att = atts[i] 38 | hard_out = hard_outs[i] 39 | 40 | x_eofs = (x == -1).nonzero() 41 | if len(x_eofs) > 0: 42 | x = x[:x_eofs[0]] 43 | att = att[:, :x_eofs[0]] 44 | 45 | y_eofs = (y == -1).nonzero() 46 | if len(y_eofs) > 0: 47 | y = y[:y_eofs[0]] 48 | 49 | hout_eofs = (hard_out == gl.n_tokens() - 1).nonzero() 50 | if len(hout_eofs) > 0: 51 | 
out = out[:hout_eofs[0]] 52 | hard_out = hard_out[:hout_eofs[0]] 53 | att = att[:hout_eofs[0]] 54 | 55 | x_words = [PRIM_GL.bpe_to_str([word]) for word in x] 56 | y_words = [gl.bpe_to_str([word]) for word in y] 57 | hy_words = [gl.bpe_to_str([word]) for word in hard_out] 58 | 59 | print() 60 | print("/".join(hy_words), " <- ", "/".join(x_words)) 61 | print("/".join(y_words)) 62 | 63 | plt.imshow(att.detach().numpy()) 64 | plt.xticks(np.arange(len(x)), x_words, rotation="vertical") 65 | plt.yticks(np.arange(len(hy_words)), hy_words) 66 | 67 | plt.show() 68 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "2" 2 | services: 3 | ilo-pi-ante-toki: 4 | build: . 5 | restart: on-failure 6 | volumes: 7 | - ./:/usr/src/app/ 8 | ports: 9 | - "3080:3080" 10 | -------------------------------------------------------------------------------- /load-data/data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import requests 4 | import time 5 | import tarfile 6 | import tempfile 7 | 8 | SENTENCES_URL = "https://downloads.tatoeba.org/exports/sentences.tar.bz2" 9 | LINKS_URL = "https://downloads.tatoeba.org/exports/links.tar.bz2" 10 | 11 | CACHE_DIR = "cache/raw/" 12 | TMP_DIR = tempfile.gettempdir() 13 | 14 | CACHE_DIR = os.path.expanduser(CACHE_DIR) 15 | 16 | 17 | if not os.path.exists(CACHE_DIR): 18 | os.makedirs(CACHE_DIR) 19 | 20 | 21 | def download_file(data_url, output_file): 22 | with requests.get(data_url, stream=True) as resp: 23 | 24 | size = int(resp.headers["Content-Length"]) 25 | print(f"Loading {size} bytes from {data_url}") 26 | 27 | block_size = 1024 # 1 kiB 28 | 29 | with tqdm(total=size, unit="iB", unit_scale=True) as prog: 30 | 31 | start = time.time() 32 | for block in resp.iter_content(block_size): 33 | prog.update(len(block)) 34 | output_file.write(block) 35 | 36 | took = time.time() - start 37 | 38 | print(f"Loaded {size} compressed bytes in {took:.4} seconds") 39 | 40 | def extract_file(file_tar, output_file, file_to_extract): 41 | extracting = tarfile.open(fileobj=file_tar, mode="r:bz2") 42 | 43 | size = extracting.getmember(file_to_extract).size 44 | print(f"Extracting {size} bytes") 45 | 46 | buf = extracting.extractfile(file_to_extract) 47 | 48 | block_size = 1024 49 | 50 | with tqdm(total=size, unit="iB", unit_scale=True) as prog: 51 | start = time.time() 52 | while True: 53 | read = buf.read(block_size) 54 | prog.update(len(read)) 55 | 56 | if len(read) == 0: 57 | break 58 | output_file.write(read) 59 | 60 | took = time.time() - start 61 | 62 | print(f"Extracted {size} bytes in {took:.4} seconds") 63 | 64 | 65 | 66 | def plural(n, sing, plur): 67 | if n == 1: 68 | return f"{n} {sing}" 69 | return f"{n} {plur}" 70 | 71 | def format_time(n_seconds): 72 | if n_seconds >= 60 * 60 * 24: 73 | # Days + hours 74 | major_length = 60 * 60 * 24 75 | minor_length = 60 * 60 76 | maj_sing, maj_plur = "day", "days" 77 | min_sing, min_plur = "hour", "hours" 78 | 79 | elif n_seconds >= 60 * 60: 80 | # Hours + minutes 81 | major_length = 60 * 60 82 | minor_length = 60 83 | maj_sing, maj_plur = "hour", "hours" 84 | min_sing, min_plur = "minute", "minutes" 85 | 86 | elif n_seconds >= 60: 87 | # Minutes + seconds 88 | major_length = 60 89 | minor_length = 1 90 | maj_sing, maj_plur = "minute", "minutes" 91 | min_sing, min_plur = "second", "seconds" 92 | 93 | else: 94 | 
# Just seconds 95 | return plural(int(n_seconds + 0.5), "second", "seconds") 96 | 97 | 98 | n_maj = int(n_seconds // major_length) 99 | n_min = (n_seconds / minor_length) % (major_length / minor_length) 100 | 101 | n_min = int(n_min + 0.5) 102 | 103 | s_maj = plural(n_maj, maj_sing, maj_plur) 104 | if n_min == 0: 105 | return s_maj 106 | 107 | s_min = plural(n_min, min_sing, min_plur) 108 | return f"{s_maj} and {s_min}" 109 | 110 | compressed_sentence_path = os.path.join(TMP_DIR, "sentences.tar.bz2") 111 | sentence_path = os.path.join(CACHE_DIR, "sentences.tsv") 112 | 113 | compressed_link_path = os.path.join(TMP_DIR, "links.tar.bz2") 114 | link_path = os.path.join(CACHE_DIR, "links.tsv") 115 | 116 | if __name__ == "__main__": 117 | if os.path.isfile(sentence_path) and os.path.isfile(link_path): 118 | file_age = time.time() - os.path.getmtime(sentence_path) 119 | print(f"The compressed data is already downloaded, but {format_time(file_age)} old") 120 | if input("Redownload? [Y/n] ").lower() == "n": 121 | print("Quitting") 122 | exit() 123 | 124 | print("Loading sentences") 125 | with open(compressed_sentence_path, "bw") as comp_file: 126 | download_file(SENTENCES_URL, comp_file) 127 | 128 | with open(compressed_sentence_path, "rb") as comp_file, open(sentence_path, "wb") as out_file: 129 | extract_file(comp_file, out_file, "sentences.csv") 130 | 131 | print("Loading links") 132 | with open(compressed_link_path, "bw") as comp_file: 133 | download_file(LINKS_URL, comp_file) 134 | 135 | with open(compressed_link_path, "rb") as comp_file, open(link_path, "wb") as out_file: 136 | extract_file(comp_file, out_file, "links.csv") 137 | 138 | 139 | print("Done") 140 | -------------------------------------------------------------------------------- /load-data/select-langs.rs: -------------------------------------------------------------------------------- 1 | // This program removes all sentences that have neither a primary-secondary pair nor a primary-auxiliary pair. 2 | // It then separates each language's sentences into separate files, and creates two links files with information about 3 | // where each sentence is located in each language's file. 4 | 5 | // The generated sentence files contain every sentence back to back, with no separators. 6 | // The links files contain links between the sentences. Each sentence link is encoded as 7 | // four numbers: the byte offset of the primary sentence, its length in bytes, the byte offset of the 8 | // secondary/auxiliary sentence, and its length in bytes. 
9 | // In the binary format, each number is encoded as a 32-bit unsigned integer 10 | 11 | mod tokens; 12 | 13 | use std::io::{Write, Result, BufWriter, BufReader, BufRead, Error, ErrorKind}; 14 | use std::fs::File; 15 | use std::env::var; 16 | use std::collections::{HashMap, HashSet}; 17 | use std::convert::TryInto; 18 | 19 | const PRIM_LANGUAGE: &str = "eng"; 20 | const SEC_LANGUAGE: &str = "toki"; 21 | const AUX_LANGUAGE: &str = "spa"; 22 | 23 | const REL_LIM: f64 = 0.0001; 24 | 25 | // Described above 26 | // TODO: Either make this a compile-time flag, or a CLI-argument 27 | const BINARY_MODE: bool = true; 28 | 29 | fn write_binary_number_to_file(file: &mut F, number: u32) -> Result<()> { 30 | let buf = number.to_le_bytes(); 31 | file.write(&buf)?; 32 | 33 | Ok(()) 34 | } 35 | 36 | fn write_ascii_number_to_file(file: &mut F, number: u32) -> Result<()> { 37 | let txt = format!("{} ", number); 38 | file.write(txt.as_bytes())?; 39 | 40 | Ok(()) 41 | } 42 | 43 | fn write_number_to_file(file: &mut F, number: u32) -> Result<()> { 44 | if BINARY_MODE { 45 | write_binary_number_to_file(file, number) 46 | } else { 47 | write_ascii_number_to_file(file, number) 48 | } 49 | } 50 | 51 | #[derive(Debug)] 52 | struct Translation { 53 | prim_language: HashMap, 54 | sec_language: HashMap, 55 | aux_language: HashMap, 56 | links: HashSet<(u32, u32)>, // Always (prim, sec) or (prim, aux) 57 | } 58 | 59 | impl Translation { 60 | fn new() -> Self { 61 | Translation { 62 | prim_language: HashMap::new(), 63 | sec_language: HashMap::new(), 64 | aux_language: HashMap::new(), 65 | links: HashSet::new(), 66 | } 67 | } 68 | } 69 | 70 | impl Translation> { 71 | fn consume_sentences(&mut self, mut file: F) -> Result { 72 | let mut counter = 0; 73 | 74 | loop { 75 | let mut id_buf = Vec::new(); 76 | file.read_until('\t' as u8, &mut id_buf)?; 77 | 78 | // Remove tab, quit if not found 79 | if id_buf.pop() != Some('\t' as u8) { 80 | break; 81 | } 82 | 83 | let mut language_buf = Vec::new(); 84 | file.read_until('\t' as u8, &mut language_buf)?; 85 | language_buf.pop(); // Remove tab 86 | let language = String::from_utf8_lossy(&language_buf); 87 | 88 | let mut sentence = Vec::new(); 89 | file.read_until('\n' as u8, &mut sentence)?; 90 | sentence.pop(); // Remove newline 91 | 92 | if language != PRIM_LANGUAGE && language != SEC_LANGUAGE && language != AUX_LANGUAGE { 93 | // println!("Skip"); 94 | continue; 95 | } 96 | 97 | let id_st = String::from_utf8(id_buf).unwrap(); 98 | let id_n: u32 = id_st.parse().unwrap(); 99 | 100 | // println!("ID: {:?}, Language: {:?}, Sentence: {:?}", id_n, language, sentence); 101 | 102 | let list_to_add = 103 | match &*language { 104 | PRIM_LANGUAGE => &mut self.prim_language, 105 | SEC_LANGUAGE => &mut self.sec_language, 106 | AUX_LANGUAGE => &mut self.aux_language, 107 | _ => unreachable!() // We checked before that language belonged to one of the previous alternatives 108 | }; 109 | 110 | list_to_add.insert(id_n, sentence.to_owned()); 111 | counter += 1; 112 | } 113 | 114 | Ok(counter) 115 | } 116 | 117 | fn consume_links(&mut self, mut file: F, remove_unlinked: bool) -> Result<(usize, usize)> { 118 | let mut n_read = 0; 119 | let mut n_wrong = 0; 120 | 121 | let mut prim_ids = HashSet::new(); 122 | let mut sec_ids = HashSet::new(); 123 | let mut aux_ids = HashSet::new(); 124 | 125 | loop { 126 | let mut first_buf = Vec::new(); 127 | file.read_until('\t' as u8, &mut first_buf)?; 128 | 129 | // Remove tab, quit if not found 130 | if first_buf.pop() != Some('\t' as u8) { 131 | break; 132 | 
} 133 | 134 | let mut second_buf = Vec::new(); 135 | file.read_until('\n' as u8, &mut second_buf)?; 136 | second_buf.pop(); 137 | 138 | 139 | let first_st = String::from_utf8(first_buf).unwrap(); 140 | let first_n: u32 = first_st.parse().unwrap(); 141 | 142 | let second_st = String::from_utf8(second_buf).unwrap(); 143 | let second_n: u32 = second_st.parse().unwrap(); 144 | 145 | match (self.prim_language.contains_key(&first_n), self.sec_language.contains_key(&second_n), self.aux_language.contains_key(&second_n)) { 146 | (true, true, false) => { 147 | self.links.insert((first_n, second_n)); 148 | prim_ids.insert(first_n); 149 | sec_ids.insert(second_n); 150 | n_read += 1; 151 | continue; 152 | } 153 | (true, false, true) => { 154 | self.links.insert((first_n, second_n)); 155 | prim_ids.insert(first_n); 156 | aux_ids.insert(second_n); 157 | n_read += 1; 158 | continue; 159 | } 160 | _ => {} 161 | } 162 | match (self.prim_language.contains_key(&second_n), self.sec_language.contains_key(&first_n), self.aux_language.contains_key(&first_n)) { 163 | (true, true, false) => { 164 | self.links.insert((second_n, first_n)); 165 | prim_ids.insert(second_n); 166 | sec_ids.insert(first_n); 167 | n_read += 1; 168 | continue; 169 | } 170 | (true, false, true) => { 171 | self.links.insert((second_n, first_n)); 172 | prim_ids.insert(second_n); 173 | aux_ids.insert(first_n); 174 | n_read += 1; 175 | continue; 176 | } 177 | _ => {} 178 | } 179 | n_wrong += 1; 180 | } 181 | 182 | if remove_unlinked { 183 | println!("Keeping {:?}/{:?}/{:?}", prim_ids.len(), sec_ids.len(), aux_ids.len()); 184 | self.prim_language.retain(|&id, _| prim_ids.contains(&id)); 185 | self.sec_language.retain(|&id, _| sec_ids.contains(&id)); 186 | self.aux_language.retain(|&id, _| aux_ids.contains(&id)); 187 | } 188 | 189 | Ok((n_read, n_wrong)) 190 | } 191 | 192 | fn stringify(self) -> Result> { 193 | let prim_language = 194 | self.prim_language 195 | .into_iter() 196 | .map(|(k, v)| { 197 | let st = String::from_utf8(v).map_err(|_| Error::new(ErrorKind::InvalidData, "Invalid UTF-8!"))?; 198 | Ok((k, st)) 199 | }) 200 | .collect::>>()?; 201 | 202 | let sec_language = 203 | self.sec_language 204 | .into_iter() 205 | .map(|(k, v)| { 206 | let st = String::from_utf8(v).map_err(|_| Error::new(ErrorKind::InvalidData, "Invalid UTF-8!"))?; 207 | Ok((k, st)) 208 | }) 209 | .collect::>>()?; 210 | 211 | let aux_language = 212 | self.aux_language 213 | .into_iter() 214 | .map(|(k, v)| { 215 | let st = String::from_utf8(v).map_err(|_| Error::new(ErrorKind::InvalidData, "Invalid UTF-8!"))?; 216 | Ok((k, st)) 217 | }) 218 | .collect::>>()?; 219 | 220 | Ok(Translation { 221 | prim_language, sec_language, aux_language, 222 | links: self.links, 223 | }) 224 | } 225 | } 226 | 227 | impl Translation { 228 | fn write_links(&self, file: &mut F, id_offset_size: &HashMap, write_secondary: bool) -> Result<()> { 229 | for &(prim_id, other_id) in &self.links { 230 | match (self.sec_language.contains_key(&other_id), write_secondary) { 231 | (true, false) | (false, true) => continue, 232 | _ => {} 233 | } 234 | 235 | let (prim_offset, prim_len) = id_offset_size.get(&prim_id).unwrap(); 236 | let (other_offset, other_len) = id_offset_size.get(&other_id).unwrap(); 237 | 238 | write_number_to_file(file, *prim_offset as u32)?; 239 | write_number_to_file(file, *prim_len as u32)?; 240 | write_number_to_file(file, *other_offset as u32)?; 241 | write_number_to_file(file, *other_len as u32)?; 242 | } 243 | Ok(()) 244 | } 245 | } 246 | 247 | fn gramify_sentences(sents: 
HashMap) -> (HashMap>, Gramophone) { 248 | let gram = Gramophone::from_word_iter( 249 | sents 250 | .values() 251 | .map(|x| x.chars()) 252 | ); 253 | let grammed_sents = 254 | sents 255 | .into_iter() 256 | .map(|(k, sent)| (k, gram.encode_text(sent.chars()))) 257 | .collect::>(); 258 | 259 | (grammed_sents, gram) 260 | } 261 | 262 | impl Translation { 263 | fn gramify(self) -> (Translation>, Gramophone, Gramophone, Gramophone) { 264 | let (prim_language, prim_gram) = gramify_sentences(self.prim_language); 265 | let (sec_language, sec_gram) = gramify_sentences(self.sec_language); 266 | let (aux_language, aux_gram) = gramify_sentences(self.aux_language); 267 | 268 | let trans = Translation { 269 | prim_language, sec_language, aux_language, 270 | links: self.links, 271 | }; 272 | 273 | (trans, prim_gram, sec_gram, aux_gram) 274 | } 275 | } 276 | 277 | impl Translation> { 278 | fn write_sentences(&self, file: &mut F, from: u8) -> Result> { 279 | let mut id_offset_size: HashMap = HashMap::new(); 280 | let mut offset = 0; 281 | 282 | let sentences = match from { 283 | 0 => &self.prim_language, 284 | 1 => &self.sec_language, 285 | 2 => &self.aux_language, 286 | _ => unimplemented!(), 287 | }; 288 | 289 | for (&id, sentence) in sentences { 290 | id_offset_size.insert(id, (offset, sentence.len() * 2)); 291 | 292 | for &point in sentence { 293 | let point_u16: u16 = point.try_into().map_err(|_| Error::new(ErrorKind::InvalidData, format!("{} is too large to fit in a u16!", point)))?; 294 | file.write(&point_u16.to_le_bytes())?; 295 | } 296 | offset += sentence.len() * 2; 297 | } 298 | 299 | Ok(id_offset_size) 300 | } 301 | } 302 | 303 | struct Gramophone { 304 | grams: Vec>, 305 | i2idx: HashMap, 306 | } 307 | 308 | impl Gramophone { 309 | // Assumes no zeros in iter 310 | fn from_word_iter< 311 | I: IntoIterator, 312 | J: IntoIterator, 313 | >(iter: I) -> Gramophone { 314 | let mut inp = Vec::new(); 315 | for word in iter { 316 | inp.extend(word.into_iter().flat_map(char::to_lowercase)); 317 | inp.push('\0'); 318 | } 319 | 320 | let (_, grams) = tokens::encode_into_ngrams(inp, REL_LIM, |&x| x != '\0' && x.is_alphabetic()); 321 | 322 | let mut i2idx = HashMap::new(); 323 | for (idx, gram) in grams.iter().enumerate() { 324 | if let &tokens::Gram::Orig(i) = gram { 325 | i2idx.insert(i, idx); 326 | } 327 | } 328 | 329 | Gramophone { 330 | grams, 331 | i2idx, 332 | } 333 | } 334 | 335 | fn encode_text>(&self, text: I) -> Vec { 336 | let mut tokens = Vec::new(); 337 | for ch in text.into_iter().flat_map(char::to_lowercase) { 338 | tokens.push(*self.i2idx.get(&ch).expect(&format!("no such char: {:?}", ch))); 339 | } 340 | 341 | for (i, gram) in self.grams.iter().enumerate() { 342 | let (a, b) = if let &tokens::Gram::Composition(a, b) = gram { 343 | (a, b) 344 | } else { 345 | continue; 346 | }; 347 | 348 | let mut contracted_tokens = Vec::new(); 349 | let mut at = 0; 350 | while at < tokens.len() { 351 | let here = tokens[at]; 352 | let next = tokens.get(at + 1); 353 | 354 | if here == a && next == Some(&b) { 355 | contracted_tokens.push(i); 356 | at += 2; 357 | } else { 358 | contracted_tokens.push(here); 359 | at += 1; 360 | } 361 | } 362 | 363 | tokens = contracted_tokens; 364 | } 365 | 366 | tokens 367 | } 368 | } 369 | 370 | fn get_cache_path(filename: &str) -> String { 371 | format!("cache/{}", filename) 372 | } 373 | 374 | fn main() -> Result<()> { 375 | let sentence_file = BufReader::new(File::open(get_cache_path("raw/sentences.tsv"))?); 376 | let links_file = 
BufReader::new(File::open(get_cache_path("raw/links.tsv"))?); 377 | 378 | let mut sentences = Translation::new(); 379 | 380 | println!("Consuming sentences"); 381 | sentences.consume_sentences(sentence_file)?; 382 | println!("Loaded {:?}/{:?}/{:?}", sentences.prim_language.len(), sentences.sec_language.len(), sentences.aux_language.len()); 383 | 384 | println!("Consuming links"); 385 | let (read, wrong) = sentences.consume_links(links_file, true)?; 386 | println!("Loaded {:?} links ({:?} were wrong)", read, wrong); 387 | 388 | println!("After filter {:?}/{:?}/{:?}", sentences.prim_language.len(), sentences.sec_language.len(), sentences.aux_language.len()); 389 | 390 | println!("Stringifying"); 391 | let sent_string = sentences.stringify()?; 392 | println!("Gramifying"); 393 | let (sent_ngram, prim_gram, sec_gram, aux_gram) = sent_string.gramify(); 394 | println!("{} / {} / {} grams", prim_gram.grams.len(), sec_gram.grams.len(), aux_gram.grams.len()); 395 | 396 | 397 | println!("Writing primary ngrams"); 398 | let mut prim_ngrams = BufWriter::new(File::create(get_cache_path("ngrams-prim.bin"))?); 399 | tokens::encode_grams(&mut prim_ngrams, prim_gram.grams)?; 400 | prim_ngrams.flush()?; 401 | 402 | println!("Writing secondary ngrams"); 403 | let mut sec_ngrams = BufWriter::new(File::create(get_cache_path("ngrams-sec.bin"))?); 404 | tokens::encode_grams(&mut sec_ngrams, sec_gram.grams)?; 405 | sec_ngrams.flush()?; 406 | 407 | println!("Writing auxiliary ngrams"); 408 | let mut aux_ngrams = BufWriter::new(File::create(get_cache_path("ngrams-aux.bin"))?); 409 | tokens::encode_grams(&mut aux_ngrams, aux_gram.grams)?; 410 | aux_ngrams.flush()?; 411 | 412 | println!("Writing primary sentences"); 413 | let mut prim_output = BufWriter::new(File::create(get_cache_path("sentences-prim.bin"))?); 414 | let mut meta = sent_ngram.write_sentences(&mut prim_output, 0)?; 415 | prim_output.flush()?; 416 | 417 | println!("Writing secondary sentences"); 418 | let mut sec_output = BufWriter::new(File::create(get_cache_path("sentences-sec.bin"))?); 419 | let sec_meta = sent_ngram.write_sentences(&mut sec_output, 1)?; 420 | sec_output.flush()?; 421 | 422 | meta.extend(sec_meta.into_iter()); 423 | 424 | println!("Writing auxiliary sentences"); 425 | let mut aux_output = BufWriter::new(File::create(get_cache_path("sentences-aux.bin"))?); 426 | let aux_meta = sent_ngram.write_sentences(&mut aux_output, 2)?; 427 | sec_output.flush()?; 428 | 429 | meta.extend(aux_meta.into_iter()); 430 | 431 | println!("Writing secondary links"); 432 | let mut links_output = BufWriter::new(File::create(get_cache_path("sec-links.bin"))?); 433 | sent_ngram.write_links(&mut links_output, &meta, true)?; 434 | links_output.flush()?; 435 | 436 | println!("Writing auxiliary links"); 437 | let mut links_output = BufWriter::new(File::create(get_cache_path("aux-links.bin"))?); 438 | sent_ngram.write_links(&mut links_output, &meta, false)?; 439 | links_output.flush()?; 440 | 441 | println!("Done!"); 442 | Ok(()) 443 | } 444 | -------------------------------------------------------------------------------- /load-data/tokens.rs: -------------------------------------------------------------------------------- 1 | use std::collections::{HashMap, HashSet}; 2 | use std::hash::Hash; 3 | use std::fmt::Debug; 4 | 5 | use std::io::Write; 6 | 7 | #[derive(Debug, Copy, Clone)] 8 | pub enum Gram { 9 | Orig(I), 10 | Composition(usize, usize), 11 | } 12 | 13 | // Retuns the tokenized text, along with a list of decompositions 14 | pub fn encode_into_ngrams 
bool>(inp: Vec, rel_lim: f64, can_pair: F) -> (Vec, Vec>) { 15 | // Convert the text into orig tokens 16 | 17 | let inp_len = inp.len() as f64; 18 | 19 | let mut tokens: Vec = Vec::new(); 20 | let mut grams = Vec::new(); 21 | 22 | let mut skips = HashSet::new(); 23 | 24 | let mut i2tok: HashMap = HashMap::new(); 25 | for i in inp { 26 | if !i2tok.contains_key(&i) { 27 | let idx = grams.len(); 28 | i2tok.insert(i, idx); 29 | grams.push(Gram::Orig(i)); 30 | 31 | if !can_pair(&i) { 32 | skips.insert(idx); 33 | } 34 | } 35 | 36 | tokens.push(*i2tok.get(&i).unwrap()); 37 | } 38 | 39 | // println!("Tokens: {:?}", tokens); 40 | // println!("Grams: {:?}", grams); 41 | 42 | for _ in 0..std::u16::MAX - tokens.len() as u16 { 43 | let pair_freq = get_pair_freq(&tokens, &skips); 44 | let (&commonest_pair, &freq) = if let Some(x) = pair_freq.iter().max_by_key(|(_, &count)| count) { 45 | x 46 | } else { 47 | eprintln!("Ran out of pairs: {:?} / {:?} - {:?}", tokens, grams, skips); 48 | return (tokens, grams); 49 | }; 50 | let rel_freq = freq as f64 / inp_len as f64; 51 | println!("Commonest: {:?}, freq {}, {}%", commonest_pair, freq, rel_freq); 52 | if rel_freq < rel_lim { 53 | break; 54 | } 55 | 56 | let new_gram_idx = grams.len(); 57 | grams.push(Gram::Composition(commonest_pair.0, commonest_pair.1)); 58 | 59 | let mut contracted_tokens = Vec::with_capacity(tokens.len() - freq); 60 | let mut at = 0; 61 | while at < tokens.len() { 62 | let here = tokens[at]; 63 | let next = tokens.get(at + 1); 64 | 65 | if here == commonest_pair.0 && next == Some(&commonest_pair.1) { 66 | contracted_tokens.push(new_gram_idx); 67 | at += 2; 68 | } else { 69 | contracted_tokens.push(here); 70 | at += 1; 71 | } 72 | } 73 | 74 | tokens = contracted_tokens; 75 | 76 | // println!("Tokens: {:?}", tokens); 77 | // println!("Grams: {:?}", grams); 78 | } 79 | 80 | (tokens, grams) 81 | } 82 | 83 | fn get_pair_freq(input: &[usize], skips: &HashSet) -> HashMap<(usize, usize), usize> { 84 | let mut counter = HashMap::new(); 85 | 86 | for (&a, &b) in input.iter().zip(input.iter().skip(1)) { 87 | if skips.contains(&a) || skips.contains(&b) { 88 | continue; // Don't combine over word boundaries 89 | } 90 | 91 | let count = counter.entry((a, b)).or_insert(0); 92 | *count += 1; 93 | } 94 | 95 | counter 96 | } 97 | 98 | fn decompose_sequence(mut tokens: Vec, grams: &[Gram]) -> Vec { 99 | // Make all tokens point into Orig-ngrams 100 | while { 101 | let mut new_tokens = Vec::new(); 102 | let mut all_origs = true; 103 | 104 | for &tok in &tokens { 105 | match grams[tok] { 106 | Gram::Orig(_) => { 107 | new_tokens.push(tok); 108 | } 109 | Gram::Composition(a, b) => { 110 | new_tokens.extend_from_slice(&[a, b]); 111 | all_origs = false; 112 | } 113 | } 114 | } 115 | 116 | tokens = new_tokens; 117 | 118 | !all_origs 119 | } {} 120 | 121 | // Now we know that grams[tok] == Gram::Orig(_) for all tok in tokens 122 | tokens 123 | .into_iter() 124 | .map(|g_idx| { 125 | match grams[g_idx] { 126 | Gram::Orig(ch) => ch, 127 | _ => unreachable!(), 128 | } 129 | }) 130 | .collect() 131 | } 132 | 133 | // Assumes ngrams are in topolocical order 134 | // Format: 135 | // Gram::Orig(ch): chl = utf-8 len of ch 136 | // offset: 0x0 0x1 0x2 0x3 0x4 0x5 0x6 0x7 0x8 137 | // value: chl --- ch... 
--- 0x0 0x0 0x0 0x0 138 | // 139 | // Gram::Composition(a, b) encodes: 140 | // offset: 0x0 0x1 0x2 0x3 0x4 0x5 0x6 0x7 0x8 141 | // value: 0x0 -------a------- -------b------- 142 | 143 | pub fn encode_grams(out: &mut F, grams: Vec>) -> std::io::Result<()> { 144 | // Assert topoorder 145 | for (i, gram) in grams.iter().enumerate() { 146 | match gram { 147 | &Gram::Composition(a, b) => debug_assert!(a < i && b < i), 148 | _ => {} 149 | } 150 | } 151 | 152 | for gram in grams { 153 | match gram { 154 | Gram::Orig(ch) => { 155 | let chl = ch.len_utf8(); 156 | out.write(&[chl as u8])?; // Safe to covert since ch.len_utf8() <= 4 157 | let mut ch_buf = [0u8; 8]; 158 | ch.encode_utf8(&mut ch_buf); 159 | out.write(&ch_buf)?; 160 | } 161 | Gram::Composition(a, b) => { 162 | out.write(&[0])?; 163 | 164 | let a_bytes = (a as u32).to_le_bytes(); 165 | let b_bytes = (b as u32).to_le_bytes(); 166 | 167 | out.write(&a_bytes)?; 168 | out.write(&b_bytes)?; 169 | } 170 | } 171 | } 172 | 173 | Ok(()) 174 | } 175 | 176 | #[allow(unused)] 177 | fn main() -> std::io::Result<()> { 178 | use std::io::Read; 179 | 180 | let mut f = std::fs::File::open("cache/sentences-sec.txt")?; 181 | let mut buf = Vec::new(); 182 | f.read_to_end(&mut buf)?; 183 | 184 | let inp = String::from_utf8_lossy(&buf).chars().collect(); 185 | 186 | println!("Encoding"); 187 | let (stream, grams) = encode_into_ngrams(inp, 0.001, |&_ch| true); 188 | 189 | // println!("{:?}", decompose_sequence(stream, &grams)); 190 | for i in 0..grams.len() { 191 | if !stream.contains(&i) { 192 | print!(" "); 193 | } else { 194 | print!("* "); 195 | } 196 | let expanded_gram = decompose_sequence(vec![i], &grams); 197 | println!("{} -> {:?}", i, expanded_gram.iter().collect::()); 198 | } 199 | 200 | let mut out = std::fs::File::create("/tmp/gramcode.bin")?; 201 | encode_grams(&mut out, grams)?; 202 | 203 | Ok(()) 204 | } 205 | -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import random 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import numpy as np 8 | from sentence_parser import load_one_pair, STYPE_SEC, PRIM_GL, SEC_GL 9 | 10 | torch.set_printoptions(precision=5) 11 | 12 | if torch.cuda.is_available(): 13 | device = torch.device('cuda') 14 | else: 15 | device = torch.device('cpu') 16 | 17 | print("starting using device", device) 18 | 19 | class Encoder(nn.Module): 20 | def __init__(self, input_size, emb_size, hidden_size): 21 | super(Encoder, self).__init__() 22 | 23 | self.emb_size = emb_size 24 | self.hidden_size = hidden_size 25 | self.input_size = input_size 26 | 27 | self.embedding = nn.Embedding(input_size, emb_size).to(device) 28 | 29 | self.rnn_right = nn.Linear(emb_size + hidden_size, hidden_size).to(device) 30 | self.rnn_left = nn.Linear(emb_size + hidden_size, hidden_size).to(device) 31 | 32 | def run(self, inp, run_right): 33 | current_hid = self.init_hidden(inp.size(0)) 34 | 35 | hidden_states = torch.FloatTensor(inp.size(0), inp.size(1), self.hidden_size).to(device) 36 | 37 | network_to_use = self.rnn_right if run_right else self.rnn_left 38 | 39 | for i in range(inp.size(1)): 40 | if not run_right: 41 | i = inp.size(1) - i - 1 42 | current_ch = inp[:, i] 43 | 44 | emb = self.embedding(current_ch % self.input_size) 45 | 46 | current_hid = network_to_use(torch.cat([emb, current_hid], axis=1)) 47 | current_hid = F.elu(current_hid) 48 | hidden_states[:, i] 
= current_hid 49 | 50 | return hidden_states 51 | 52 | def forward(self, inp): 53 | hidden_right = self.run(inp, True) 54 | hidden_left = self.run(inp, False) 55 | 56 | return torch.cat([hidden_right, hidden_left], axis=2) 57 | 58 | def init_hidden(self, batch_size): 59 | return torch.zeros((batch_size, self.hidden_size)).to(device) 60 | 61 | class Decoder(nn.Module): 62 | def __init__(self, output_size, emb_size, enc_hid_size, dec_hid_size): 63 | super(Decoder, self).__init__() 64 | 65 | self.emb_size = emb_size 66 | self.enc_hid_size = enc_hid_size 67 | self.dec_hid_size = dec_hid_size 68 | self.output_size = output_size 69 | 70 | self.energizer_l1 = nn.Linear(enc_hid_size + dec_hid_size, 20) 71 | self.energizer_l2 = nn.Linear(20, 1) 72 | 73 | self.embedding = nn.Embedding(output_size, emb_size) 74 | 75 | self.rnn = nn.Linear(emb_size + dec_hid_size + enc_hid_size, dec_hid_size) 76 | self.out = nn.Linear(emb_size + dec_hid_size + enc_hid_size, output_size) 77 | 78 | def forward(self, enc_hid, real_output, teacher_forcing_prob=0.5, choice=False, confidence_boost=1): 79 | batch_size = enc_hid.size(0) 80 | inp_size = enc_hid.size(1) 81 | out_size = real_output.size(1) 82 | 83 | last_char = torch.LongTensor(batch_size).to(device).zero_() 84 | last_hidden = self.init_hidden(batch_size) 85 | 86 | output = torch.FloatTensor(batch_size, out_size, self.output_size).to(device) 87 | weights_mat = torch.FloatTensor(batch_size, out_size, inp_size).to(device) 88 | 89 | hard_out = torch.LongTensor(batch_size, out_size).to(device) 90 | 91 | for i in range(out_size): 92 | # Generate energies 93 | 94 | energies = torch.FloatTensor(batch_size, inp_size).to(device) 95 | 96 | for j in range(inp_size): 97 | j_energies = self.energizer_l1(torch.cat([enc_hid[:, j], last_hidden], axis=1)) 98 | j_energies = F.elu(j_energies) 99 | j_energies = self.energizer_l2(j_energies) 100 | j_energies = F.elu(j_energies) 101 | energies[:, j] = j_energies[:, 0] 102 | 103 | summed = torch.exp(energies).sum(axis=1) 104 | 105 | weights = torch.exp(energies) / torch.unsqueeze(summed, 1).repeat(1, inp_size) 106 | 107 | weights_mat[:, i] = weights 108 | 109 | weights = torch.unsqueeze(weights, 2).repeat(1, 1, enc_hid.size(2)) 110 | context = (enc_hid * weights).sum(axis=1) 111 | 112 | last_ch = self.embedding(last_char % self.output_size) 113 | 114 | new_hidden = self.rnn(torch.cat([last_ch, last_hidden, context], axis=1)) 115 | new_hidden = F.elu(new_hidden) 116 | 117 | out = self.out(torch.cat([last_ch, new_hidden, context], axis=1)) 118 | out = F.elu(out) 119 | output[:, i] = out 120 | last_hidden = new_hidden 121 | 122 | # Update last_char 123 | if random.random() < teacher_forcing_prob: 124 | # Teacher forcing 125 | last_char_idx = real_output[:, i] 126 | else: 127 | if choice: 128 | dist = torch.exp(out * confidence_boost) 129 | last_char_idx = torch.multinomial(dist, num_samples=1)[:,0] 130 | else: 131 | last_char_idx = out.argmax(axis=1) 132 | 133 | hard_out[:, i] = last_char_idx 134 | 135 | last_char = last_char_idx.clone() 136 | 137 | return output, weights_mat, hard_out 138 | 139 | 140 | def init_hidden(self, batch_size): 141 | return torch.zeros((batch_size, self.dec_hid_size)).to(device) 142 | 143 | def into_one_hot(values, n_tokens): 144 | one_hot = torch.LongTensor(values.shape[0], values.shape[1], n_tokens, device=device).zero_() 145 | for i in range(values.size(0)): 146 | one_hot[i, torch.arange(values.size(1)), values[i] % n_tokens] = 1 147 | return one_hot 148 | 149 | def generate_batch(batch_size, other_stype, 
max_length=None): 150 | xs, ys = [], [] 151 | 152 | if max_length is None: 153 | max_length = 3 + random.expovariate(1 / 10) 154 | max_length = min(15, max(3, max_length)) 155 | 156 | min_length = int(max_length * 0.9 - 2) 157 | 158 | longest_x = longest_y = 0 159 | for i in range(batch_size): 160 | y = None 161 | while y is None or not (min_length < len(y) < max_length or min_length < len(x) < max_length): 162 | x, y = load_one_pair(other_stype) 163 | if max_length == -1: 164 | break 165 | 166 | longest_x = max(longest_x, len(x)) 167 | longest_y = max(longest_y, len(y)) 168 | 169 | xs.append(x) 170 | ys.append(y) 171 | 172 | # Pad each sentence to the appropriate length 173 | for x in xs: 174 | x += [-1] * (longest_x - len(x)) 175 | 176 | for y in ys: 177 | y += [-1] * (longest_y - len(y)) 178 | 179 | x_tensors = [] 180 | y_tensors = [] 181 | 182 | for x in xs: 183 | x_tensor = torch.LongTensor(x) 184 | x_tensors.append(x_tensor.view(1, -1)) 185 | 186 | for y in ys: 187 | y_tensor = torch.LongTensor(y) 188 | y_tensors.append(y_tensor.view(1, -1)) 189 | 190 | return torch.cat(x_tensors, dim=0).to(device), torch.cat(y_tensors, dim=0).to(device) 191 | 192 | ENCODER_SAVE = "save/enc.pth" 193 | SEC_DECODER_SAVE = "save/dec-sec.pth" 194 | AUX_DECODER_SAVE = "save/dec-aux.pth" 195 | SEC_OPT_SAVE = "save/opt-sec.pth" 196 | AUX_OPT_SAVE = "save/opt-aux.pth" 197 | 198 | 199 | def load_from_save(): 200 | enc = Encoder(PRIM_GL.n_tokens(), 600, 150) 201 | sec_dec = Decoder(SEC_GL.n_tokens(), 600, 300, 450) 202 | aux_dec = Decoder(SEC_GL.n_tokens(), 600, 300, 450) 203 | 204 | if os.path.isfile(ENCODER_SAVE) \ 205 | and os.path.isfile(SEC_DECODER_SAVE) \ 206 | and os.path.isfile(AUX_DECODER_SAVE) \ 207 | and os.path.isfile(SEC_OPT_SAVE) \ 208 | and os.path.isfile(AUX_OPT_SAVE): 209 | print("Loading from save") 210 | enc.load_state_dict(torch.load(ENCODER_SAVE, map_location=device)) 211 | sec_dec.load_state_dict(torch.load(SEC_DECODER_SAVE, map_location=device)) 212 | aux_dec.load_state_dict(torch.load(AUX_DECODER_SAVE, map_location=device)) 213 | 214 | enc = enc.to(device) 215 | sec_dec = sec_dec.to(device) 216 | aux_dec = aux_dec.to(device) 217 | 218 | sec_opt = torch.optim.Adam( 219 | list(enc.parameters()) + list(sec_dec.parameters()), 220 | lr=0.0005, 221 | ) 222 | 223 | aux_opt = torch.optim.Adam( 224 | list(enc.parameters()) + list(aux_dec.parameters()), 225 | lr=0.0005, 226 | ) 227 | 228 | sec_opt.load_state_dict(torch.load(SEC_OPT_SAVE, map_location=device)) 229 | aux_opt.load_state_dict(torch.load(AUX_OPT_SAVE, map_location=device)) 230 | 231 | else: 232 | enc = enc.to(device) 233 | sec_dec = sec_dec.to(device) 234 | aux_dec = aux_dec.to(device) 235 | 236 | sec_opt = torch.optim.Adam( 237 | list(enc.parameters()) + list(sec_dec.parameters()), 238 | lr=0.0005, 239 | ) 240 | 241 | aux_opt = torch.optim.Adam( 242 | list(enc.parameters()) + list(aux_dec.parameters()), 243 | lr=0.0005, 244 | ) 245 | 246 | return enc, sec_dec, aux_dec, sec_opt, aux_opt 247 | 248 | def save(enc, sec_dec, aux_dec, sec_opt, aux_opt): 249 | torch.save(enc.state_dict(), ENCODER_SAVE) 250 | torch.save(sec_dec.state_dict(), SEC_DECODER_SAVE) 251 | torch.save(aux_dec.state_dict(), AUX_DECODER_SAVE) 252 | torch.save(sec_opt.state_dict(), SEC_OPT_SAVE) 253 | torch.save(aux_opt.state_dict(), AUX_OPT_SAVE) 254 | 255 | 256 | if __name__ == "__main__": 257 | enc, sec_dec, aux_dec, _ = load_from_save() 258 | 259 | prim_sent, sec_sent = generate_batch(256, STYPE_SEC) #load_one_pair(STYPE_SEC) 260 | 261 | print("Prim :", 
prim_sent.size()) 262 | print("Sec :", sec_sent.size()) 263 | 264 | hid = enc(prim_sent) 265 | print("Hidden:", hid.size()) 266 | 267 | dec, weights, _ = sec_dec(hid, sec_sent) 268 | print("Dec :", dec.size()) 269 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # ilo pi ante toki 2 | 3 | ilo pi ante toki is a translator based on transfer learning on top of the method proposed in Neural Machine Translation by Jointly Learning to Align and Translate (Bahdanau et al., arXiv:1409.0473 [cs.CL]). 4 | 5 | The transfer is based on first training the network to translate from English to Spanish (this can be changed in the data-loading script), and then building on that to translate from English to toki pona. This works well because the English -> Spanish dataset is quite large (about 200,000 pairs), which gives the network a lot of context about English, while the English -> toki pona dataset is much smaller, about 13,000 pairs. 6 | 7 | All sentences are loaded from [Tatoeba](https://tatoeba.org), and are under the CC BY 2.0 FR license. 8 | 9 | ## Loading the data 10 | 11 | The first step is to download the raw data from Tatoeba, which is done with a Python script. It downloads the `sentences.tar.bz2` and `links.tar.bz2` files and untars them. 12 | 13 | ```sh 14 | python3 load-data/data_loader.py 15 | ``` 16 | 17 | By default the extracted files are put into the folder `cache/raw/`, which is created automatically by `data_loader.py`. If you are on Windows, you might need to change this path to something else; note that it is hard-coded separately in each file. 18 | 19 | The uncompressed data is quite large, around 450MiB for the sentences and 250MiB for the links. It includes a lot of languages we don't need, and is stored in a format that is inefficient for reading arbitrary sentence pairs. The program `load-data/select-langs.rs` converts this data into a friendlier format that only includes the languages we want. The languages are specified in that file: the primary language is the input to the translator, the secondary language is the output, and the auxiliary language is used for the transfer-learning step. 20 | 21 | ```sh 22 | rustc -O load-data/select-langs.rs 23 | ./select-langs 24 | ``` 25 | 26 | This will run for a few minutes. 
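The binary files that `select-langs` writes (with the default `BINARY_MODE`) are straightforward to read back. Each record in a links file is four little-endian 32-bit unsigned integers: the byte offset and byte length of the primary sentence, then the byte offset and byte length of the secondary (or auxiliary) sentence; the sentence files themselves are sequences of little-endian 16-bit token ids. As a minimal sketch (this mirrors what `sentence_parser.py` does, and assumes the default `cache/` layout):

```python
# Sketch: read the first (primary, secondary) token-id pair from the binary
# files written by select-langs.rs. Mirrors sentence_parser.py's load_one_pair;
# the paths assume the default cache/ layout.
import struct

with open("cache/sec-links.bin", "rb") as links, \
     open("cache/sentences-prim.bin", "rb") as prim, \
     open("cache/sentences-sec.bin", "rb") as sec:
    # One link record = 4 little-endian u32s:
    # primary offset, primary length, other offset, other length (all in bytes).
    p_start, p_len, o_start, o_len = struct.unpack("<4I", links.read(16))

    prim.seek(p_start)
    prim_tokens = struct.unpack(f"<{p_len // 2}H", prim.read(p_len))

    sec.seek(o_start)
    sec_tokens = struct.unpack(f"<{o_len // 2}H", sec.read(o_len))

print(prim_tokens)  # n-gram token ids; decode them using the ngrams-*.bin tables
```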
27 | 28 | ## Training the model 29 | 30 | TODO 31 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.20 2 | aiohttp==3.7.4 3 | tqdm==4.53.0 4 | requests==2.24 5 | -------------------------------------------------------------------------------- /sentence_parser.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from abc import ABC, abstractmethod 4 | import struct 5 | 6 | class Gram(ABC): 7 | @abstractmethod 8 | def parse_one(tag_byte, inp_file): 9 | pass 10 | 11 | @abstractmethod 12 | def __str__(self): 13 | pass 14 | 15 | def __repr__(self): 16 | return str(self) 17 | 18 | class Orig(Gram): 19 | def __init__(self, char): 20 | self.char = char 21 | 22 | def parse_one(tag_byte, inp_file): 23 | if tag_byte == 0: 24 | return None 25 | 26 | n_utf8_bytes = tag_byte 27 | utf8 = inp_file.read(n_utf8_bytes) 28 | char = utf8.decode('utf-8') 29 | inp_file.read(8 - n_utf8_bytes) # Skip padding 30 | 31 | return Orig(char) 32 | 33 | def __str__(self): 34 | return f"Orig({self.char})" 35 | 36 | class Composition(Gram): 37 | def __init__(self, a, b): 38 | # Indicies into gram list 39 | self.a = a 40 | self.b = b 41 | 42 | def parse_one(tag_byte, inp_file): 43 | if tag_byte != 0: 44 | return None 45 | 46 | a, = struct.unpack("")) 60 | 61 | def n_tokens(self): 62 | return len(self.gram_list) 63 | 64 | def from_file(inp_file): 65 | gram_list = [] 66 | while True: 67 | tag_lst = inp_file.read(1) 68 | if len(tag_lst) == 0: 69 | break 70 | tag = tag_lst[0] 71 | 72 | found_any = False 73 | for target_class in bpe_list: 74 | parsed = target_class.parse_one(tag, inp_file) 75 | if parsed == None: 76 | continue 77 | 78 | gram_list.append(parsed) 79 | found_any = True 80 | break 81 | 82 | if not found_any: 83 | raise Exception(f"invalid tag: {tag}") 84 | 85 | 86 | return GramList(gram_list) 87 | 88 | def bpe_to_str(self, bpe): 89 | while True: 90 | all_are_origs = True 91 | new_bpe = [] 92 | for token in bpe: 93 | pointed_gram = self.gram_list[token] 94 | 95 | if isinstance(pointed_gram, Orig): 96 | new_bpe.append(token) 97 | if isinstance(pointed_gram, Composition): 98 | new_bpe.append(pointed_gram.a) 99 | new_bpe.append(pointed_gram.b) 100 | all_are_origs = False 101 | 102 | if all_are_origs: 103 | break 104 | 105 | bpe = new_bpe 106 | 107 | return "".join(self.gram_list[x].char for x in bpe) 108 | 109 | def str_to_bpe(self, st): 110 | # convert into tokens 111 | st = st.lower() 112 | bpe = [] 113 | for ch in st: 114 | found = False 115 | for i, gram in enumerate(self.gram_list): 116 | if isinstance(gram, Orig) and gram.char == ch: 117 | bpe.append(i) 118 | found = True 119 | break 120 | if not found: 121 | print(f"character {ch} not found!") 122 | 123 | # because gram_list is toposorted, we just need to apply all compositions in order 124 | for i, gram in enumerate(self.gram_list): 125 | if not isinstance(gram, Composition): 126 | continue 127 | 128 | new_bpe = [] 129 | 130 | skip_next = False 131 | for token_here, next_token in zip(bpe, bpe[1:] + [None]): 132 | if skip_next: 133 | skip_next = False 134 | continue 135 | 136 | if token_here == gram.a and next_token == gram.b: 137 | new_bpe.append(i) 138 | skip_next = True 139 | else: 140 | new_bpe.append(token_here) 141 | 142 | bpe = new_bpe 143 | 144 | return bpe 145 | 146 | 147 | def __str__(self): 148 | return f"GramList({self.gram_list})" 149 | 150 
| STYPE_PRIM = 0 151 | STYPE_SEC = 1 152 | STYPE_AUX = 2 153 | 154 | def open_size(path): 155 | return open(os.path.expanduser(path), "rb"), os.path.getsize(os.path.expanduser(path)) 156 | 157 | sec_links, sec_links_size = open_size("cache/sec-links.bin") 158 | aux_links, aux_links_size = open_size("cache/aux-links.bin") 159 | 160 | sents_prim = open(os.path.expanduser("cache/sentences-prim.bin"), "rb") 161 | sents_sec = open(os.path.expanduser("cache/sentences-sec.bin"), "rb") 162 | sents_aux = open(os.path.expanduser("cache/sentences-aux.bin"), "rb") 163 | 164 | def load_one_pair(other_stype): 165 | links_file = sec_links if other_stype == STYPE_SEC else aux_links 166 | links_size = sec_links_size if other_stype == STYPE_SEC else aux_links_size 167 | sents_other = sents_sec if other_stype == STYPE_SEC else sents_aux 168 | 169 | n_links = links_size // (4 * 4) 170 | selected = random.randrange(0, n_links) 171 | file_offset = selected * 4 * 4 172 | 173 | links_file.seek(file_offset) 174 | 175 | p_start, p_len, o_start, o_len = struct.unpack("<4I", links_file.read(4 * 4)) 176 | 177 | sents_prim.seek(p_start) 178 | prim_sent = list(struct.unpack(f"<{p_len // 2}H", sents_prim.read(p_len))) 179 | 180 | sents_other.seek(o_start) 181 | other_sent = list(struct.unpack(f"<{o_len // 2}H", sents_other.read(o_len))) 182 | 183 | return prim_sent + [-1], other_sent + [-1] 184 | 185 | PRIM_GL = GramList.from_file(open(os.path.expanduser("cache/ngrams-prim.bin"), "rb")) 186 | SEC_GL = GramList.from_file(open(os.path.expanduser("cache/ngrams-sec.bin"), "rb")) 187 | AUX_GL = GramList.from_file(open(os.path.expanduser("cache/ngrams-aux.bin"), "rb")) 188 | 189 | if __name__ == "__main__": 190 | prim, sec = load_one_pair(STYPE_SEC) 191 | 192 | print("/".join(PRIM_GL.bpe_to_str([x]) for x in prim)) 193 | 194 | print("/".join(SEC_GL.bpe_to_str([x]) for x in sec)) 195 | -------------------------------------------------------------------------------- /static/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | ilo pi ante toki 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 75 | 76 | 77 | 78 |

[static/index.html: markup, inline styles, and the client script were not preserved in this dump; the recoverable page text is:]

ilo pi ante toki language changing tool

pro tips:
  • end your sentences with a period
  • for names, use Tom or Mary (jan ton / jan mewi in toki pona)
  • every request has a one second cool down, so don't spam the system

source available at github. please give a star :)

made by coral
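The page's client script (not preserved above) submits the text box contents to the `/translate` endpoint defined in `api.py`. As an illustration only (this is not the page's actual JavaScript), a request against a locally running instance on port 3080, the port exposed in `docker-compose.yml`, could look like:

```python
# Sketch of a client for api.py's POST /translate endpoint, assuming the
# service is running locally on port 3080 (see docker-compose.yml).
# Illustrative only; the web page itself does this from JavaScript.
import requests

resp = requests.post(
    "http://localhost:3080/translate",
    json={
        "input": "I like cats.",   # English text to translate
        "confidence_boost": 1,     # clamped by the server to [-3, 3]
    },
)
data = resp.json()

if "error" in data:
    # e.g. the one-second-per-IP rate limit, or the model checkpoint not being loaded
    print("error:", data["error"])
else:
    # api.py returns the translation, a confidence score (product of the
    # per-token probabilities), and the server-side duration in seconds.
    print(data["result"], data["confidence"], data["duration"])
```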

113 | 114 | 115 | 164 | 165 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from sentence_parser import STYPE_SEC, STYPE_AUX, PRIM_GL, SEC_GL, AUX_GL 9 | from network import device, Encoder, Decoder, into_one_hot, generate_batch, load_from_save, save 10 | 11 | if device.type == "cuda": 12 | BATCH_SIZE = 1024 13 | else: 14 | BATCH_SIZE = 32 15 | 16 | print(f"Using batch size of {BATCH_SIZE}") 17 | 18 | def display_tokens(toklist, gl): 19 | out = "" 20 | last = 0 # 0 = any char, 1 = end, 2 = end after another end 21 | for tidx in toklist: 22 | tok = gl.bpe_to_str([tidx]) 23 | if tok != "": 24 | out += "/" + tok 25 | last = 0 26 | else: 27 | if last == 0: 28 | out += "/" 29 | last += 1 30 | elif last == 1: 31 | out += "..." 32 | last += 1 33 | return out[1:] 34 | 35 | enc, sec_dec, aux_dec, sec_opt, aux_opt = load_from_save() 36 | 37 | EPSILON = 1e-4 38 | 39 | if __name__ == "__main__": 40 | torch.autograd.set_detect_anomaly(True) 41 | crit = nn.CrossEntropyLoss() 42 | 43 | epoch = 0 44 | while True: 45 | epoch += 1 46 | 47 | sec_losses = [] 48 | aux_losses = [] 49 | 50 | sec_info = ("sec", sec_losses, sec_dec, sec_opt, STYPE_SEC) 51 | aux_info = ("aux", aux_losses, aux_dec, aux_opt, STYPE_AUX) 52 | 53 | for batch_nr in range(16): 54 | print(hex(batch_nr)[2:], end=" ") 55 | for name, losses, dec, opt, stype in [sec_info, aux_info, aux_info]: 56 | print(name, end=":") 57 | xs, ys = generate_batch(BATCH_SIZE, stype) 58 | 59 | print("l={:2d}/{:2d};".format(xs.size(1), ys.size(1)), end=" ", flush=True) 60 | 61 | enc.zero_grad() 62 | dec.zero_grad() 63 | 64 | print("z", end="", flush=True) 65 | 66 | hids = enc(xs) 67 | print("e", end="", flush=True) 68 | y_hat, _, _ = dec(hids, ys) 69 | print("d", end="", flush=True) 70 | 71 | pred = y_hat.argmax(axis=2) 72 | is_not_eof = ys != -1 73 | acc = ((pred == ys) * is_not_eof).type(torch.FloatTensor).sum() / (is_not_eof.type(torch.FloatTensor).sum()) 74 | 75 | print("a", end=" ", flush=True) 76 | 77 | loss = crit(EPSILON + y_hat.view(-1, SEC_GL.n_tokens()), ys.view(-1) % SEC_GL.n_tokens()) 78 | 79 | print("L={:.3f}; a={:6.3f}%".format(loss, acc*100), end=" ") 80 | loss.backward() 81 | print("b", end=" ") 82 | opt.step() 83 | print("s", end=";") 84 | 85 | enc.zero_grad() 86 | dec.zero_grad() 87 | 88 | losses.append(loss.item()) 89 | print() 90 | 91 | save(enc, sec_dec, aux_dec, sec_opt, aux_opt) 92 | 93 | print("Epoch", epoch, "done") 94 | 95 | for name, losses, dec, opt, stype in [sec_info, aux_info]: 96 | print(f"For {name}") 97 | xs, ys = generate_batch(4, stype, max_length=10) 98 | gl = SEC_GL if stype == STYPE_SEC else AUX_GL 99 | 100 | hids = enc(xs) 101 | y_hat, _, _ = dec(hids, ys, teacher_forcing_prob=0) 102 | 103 | # Display 104 | for i in range(len(xs)): 105 | print() 106 | print(">", display_tokens(xs[i], PRIM_GL)) 107 | print("=", display_tokens(ys[i], gl)) 108 | gen = y_hat[i].argmax(dim=1) 109 | print("≈", display_tokens(gen, gl)) 110 | 111 | print("Loss:", sum(losses) / len(losses)) 112 | 113 | --------------------------------------------------------------------------------