├── .gitignore ├── requirements.txt ├── sample.py ├── README.md ├── utils.py ├── huffman.py ├── run_single.py ├── huffman_baseline.py ├── block_baseline.py └── arithmetic.py /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | .ipynb_checkpoints 3 | __pycache__ 4 | data* 5 | generations/ 6 | gpt345_data/ 7 | pytorch-pretrained-BERT/ 8 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | bitarray==3.4.2 2 | certifi==2025.4.26 3 | charset-normalizer==3.4.2 4 | filelock==3.18.0 5 | fsspec==2025.5.1 6 | hf-xet==1.1.3 7 | huggingface-hub==0.32.5 8 | idna==3.10 9 | Jinja2==3.1.6 10 | MarkupSafe==3.0.2 11 | mpmath==1.3.0 12 | networkx==3.2.1 13 | numpy==2.0.2 14 | nvidia-cublas-cu12==12.6.4.1 15 | nvidia-cuda-cupti-cu12==12.6.80 16 | nvidia-cuda-nvrtc-cu12==12.6.77 17 | nvidia-cuda-runtime-cu12==12.6.77 18 | nvidia-cudnn-cu12==9.5.1.17 19 | nvidia-cufft-cu12==11.3.0.4 20 | nvidia-cufile-cu12==1.11.1.6 21 | nvidia-curand-cu12==10.3.7.77 22 | nvidia-cusolver-cu12==11.7.1.2 23 | nvidia-cusparse-cu12==12.5.4.2 24 | nvidia-cusparselt-cu12==0.6.3 25 | nvidia-nccl-cu12==2.26.2 26 | nvidia-nvjitlink-cu12==12.6.85 27 | nvidia-nvtx-cu12==12.6.77 28 | packaging==25.0 29 | PyYAML==6.0.2 30 | regex==2024.11.6 31 | requests==2.32.4 32 | safetensors==0.5.3 33 | sympy==1.14.0 34 | tokenizers==0.21.1 35 | torch==2.7.1 36 | tqdm==4.67.1 37 | transformers==4.52.4 38 | triton==3.3.1 39 | typing_extensions==4.14.0 40 | urllib3==2.4.0 41 | -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from utils import limit_past, kl, entropy 5 | 6 | def sample(model, enc, length, context, temperature=1.0, device='cuda', topk=-1): 7 | assert length > 0 8 | 9 | context = torch.tensor(context[-1022:], device=device, dtype=torch.long) 10 | 11 | prev = context 12 | output = context 13 | past = None 14 | 15 | total_log_probs = 0 16 | total_entropy_ptau = 0 17 | total_num = 0 18 | total_kl = 0 # in bits 19 | 20 | with torch.no_grad(): 21 | while total_num < length: 22 | if past and past[0].shape[3] >= 1023: 23 | raise RuntimeError 24 | 25 | logits, past = model(prev.unsqueeze(0), past=past) 26 | past = limit_past(past) 27 | logits[0, -1, -1] = -1e10 # endoftext can't happen 28 | logits[0, -1, 628] = -1e10 # 2 newlines can't happen 29 | logits, indices = logits[0, -1, :].sort(descending=True) 30 | base_log_probs = F.log_softmax(logits, dim=-1) 31 | 32 | if topk > 0: 33 | logits = logits[:topk] 34 | 35 | logits = logits / temperature 36 | log_probs = F.log_softmax(logits, dim=-1) 37 | probs = torch.exp(log_probs) 38 | 39 | total_kl += kl(probs, log_probs, base_log_probs[:topk]) 40 | 41 | selection = torch.multinomial(probs, num_samples=1).item() 42 | log_prob_chosen = base_log_probs[selection] 43 | total_log_probs += log_prob_chosen.item() 44 | 45 | total_entropy_ptau += entropy(probs, log_probs) 46 | 47 | prev = indices[selection].view(1) 48 | output = torch.cat((output, prev)) 49 | total_num += 1 50 | 51 | avg_NLL = -total_log_probs/total_num 52 | avg_KL = total_kl/total_num 53 | avg_Hq = total_entropy_ptau/total_num 54 | 55 | return output[len(context):].tolist(), avg_NLL, avg_KL, avg_Hq 56 | -------------------------------------------------------------------------------- 
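A minimal, hypothetical driver for the sampler above, not part of the repository: it assumes a CUDA device (the `sample` default) and uses illustrative parameter values.

import math

from utils import get_model, encode_context
from sample import sample

# Hypothetical example; 'gpt2-medium' would reproduce the paper's 345M-parameter setting.
enc, model = get_model(model_name='gpt2')
context_tokens = encode_context("Washington received his initial military training "
                                "and command with the Virginia Regiment.", enc)

# Draw 50 tokens at temperature 0.9 from the top-300 truncated distribution.
tokens, avg_nll, avg_kl, avg_hq = sample(model, enc, 50, context_tokens,
                                         temperature=0.9, topk=300)
print(enc.decode(tokens))
print('ppl: %.2f  kl: %.3f bits  entropy: %.2f bits' % (math.exp(avg_nll), avg_kl, avg_hq))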
/README.md: -------------------------------------------------------------------------------- 1 | # STEGASURAS 2 | STEGanography via Arithmetic coding and Strong neURAl modelS 3 | 4 | This repository contains implementations of the steganography algorithms from ["Neural Linguistic Steganography," Zachary M. Ziegler*, Yuntian Deng*, Alexander M. Rush](https://arxiv.org/abs/1909.01496). 5 | 6 | ## Update (June 2025) 7 | 8 | Thanks to [Sabrina Ning](https://github.com/sabrina-ning) and [Wei-Chiu Ma](https://www.cs.cornell.edu/~weichiu/) for updating this repository to be compatible with the latest versions of PyTorch and Hugging Face Transformers. 9 | 10 | ## Online Demo 11 | 12 | Our online demo can be found at [https://steganography.live/](https://steganography.live/). 13 | 14 | ## Language model 15 | 16 | Experiments in the paper use the medium (345M parameter) GPT model via [pytorch_transformers](https://github.com/huggingface/pytorch-transformers). For compute reasons the default in this code base is the small version but the medium or large versions can be used by changing the `model_name` parameter of `get_model`. 17 | 18 | ## Algorithms 19 | 20 | The steganography algorithms implemented are: 21 | 1. Our proposed arithmetic coding-based algorithm 22 | 2. The Huffman algorithm from [RNN-Stega: Linguistic Steganography Based on Recurrent Neural Networks](https://ieeexplore.ieee.org/document/8470163) 23 | 3. The binning algorithm from [Generating Steganographic Text with LSTMs](https://arxiv.org/abs/1705.10742) 24 | 25 | An example of encoding and decoding a message is in `run_single.py`. The algorithm used is determined by the `mode` parameter. 26 | 27 | 28 | ## Citation 29 | 30 | ``` 31 | @inproceedings{ziegler-etal-2019-neural, 32 | title = "Neural Linguistic Steganography", 33 | author = "Ziegler, Zachary and 34 | Deng, Yuntian and 35 | Rush, Alexander", 36 | editor = "Inui, Kentaro and 37 | Jiang, Jing and 38 | Ng, Vincent and 39 | Wan, Xiaojun", 40 | booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)", 41 | month = nov, 42 | year = "2019", 43 | address = "Hong Kong, China", 44 | publisher = "Association for Computational Linguistics", 45 | url = "https://aclanthology.org/D19-1115/", 46 | doi = "10.18653/v1/D19-1115", 47 | pages = "1210--1215", 48 | abstract = "Whereas traditional cryptography encrypts a secret message into an unintelligible form, steganography conceals that communication is taking place by encoding a secret message into a cover signal. Language is a particularly pragmatic cover signal due to its benign occurrence and independence from any one medium. Traditionally, linguistic steganography systems encode secret messages in existing text via synonym substitution or word order rearrangements. Advances in neural language models enable previously impractical generation-based techniques. We propose a steganography technique based on arithmetic coding with large-scale neural language models. We find that our approach can generate realistic looking cover sentences as evaluated by humans, while at the same time preserving security by matching the cover message distribution with the language model distribution." 
49 | } 50 | ``` 51 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import bitarray 4 | 5 | from transformers import AutoModelForCausalLM, AutoTokenizer 6 | 7 | def decode(self, token_ids, **kwargs): 8 | filtered_tokens = self.convert_ids_to_tokens(token_ids) 9 | text = self.convert_tokens_to_string(filtered_tokens) 10 | return text 11 | AutoTokenizer.decode = decode 12 | 13 | def _convert_token_to_id(self, token): 14 | return self.encoder.get(token, 0) 15 | AutoTokenizer._convert_token_to_id = _convert_token_to_id 16 | 17 | 18 | # handles both old and new cache formats 19 | def limit_past(past): 20 | past = list(past) 21 | for i in range(len(past)): 22 | if isinstance(past[i], tuple): 23 | key, value = past[i] 24 | past[i] = ( 25 | key[:, :, :, -1022:], 26 | value[:, :, :, -1022:] 27 | ) 28 | else: 29 | past[i] = past[i][:, :, :, -1022:] 30 | return past 31 | 32 | def kl(q, logq, logp): 33 | res = q*(logq-logp)/0.69315 34 | res[q==0] = 0 35 | return res.sum().item() # in bits 36 | 37 | def entropy(q, logq): 38 | res = q*logq/0.69315 39 | res[q==0] = 0 40 | return -res.sum().item() # in bits 41 | 42 | # e.g. [0, 1, 1, 1] looks like 1110=14 43 | def bits2int(bits): 44 | res = 0 45 | for i, bit in enumerate(bits): 46 | res += bit*(2**i) 47 | return res 48 | 49 | def int2bits(inp, num_bits): 50 | if num_bits == 0: 51 | return [] 52 | strlist = ('{0:0%db}'%num_bits).format(inp) 53 | return [int(strval) for strval in reversed(strlist)] 54 | 55 | def is_sent_finish(token_idx, enc): 56 | token = enc.decode([token_idx]) 57 | return '.' in token or '!' in token or '?' in token 58 | 59 | def num_same_from_beg(bits1, bits2): 60 | assert len(bits1) == len(bits2) 61 | for i in range(len(bits1)): 62 | if bits1[i] != bits2[i]: 63 | break 64 | return i 65 | 66 | def encode_context(raw_text, enc): 67 | context_tokens = enc.encode('<|endoftext|>') + enc.encode(raw_text) 68 | return context_tokens 69 | 70 | # Use gpt2-medium for 345M param model 71 | # Use gpt2-large for 774M param model 72 | def get_model(seed=1234, model_name='gpt2'): 73 | np.random.seed(seed) 74 | torch.random.manual_seed(seed) 75 | torch.cuda.manual_seed(seed) 76 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 77 | 78 | enc = AutoTokenizer.from_pretrained(model_name) 79 | enc.unk_token = None 80 | enc.bos_token = None 81 | enc.eos_token = None 82 | 83 | model = AutoModelForCausalLM.from_pretrained(model_name) 84 | model.to(device) 85 | model.eval() 86 | # model.double() 87 | 88 | return enc, model 89 | 90 | enc32_itoc = ['\0', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '.', ',', "'", '!', ' '] 91 | enc32_ctoi = {k: v for v, k in enumerate(enc32_itoc)} 92 | def enc32(text): 93 | bits = [] 94 | for c in text: 95 | bits.extend(int2bits(enc32_ctoi[c], 5)) 96 | return bits 97 | 98 | def dec32(bits): 99 | text = '' 100 | for i in range(0, len(bits), 5): 101 | c = enc32_itoc[bits2int(bits[i:i+5])] 102 | if c == '\0': 103 | break 104 | text += c 105 | return text 106 | 107 | # message should be bit string 108 | # encoded should be text string 109 | def expansion_ratio(message, encoded): 110 | message_bits = len(message) 111 | encoded_ba = bitarray.bitarray() 112 | encoded_ba.frombytes(encoded.encode('utf-8')) 113 | encoded_bits = len(encoded_ba.tolist()) 114 | return 
encoded_bits/message_bits 115 | -------------------------------------------------------------------------------- /huffman.py: -------------------------------------------------------------------------------- 1 | import heapq 2 | from functools import total_ordering 3 | 4 | """ 5 | Code for Huffman Coding, compression and decompression. 6 | Explanation at http://bhrigu.me/blog/2017/01/17/huffman-coding-python-implementation/ 7 | Adapted from https://github.com/bhrigu123/huffman-coding 8 | """ 9 | 10 | @total_ordering 11 | class HeapNode: 12 | def __init__(self, token, freq): 13 | self.token = token 14 | self.freq = freq 15 | self.left = None 16 | self.right = None 17 | 18 | # defining comparators less_than and equals 19 | def __lt__(self, other): 20 | return self.freq < other.freq 21 | 22 | def __eq__(self, other): 23 | if(other == None): 24 | return False 25 | if(not isinstance(other, HeapNode)): 26 | return False 27 | return self.freq == other.freq 28 | 29 | class HuffmanCoding: 30 | def __init__(self): 31 | self.heap = [] 32 | self.codes = {} 33 | self.reverse_mapping = {} 34 | 35 | # functions for compression: 36 | 37 | def make_heap(self, frequency): 38 | for key in frequency: 39 | node = HeapNode(key, frequency[key]) 40 | heapq.heappush(self.heap, node) 41 | 42 | def make_heap_from_array(self, freqs): 43 | for index in range(len(freqs)): 44 | node = HeapNode(index, freqs[index]) 45 | heapq.heappush(self.heap, node) 46 | 47 | def merge_nodes(self): 48 | while(len(self.heap)>1): 49 | node1 = heapq.heappop(self.heap) 50 | node2 = heapq.heappop(self.heap) 51 | 52 | merged = HeapNode(None, node1.freq + node2.freq) 53 | merged.left = node1 54 | merged.right = node2 55 | 56 | heapq.heappush(self.heap, merged) 57 | 58 | 59 | def make_codes_helper(self, root, current_code): 60 | if(root == None): 61 | return 62 | 63 | if(root.token != None): 64 | self.codes[root.token] = current_code 65 | self.reverse_mapping[current_code] = root.token 66 | return 67 | 68 | self.make_codes_helper(root.left, current_code + "0") 69 | self.make_codes_helper(root.right, current_code + "1") 70 | 71 | def make_codes(self): 72 | root = heapq.heappop(self.heap) 73 | current_code = "" 74 | self.make_codes_helper(root, current_code) 75 | return root 76 | 77 | 78 | def get_encoded_tokens(self, token_list): 79 | encoded_text = "" 80 | for token in token_list: 81 | encoded_text += self.codes[token] 82 | return encoded_text 83 | 84 | def decode_text(self, encoded_text): 85 | current_code = "" 86 | decoded_text = "" 87 | 88 | for bit in encoded_text: 89 | current_code += bit 90 | if(current_code in self.reverse_mapping): 91 | character = self.reverse_mapping[current_code] 92 | decoded_text += character 93 | current_code = "" 94 | 95 | return decoded_text 96 | 97 | 98 | def decompress(self, input_path): 99 | filename, file_extension = os.path.splitext(self.path) 100 | output_path = filename + "_decompressed" + ".txt" 101 | 102 | with open(input_path, 'rb') as file, open(output_path, 'w') as output: 103 | bit_string = "" 104 | 105 | byte = file.read(1) 106 | while(len(byte) > 0): 107 | byte = ord(byte) 108 | bits = bin(byte)[2:].rjust(8, '0') 109 | bit_string += bits 110 | byte = file.read(1) 111 | 112 | encoded_text = self.remove_padding(bit_string) 113 | 114 | decompressed_text = self.decode_text(encoded_text) 115 | 116 | output.write(decompressed_text) 117 | 118 | return output_path 119 | 120 | -------------------------------------------------------------------------------- /run_single.py: 
-------------------------------------------------------------------------------- 1 | import bitarray 2 | import math 3 | 4 | from utils import get_model, encode_context 5 | 6 | from arithmetic import encode_arithmetic, decode_arithmetic 7 | from block_baseline import get_bins, encode_block, decode_block 8 | from huffman_baseline import encode_huffman, decode_huffman 9 | from sample import sample 10 | 11 | def main(): 12 | enc, model = get_model(model_name='gpt2') 13 | 14 | 15 | ## PARAMETERS 16 | message_str = "This is a very secret message!" 17 | 18 | unicode_enc = False 19 | mode = 'arithmetic' 20 | block_size = 3 # for huffman and bins 21 | temp = 0.9 # for arithmetic 22 | precision = 26 # for arithmetic 23 | sample_tokens = 100 # for sample 24 | topk = 300 25 | finish_sent=False # whether or not to force finish sent. If so, stats displayed will be for non-finished sentence 26 | 27 | ## VALIDATE PARAMETERS 28 | if mode not in ['arithmetic', 'huffman', 'bins', 'sample']: 29 | raise NotImplementedError 30 | 31 | if mode == 'bins': 32 | bin2words, words2bin = get_bins(len(enc.encoder), block_size) 33 | 34 | context = \ 35 | """Washington received his initial military training and command with the Virginia Regiment during the French and Indian War. He was later elected to the Virginia House of Burgesses and was named a delegate to the Continental Congress, where he was appointed Commanding General of the nation's Continental Army. Washington led American forces, allied with France, in the defeat of the British at Yorktown. Once victory for the United States was in hand in 1783, Washington resigned his commission. 36 | 37 | 38 | """ 39 | 40 | context_tokens = encode_context(context, enc) 41 | 42 | # ------------------------------------------------------------------------------------ 43 | # ------------------------------------------------------------------------------------ 44 | 45 | # First encode message to uniform bits, without any context 46 | # (not essential this is arithmetic vs ascii, but it's more efficient when the message is natural language) 47 | if unicode_enc: 48 | ba = bitarray.bitarray() 49 | ba.frombytes(message_str.encode('utf-8')) 50 | message = ba.tolist() 51 | else: 52 | message_ctx = enc.encode('<|endoftext|>') 53 | message_str += '' 54 | message = decode_arithmetic(model, enc, message_str, message_ctx, precision=40, topk=60000) 55 | 56 | # Next encode bits into cover text, using arbitrary context 57 | Hq = 0 58 | if mode == 'arithmetic': 59 | out, nll, kl, words_per_bit, Hq = encode_arithmetic(model, enc, message, context_tokens, temp=temp, finish_sent=finish_sent, precision=precision, topk=topk) 60 | elif mode == 'huffman': 61 | out, nll, kl, words_per_bit = encode_huffman(model, enc, message, context_tokens, block_size, finish_sent=finish_sent) 62 | elif mode == 'bins': 63 | out, nll, kl, words_per_bit = encode_block(model, enc, message, context_tokens, block_size, bin2words, words2bin, finish_sent=finish_sent) 64 | elif mode == 'sample': 65 | out, nll, kl, Hq = sample(model, enc, sample_tokens, context_tokens, temperature=temp, topk=topk) 66 | words_per_bit = 1 67 | text = enc.decode(out) 68 | 69 | print(message) 70 | print(len(message)) 71 | print("="*40 + " Encoding " + "="*40) 72 | print(text) 73 | print('ppl: %0.2f, kl: %0.3f, words/bit: %0.2f, bits/word: %0.2f, entropy: %.2f' % (math.exp(nll), kl, words_per_bit, 1/words_per_bit, Hq/0.69315)) 74 | 75 | # Decode binary message from bits using the same arbitrary context 76 | if mode != 'sample': 77 | if mode == 
'arithmetic': 78 | message_rec = decode_arithmetic(model, enc, text, context_tokens, temp=temp, precision=precision, topk=topk) 79 | elif mode == 'huffman': 80 | message_rec = decode_huffman(model, enc, text, context_tokens, block_size) 81 | elif mode == 'bins': 82 | message_rec = decode_block(model, enc, text, context_tokens, block_size, bin2words, words2bin) 83 | 84 | print("="*40 + " Recovered Message " + "="*40) 85 | print(message_rec) 86 | print("=" * 80) 87 | # Finally map message bits back to original text 88 | if unicode_enc: 89 | message_rec = [bool(item) for item in message_rec] 90 | ba = bitarray.bitarray(message_rec) 91 | reconst = ba.tobytes().decode('utf-8', 'ignore') 92 | else: 93 | reconst = encode_arithmetic(model, enc, message_rec, message_ctx, precision=40, topk=60000) 94 | reconst = enc.decode(reconst[0]) 95 | print(reconst) 96 | 97 | if __name__ == '__main__': 98 | main() 99 | -------------------------------------------------------------------------------- /huffman_baseline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from huffman import HuffmanCoding 5 | from utils import kl, entropy, is_sent_finish, limit_past 6 | 7 | def encode_huffman(model, enc, message, context, bits_per_word, finish_sent=False, device='cuda'): 8 | length = len(message) 9 | 10 | context = torch.tensor(context[-1022:], device=device, dtype=torch.long) 11 | 12 | prev = context 13 | output = context 14 | past = None 15 | 16 | total_num = 0 17 | total_num_for_stats = 0 18 | total_log_probs = 0 19 | total_kl = 0 # in bits 20 | total_num_sents = 0 21 | 22 | with torch.no_grad(): 23 | i = 0 24 | sent_finish = False 25 | while i < length or (finish_sent and not sent_finish): 26 | logits, past = model(prev.unsqueeze(0), past=past) 27 | past = limit_past(past) 28 | logits[0, -1, -1] = -1e10 # endoftext can't happen 29 | logits[0, -1, 628] = -1e10 # 2 newlines can't happen 30 | logits, indices = logits[0, -1, :].sort(descending=True) 31 | 32 | # Get the top 2**bits options 33 | indices = indices[:2**bits_per_word] 34 | log_probs = F.log_softmax(logits, dim=-1)[:2**bits_per_word] 35 | probs = torch.exp(log_probs) 36 | 37 | if i >= length: 38 | selection = 0 39 | sent_finish = is_sent_finish(indices[0].item(), enc) 40 | else: 41 | probs_array = probs.cpu().numpy() 42 | coding = HuffmanCoding() 43 | coding.make_heap_from_array(probs_array) 44 | coding.merge_nodes() 45 | root = coding.make_codes() 46 | 47 | #print(message[i:i+10]) 48 | while root.token is None: 49 | if i >= length or message[i] == 0: 50 | root = root.left 51 | else: 52 | root = root.right 53 | i += 1 54 | selection = root.token 55 | 56 | logq = torch.tensor([-len(coding.codes[idx]) for idx in range(len(probs_array))], dtype=torch.float, device=device) # in bits 57 | logq = logq*0.69315 # in nats 58 | q = torch.exp(logq) 59 | total_kl += kl(q, logq, log_probs) 60 | total_log_probs += log_probs[selection].item() 61 | total_num_for_stats += 1 62 | 63 | total_num += 1 64 | 65 | prev = indices[selection].view(1) 66 | output = torch.cat((output, prev)) 67 | 68 | avg_NLL = -total_log_probs/total_num_for_stats 69 | avg_KL = total_kl/total_num_for_stats 70 | words_per_bit = total_num_for_stats/i 71 | 72 | return output[len(context):].tolist(), avg_NLL, avg_KL, words_per_bit 73 | 74 | def decode_huffman(model, enc, text, context, bits_per_word, device='cuda'): 75 | # inp is a list of token indices 76 | # context is a list of token indices 77 | inp = 
enc.encode(text) 78 | i = 0 79 | while i < len(inp): 80 | if inp[i] == 628: 81 | inp[i] = 198 82 | inp[i+1:i+1] = [198] 83 | i += 2 84 | else: 85 | i += 1 86 | 87 | context = torch.tensor(context[-1022:], device=device, dtype=torch.long) 88 | prev = context 89 | past = None 90 | 91 | message = [] 92 | with torch.no_grad(): 93 | i = 0 94 | while i < len(inp): 95 | if past and past[0].shape[3] >= 1023: 96 | raise RuntimeError 97 | 98 | logits, past = model(prev.unsqueeze(0), past=past) 99 | past = limit_past(past) 100 | logits[0, -1, -1] = -1e10 # endoftext can't happen 101 | logits[0, -1, 628] = -1e10 # 2 newlines can't happen 102 | logits, indices = logits[0, -1, :].sort(descending=True) 103 | 104 | # Get the top 2**bits options 105 | indices = indices[:2**bits_per_word] 106 | log_probs = F.log_softmax(logits, dim=-1)[:2**bits_per_word] 107 | probs = torch.exp(log_probs) 108 | 109 | if inp[i] not in indices: 110 | true_token_text = enc.decoder[inp[i]] 111 | for rank_idx in range(2**bits_per_word): 112 | prop_token_text = enc.decoder[indices[rank_idx].item()] 113 | # common case that is not caught 114 | if inp[i] == 128 and indices[rank_idx] == 198: 115 | rank = rank_idx 116 | inp[i] = indices[rank_idx].item() 117 | break 118 | 119 | # Is there a more likely prefix token that could be the actual token generated? 120 | if len(prop_token_text) <= len(true_token_text) and \ 121 | prop_token_text == true_token_text[:len(prop_token_text)]: 122 | rank = rank_idx 123 | suffix = true_token_text[len(prop_token_text):] 124 | suffix_tokens = enc.encode(suffix) # a list 125 | inp[i] = indices[rank_idx].item() 126 | inp[i+1:i+1] = suffix_tokens # insert suffix tokens into list 127 | break 128 | 129 | # Is there a more likely longer token that could be the actual token generated? 
130 | elif len(prop_token_text) > len(true_token_text) and \ 131 | true_token_text == prop_token_text[:len(true_token_text)]: 132 | whole_text = true_token_text 133 | num_extra = 1 134 | while len(whole_text) < len(prop_token_text): 135 | whole_text += enc.decoder[inp[i+num_extra]] 136 | num_extra += 1 137 | if prop_token_text == whole_text[:len(prop_token_text)]: 138 | rank = rank_idx 139 | inp[i] = indices[rank_idx].item() 140 | for j in range(1, num_extra): 141 | del inp[i+j] 142 | 143 | if len(whole_text) > len(prop_token_text): 144 | suffix = whole_text[len(prop_token_text):] 145 | suffix_tokens = enc.encode(suffix) # a list 146 | inp[i+1:i+1] = suffix_tokens # insert suffix tokens into list 147 | break 148 | else: 149 | print('Unable to fix BPE error: token received: %s=%d, text: %s' % (true_token_text, inp[i], text)) 150 | rank = 0 151 | else: 152 | rank = (indices == inp[i]).nonzero().item() 153 | 154 | probs_array = probs.cpu().numpy() 155 | coding = HuffmanCoding() 156 | coding.make_heap_from_array(probs_array) 157 | coding.merge_nodes() 158 | coding.make_codes() 159 | 160 | tokens_t = map(int, coding.codes[rank]) 161 | 162 | message.extend(tokens_t) 163 | prev = torch.tensor([inp[i]], device=device, dtype=torch.long) 164 | i += 1 165 | 166 | return message 167 | -------------------------------------------------------------------------------- /block_baseline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | import numpy as np 5 | from utils import kl, entropy, is_sent_finish, limit_past, bits2int, int2bits 6 | 7 | # number of bins is 2^block_size 8 | # each bin contains vocab_size/2^block_size words 9 | def get_bins(vocab_size, block_size): 10 | num_bins = 2**block_size 11 | words_per_bin = vocab_size/num_bins 12 | 13 | vocab_ordering = np.arange(vocab_size) 14 | np.random.seed(block_size) 15 | np.random.shuffle(vocab_ordering) 16 | 17 | bin2words = [vocab_ordering[int(i*words_per_bin):int((i+1)*words_per_bin)] for i in range(num_bins)] 18 | bin2words = [np.array(words) for words in bin2words] 19 | words2bin_list = [{i: j for i in bin2words[j]} for j in range(num_bins)] 20 | words2bin = {} 21 | for d in words2bin_list: 22 | words2bin.update(d) 23 | 24 | return bin2words, words2bin 25 | 26 | def encode_block(model, enc, message, context, block_size, bin2words, words2bin, finish_sent=False, device='cuda'): 27 | length = len(message) 28 | 29 | context = torch.tensor(context[-1022:], device=device, dtype=torch.long) 30 | 31 | prev = context 32 | output = context 33 | past = None 34 | 35 | total_num = 0 36 | total_num_for_stats = 0 37 | total_log_probs = 0 38 | total_kl = 0 # in bits 39 | total_num_sents = 0 40 | 41 | with torch.no_grad(): 42 | i = 0 43 | sent_finish = False 44 | while i < length or (finish_sent and not sent_finish): 45 | logits, past = model(prev.unsqueeze(0), past=past) 46 | past = limit_past(past) 47 | logits[0, -1, -1] = -1e10 # endoftext can't happen 48 | logits[0, -1, 628] = -1e10 # 2 newlines can't happen 49 | logits = logits[0, -1, :] 50 | log_probs = F.log_softmax(logits, dim=-1) 51 | 52 | filtered_logits = logits.clone() 53 | filtered_logits[:] = -1e10 # first set all to 0 54 | 55 | if i >= length: 56 | _, indices = logits.sort(descending=True) 57 | sent_finish = is_sent_finish(indices[0].item(), enc) 58 | else: 59 | # First calculate logq 60 | logq = logits.clone() 61 | logq[:] = -1e10 # first set all to 0 62 | 63 | for bin_val in range(2**block_size): 64 | 
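# Build logq, the distribution implied by the binning scheme: each pass of this loop
# keeps only the tokens belonging to one bin and gives the most likely token in that
# bin probability 2**-block_size (a code length of block_size bits). q is compared
# against the model's log_probs below to accumulate the KL statistic.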
filtered_logits = logits.clone() 65 | filtered_logits[:] = -1e10 # first set all to 0 66 | available_tokens = bin2words[bin_val] 67 | filtered_logits[available_tokens] = logits[available_tokens] 68 | filtered_logits, indices = filtered_logits.sort(descending=True) 69 | 70 | logq[indices[0]] = -block_size # in bits 71 | 72 | logq = logq*0.69315 # in nats 73 | q = torch.exp(logq) 74 | 75 | # Then find the actual word for the right bin 76 | m_part = message[i:i+block_size] 77 | 78 | filtered_logits = logits.clone() 79 | filtered_logits[:] = -1e10 # first set all to 0 80 | available_tokens = bin2words[bits2int(m_part)] 81 | filtered_logits[available_tokens] = logits[available_tokens] 82 | filtered_logits, indices = filtered_logits.sort(descending=True) 83 | 84 | total_kl += kl(q, logq, log_probs) 85 | total_log_probs += log_probs[indices[0]].item() 86 | i += block_size 87 | total_num_for_stats += 1 88 | 89 | 90 | total_num += 1 91 | prev = indices[0].view(1) 92 | output = torch.cat((output, prev)) 93 | 94 | avg_NLL = -total_log_probs/total_num_for_stats 95 | avg_KL = total_kl/total_num_for_stats 96 | words_per_bit = total_num_for_stats/i 97 | 98 | return output[len(context):].tolist(), avg_NLL, avg_KL, words_per_bit 99 | 100 | def decode_block(model, enc, text, context, block_size, bin2words, words2bin, device='cuda'): 101 | # inp is a list of token indices 102 | # context is a list of token indices 103 | inp = enc.encode(text) 104 | i = 0 105 | while i < len(inp): 106 | if inp[i] == 628: 107 | inp[i] = 198 108 | inp[i+1:i+1] = [198] 109 | i += 2 110 | else: 111 | i += 1 112 | 113 | context = torch.tensor(context[-1022:], device=device, dtype=torch.long) 114 | prev = context 115 | past = None 116 | 117 | message = [] 118 | with torch.no_grad(): 119 | i = 0 120 | while i < len(inp): 121 | if past and past[0].shape[3] >= 1023: 122 | raise RuntimeError 123 | bin_num = words2bin[inp[i]] 124 | 125 | logits, past = model(prev.unsqueeze(0), past=past) 126 | past = limit_past(past) 127 | logits[0, -1, -1] = -1e10 # endoftext can't happen 128 | logits[0, -1, 628] = -1e10 # 2 newlines can't happen 129 | 130 | logits = logits[0, -1, :] 131 | filtered_logits = logits.clone() 132 | filtered_logits[:] = -1e10 # first set all to 0 133 | 134 | available_tokens = bin2words[bin_num] 135 | filtered_logits[available_tokens] = logits[available_tokens] 136 | filtered_logits, indices = filtered_logits.sort(descending=True) 137 | 138 | rank = (indices == inp[i]).nonzero().item() 139 | 140 | # Handle errors that could happen because of BPE 141 | if rank > 0: 142 | true_token_text = enc.decoder[inp[i]] 143 | for bin_num in range(len(bin2words)): 144 | filtered_logits = logits.clone() 145 | filtered_logits[:] = -1e10 # first set all to 0 146 | 147 | available_tokens = bin2words[bin_num] 148 | filtered_logits[available_tokens] = logits[available_tokens] 149 | filtered_logits, indices = filtered_logits.sort(descending=True) 150 | 151 | prop_token_text = enc.decoder[indices[0].item()] 152 | #print(true_token_text, prop_token_text) 153 | 154 | # Is there a more likely prefix token that could be the actual token generated? 
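# (BPE re-encoding merged the generated bin token with the following text into one longer
# received token; keep the shorter candidate as the generated token and re-insert the
# leftover characters as separate suffix tokens.)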
155 | if len(prop_token_text) < len(true_token_text) and \ 156 | prop_token_text == true_token_text[:len(prop_token_text)]: 157 | suffix = true_token_text[len(prop_token_text):] 158 | suffix_tokens = enc.encode(suffix) # a list 159 | inp[i] = indices[0].item() 160 | inp[i+1:i+1] = suffix_tokens # insert suffix tokens into list 161 | break 162 | 163 | # Is there a more likely longer token that could be the actual token generated? 164 | elif len(prop_token_text) > len(true_token_text) and \ 165 | true_token_text == prop_token_text[:len(true_token_text)]: 166 | whole_text = true_token_text 167 | num_extra = 1 168 | while len(whole_text) < len(prop_token_text): 169 | whole_text += enc.decoder[inp[i+num_extra]] 170 | num_extra += 1 171 | if prop_token_text == whole_text[:len(prop_token_text)]: 172 | inp[i] = indices[0].item() 173 | for j in range(1, num_extra): 174 | del inp[i+j] 175 | 176 | if len(whole_text) > len(prop_token_text): 177 | suffix = whole_text[len(prop_token_text):] 178 | suffix_tokens = enc.encode(suffix) # a list 179 | inp[i+1:i+1] = suffix_tokens # insert suffix tokens into list 180 | break 181 | else: 182 | print('Unable to fix BPE error: token received: %s=%d, text: %s' % (true_token_text, inp[i], text)) 183 | 184 | tokens_t = int2bits(bin_num, block_size) 185 | 186 | message.extend(tokens_t) 187 | prev = torch.tensor([inp[i]], device=device, dtype=torch.long) 188 | i += 1 189 | 190 | return message 191 | 192 | if __name__ == '__main__': 193 | np.random.seed(123) 194 | 195 | bin2words, words2bin = get_bins(50257, 5) 196 | print(words2bin[153]) 197 | -------------------------------------------------------------------------------- /arithmetic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from transformers import DynamicCache 4 | 5 | from utils import limit_past, kl, entropy, bits2int, int2bits, is_sent_finish, num_same_from_beg 6 | 7 | def encode_arithmetic(model, enc, message, context, finish_sent=False, device='cuda', temp=1.0, precision=16, topk=50000): 8 | context = torch.tensor(context[-1022:], device=device, dtype=torch.long) 9 | 10 | max_val = 2**precision 11 | # threshold = 2**(-precision) 12 | cur_interval = [0, max_val] # bottom inclusive, top exclusive 13 | 14 | prev = context 15 | output = context 16 | past = None 17 | 18 | # total_num = 0 19 | total_num_for_stats = 0 20 | total_log_probs = 0 21 | total_kl = 0 # in bits 22 | total_entropy_ptau = 0 23 | # total_num_sents = 0 24 | 25 | with torch.no_grad(): 26 | i = 0 27 | sent_finish = False 28 | while i < len(message) or (finish_sent and not sent_finish): 29 | out = model(prev.unsqueeze(0), past_key_values=DynamicCache.from_legacy_cache(past), use_cache=True) 30 | logits = out.logits 31 | past = out.past_key_values 32 | past = limit_past(past) 33 | 34 | logits[0, -1, -1] = -1e20 # endoftext token can't happen 35 | logits[0, -1, 628] = -1e20 # 2 newlines token can't happen 36 | 37 | logits, indices = logits[0, -1, :].sort(descending=True) 38 | logits = logits.double() 39 | logits_temp = logits / temp 40 | probs_temp = F.softmax(logits_temp, dim=0) 41 | log_probs_temp = F.log_softmax(logits_temp, dim=0) 42 | log_probs = F.log_softmax(logits, dim=0) 43 | 44 | # conditions for having reached the end of the message 45 | if i >= len(message): 46 | selection = 0 47 | sent_finish = is_sent_finish(indices[selection].item(), enc) 48 | else: 49 | # Cutoff low probabilities that would be rounded to 0 50 | cur_int_range = 
cur_interval[1]-cur_interval[0] 51 | cur_threshold = 1/cur_int_range 52 | k = min(max(2, (probs_temp < cur_threshold).nonzero()[0].item()), topk) 53 | probs_temp_int = probs_temp[:k] # Cutoff all but top k 54 | 55 | # Rescale to correct range 56 | probs_temp_int = probs_temp_int/probs_temp_int.sum()*cur_int_range 57 | 58 | # Round probabilities to integers given precision 59 | probs_temp_int = probs_temp_int.round().long() 60 | cum_probs = probs_temp_int.cumsum(0) 61 | 62 | # Remove any elements from the bottom if rounding caused the total prob to be too large 63 | overfill_index = (cum_probs > cur_int_range).nonzero() 64 | if len(overfill_index) > 0: 65 | cum_probs = cum_probs[:overfill_index[0]] 66 | 67 | # Add any mass to the top if removing/rounding causes the total prob to be too small 68 | cum_probs += cur_int_range-cum_probs[-1] # add 69 | 70 | # Get out resulting probabilities 71 | probs_final = cum_probs.clone() 72 | probs_final[1:] = cum_probs[1:] - cum_probs[:-1] 73 | 74 | # Convert to position in range 75 | cum_probs += cur_interval[0] 76 | 77 | # Get selected index based on binary fraction from message bits 78 | message_bits = message[i:i+precision] 79 | if i+precision > len(message): 80 | message_bits = message_bits + [0]*(i+precision-len(message)) 81 | message_idx = bits2int(reversed(message_bits)) 82 | selection = (cum_probs > message_idx).nonzero()[0].item() 83 | 84 | # Calculate new range as ints 85 | new_int_bottom = cum_probs[selection-1] if selection > 0 else cur_interval[0] 86 | new_int_top = cum_probs[selection] 87 | 88 | # Convert range to bits 89 | new_int_bottom_bits_inc = list(reversed(int2bits(new_int_bottom, precision))) 90 | new_int_top_bits_inc = list(reversed(int2bits(new_int_top-1, precision))) # -1 here because upper bound is exclusive 91 | 92 | # Consume most significant bits which are now fixed and update interval 93 | num_bits_encoded = num_same_from_beg(new_int_bottom_bits_inc, new_int_top_bits_inc) 94 | i += num_bits_encoded 95 | 96 | new_int_bottom_bits = new_int_bottom_bits_inc[num_bits_encoded:] + [0]*num_bits_encoded 97 | new_int_top_bits = new_int_top_bits_inc[num_bits_encoded:] + [1]*num_bits_encoded 98 | 99 | cur_interval[0] = bits2int(reversed(new_int_bottom_bits)) 100 | cur_interval[1] = bits2int(reversed(new_int_top_bits))+1 # +1 here because upper bound is exclusive 101 | 102 | # Gather statistics 103 | total_log_probs += log_probs[selection].item() 104 | 105 | q = probs_final.double()/probs_final.sum() 106 | logq = q.log() 107 | total_kl += kl(q, logq, log_probs[:len(q)]) 108 | total_entropy_ptau += entropy(probs_temp, log_probs_temp) 109 | total_num_for_stats += 1 110 | 111 | # Update history with new token 112 | prev = indices[selection].view(1) 113 | output = torch.cat((output, prev)) 114 | # total_num += 1 115 | # print(enc.decode(prev.tolist()), message_bits[:num_bits_encoded]) 116 | 117 | # For text->bits->text 118 | partial = enc.decode(output[len(context):].tolist()) 119 | if '' in partial: 120 | break 121 | 122 | avg_NLL = -total_log_probs/total_num_for_stats 123 | avg_KL = total_kl/total_num_for_stats 124 | avg_Hq = total_entropy_ptau/total_num_for_stats 125 | words_per_bit = total_num_for_stats/i 126 | 127 | return output[len(context):].tolist(), avg_NLL, avg_KL, words_per_bit, avg_Hq 128 | 129 | def decode_arithmetic(model, enc, text, context, device='cuda', temp=1.0, precision=16, topk=50000): 130 | # inp is a list of token indices 131 | # context is a list of token indices 132 | inp = enc.encode(text) 133 | # common BPE error 
case: 128, 128 (2 newlines) is interpretted as 628 (2 newlines) 134 | i = 0 135 | while i < len(inp): 136 | if inp[i] == 628: 137 | inp[i] = 198 138 | inp[i+1:i+1] = [198] 139 | i += 2 140 | else: 141 | i += 1 142 | 143 | context = torch.tensor(context[-1022:], device=device, dtype=torch.long) 144 | 145 | max_val = 2**precision 146 | # threshold = 2**(-precision) 147 | cur_interval = [0, max_val] # bottom inclusive, top exclusive 148 | 149 | prev = context 150 | past = None 151 | message = [] 152 | with torch.no_grad(): 153 | i = 0 154 | while i < len(inp): 155 | out = model(prev.unsqueeze(0), past_key_values=DynamicCache.from_legacy_cache(past), use_cache=True) 156 | logits = out.logits 157 | past = out.past_key_values 158 | past = limit_past(past) 159 | 160 | logits[0, -1, -1] = -1e10 # endoftext can't happen 161 | logits[0, -1, 628] = -1e10 # 2 newlines can't happen 162 | 163 | logits, indices = logits[0, -1, :].sort(descending=True) 164 | logits = logits.double() 165 | logits_temp = logits / temp 166 | probs_temp = F.softmax(logits_temp, dim=0) 167 | 168 | # Cutoff low probabilities that would be rounded to 0 169 | cur_int_range = cur_interval[1]-cur_interval[0] 170 | cur_threshold = 1/cur_int_range 171 | k = min(max(2, (probs_temp < cur_threshold).nonzero()[0].item()), topk) 172 | probs_temp_int = probs_temp[:k] # Cutoff all but top k 173 | 174 | # Rescale to correct range 175 | probs_temp_int = probs_temp_int/probs_temp_int.sum()*cur_int_range 176 | 177 | # Round probabilities to integers given precision 178 | probs_temp_int = probs_temp_int.round().long() 179 | cum_probs = probs_temp_int.cumsum(0) 180 | 181 | # Remove any elements from the bottom if rounding caused the total prob to be too large 182 | overfill_index = (cum_probs > cur_int_range).nonzero() 183 | if len(overfill_index) > 0: 184 | cum_probs = cum_probs[:overfill_index[0]] 185 | k = overfill_index[0].item() 186 | 187 | # Add any mass to the top if removing/rounding causes the total prob to be too small 188 | cum_probs += cur_int_range-cum_probs[-1] # add 189 | 190 | # Convert to position in range 191 | cum_probs += cur_interval[0] 192 | 193 | rank = (indices == inp[i]).nonzero().item() 194 | 195 | # Handle most errors that could happen because of BPE with heuristic 196 | if rank >= k: 197 | true_token_text = enc.decode([inp[i]]) 198 | for rank_idx in range(k): 199 | prop_token_text = enc.decode([indices[rank_idx].item()]) 200 | # common case that is not caught 201 | if inp[i] == 128 and indices[rank_idx] == 198: 202 | rank = rank_idx 203 | inp[i] = indices[rank_idx].item() 204 | break 205 | 206 | # Is there a more likely prefix token that could be the actual token generated? 207 | if len(prop_token_text) <= len(true_token_text) and \ 208 | prop_token_text == true_token_text[:len(prop_token_text)]: 209 | rank = rank_idx 210 | suffix = true_token_text[len(prop_token_text):] 211 | suffix_tokens = enc.encode(suffix) # a list 212 | inp[i] = indices[rank_idx].item() 213 | inp[i+1:i+1] = suffix_tokens # insert suffix tokens into list 214 | break 215 | 216 | # Is there a more likely longer token that could be the actual token generated? 
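# (The opposite case: BPE re-encoding of the cover text split one generated token into
# several shorter tokens. Rebuild the longer candidate from the following received tokens,
# substitute it for inp[i], drop the tokens it absorbed, and re-encode any leftover
# characters as a suffix.)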
217 | elif len(prop_token_text) > len(true_token_text) and \ 218 | true_token_text == prop_token_text[:len(true_token_text)]: 219 | whole_text = true_token_text 220 | num_extra = 1 221 | while len(whole_text) < len(prop_token_text): 222 | whole_text += enc.decode([inp[i+num_extra]]) 223 | num_extra += 1 224 | if prop_token_text == whole_text[:len(prop_token_text)]: 225 | rank = rank_idx 226 | inp[i] = indices[rank_idx].item() 227 | for j in range(1, num_extra): 228 | del inp[i+j] 229 | 230 | if len(whole_text) > len(prop_token_text): 231 | suffix = whole_text[len(prop_token_text):] 232 | suffix_tokens = enc.encode(suffix) # a list 233 | inp[i+1:i+1] = suffix_tokens # insert suffix tokens into list 234 | break 235 | else: 236 | print('Unable to fix BPE error: token received: %s=%d, text: %s' % (true_token_text, inp[i], text)) 237 | rank = 0 238 | 239 | selection = rank 240 | 241 | # Calculate new range as ints 242 | new_int_bottom = cum_probs[selection-1] if selection > 0 else cur_interval[0] 243 | new_int_top = cum_probs[selection] 244 | 245 | # Convert range to bits 246 | new_int_bottom_bits_inc = list(reversed(int2bits(new_int_bottom, precision))) 247 | new_int_top_bits_inc = list(reversed(int2bits(new_int_top-1, precision))) # -1 here because upper bound is exclusive 248 | 249 | # Emit most significant bits which are now fixed and update interval 250 | num_bits_encoded = num_same_from_beg(new_int_bottom_bits_inc, new_int_top_bits_inc) 251 | if i == len(inp)-1: 252 | new_bits = new_int_bottom_bits_inc 253 | else: 254 | new_bits = new_int_top_bits_inc[:num_bits_encoded] 255 | message += new_bits 256 | 257 | new_int_bottom_bits = new_int_bottom_bits_inc[num_bits_encoded:] + [0]*num_bits_encoded 258 | new_int_top_bits = new_int_top_bits_inc[num_bits_encoded:] + [1]*num_bits_encoded 259 | 260 | cur_interval[0] = bits2int(reversed(new_int_bottom_bits)) 261 | cur_interval[1] = bits2int(reversed(new_int_top_bits))+1 # +1 here because upper bound is exclusive 262 | 263 | # Update history with new token 264 | prev = torch.tensor([inp[i]], device=device, dtype=torch.long) 265 | # print(enc.decode([inp[i]]), new_bits) 266 | i += 1 267 | 268 | return message 269 | --------------------------------------------------------------------------------
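To make the interval bookkeeping in `encode_arithmetic` concrete, here is a stand-alone sketch that replays a single step of the loop with invented numbers (precision 6, four candidate tokens whose rounded cumulative counts are 32, 48, 56, 64). It uses only the bit helpers from `utils.py` and is illustrative rather than part of the pipeline.

from utils import bits2int, int2bits, num_same_from_beg

# One illustrative encoder step with precision = 6, i.e. the working interval is [0, 64).
precision = 6
cur_interval = [0, 2**precision]

# Pretend the rounded cumulative token counts over the current range are:
cum_probs = [32, 48, 56, 64]   # token 0 covers [0, 32), token 1 covers [32, 48), ...

# Next six message bits, read most-significant bit first (hence the reversed() below).
message_bits = [1, 0, 1, 1, 0, 0]
message_idx = bits2int(reversed(message_bits))                            # 44
selection = next(j for j, c in enumerate(cum_probs) if c > message_idx)   # 1, like (cum_probs > message_idx).nonzero()[0]

# Sub-interval of the chosen token.
new_bottom = cum_probs[selection - 1] if selection > 0 else cur_interval[0]   # 32
new_top = cum_probs[selection]                                                # 48

bottom_bits = list(reversed(int2bits(new_bottom, precision)))   # [1, 0, 0, 0, 0, 0]
top_bits = list(reversed(int2bits(new_top - 1, precision)))     # [1, 0, 1, 1, 1, 1]

# The shared prefix "10" is now fixed, so emitting this token embeds two message bits.
num_bits_encoded = num_same_from_beg(bottom_bits, top_bits)
assert num_bits_encoded == 2 and message_bits[:num_bits_encoded] == [1, 0]

# Drop the fixed bits and rescale, exactly as encode_arithmetic/decode_arithmetic do.
cur_interval[0] = bits2int(reversed(bottom_bits[num_bits_encoded:] + [0] * num_bits_encoded))
cur_interval[1] = bits2int(reversed(top_bits[num_bits_encoded:] + [1] * num_bits_encoded)) + 1
assert cur_interval == [0, 64]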