├── .gitignore ├── README.md ├── audio_samples ├── README.md ├── p=0.80_0_Discop.flac ├── p=0.80_0_RS.flac ├── p=0.92_4_Discop.flac ├── p=0.92_4_RS.flac ├── p=0.95_1_Discop.flac ├── p=0.95_1_RS.flac ├── p=0.98_2_Discop.flac ├── p=0.98_2_RS.flac ├── p=1.00_3_Discop.flac └── p=1.00_3_RS.flac ├── recursion.png ├── requirements.txt ├── rotate.png ├── src ├── README.md ├── config.py ├── get_statistics.py ├── model.py ├── random_sample_cy.pyx ├── run_single_example.py ├── setup.py ├── stega_cy.pyx ├── stega_tts.py ├── tacotron │ ├── TTS_cleaner.py │ ├── __init__.py │ ├── cmudict-0.7b.txt │ ├── config.toml │ ├── dataset.py │ ├── model.py │ └── text.py ├── univoc │ ├── __init__.py │ ├── config │ │ ├── config.yaml │ │ ├── preprocess.yaml │ │ └── train.yaml │ ├── dataset.py │ └── model.py └── utils.py └── temp ├── message.txt └── small.png /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | .DS_Store 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Discop 2 | 3 | Discop: Provably Secure Steganography in Practice Based on “Distribution Copies” 4 | 5 | [Jinyang Ding](https://dingjinyang.github.io/), [Kejiang Chen](http://home.ustc.edu.cn/~chenkj/), [Yaofei Wang](http://faculty.hfut.edu.cn/yaofeiwang/en/index.htm), Na Zhao, [Weiming Zhang](http://staff.ustc.edu.cn/~zhangwm/), and [Nenghai Yu](http://staff.ustc.edu.cn/~ynh/) 6 | 7 | In [IEEE Symposium on Security and Privacy (IEEE S&P) 2023](https://sp2023.ieee-security.org/) 8 | 9 | [![paper](https://img.shields.io/badge/paper-red)](https://dingjinyang.github.io/uploads/Discop_sp23_paper.pdf) [![cite](https://img.shields.io/badge/cite-orange)](#citation) [![slides](https://img.shields.io/badge/slides-yellow)](https://dingjinyang.github.io/uploads/Discop_sp23_slides.pdf) [![doi](https://img.shields.io/badge/doi-green)](https://doi.org/10.1109/SP46215.2023.10179287) [![blog_post](https://img.shields.io/badge/blog_post-blue)](https://comydream.github.io/2023/06/07/discop-sp23/) [![semantic_scholar](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.semanticscholar.org%2Fgraph%2Fv1%2Fpaper%2F200526f0cfaf9ac9e452890b3ef7bc1a4b42c98a?fields=citationCount&query=citationCount&prefix=cited%20by%20&logo=semanticscholar&label=%20&labelColor=purple&color=purple)](https://www.semanticscholar.org/paper/Discop%3A-Provably-Secure-Steganography-in-Practice-Ding-Chen/200526f0cfaf9ac9e452890b3ef7bc1a4b42c98a) 10 | 11 | ## Brief Overview 12 | 13 | Given a probability distribution to sample from, if we want to embed $n$ bits, we construct $2^{n}$ copies of the distribution by rotation and use the copy index to express information. 14 | 15 | ![distribution copies](rotate.png) 16 | 17 | To improve the embedding rate, we decompose the multi-variate distribution into multiple bi-variate distributions through a Huffman tree. 18 | 19 | ![recursion](recursion.png) 20 | 21 | The embedding rate can reach about 0.95 of its theoretical limit. 22 | 23 | ## Usage 24 | 25 | ### Preparation 26 | 27 | First, please ensure that you have installed all the required libraries for this repository. 28 | 29 | We recommend using [Anaconda](https://anaconda.org/anaconda/conda) and execute the following commands. 
30 | 31 | ```shell 32 | conda create -n discop python=3.8.12 33 | conda activate discop 34 | 35 | # Visit the PyTorch website (https://pytorch.org/get-started/locally/) for installation commands tailored to your environment 36 | # We have not tested PyTorch versions other than v1.12.0. 37 | conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch 38 | # Install other dependencies 39 | python -m pip install -r requirements.txt 40 | 41 | # Build the Cython files 42 | python src/setup.py build_ext --build-lib=src/ 43 | ``` 44 | 45 | ### Run Single Example 46 | 47 | You can modify the default settings for each generation task in `src/config.py`. 48 | 49 | The program may automatically download the required pretrained models and datasets during the first run. 50 | 51 | ```shell 52 | python src/run_single_example.py 53 | ``` 54 | 55 | ### Get Statistics 56 | 57 | ```shell 58 | python src/get_statistics.py 59 | ``` 60 | 61 | ## Acknowledgment 62 | 63 | In the text generation and image completion tasks, we directly employ the pre-trained models provided by [Hugging Face](https://huggingface.co/models). 64 | 65 | In the text-to-speech (TTS) task, we utilize publicly available pre-trained models from [bshall/Tacotron](https://github.com/bshall/Tacotron/tree/main/tacotron) and [bshall/UniversalVocoding](https://github.com/bshall/UniversalVocoding). 66 | We have incorporated them into our code repository (`src/tacotron/` and `src/univoc/`) and made some adaptations as needed. 67 | 68 | - Add `src/tacotron/TTS_cleaner.py`, which borrows from [Coqui.ai TTS](https://github.com/coqui-ai/TTS/blob/main/TTS/tts/utils/text/cleaners.py). 69 | - Add the `encode_speech()`, `decode_speech()`, and `random_sample_speech()` functions in `src/univoc/model.py` to facilitate Discop’s message embedding and extraction, as well as random sampling. 70 | 71 | ## Citation 72 | 73 | If you find this work useful, please consider citing: 74 | 75 | ```latex 76 | @inproceedings{dingDiscopProvablySecure2023, 77 | title = {Discop: Provably Secure Steganography in Practice Based on ``Distribution Copies''}, 78 | shorttitle = {Discop}, 79 | booktitle = {2023 {IEEE} Symposium on Security and Privacy ({SP})}, 80 | author = {Ding, Jinyang and Chen, Kejiang and Wang, Yaofei and Zhao, Na and Zhang, Weiming and Yu, Nenghai}, 81 | year = {2023}, 82 | month = may, 83 | pages = {2238--2255}, 84 | publisher = {{IEEE}}, 85 | doi = {10.1109/SP46215.2023.10179287}, 86 | url = {https://ieeexplore.ieee.org/document/10179287}, 87 | isbn = {978-1-66549-336-9}, 88 | langid = {english} 89 | } 90 | ``` 91 | 92 | ## Further Reading 93 | 94 | [comydream/provably-secure-steganography: Provably Secure Steganography](https://github.com/comydream/provably-secure-steganography) 95 | -------------------------------------------------------------------------------- /audio_samples/README.md: -------------------------------------------------------------------------------- 1 | # Audio Clip Examples 2 | 3 | File name format: 4 | 5 | ``` 6 | p=__.flac 7 | ``` 8 | 9 | Text: 10 | 11 | | text-id | Text | 12 | | ------- | -------------------------------------------------------------------------------------------------------------------------------------- | 13 | | 0 | I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. | 14 | | 1 | I think I will make a movie next weekend. 
| 15 | | 2 | When I first saw a glimpse of this movie, I quickly noticed the actress who was playing the role of Lucille Ball. | 16 | | 3 | This is said to be a personal film for Peter. | 17 | | 4 | This film is just plain horrible. | 18 | 19 | 20 | -------------------------------------------------------------------------------- /audio_samples/p=0.80_0_Discop.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/comydream/Discop/3c3a10099a242eae405b49cc4d09fba1abb148ad/audio_samples/p=0.80_0_Discop.flac -------------------------------------------------------------------------------- /audio_samples/p=0.80_0_RS.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/comydream/Discop/3c3a10099a242eae405b49cc4d09fba1abb148ad/audio_samples/p=0.80_0_RS.flac -------------------------------------------------------------------------------- /audio_samples/p=0.92_4_Discop.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/comydream/Discop/3c3a10099a242eae405b49cc4d09fba1abb148ad/audio_samples/p=0.92_4_Discop.flac -------------------------------------------------------------------------------- /audio_samples/p=0.92_4_RS.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/comydream/Discop/3c3a10099a242eae405b49cc4d09fba1abb148ad/audio_samples/p=0.92_4_RS.flac -------------------------------------------------------------------------------- /audio_samples/p=0.95_1_Discop.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/comydream/Discop/3c3a10099a242eae405b49cc4d09fba1abb148ad/audio_samples/p=0.95_1_Discop.flac -------------------------------------------------------------------------------- /audio_samples/p=0.95_1_RS.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/comydream/Discop/3c3a10099a242eae405b49cc4d09fba1abb148ad/audio_samples/p=0.95_1_RS.flac -------------------------------------------------------------------------------- /audio_samples/p=0.98_2_Discop.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/comydream/Discop/3c3a10099a242eae405b49cc4d09fba1abb148ad/audio_samples/p=0.98_2_Discop.flac -------------------------------------------------------------------------------- /audio_samples/p=0.98_2_RS.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/comydream/Discop/3c3a10099a242eae405b49cc4d09fba1abb148ad/audio_samples/p=0.98_2_RS.flac -------------------------------------------------------------------------------- /audio_samples/p=1.00_3_Discop.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/comydream/Discop/3c3a10099a242eae405b49cc4d09fba1abb148ad/audio_samples/p=1.00_3_Discop.flac -------------------------------------------------------------------------------- /audio_samples/p=1.00_3_RS.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/comydream/Discop/3c3a10099a242eae405b49cc4d09fba1abb148ad/audio_samples/p=1.00_3_RS.flac -------------------------------------------------------------------------------- 
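Note: the placeholders in the file-name format string of `audio_samples/README.md` were evidently stripped (it reads `p=__.flac`). Judging from the sample files above, the pattern is presumably `p={top_p}_{text-id}_{method}.flac`, where `{method}` is `Discop` or `RS` (random sampling). A minimal, hypothetical parsing helper under that assumption (not part of the repository):

```python
import re

# Assumed pattern inferred from the file names above: p={top_p}_{text_id}_{method}.flac
SAMPLE_NAME_RE = re.compile(r'p=(\d\.\d{2})_(\d+)_(Discop|RS)\.flac')

def parse_sample_name(name: str):
    """Split an audio-sample file name into (top_p, text_id, method)."""
    m = SAMPLE_NAME_RE.fullmatch(name)
    if m is None:
        raise ValueError(f'unexpected file name: {name}')
    top_p, text_id, method = m.groups()
    return float(top_p), int(text_id), method

print(parse_sample_name('p=0.92_4_Discop.flac'))  # -> (0.92, 4, 'Discop')
```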
/recursion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/comydream/Discop/3c3a10099a242eae405b49cc4d09fba1abb148ad/recursion.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | psutil 2 | entrypoints 3 | jsonschema>=3.0 4 | future 5 | matplotlib 6 | openpyxl 7 | tensorboard>=2.9.1 8 | google-auth-oauthlib<1.1,>=0.5 9 | tensorboard-data-server<0.8.0,>=0.7.0 10 | werkzeug>=1.0.1 11 | requests==2.25.1 12 | sacremoses 13 | nltk==3.7 14 | datasets==2.0.0 15 | transformers~=4.19.2 16 | scipy~=1.7.3 17 | pandas~=1.3.5 18 | numpy<2,>=1.22.1 19 | librosa~=0.8.0 20 | tqdm~=4.62.3 21 | omegaconf~=2.0.6 22 | pillow~=9.0.1 23 | cython~=0.29.28 24 | soundfile~=0.10.3.post1 25 | toml==0.10.2 26 | inflect==5.6.0 27 | importlib-resources -------------------------------------------------------------------------------- /rotate.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/comydream/Discop/3c3a10099a242eae405b49cc4d09fba1abb148ad/rotate.png -------------------------------------------------------------------------------- /src/README.md: -------------------------------------------------------------------------------- 1 | ## Directory Tree 2 | 3 | ``` 4 | . 5 | ├── tacotron/ 6 | ├── univoc/ 7 | ├── config.py 8 | ├── get_statistics.py 9 | ├── model.py 10 | ├── random_sample_cy.pyx 11 | ├── run_single_example.py 12 | ├── setup.py 13 | ├── stega_cy.pyx 14 | ├── stega_tts.py 15 | └── utils.py 16 | ``` 17 | -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | 5 | class Settings: 6 | 7 | def __init__(self, 8 | task: str = 'text', 9 | algo: str = 'Discop', 10 | model_name: str = 'gpt2', 11 | temp: float = 1.0, 12 | top_p: float = 0.92, 13 | length: int = 100, 14 | seed: int = os.urandom(1), 15 | device=torch.device('cpu')): 16 | 17 | if task not in ['text', 'image', 'text-to-speech']: 18 | raise NotImplementedError("`Settings.task` must belong to {'text', 'image', 'text-to-speech'}!") 19 | self.task = task 20 | 21 | if algo not in ['Discop', 'Discop_baseline', 'sample']: 22 | raise NotImplementedError("`Settings.algo` must belong to {'Discop', 'Discop_baseline', 'sample'}!") 23 | self.algo = algo 24 | 25 | self.model_name = model_name 26 | 27 | if temp is None: 28 | temp = 1.0 29 | self.temp = temp 30 | 31 | if top_p is None: 32 | top_p = 1.0 33 | elif top_p <= 0 or top_p > 1: 34 | raise ValueError('`top_p` must be in (0, 1]!') 35 | self.top_p = top_p 36 | 37 | self.length = length 38 | self.seed = seed 39 | self.device = device 40 | 41 | def __call__(self): 42 | return self.algo, self.temp, self.top_p, self.length, self.seed 43 | 44 | def __str__(self): 45 | return '\n'.join('{} = {}'.format(key, value) for (key, value) in self.__dict__.items()) 46 | 47 | 48 | # text_default_settings = Settings('text', model_name='gpt2', top_p=0.92, length=100) 49 | text_default_settings = Settings('text', model_name='transfo-xl-wt103', top_p=0.92, length=100) 50 | 51 | image_default_settings = Settings('image', model_name='openai/imagegpt-small', top_p=1.0) 52 | audio_default_settings = Settings('text-to-speech', model_name='univoc', top_p=0.98) 53 | 
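The default `Settings` objects above are consumed by the entry-point scripts. Below is a minimal sketch of wiring a `Settings` object into the text pipeline, mirroring `src/run_single_example.py`; the context string, message bits, and seed are placeholders, and `stega_cy` must first be built via `python src/setup.py build_ext --build-lib=src/`.

```python
import torch

from config import Settings
from model import get_model, get_tokenizer
from stega_cy import encode_text, decode_text  # compiled Cython extension

# Placeholder values; mirrors src/run_single_example.py
settings = Settings(task='text', algo='Discop', model_name='gpt2',
                    top_p=0.92, length=100, seed=42,
                    device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'))
model = get_model(settings)
tokenizer = get_tokenizer(settings)

context = 'We were both young when I first saw you.'
message = '0100111010' * 100  # bit string to embed
out = encode_text(model, tokenizer, message, context, settings)                   # embed
recovered = decode_text(model, tokenizer, out.generated_ids, context, settings)   # extract
print(recovered == message[:out.n_bits])  # expected: True
```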
-------------------------------------------------------------------------------- /src/get_statistics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import random 5 | import numpy as np 6 | import pandas as pd 7 | from tqdm import tqdm 8 | from scipy.io.wavfile import write 9 | from nltk import sent_tokenize 10 | from datasets import load_dataset 11 | 12 | from config import Settings, text_default_settings, image_default_settings, audio_default_settings 13 | from model import get_model, get_feature_extractor, get_tokenizer 14 | from utils import check_dir, SingleExampleOutput 15 | 16 | # top_p_lst = [0.80, 0.92, 0.95, 0.98, 1.0] 17 | # top_p_lst = [0.95] 18 | top_p_lst = [0.95] 19 | 20 | text_context_dataset = 'imdb' 21 | # image_context_dataset = 'huggan/CelebA-faces' 22 | image_context_dataset = 'nielsr/CelebA-faces' 23 | tts_text_dataset = 'imdb' 24 | 25 | message_file_path = os.path.join('temp', 'message.txt') 26 | with open(message_file_path, 'r', encoding='utf-8') as f: 27 | message = f.read() 28 | 29 | summary_columns = [ 30 | 'algorithm', 'temperature', 'top-p', 'total_n_bits', 'total_n_tokens', 'total_entropy', 'total_time_cost', 'ave_time_cost', 31 | 'ave_kld', 'max_kld', 'ave_embedding_rate', 'ave_entropy', 'utilization_rate', 'ave_perplexity', 'ave_minimum_entropy', 32 | 'perplexity_std' 33 | ] 34 | 35 | 36 | class Summary: 37 | 38 | def __init__(self, settings: Settings) -> None: 39 | self.task = settings.task 40 | self.n_examples = 0 41 | self.total_ave_kld = 0 42 | self.total_minimum_entropy = 0 43 | self.perplexity_list = [] 44 | self.output = { 45 | 'algorithm': settings.algo, 46 | 'temperature': settings.temp, 47 | 'top-p': settings.top_p, 48 | 'total_n_bits': 0, 49 | 'total_n_tokens': 0, 50 | 'total_entropy': 0, 51 | 'total_time_cost': 0, 52 | 'max_kld': 0 53 | } 54 | 55 | def __str__(self) -> str: 56 | self.process() 57 | selected_attr = list(self.output.keys()) 58 | return '\n'.join('{} = {}'.format(x, self.output[x]) for x in selected_attr) 59 | 60 | def add_example(self, example: SingleExampleOutput) -> None: 61 | self.output['total_n_bits'] += example.n_bits 62 | self.output['total_n_tokens'] += example.n_tokens 63 | self.output['total_entropy'] += example.total_entropy 64 | self.output['total_time_cost'] += example.time_cost 65 | self.perplexity_list.append(example.perplexity) 66 | self.total_ave_kld += example.ave_kld 67 | if example.max_kld > self.output['max_kld']: 68 | self.output['max_kld'] = example.max_kld 69 | self.n_examples += 1 70 | self.total_minimum_entropy += example.total_minimum_entropy 71 | 72 | def process(self) -> None: 73 | self.output['ave_embedding_rate'] = self.output['total_n_bits'] / self.output['total_n_tokens'] 74 | self.output['utilization_rate'] = self.output['total_n_bits'] / self.output['total_entropy'] if self.output[ 75 | 'total_entropy'] != 0 else 0 76 | self.output['ave_entropy'] = self.output['total_entropy'] / self.output['total_n_tokens'] 77 | # self.output['ave_perplexity'] = self.total_perplexity / self.n_examples 78 | self.output['ave_perplexity'] = np.mean(self.perplexity_list) 79 | self.output['perplexity_std'] = np.std(self.perplexity_list) 80 | self.output['ave_kld'] = self.total_ave_kld / self.n_examples 81 | self.output['ave_time_cost'] = self.output['total_time_cost'] / self.output['total_n_bits'] if self.output[ 82 | 'total_n_bits'] != 0 else 0 83 | self.output['ave_minimum_entropy'] = self.total_minimum_entropy / 
self.output['total_n_tokens'] 84 | 85 | def gather(self) -> pd.DataFrame: 86 | self.process() 87 | ret_lst = [] 88 | for column in summary_columns: 89 | ret_lst.append(self.output[column]) 90 | df = pd.DataFrame(ret_lst, index=summary_columns).T 91 | # perplexity_np = np.array(self.perplexity_list) 92 | # save_perplexity_np_dir = os.path.join('results', self.task) 93 | # save_perplexity_np_path = os.path.join(save_perplexity_np_dir, 94 | # 'perplexity_{}.npy'.format(time.strftime("%m%d_%H%M", time.localtime()))) 95 | # check_dir(save_perplexity_np_dir) 96 | # np.save(save_perplexity_np_path, perplexity_np) 97 | return df 98 | 99 | 100 | def get_text_statistics(settings: Settings = text_default_settings, n_examples: int = 100, save_data: bool = False) -> None: 101 | if settings.algo == 'sample': 102 | from random_sample_cy import encode_text 103 | elif settings.algo in ['Discop', 'Discop_baseline']: 104 | from stega_cy import encode_text 105 | else: 106 | raise NotImplementedError("`Settings.algo` must belong to {'Discop', 'Discop_baseline', 'sample'}!") 107 | 108 | model = get_model(settings) 109 | tokenizer = get_tokenizer(settings) 110 | 111 | dataset = load_dataset(text_context_dataset, split='train')[:n_examples]['text'] 112 | 113 | time_stamp = time.strftime("%m%d_%H%M", time.localtime()) 114 | 115 | if save_data: 116 | save_data_main_dir = os.path.join( 117 | 'data', settings.task, 118 | '{}_{}_{}'.format(settings.model_name.split('/')[-1], 119 | str(settings.device).replace(':', '-'), time_stamp)) 120 | check_dir(save_data_main_dir) 121 | 122 | df = pd.DataFrame(columns=summary_columns) 123 | for top_p in top_p_lst: 124 | settings.top_p = top_p 125 | summary = Summary(settings) 126 | 127 | if save_data: 128 | save_context_path = os.path.join(save_data_main_dir, '{:.2f}_context.txt'.format(top_p)) 129 | save_stego_path = os.path.join(save_data_main_dir, '{:.2f}_stego.txt'.format(top_p)) 130 | f_context = open(save_context_path, 'w') 131 | f_stego = open(save_stego_path, 'w') 132 | context_lst = [] 133 | stego_lst = [] 134 | 135 | for i in tqdm(range(n_examples), ncols=70, desc='p={:.2f}'.format(top_p)): 136 | random.seed(os.urandom(1)) 137 | message_start_index = random.randint(0, 10000) 138 | 139 | context = dataset[i] 140 | context = context.replace('

<br /><br />', ' ').replace('<br />', ' ') # remove all '<br />
' 141 | context = ' '.join(sent_tokenize(context)[:3]) # Selecting leading 3 sentences as `context` 142 | settings.seed = os.urandom(1) 143 | example = encode_text(model, tokenizer, message[message_start_index:], context, settings) 144 | summary.add_example(example) 145 | if save_data: 146 | context_lst.append(context) 147 | stego_lst.append(example.stego_object) 148 | if save_data: 149 | f_context.write('\n'.join(context_lst)) 150 | f_stego.write('\n'.join(stego_lst)) 151 | 152 | print(summary) 153 | print() 154 | df = pd.concat([df, summary.gather()], ignore_index=True) 155 | save_table_dir = os.path.join('results', settings.task) 156 | check_dir(save_table_dir) 157 | save_table_filename = '{}_{}_{}.xlsx'.format( 158 | settings.model_name.split('/')[-1], 159 | str(settings.device).replace(':', '-'), time_stamp) 160 | save_table_path = os.path.join(save_table_dir, save_table_filename) 161 | df.to_excel(save_table_path) 162 | if save_data: 163 | f_context.close() 164 | f_stego.close() 165 | 166 | 167 | def get_image_statistics(settings: Settings = image_default_settings, 168 | n_examples: int = 100, 169 | context_ratio: float = 0.5, 170 | save_data: bool = False) -> None: 171 | if settings.algo == 'sample': 172 | from random_sample_cy import encode_image 173 | elif settings.algo in ['Discop', 'Discop_baseline']: 174 | from stega_cy import encode_image 175 | else: 176 | raise NotImplementedError("`Settings.algo` must belong to {'Discop', 'Discop_baseline', 'sample'}!") 177 | 178 | model = get_model(settings) 179 | feature_extractor = get_feature_extractor(settings) 180 | 181 | dataset = load_dataset(image_context_dataset, split='train')[:n_examples]['image'] 182 | width, height = dataset[0].size 183 | # resize 184 | width_after = feature_extractor.size 185 | height_after = round(width_after / width * height) 186 | 187 | time_stamp = time.strftime("%m%d_%H%M", time.localtime()) 188 | 189 | if save_data: 190 | save_data_main_dir = os.path.join( 191 | 'data', settings.task, 192 | '{}_{}_{}'.format(settings.model_name.split('/')[-1], 193 | str(settings.device).replace(':', '-'), time_stamp)) 194 | 195 | df = pd.DataFrame(columns=summary_columns) 196 | for top_p in top_p_lst: 197 | settings.top_p = top_p 198 | summary = Summary(settings) 199 | 200 | if save_data: 201 | save_data_sub_dir = os.path.join(save_data_main_dir, '{:.2f}'.format(top_p)) 202 | check_dir(save_data_sub_dir) 203 | 204 | for i in tqdm(range(n_examples), ncols=70, desc='p={:.2f}'.format(top_p)): 205 | # for i in tqdm(range(6025, 10000), ncols=70, desc='p={:.2f}'.format(top_p)): 206 | random.seed(os.urandom(1)) 207 | message_start_index = random.randint(0, 10000) 208 | 209 | original_img = dataset[i] 210 | original_img = original_img.resize([width_after, height_after]) 211 | original_img = original_img.crop((0, 4, 32, 36)) 212 | 213 | settings.seed = os.urandom(1) 214 | example = encode_image(model, 215 | feature_extractor, 216 | message[message_start_index:], 217 | settings, 218 | context_ratio=context_ratio, 219 | original_img=original_img) 220 | summary.add_example(example) 221 | if save_data: 222 | save_data_path = os.path.join(save_data_sub_dir, '{}.png'.format(i)) 223 | example.stego_object.save(save_data_path) 224 | print(summary) 225 | print() 226 | df = pd.concat([df, summary.gather()], ignore_index=True) 227 | save_table_dir = os.path.join('results', settings.task) 228 | check_dir(save_table_dir) 229 | save_table_filename = '{}_{}_{}.xlsx'.format( 230 | settings.model_name.split('/')[-1], 231 | 
str(settings.device).replace(':', '-'), time_stamp) 232 | save_table_path = os.path.join(save_table_dir, save_table_filename) 233 | df.to_excel(save_table_path) 234 | 235 | 236 | def get_audio_statistics(settings: Settings = audio_default_settings, n_examples: int = 30, save_data: bool = False) -> None: 237 | from stega_tts import get_tts_model 238 | if settings.algo == 'sample': 239 | from stega_tts import random_sample_speech as encode_speech 240 | elif settings.algo in ['Discop', 'Discop_baseline']: 241 | from stega_tts import encode_speech 242 | else: 243 | raise NotImplementedError("`Settings.algo` must belong to {'Discop', 'Discop_baseline', 'sample'}!") 244 | 245 | vocoder, tacotron, cmudict = get_tts_model(settings) 246 | 247 | dataset = load_dataset(tts_text_dataset, split='train')[:n_examples]['text'] 248 | 249 | time_stamp = time.strftime("%m%d_%H%M", time.localtime()) 250 | 251 | if save_data: 252 | save_data_main_dir = os.path.join( 253 | 'data', settings.task, 254 | '{}_{}_{}'.format(settings.model_name.split('/')[-1], 255 | str(settings.device).replace(':', '-'), time_stamp)) 256 | check_dir(save_data_main_dir) 257 | 258 | df = pd.DataFrame(columns=summary_columns) 259 | for top_p in top_p_lst: 260 | settings.top_p = top_p 261 | 262 | summary = Summary(settings) 263 | 264 | if save_data: 265 | save_data_sub_dir = os.path.join(save_data_main_dir, '{:.2f}'.format(top_p)) 266 | check_dir(save_data_sub_dir) 267 | save_text_path = os.path.join(save_data_main_dir, '{:.2f}.txt'.format(top_p)) 268 | text_lst = [] 269 | f_text = open(save_text_path, 'w') 270 | 271 | for i in tqdm(range(n_examples), ncols=70, desc='p={:.2f}'.format(top_p)): 272 | random.seed(os.urandom(1)) 273 | message_start_index = random.randint(0, 10000) 274 | 275 | text = dataset[i] 276 | text = text.replace('

<br /><br />', ' ').replace('<br />', ' ') # remove all '<br />
' 277 | text = ' '.join(sent_tokenize(text)[:1]) # Selecting leading 1 sentences as `context` 278 | settings.seed = os.urandom(1) 279 | example, sr = encode_speech(vocoder, tacotron, cmudict, message[message_start_index:], text, settings) 280 | 281 | summary.add_example(example) 282 | if save_data: 283 | save_data_path = os.path.join(save_data_sub_dir, '{}.flac'.format(i)) 284 | write(os.path.join(save_data_path), sr, example.stego_object) 285 | text_lst.append(text) 286 | if save_data: 287 | f_text.write('\n'.join(text_lst)) 288 | 289 | print(summary) 290 | print() 291 | df = pd.concat([df, summary.gather()], ignore_index=True) 292 | save_table_dir = os.path.join('results', settings.task) 293 | check_dir(save_table_dir) 294 | save_table_filename = '{}_{}_{}.xlsx'.format( 295 | settings.model_name.split('/')[-1], 296 | str(settings.device).replace(':', '-'), time_stamp) 297 | save_table_path = os.path.join(save_table_dir, save_table_filename) 298 | df.to_excel(save_table_path) 299 | if save_data: 300 | f_text.close() 301 | 302 | 303 | if __name__ == '__main__': 304 | # # Text Generation 305 | settings = text_default_settings 306 | settings.device = torch.device('cuda:0') 307 | # settings.algo = 'Discop_baseline' 308 | # settings.algo = 'sample' 309 | get_text_statistics(settings, n_examples=10000, save_data=True) 310 | 311 | # # Image Completion 312 | settings = image_default_settings 313 | settings.device = torch.device('cuda:0') 314 | # settings.algo = 'Discop_baseline' 315 | # settings.algo = 'sample' 316 | get_image_statistics(settings, n_examples=10000, save_data=True) 317 | 318 | # Text-to-Speech 319 | settings = audio_default_settings 320 | settings.device = torch.device('cuda:0') 321 | # settings.algo = 'Discop_baseline' 322 | # settings.algo = 'sample' 323 | get_audio_statistics(settings, n_examples=1000, save_data=True) 324 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | from transformers import GPT2Tokenizer, GPT2LMHeadModel 2 | from transformers import TransfoXLTokenizer, TransfoXLLMHeadModel 3 | from transformers import ImageGPTFeatureExtractor, ImageGPTForCausalImageModeling 4 | 5 | from transformers import PreTrainedTokenizer, PreTrainedModel 6 | from config import Settings 7 | 8 | 9 | def get_model(settings: Settings) -> PreTrainedModel: 10 | if settings.task == 'text': 11 | if settings.model_name in ['gpt2', 'distilgpt2']: 12 | model = GPT2LMHeadModel.from_pretrained(settings.model_name).to(settings.device) 13 | elif settings.model_name == 'transfo-xl-wt103': 14 | model = TransfoXLLMHeadModel.from_pretrained(settings.model_name).to(settings.device) 15 | else: 16 | raise NotImplementedError 17 | elif settings.task == 'image': 18 | if settings.model_name == 'openai/imagegpt-small': 19 | model = ImageGPTForCausalImageModeling.from_pretrained(settings.model_name).to(settings.device) 20 | else: 21 | raise NotImplementedError 22 | else: 23 | raise NotImplementedError 24 | model.eval() 25 | return model 26 | 27 | 28 | def get_tokenizer(settings: Settings) -> PreTrainedTokenizer: 29 | assert settings.task == 'text' 30 | if settings.model_name in ['gpt2', 'distilgpt2']: 31 | tokenizer = GPT2Tokenizer.from_pretrained(settings.model_name) # local_files_only=True 32 | elif settings.model_name == 'transfo-xl-wt103': 33 | tokenizer = TransfoXLTokenizer.from_pretrained(settings.model_name) # local_files_only=True 34 | else: 35 | raise NotImplementedError 
36 | return tokenizer 37 | 38 | 39 | def get_feature_extractor(settings: Settings) -> ImageGPTFeatureExtractor: 40 | assert settings.task == 'image' 41 | if settings.model_name == 'openai/imagegpt-small': 42 | feature_extractor = ImageGPTFeatureExtractor.from_pretrained(settings.model_name) # local_files_only=True 43 | else: 44 | raise NotImplementedError 45 | return feature_extractor 46 | -------------------------------------------------------------------------------- /src/random_sample_cy.pyx: -------------------------------------------------------------------------------- 1 | # distutils: language = c++ 2 | # cython: c_string_type=unicode, c_string_encoding=utf8 3 | from cython.operator cimport dereference as deref 4 | from libc.math cimport log2 5 | from libc.time cimport time, time_t, difftime 6 | from libcpp cimport nullptr 7 | from libcpp.map cimport map 8 | from libcpp.string cimport string 9 | from libcpp.vector cimport vector 10 | from libcpp.queue cimport queue 11 | from libcpp.memory cimport shared_ptr, make_shared 12 | 13 | import random 14 | import numpy as np 15 | from PIL import Image 16 | import torch 17 | from scipy.stats import entropy 18 | from tqdm import tqdm 19 | # import time as py_time 20 | 21 | from config import text_default_settings, image_default_settings, audio_default_settings 22 | from model import get_model, get_tokenizer, get_feature_extractor 23 | from utils import get_probs_indices_past, set_seed, SingleExampleOutput 24 | 25 | ## Classes & Structures 26 | # Sampling (Encoding) results and statistics for single time step 27 | cdef struct CySingleEncodeStepOutput: 28 | int sampled_index 29 | int n_bits 30 | double entropy_t 31 | double kld 32 | double minimum_entropy_t 33 | 34 | cdef class SingleEncodeStepOutput: 35 | cdef public: 36 | int sampled_index, n_bits 37 | double entropy_t, kld, minimum_entropy_t 38 | def __init__(self, 39 | int sampled_index, 40 | int n_bits, 41 | double entropy_t, 42 | double kld, 43 | double minimum_entropy_t): 44 | self.sampled_index = sampled_index 45 | self.n_bits = n_bits 46 | self.entropy_t = entropy_t 47 | self.kld = kld 48 | self.minimum_entropy_t = minimum_entropy_t 49 | 50 | def __call__(self): 51 | return self.sampled_index, self.n_bits, self.entropy_t, self.kld, self.minimum_entropy_t 52 | 53 | def __str__(self): 54 | d = { 55 | 'sampled_index': self.sampled_index, 56 | 'n_bits': self.n_bits, 57 | 'entropy_t': self.entropy_t, 58 | 'kld': self.kld, 59 | 'minimum_entropy_t': self.minimum_entropy_t 60 | } 61 | return '\n'.join('{} = {}'.format(key, value) for (key, value) in d.items()) 62 | 63 | 64 | ## Random sampling - single time step 65 | cdef CySingleEncodeStepOutput cy_encode_step(list indices, list probs, string message_bits): 66 | # Encode step 67 | cdef: 68 | int index_idx, sampled_index 69 | double ptr = random.random(), minimum_entropy_t = -log2(probs[0]), entropy_t = entropy(probs, base=2) 70 | 71 | probs_cumsum = torch.tensor(probs).cumsum(dim=0) 72 | interval_begin = torch.cat((torch.tensor([0], device=probs_cumsum.device), probs_cumsum[:-1]), dim=0) 73 | index_idx = (ptr >= interval_begin).nonzero()[-1].item() 74 | sampled_index = indices[index_idx] 75 | 76 | return CySingleEncodeStepOutput(sampled_index, 0, entropy_t, 0, minimum_entropy_t) 77 | 78 | ## Random sampling - main loop 79 | def encode(model, context, message_bits, settings, bint verbose = False, string tqdm_desc = b'Enc '): 80 | # Steganography Encoding (message_bits -> English text) 81 | cdef: 82 | int t = 0, length = settings.length, 
indices_idx 83 | double time_cost 84 | time_t start, end 85 | string stego_object 86 | list generated_ids = [] 87 | 88 | # CySingleEncodeStepOutput 89 | CySingleEncodeStepOutput single_encode_step_output 90 | int sampled_index 91 | int capacity_t 92 | double entropy_t 93 | double kld_step 94 | double minimum_entropy_t 95 | 96 | # statistics 97 | int total_capacity = 0 98 | double total_entropy = 0.0 99 | double total_minimum_entropy = 0.0 100 | double total_log_probs = 0.0 # for perplexity 101 | double total_kld = 0.0 102 | double max_kld = 0.0 103 | double perplexity, ave_kld 104 | 105 | set_seed(settings.seed) 106 | 107 | past = None # pass into the `past_keys_values` for speed-up 108 | prev = context # indices that were never passed to the model before 109 | 110 | start = time(NULL) 111 | for t in tqdm(range(length), desc=tqdm_desc, ncols=70): 112 | probs, indices, past = get_probs_indices_past(model, prev, past, settings) 113 | probs = probs.tolist() 114 | indices = indices.tolist() 115 | 116 | single_encode_step_output = cy_encode_step(indices, probs, message_bits) 117 | sampled_index = single_encode_step_output.sampled_index 118 | capacity_t = single_encode_step_output.n_bits 119 | entropy_t = single_encode_step_output.entropy_t 120 | kld_step = single_encode_step_output.kld 121 | minimum_entropy_t = single_encode_step_output.minimum_entropy_t 122 | 123 | indices_idx = indices.index(sampled_index) 124 | 125 | # update statistics 126 | total_entropy += entropy_t 127 | total_minimum_entropy += minimum_entropy_t 128 | total_log_probs += log2(probs[indices_idx]) 129 | total_kld += kld_step 130 | if kld_step > max_kld: 131 | max_kld = kld_step 132 | 133 | # when `capacity_t == 0`, cannot embed message, but still needs to return a token_index 134 | if capacity_t > 0: 135 | total_capacity += capacity_t 136 | message_bits = message_bits[capacity_t:] # remove the encoded part of `message_bits` 137 | generated_ids.append(sampled_index) 138 | if settings.task == 'text': 139 | prev = torch.tensor([sampled_index], device=settings.device).unsqueeze(0) 140 | elif settings.task == 'image': 141 | prev = torch.tensor([sampled_index], device=settings.device) 142 | end = time(NULL) 143 | time_cost = difftime(end, start) 144 | 145 | perplexity = 2 ** (-1 / length * total_log_probs) 146 | ave_kld = total_kld / length 147 | return SingleExampleOutput(generated_ids, None, total_capacity, total_entropy, ave_kld, max_kld, perplexity, 148 | time_cost, settings, 149 | total_minimum_entropy) 150 | 151 | def encode_text(model, tokenizer, message_bits, context, settings = text_default_settings): 152 | # tokenizer = get_tokenizer(settings) 153 | # model = get_model(settings) 154 | context = tokenizer(context, return_tensors='pt', max_length=1024, truncation=True)['input_ids'].to(settings.device) 155 | 156 | single_encode_step_output = encode(model, context, message_bits, settings) 157 | single_encode_step_output.stego_object = tokenizer.decode(single_encode_step_output.generated_ids) 158 | 159 | return single_encode_step_output 160 | 161 | def encode_image(model, feature_extractor, message_bits, settings = image_default_settings, bint verbose = False, 162 | string tqdm_desc = b'Enc ', 163 | double context_ratio = 0.0, original_img = None): 164 | cdef: 165 | int n_pixels_to_gen = 1024, n_pixels_context, n_px, width, height, width_after, height_after, height_context 166 | # feature_extractor = get_feature_extractor(settings) 167 | # model = get_model(settings) 168 | clusters = feature_extractor.clusters # with shape 
(512, 3) 169 | n_px = feature_extractor.size # 32 170 | 171 | context = torch.tensor([model.config.vocab_size - 1], device=settings.device) # initialize with SOS token 172 | if context_ratio != 0.0: 173 | if original_img is None: 174 | raise ValueError('If you set `context_ratio`, please make sure that `original_img` is not None!') 175 | elif type(original_img) == str: 176 | img = Image.open(original_img) 177 | else: 178 | img = original_img 179 | width, height = img.size 180 | 181 | # resize if needed 182 | if width != n_px: 183 | width_after = n_px 184 | height_after = round(width_after / width * height) 185 | img = img.resize((width_after, height_after)) 186 | width, height = width_after, height_after 187 | 188 | pixel_indices_lst = feature_extractor(img, return_tensors='pt')['input_ids'][0].to(settings.device) 189 | height_context = round(n_px * context_ratio) 190 | n_pixels_context = n_px * height_context 191 | primers = pixel_indices_lst[:n_pixels_context] 192 | context = torch.cat((context, primers), dim=-1) 193 | n_pixels_to_gen -= len(primers) 194 | 195 | output_pixel_ids = pixel_indices_lst.tolist()[:n_pixels_context] 196 | else: 197 | output_pixel_ids = [] 198 | settings.length = n_pixels_to_gen 199 | 200 | single_encode_step_output = encode(model, context, message_bits, settings) 201 | output_pixel_ids.extend(single_encode_step_output.generated_ids) 202 | 203 | def pixel_ids_lst_to_pil_image(pixel_indices_lst): 204 | pixel_indices_array = np.array(pixel_indices_lst) 205 | img = Image.fromarray( 206 | np.reshape(np.rint(127.5 * (clusters[pixel_indices_array] + 1.0)), [n_px, n_px, 3]).astype(np.uint8)) 207 | return img 208 | single_encode_step_output.stego_object = pixel_ids_lst_to_pil_image(output_pixel_ids) 209 | return single_encode_step_output 210 | 211 | ## Python interface 212 | def encode_step(list indices, list probs, string message_bits): 213 | cdef CySingleEncodeStepOutput single_encode_step_output = cy_encode_step(indices, probs, message_bits) 214 | sampled_index = single_encode_step_output.sampled_index 215 | n_bits = single_encode_step_output.n_bits 216 | entropy_t = single_encode_step_output.entropy_t 217 | kld = single_encode_step_output.kld 218 | minimum_entropy_t = single_encode_step_output.minimum_entropy_t 219 | return SingleEncodeStepOutput(sampled_index, n_bits, entropy_t, kld, minimum_entropy_t) 220 | -------------------------------------------------------------------------------- /src/run_single_example.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | import torch 4 | from scipy.io.wavfile import read, write 5 | from PIL import Image 6 | 7 | from config import Settings, text_default_settings, image_default_settings, audio_default_settings 8 | from model import get_model, get_feature_extractor, get_tokenizer 9 | from utils import SingleExampleOutput, check_dir 10 | 11 | # Load message 12 | message_file_path = os.path.join('temp', 'message.txt') 13 | with open(message_file_path, 'r', encoding='utf-8') as f: 14 | message = f.read() 15 | # message *= 10 16 | 17 | 18 | def test_text(settings: Settings = text_default_settings, context: Optional[str] = None): 19 | if settings.algo == 'sample': 20 | from random_sample_cy import encode_text 21 | elif settings.algo in ['Discop', 'Discop_baseline']: 22 | from stega_cy import encode_text, decode_text 23 | else: 24 | raise NotImplementedError("`Settings.algo` must belong to {'Discop', 'Discop_baseline', 'sample'}!") 25 | 26 | if 
context is None: 27 | context = 'We were both young when I first saw you, I close my eyes and the flashback starts.' 28 | 29 | model = get_model(settings) 30 | tokenizer = get_tokenizer(settings) 31 | 32 | single_example_output: SingleExampleOutput = encode_text(model, tokenizer, message, context, settings) 33 | print(single_example_output) 34 | if settings.algo != 'sample': 35 | message_encoded = message[:single_example_output.n_bits] 36 | message_decoded = decode_text(model, tokenizer, single_example_output.generated_ids, context, settings) 37 | print(message_encoded) 38 | print(message_decoded) 39 | print(message_encoded == message_decoded) 40 | 41 | 42 | def test_image(settings: Settings = image_default_settings, 43 | context_ratio: float = 0.5, 44 | original_img: Image = Image.open(os.path.join('temp', 'small.png'))): 45 | if settings.algo == 'sample': 46 | from random_sample_cy import encode_image 47 | elif settings.algo in ['Discop', 'Discop_baseline']: 48 | from stega_cy import encode_image, decode_image 49 | else: 50 | raise NotImplementedError("`Settings.algo` must belong to {'Discop', 'Discop_baseline', 'sample'}!") 51 | 52 | model = get_model(settings) 53 | feature_extractor = get_feature_extractor(settings) 54 | 55 | single_example_output: SingleExampleOutput = encode_image(model, 56 | feature_extractor, 57 | message, 58 | context_ratio=context_ratio, 59 | original_img=original_img) 60 | print(single_example_output) 61 | 62 | stego_img = single_example_output.stego_object 63 | # stego_img.save('stego.png') 64 | if settings.algo != 'sample': 65 | message_encoded = message[:single_example_output.n_bits] 66 | message_decoded = decode_image(model, feature_extractor, stego_img, context_ratio=context_ratio) 67 | print(message_encoded == message_decoded) 68 | 69 | 70 | def test_tts(settings: Settings = audio_default_settings, 71 | text: str = "We are both young.", 72 | save_audio_dir: Optional[str] = 'temp'): 73 | from stega_tts import get_tts_model, encode_speech, decode_speech, random_sample_speech 74 | 75 | vocoder, tacotron, cmudict = get_tts_model(settings) 76 | 77 | # Encode 78 | if settings.algo == 'sample': 79 | single_example_output, sr = random_sample_speech(vocoder, tacotron, cmudict, message, text, settings) 80 | elif settings.algo in ['Discop', 'Discop_baseline']: 81 | single_example_output, sr = encode_speech(vocoder, tacotron, cmudict, message, text, settings) 82 | else: 83 | raise NotImplementedError("`Settings.algo` must belong to {'Discop', 'Discop_baseline', 'sample'}!") 84 | 85 | print(single_example_output) 86 | 87 | wav = single_example_output.stego_object 88 | if save_audio_dir is not None: 89 | check_dir(save_audio_dir) 90 | write(os.path.join(save_audio_dir, 'test.flac'), sr, wav) 91 | message_encoded = message[:single_example_output.n_bits] 92 | 93 | # Decode 94 | if settings.algo != 'sample': 95 | if save_audio_dir is not None: 96 | sr, wav = read(os.path.join(save_audio_dir, 'test.flac')) 97 | message_decoded = decode_speech(vocoder, tacotron, cmudict, wav, text, settings) 98 | # print(message_decoded) 99 | print(message_encoded == message_decoded) 100 | 101 | 102 | if __name__ == '__main__': 103 | # Text Generation 104 | settings = text_default_settings 105 | settings.device = torch.device('cuda:0') 106 | # settings.algo = 'Discop_baseline' 107 | # settings.algo = 'sample' 108 | context = """I remember this film, it was the first film I had watched at the cinema.""" 109 | test_text(settings, context) 110 | 111 | # Image Completion 112 | settings = 
image_default_settings 113 | settings.device = torch.device('cuda:0') 114 | # settings.algo = 'Discop_baseline' 115 | # settings.algo = 'sample' 116 | test_image(settings) 117 | 118 | # Text-to-Speech 119 | settings: Settings = audio_default_settings 120 | settings.device = torch.device('cuda:0') 121 | # settings.algo = 'Discop_baseline' 122 | # settings.algo = 'sample' 123 | settings.seed = 1 # debug 124 | settings.top_p = 0.98 125 | test_tts(settings) 126 | -------------------------------------------------------------------------------- /src/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from distutils.core import setup 3 | from Cython.Build import cythonize 4 | 5 | setup(ext_modules=cythonize(os.path.join('src', 'stega_cy.pyx'), 6 | annotate=False, 7 | compiler_directives={ 8 | 'boundscheck': False, 9 | 'wraparound': False, 10 | 'language_level': 3 11 | })) 12 | 13 | setup(ext_modules=cythonize(os.path.join('src', 'random_sample_cy.pyx'), 14 | annotate=False, 15 | compiler_directives={ 16 | 'boundscheck': False, 17 | 'wraparound': False, 18 | 'language_level': 3 19 | })) -------------------------------------------------------------------------------- /src/stega_cy.pyx: -------------------------------------------------------------------------------- 1 | # distutils: language = c++ 2 | # cython: c_string_type=unicode, c_string_encoding=utf8 3 | from cython.operator cimport dereference as deref 4 | from libc.math cimport log2 5 | from libcpp cimport nullptr 6 | from libcpp.map cimport map 7 | from libcpp.string cimport string 8 | from libcpp.vector cimport vector 9 | from libcpp.queue cimport queue 10 | from libcpp.memory cimport shared_ptr, make_shared 11 | 12 | import time 13 | import random 14 | import numpy as np 15 | from PIL import Image 16 | import torch 17 | from scipy.stats import entropy 18 | from tqdm import tqdm 19 | # import time as py_time 20 | 21 | from config import text_default_settings, image_default_settings, audio_default_settings 22 | from model import get_model, get_tokenizer, get_feature_extractor 23 | from utils import get_probs_indices_past, set_seed, SingleExampleOutput 24 | 25 | cdef bint msg_exhausted_flag = False 26 | 27 | ## Classes & Structures 28 | # Nodes of Huffman tree 29 | cdef struct Node: 30 | double prob 31 | shared_ptr[Node] left 32 | shared_ptr[Node] right 33 | int index 34 | # >=0 - index 35 | # -1 - None 36 | int search_path 37 | # 0 - this node 38 | # -1 - in left subtree 39 | # 1 - in right subtree 40 | # 9 - unknown 41 | 42 | 43 | cdef inline bint is_leaf(shared_ptr[Node] node_ptr): 44 | return deref(node_ptr).index != -1 45 | 46 | # Sampling (Encoding) results and statistics for single time step 47 | cdef struct CySingleEncodeStepOutput: 48 | int sampled_index 49 | int n_bits 50 | double entropy_t 51 | double kld 52 | double minimum_entropy_t 53 | 54 | cdef class SingleEncodeStepOutput: 55 | cdef public: 56 | int sampled_index, n_bits 57 | double entropy_t, kld, minimum_entropy_t 58 | def __init__(self, 59 | int sampled_index, 60 | int n_bits, 61 | double entropy_t, 62 | double kld, 63 | double minimum_entropy_t): 64 | self.sampled_index = sampled_index 65 | self.n_bits = n_bits 66 | self.entropy_t = entropy_t 67 | self.kld = kld 68 | self.minimum_entropy_t = minimum_entropy_t 69 | 70 | def __call__(self): 71 | return self.sampled_index, self.n_bits, self.entropy_t, self.kld, self.minimum_entropy_t 72 | 73 | def __str__(self): 74 | d = { 75 | 'sampled_index': self.sampled_index, 76 
| 'n_bits': self.n_bits, 77 | 'entropy_t': self.entropy_t, 78 | 'kld': self.kld, 79 | 'minimum_entropy_t': self.minimum_entropy_t 80 | } 81 | return '\n'.join('{} = {}'.format(key, value) for (key, value) in d.items()) 82 | 83 | 84 | # Decoding results for single time step 85 | cdef struct CySingleDecodeStepOutput: 86 | string message_decoded_t 87 | 88 | cdef class SingleDecodeStepOutput: 89 | cdef public: 90 | string message_decoded_t 91 | 92 | def __init__(self, string message_decoded_t) -> None: 93 | self.message_decoded_t = message_decoded_t 94 | 95 | def __call__(self): 96 | return self.message_decoded_t 97 | 98 | ## Utils 99 | # Building a Huffman tree 100 | cdef shared_ptr[Node] create_huffman_tree(list indices, list probs, int search_for): 101 | # Returns a pointer to the root node of the Huffman tree 102 | # if `search_for == -1`, we don't need to initialize the `search_path` of any Node object 103 | cdef: 104 | int sz = len(indices) 105 | int i, search_path 106 | double prob 107 | shared_ptr[Node] node_ptr, first, second, ans 108 | queue[shared_ptr[Node]] q1, q2 109 | 110 | for i in range(sz - 1, -1, -1): 111 | # search_path = 0 if search_for == indices[i] else 9 112 | if search_for == indices[i]: 113 | search_path = 0 114 | else: 115 | search_path = 9 116 | node_ptr = make_shared[Node]( 117 | Node(probs[i], shared_ptr[Node](nullptr), shared_ptr[Node](nullptr), indices[i], search_path)) 118 | q1.push(node_ptr) 119 | 120 | while q1.size() + q2.size() > 1: 121 | # first 122 | if not q1.empty() and not q2.empty() and deref(q1.front()).prob < deref(q2.front()).prob: 123 | first = q1.front() 124 | q1.pop() 125 | elif q1.empty(): 126 | first = q2.front() 127 | q2.pop() 128 | elif q2.empty(): 129 | first = q1.front() 130 | q1.pop() 131 | else: 132 | first = q2.front() 133 | q2.pop() 134 | 135 | # second 136 | if not q1.empty() and not q2.empty() and deref(q1.front()).prob < deref(q2.front()).prob: 137 | second = q1.front() 138 | q1.pop() 139 | elif q1.empty(): 140 | second = q2.front() 141 | q2.pop() 142 | elif q2.empty(): 143 | second = q1.front() 144 | q1.pop() 145 | else: 146 | second = q2.front() 147 | q2.pop() 148 | 149 | # merge 150 | prob = deref(first).prob + deref(second).prob 151 | search_path = 9 152 | if deref(first).search_path != 9: 153 | search_path = -1 154 | elif deref(second).search_path != 9: 155 | search_path = 1 156 | q2.push(make_shared[Node](Node(prob, first, second, -1, search_path))) 157 | 158 | if not q2.empty(): 159 | ans = q2.front() 160 | else: 161 | ans = q1.front() 162 | return ans 163 | 164 | ## Steganography process - single time step 165 | # Sampling (Encoding) - single time step 166 | cdef CySingleEncodeStepOutput cy_encode_step(list indices, list probs, string message_bits): 167 | # Encode step 168 | global msg_exhausted_flag 169 | cdef: 170 | int sampled_index, n_bits = 0 171 | double entropy_t = 0.0, kld = 0.0, minimum_entropy_t = 0.0, prob_sum, ptr, ptr_0, ptr_1, partition 172 | shared_ptr[Node] node_ptr = create_huffman_tree(indices, probs, -1) 173 | vector[int] path_table = [-1, 1] 174 | int len_message_bits = len(message_bits) 175 | 176 | # if len_message_bits > 0: 177 | # print('len(message_bits) = {}'.format(len_message_bits)) 178 | while not is_leaf(node_ptr): # non-leaf node 179 | prob_sum = deref(node_ptr).prob 180 | ptr = random.random() 181 | ptr_0 = ptr * prob_sum 182 | ptr_1 = (ptr + 0.5) * prob_sum 183 | if ptr_1 > prob_sum: 184 | ptr_1 -= prob_sum 185 | 186 | partition = deref(deref(node_ptr).left).prob 187 | 188 | # path_table[0] = 
-1 if (ptr_0 < partition) else 1 189 | if ptr_0 < partition: 190 | path_table[0] = -1 191 | else: 192 | path_table[0] = 1 193 | # path_table[1] = -1 if (ptr_1 < partition) else 1 194 | if ptr_1 < partition: 195 | path_table[1] = -1 196 | else: 197 | path_table[1] = 1 198 | 199 | # node_ptr = deref(node_ptr).right if path_table[message_bits[n_bits] - 48] == 1 else deref(node_ptr).left 200 | if not msg_exhausted_flag and (len_message_bits <= n_bits): 201 | print('[*] The message is exhausted and will be padded with all zeros!') 202 | msg_exhausted_flag = True 203 | # print(n_bits) 204 | if msg_exhausted_flag: 205 | if path_table[0] == 1: 206 | node_ptr = deref(node_ptr).right 207 | else: 208 | node_ptr = deref(node_ptr).left 209 | else: 210 | if path_table[message_bits[n_bits] - 48] == 1: 211 | node_ptr = deref(node_ptr).right 212 | else: 213 | node_ptr = deref(node_ptr).left 214 | 215 | if path_table[0] != path_table[1]: 216 | n_bits += 1 217 | # print(deref(node_ptr).index) 218 | sampled_index = deref(node_ptr).index 219 | minimum_entropy_t = -log2(probs[0]) 220 | entropy_t = entropy(probs, base=2) 221 | return CySingleEncodeStepOutput(sampled_index, n_bits, entropy_t, kld, minimum_entropy_t) 222 | 223 | 224 | # Discop Baseline Sampling (Encoding) - single time step 225 | cdef CySingleEncodeStepOutput cy_baseline_encode_step(list indices, list probs, string message_bits): 226 | # Encode step 227 | global msg_exhausted_flag 228 | cdef: 229 | int sampled_index, n_bits = 0, capacity, capacity_upper_bound, i 230 | double entropy_t = 0.0, kld = 0.0, minimum_entropy_t = 0.0, ptr, ptr_i, rotate_step_size 231 | int len_message_bits = len(message_bits) 232 | 233 | probs_cumsum = torch.tensor(probs).cumsum(dim=0) 234 | interval_begin = torch.cat((torch.tensor([0], device=probs_cumsum.device), probs_cumsum[:-1]), dim=0) 235 | 236 | # Determine capacity 237 | capacity = int(log2(1 / probs[0])) 238 | capacity_upper_bound = capacity + 1 239 | 240 | tbl = {} # message bits -> token_index 241 | ptr = random.random() 242 | 243 | while capacity <= capacity_upper_bound: 244 | if capacity == 0: 245 | capacity += 1 246 | continue 247 | rotate_step_size = 2.0**-capacity 248 | is_available = True 249 | tbl_new = {} 250 | for i in range(2**capacity): 251 | ptr_i = ptr + i * rotate_step_size 252 | if ptr_i >= 1.0: 253 | ptr_i -= 1 254 | index_idx = (ptr_i >= interval_begin).nonzero()[-1].item() 255 | index = indices[index_idx] 256 | if index in tbl_new.values(): 257 | is_available = False 258 | break 259 | tbl_new[i] = index 260 | if not is_available: 261 | break 262 | tbl = tbl_new 263 | n_bits = capacity 264 | capacity += 1 265 | if n_bits < 1: 266 | sampled_index = indices[(ptr >= interval_begin).nonzero()[-1].item()] 267 | else: 268 | cur_message_bits_decimal = 0 269 | base = 1 270 | for d in range(n_bits - 1, -1, -1): 271 | if message_bits[d] == b'1': 272 | cur_message_bits_decimal += base 273 | base *= 2 274 | sampled_index = tbl[cur_message_bits_decimal] 275 | 276 | minimum_entropy_t = -log2(probs[0]) 277 | entropy_t = entropy(probs, base=2) 278 | return CySingleEncodeStepOutput(sampled_index, n_bits, entropy_t, kld, minimum_entropy_t) 279 | 280 | 281 | # Decoding - single time step 282 | cdef CySingleDecodeStepOutput cy_decode_step(list indices, list probs, int stego_t): 283 | # Decode step 284 | cdef: 285 | string message_decoded_t 286 | double prob_sum, ptr, ptr_0, ptr_1, partition 287 | shared_ptr[Node] node_ptr = create_huffman_tree(indices, probs, stego_t) 288 | vector[int] path_table = vector[int](2) 
289 | map[int, string] path_table_swap 290 | 291 | while not is_leaf(node_ptr): # non-leaf node 292 | prob_sum = deref(node_ptr).prob 293 | ptr = random.random() 294 | ptr_0 = ptr * prob_sum 295 | ptr_1 = (ptr + 0.5) * prob_sum 296 | if ptr_1 > prob_sum: 297 | ptr_1 -= prob_sum 298 | 299 | partition = deref(deref(node_ptr).left).prob 300 | 301 | # path_table[0] = -1 if (ptr_0 < partition) else 1 302 | if ptr_0 < partition: 303 | path_table[0] = -1 304 | else: 305 | path_table[0] = 1 306 | # path_table[1] = -1 if (ptr_1 < partition) else 1 307 | if ptr_1 < partition: 308 | path_table[1] = -1 309 | else: 310 | path_table[1] = 1 311 | 312 | if path_table[0] != path_table[1]: # can embed 1 bit 313 | if deref(node_ptr).search_path == 9: # fail to decode 314 | message_decoded_t = b'x' 315 | break 316 | 317 | if path_table[0] == -1: 318 | path_table_swap[-1] = b'0' 319 | path_table_swap[1] = b'1' 320 | else: 321 | path_table_swap[-1] = b'1' 322 | path_table_swap[1] = b'0' 323 | message_decoded_t += path_table_swap[deref(node_ptr).search_path] 324 | 325 | # walk 326 | if deref(node_ptr).search_path == -1: 327 | node_ptr = deref(node_ptr).left 328 | else: 329 | node_ptr = deref(node_ptr).right 330 | else: 331 | if path_table[0] == -1: 332 | node_ptr = deref(node_ptr).left 333 | else: 334 | node_ptr = deref(node_ptr).right 335 | 336 | if deref(node_ptr).search_path != 0: # cannot reach a leaf node 337 | message_decoded_t = b'x' 338 | return CySingleDecodeStepOutput(message_decoded_t) 339 | 340 | 341 | # Discop Baseline Decoding - single time step 342 | cdef CySingleDecodeStepOutput cy_baseline_decode_step(list indices, list probs, int stego_t): 343 | # Decode step 344 | cdef: 345 | int capacity, capacity_upper_bound, n_bits = 0 346 | string message_decoded_t 347 | double ptr 348 | probs_cumsum = torch.tensor(probs).cumsum(dim=0) 349 | interval_begin = torch.cat((torch.tensor([0], device=probs_cumsum.device), probs_cumsum[:-1]), dim=0) 350 | 351 | # Determine capacity 352 | capacity = int(log2(1 / probs[0])) 353 | capacity_upper_bound = capacity + 1 354 | 355 | tbl = {} # message bits -> token_index 356 | ptr = random.random() 357 | 358 | while capacity <= capacity_upper_bound: 359 | if capacity == 0: 360 | capacity += 1 361 | continue 362 | rotate_step_size = 2.0**-capacity 363 | is_available = True 364 | tbl_new = {} 365 | for i in range(2**capacity): 366 | ptr_i = ptr + i * rotate_step_size 367 | if ptr_i >= 1.0: 368 | ptr_i -= 1 369 | index_idx = (ptr_i >= interval_begin).nonzero()[-1].item() 370 | index = indices[index_idx] 371 | if index in tbl_new.values(): 372 | is_available = False 373 | break 374 | tbl_new[i] = index 375 | if not is_available: 376 | break 377 | tbl = tbl_new 378 | n_bits = capacity 379 | capacity += 1 380 | if n_bits < 1: 381 | message_decoded_t = b'' 382 | else: 383 | if stego_t not in tbl.values(): # Error 384 | message_decoded_t = b'x' 385 | tbl_swapped = dict(zip(tbl.values(), tbl.keys())) # token_index -> message bits 386 | message_decoded_t = bin(tbl_swapped[stego_t])[2:].zfill(n_bits) 387 | return CySingleDecodeStepOutput(message_decoded_t) 388 | 389 | 390 | def encode(model, context, message_bits, settings, bint verbose = False, string tqdm_desc = b'Enc '): 391 | # Steganography Encoding (message_bits -> English text) 392 | cdef: 393 | int t = 0, length = settings.length, indices_idx 394 | string stego_object 395 | list generated_ids = [] 396 | 397 | # CySingleEncodeStepOutput 398 | CySingleEncodeStepOutput single_encode_step_output 399 | int sampled_index 400 | int 
capacity_t 401 | double entropy_t 402 | double kld_step 403 | double minimum_entropy_t 404 | 405 | # statistics 406 | int total_capacity = 0 407 | double total_entropy = 0.0 408 | double total_minimum_entropy = 0.0 409 | double total_log_probs = 0.0 # for perplexity 410 | double total_kld = 0.0 411 | double max_kld = 0.0 412 | double perplexity, ave_kld 413 | 414 | set_seed(settings.seed) 415 | 416 | past = None # pass into the `past_keys_values` for speed-up 417 | prev = context # indices that were never passed to the model before 418 | 419 | start = time.time() 420 | for t in tqdm(range(length), desc=tqdm_desc, ncols=70): 421 | probs, indices, past = get_probs_indices_past(model, prev, past, settings) 422 | probs = probs.tolist() 423 | indices = indices.tolist() 424 | 425 | if settings.algo == 'Discop': 426 | single_encode_step_output = cy_encode_step(indices, probs, message_bits) 427 | elif settings.algo == 'Discop_baseline': 428 | single_encode_step_output = cy_baseline_encode_step(indices, probs, message_bits) 429 | sampled_index = single_encode_step_output.sampled_index 430 | capacity_t = single_encode_step_output.n_bits 431 | entropy_t = single_encode_step_output.entropy_t 432 | kld_step = single_encode_step_output.kld 433 | minimum_entropy_t = single_encode_step_output.minimum_entropy_t 434 | 435 | indices_idx = indices.index(sampled_index) 436 | 437 | # update statistics 438 | total_entropy += entropy_t 439 | total_minimum_entropy += minimum_entropy_t 440 | total_log_probs += log2(probs[indices_idx]) 441 | total_kld += kld_step 442 | if kld_step > max_kld: 443 | max_kld = kld_step 444 | 445 | # when `capacity_t == 0`, cannot embed message, but still needs to return a token_index 446 | if capacity_t > 0: 447 | total_capacity += capacity_t 448 | message_bits = message_bits[capacity_t:] # remove the encoded part of `message_bits` 449 | generated_ids.append(sampled_index) 450 | if settings.task == 'text': 451 | prev = torch.tensor([sampled_index], device=settings.device).unsqueeze(0) 452 | elif settings.task == 'image': 453 | prev = torch.tensor([sampled_index], device=settings.device) 454 | end = time.time() 455 | 456 | perplexity = 2 ** (-1 / length * total_log_probs) 457 | ave_kld = total_kld / length 458 | return SingleExampleOutput(generated_ids, None, total_capacity, total_entropy, ave_kld, max_kld, perplexity, 459 | end - start, settings, 460 | total_minimum_entropy) 461 | 462 | def decode(model, context, list stego, settings, bint verbose = False, string tqdm_desc = b'Dec '): 463 | # Steganography Decoding (English text -> message_bits) 464 | cdef: 465 | int t = 0, length = len(stego), indices_idx 466 | string message_decoded 467 | 468 | # CySingleEncodeStepOutput 469 | CySingleDecodeStepOutput single_decode_step_output 470 | int sampled_index 471 | int capacity_t 472 | double entropy_t 473 | double kld_step 474 | double minimum_entropy_t 475 | 476 | set_seed(settings.seed) 477 | past = None # pass into the `past_keys_values` for speed-up 478 | prev = context # indices that were never passed to the model before 479 | 480 | start = time.time() 481 | # while t < length: 482 | for t in tqdm(range(length), desc=tqdm_desc, ncols=70): 483 | probs, indices, past = get_probs_indices_past(model, prev, past, settings) 484 | probs = probs.tolist() 485 | indices = indices.tolist() 486 | 487 | if settings.algo == 'Discop': 488 | single_decode_step_output = cy_decode_step(indices, probs, stego[t]) 489 | elif settings.algo == 'Discop_baseline': 490 | single_decode_step_output = 
cy_baseline_decode_step(indices, probs, stego[t]) 491 | message_decoded_t = single_decode_step_output.message_decoded_t 492 | 493 | # print(message_decoded_t) 494 | 495 | if message_decoded_t == b'x': 496 | print('Fail to decode!') 497 | break 498 | message_decoded += message_decoded_t 499 | if settings.task == 'text': 500 | prev = torch.tensor([stego[t]], device=settings.device).unsqueeze(0) 501 | elif settings.task == 'image': 502 | prev = torch.tensor([stego[t]], device=settings.device) 503 | end = time.time() 504 | print('Decode time = {}s'.format(end - start)) 505 | return message_decoded 506 | 507 | def encode_text(model, tokenizer, message_bits, context, settings = text_default_settings): 508 | # tokenizer = get_tokenizer(settings) 509 | # model = get_model(settings) 510 | context = tokenizer(context, return_tensors='pt', max_length=1024, truncation=True)['input_ids'].to(settings.device) 511 | 512 | single_encode_step_output = encode(model, context, message_bits, settings) 513 | single_encode_step_output.stego_object = tokenizer.decode(single_encode_step_output.generated_ids) 514 | 515 | return single_encode_step_output 516 | 517 | def decode_text(model, tokenizer, list stego, context, settings = text_default_settings): 518 | # tokenizer = get_tokenizer(settings) 519 | # model = get_model(settings) 520 | context = tokenizer(context, return_tensors='pt', max_length=1024, truncation=True)['input_ids'].to(settings.device) 521 | 522 | message_decoded = decode(model, context, stego, settings) 523 | 524 | return message_decoded 525 | 526 | def encode_image(model, feature_extractor, message_bits, settings = image_default_settings, bint verbose = False, 527 | string tqdm_desc = b'Enc ', 528 | double context_ratio = 0.0, original_img = None): 529 | cdef: 530 | int n_pixels_to_gen = 1024, n_pixels_context, n_px, width, height, width_after, height_after, height_context 531 | # feature_extractor = get_feature_extractor(settings) 532 | # model = get_model(settings) 533 | clusters = feature_extractor.clusters # with shape (512, 3) 534 | n_px = feature_extractor.size # 32 535 | 536 | context = torch.tensor([model.config.vocab_size - 1], device=settings.device) # initialize with SOS token 537 | if context_ratio != 0.0: 538 | if original_img is None: 539 | raise ValueError('If you set `context_ratio`, please make sure that `original_img` is not None!') 540 | elif type(original_img) == str: 541 | img = Image.open(original_img) 542 | else: 543 | img = original_img 544 | width, height = img.size 545 | 546 | # resize if needed 547 | if width != n_px: 548 | width_after = n_px 549 | height_after = round(width_after / width * height) 550 | img = img.resize((width_after, height_after)) 551 | width, height = width_after, height_after 552 | 553 | pixel_indices_lst = feature_extractor(img, return_tensors='pt')['input_ids'][0].to(settings.device) 554 | height_context = round(n_px * context_ratio) 555 | n_pixels_context = n_px * height_context 556 | primers = pixel_indices_lst[:n_pixels_context] 557 | context = torch.cat((context, primers), dim=-1) 558 | n_pixels_to_gen -= len(primers) 559 | 560 | output_pixel_ids = pixel_indices_lst.tolist()[:n_pixels_context] 561 | else: 562 | output_pixel_ids = [] 563 | settings.length = n_pixels_to_gen 564 | 565 | single_encode_step_output = encode(model, context, message_bits, settings) 566 | output_pixel_ids.extend(single_encode_step_output.generated_ids) 567 | 568 | def pixel_ids_lst_to_pil_image(pixel_indices_lst): 569 | pixel_indices_array = np.array(pixel_indices_lst) 
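        # Each generated id indexes the feature extractor's colour palette `clusters`
        # (shape (512, 3), values in [-1, 1]); np.rint(127.5 * (c + 1.0)) maps a palette
        # entry back to 0-255 RGB, and the flat id sequence is reshaped below into an
        # (n_px, n_px, 3) = (32, 32, 3) image.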
570 | img = Image.fromarray( 571 | np.reshape(np.rint(127.5 * (clusters[pixel_indices_array] + 1.0)), [n_px, n_px, 3]).astype(np.uint8)) 572 | return img 573 | single_encode_step_output.stego_object = pixel_ids_lst_to_pil_image(output_pixel_ids) 574 | return single_encode_step_output 575 | 576 | def decode_image(model, feature_extractor, img, settings = image_default_settings, bint verbose = False, 577 | string tqdm_desc = b'Dec ', 578 | double context_ratio = 0.0): 579 | cdef: 580 | int height_context, n_pixels_context, n_pixels_to_gen = 1024, n_px 581 | list stego 582 | string message_decoded 583 | if type(img) == str: 584 | img = Image.open(img) 585 | 586 | # feature_extractor = get_feature_extractor(settings) 587 | # model = get_model(settings) 588 | clusters = feature_extractor.clusters # with shape (512, 3) 589 | n_px = feature_extractor.size # 32 590 | 591 | stego = feature_extractor(img)['input_ids'][0].tolist() 592 | if len(stego) != 1024: 593 | raise ValueError('len(stego) must be 1024!') 594 | 595 | context = torch.tensor([model.config.vocab_size - 1], device=settings.device) # initialize with SOS token 596 | if context_ratio != 0.0: 597 | pixel_indices_lst = feature_extractor(img, return_tensors='pt')['input_ids'][0].to(settings.device) 598 | height_context = round(n_px * context_ratio) 599 | n_pixels_context = n_px * height_context 600 | primers = pixel_indices_lst[:n_pixels_context] 601 | context = torch.cat((context, primers), dim=-1) 602 | stego = stego[n_pixels_context:] 603 | n_pixels_to_gen -= len(primers) 604 | 605 | settings.length = n_pixels_to_gen 606 | 607 | message_decoded = decode(model, context, stego, settings) 608 | return message_decoded 609 | 610 | ## Python interface 611 | def encode_step(settings, list indices, list probs, string message_bits): 612 | cdef CySingleEncodeStepOutput single_encode_step_output 613 | if settings.algo == 'Discop': 614 | single_encode_step_output = cy_encode_step(indices, probs, message_bits) 615 | elif settings.algo == 'Discop_baseline': 616 | single_encode_step_output = cy_baseline_encode_step(indices, probs, message_bits) 617 | sampled_index = single_encode_step_output.sampled_index 618 | n_bits = single_encode_step_output.n_bits 619 | entropy_t = single_encode_step_output.entropy_t 620 | kld = single_encode_step_output.kld 621 | minimum_entropy_t = single_encode_step_output.minimum_entropy_t 622 | return SingleEncodeStepOutput(sampled_index, n_bits, entropy_t, kld, minimum_entropy_t) 623 | 624 | def decode_step(settings, list indices, list probs, int stego_t): 625 | cdef CySingleDecodeStepOutput single_decode_step_output 626 | if settings.algo == 'Discop': 627 | single_decode_step_output = cy_decode_step(indices, probs, stego_t) 628 | elif settings.algo == 'Discop_baseline': 629 | single_decode_step_output = cy_baseline_decode_step(indices, probs, stego_t) 630 | if single_decode_step_output.message_decoded_t == b'x': 631 | raise ValueError('Fail to decode!') 632 | return single_decode_step_output.message_decoded_t 633 | -------------------------------------------------------------------------------- /src/stega_tts.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tacotron import load_cmudict, Tacotron, text_to_id 3 | from univoc import Vocoder 4 | 5 | from config import Settings, audio_default_settings 6 | 7 | 8 | def get_tts_model(settings: Settings): 9 | assert settings.task == 'text-to-speech' and settings.model_name == 'univoc' 10 | vocoder = 
Vocoder.from_pretrained( 11 | "https://github.com/bshall/UniversalVocoding/releases/download/v0.2/univoc-ljspeech-7mtpaq.pt", 12 | map_location=settings.device).to(settings.device) 13 | tacotron = Tacotron.from_pretrained("https://github.com/bshall/Tacotron/releases/download/v0.1/tacotron-ljspeech-yspjx3.pt", 14 | map_location=settings.device).to(settings.device) 15 | cmudict = load_cmudict() 16 | return vocoder, tacotron, cmudict 17 | 18 | 19 | def encode_speech(vocoder: Vocoder, 20 | tacotron: Tacotron, 21 | cmudict, 22 | message_bits: str, 23 | text: str, 24 | settings: Settings = audio_default_settings, 25 | verbose: bool = False, 26 | tqdm_desc: str = 'Enc '): 27 | x = torch.tensor(text_to_id(text, cmudict), dtype=torch.long, device=settings.device).unsqueeze(0) 28 | mel, _ = tacotron.generate(x) 29 | mel = mel.transpose(1, 2) 30 | 31 | single_encode_step_output, sr = vocoder.encode_speech(mel, message_bits, settings=settings, tqdm_desc=tqdm_desc) 32 | 33 | return single_encode_step_output, sr 34 | 35 | 36 | def decode_speech(vocoder: Vocoder, 37 | tacotron: Tacotron, 38 | cmudict, 39 | speech, 40 | text: str, 41 | settings: Settings = audio_default_settings, 42 | verbose: bool = False, 43 | tqdm_desc: str = 'Dec '): 44 | x = torch.tensor(text_to_id(text, cmudict), dtype=torch.long, device=settings.device).unsqueeze(0) 45 | mel, _ = tacotron.generate(x) 46 | mel = mel.transpose(1, 2) 47 | 48 | message_decoded = vocoder.decode_speech(mel, speech, settings=settings, tqdm_desc=tqdm_desc) 49 | 50 | return message_decoded 51 | 52 | 53 | def random_sample_speech(vocoder, 54 | tacotron, 55 | cmudict, 56 | message_bits, 57 | text: str, 58 | settings: Settings = audio_default_settings, 59 | verbose: bool = False, 60 | tqdm_desc: str = 'Enc '): 61 | x = torch.tensor(text_to_id(text, cmudict), dtype=torch.long, device=settings.device).unsqueeze(0) 62 | mel, _ = tacotron.generate(x) 63 | mel = mel.transpose(1, 2) 64 | 65 | single_encode_step_output, sr = vocoder.random_sample_speech(mel, message_bits, tqdm_desc=tqdm_desc) 66 | 67 | return single_encode_step_output, sr -------------------------------------------------------------------------------- /src/tacotron/TTS_cleaner.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Dict 3 | 4 | import inflect 5 | # Borrows from https://github.com/coqui-ai/TTS/blob/main/TTS/tts/utils/text/cleaners.py 6 | 7 | _inflect = inflect.engine() 8 | 9 | 10 | # lowercase 11 | def lowercase(text): 12 | return text.lower() 13 | 14 | 15 | # time 16 | _time_re = re.compile( 17 | r"""\b 18 | ((0?[0-9])|(1[0-1])|(1[2-9])|(2[0-3])) # hours 19 | : 20 | ([0-5][0-9]) # minutes 21 | \s*(a\\.m\\.|am|pm|p\\.m\\.|a\\.m|p\\.m)? 
# am/pm 22 | \b""", 23 | re.IGNORECASE | re.X, 24 | ) 25 | 26 | 27 | def _expand_num(n: int) -> str: 28 | return _inflect.number_to_words(n) 29 | 30 | 31 | def _expand_time_english(match: "re.Match") -> str: 32 | hour = int(match.group(1)) 33 | past_noon = hour >= 12 34 | time = [] 35 | if hour > 12: 36 | hour -= 12 37 | elif hour == 0: 38 | hour = 12 39 | past_noon = True 40 | time.append(_expand_num(hour)) 41 | 42 | minute = int(match.group(6)) 43 | if minute > 0: 44 | if minute < 10: 45 | time.append("oh") 46 | time.append(_expand_num(minute)) 47 | am_pm = match.group(7) 48 | if am_pm is None: 49 | time.append("p m" if past_noon else "a m") 50 | else: 51 | time.extend(list(am_pm.replace(".", ""))) 52 | return " ".join(time) 53 | 54 | 55 | def expand_time_english(text: str) -> str: 56 | return re.sub(_time_re, _expand_time_english, text) 57 | 58 | 59 | # en_normalize_numbers 60 | _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") 61 | _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") 62 | _currency_re = re.compile(r"(£|\$|¥)([0-9\,\.]*[0-9]+)") 63 | _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") 64 | _number_re = re.compile(r"-?[0-9]+") 65 | 66 | 67 | def _remove_commas(m): 68 | return m.group(1).replace(",", "") 69 | 70 | 71 | def _expand_decimal_point(m): 72 | return m.group(1).replace(".", " point ") 73 | 74 | 75 | def __expand_currency(value: str, inflection: Dict[float, str]) -> str: 76 | parts = value.replace(",", "").split(".") 77 | if len(parts) > 2: 78 | return f"{value} {inflection[2]}" # Unexpected format 79 | text = [] 80 | integer = int(parts[0]) if parts[0] else 0 81 | if integer > 0: 82 | integer_unit = inflection.get(integer, inflection[2]) 83 | text.append(f"{integer} {integer_unit}") 84 | fraction = int(parts[1]) if len(parts) > 1 and parts[1] else 0 85 | if fraction > 0: 86 | fraction_unit = inflection.get(fraction / 100, inflection[0.02]) 87 | text.append(f"{fraction} {fraction_unit}") 88 | if len(text) == 0: 89 | return f"zero {inflection[2]}" 90 | return " ".join(text) 91 | 92 | 93 | def _expand_currency(m: "re.Match") -> str: 94 | currencies = { 95 | "$": { 96 | 0.01: "cent", 97 | 0.02: "cents", 98 | 1: "dollar", 99 | 2: "dollars", 100 | }, 101 | "€": { 102 | 0.01: "cent", 103 | 0.02: "cents", 104 | 1: "euro", 105 | 2: "euros", 106 | }, 107 | "£": { 108 | 0.01: "penny", 109 | 0.02: "pence", 110 | 1: "pound sterling", 111 | 2: "pounds sterling", 112 | }, 113 | "¥": { 114 | # TODO rin 115 | 0.02: "sen", 116 | 2: "yen", 117 | }, 118 | } 119 | unit = m.group(1) 120 | currency = currencies[unit] 121 | value = m.group(2) 122 | return __expand_currency(value, currency) 123 | 124 | 125 | def _expand_ordinal(m): 126 | return _inflect.number_to_words(m.group(0)) 127 | 128 | 129 | def _expand_number(m): 130 | num = int(m.group(0)) 131 | if 1000 < num < 3000: 132 | if num == 2000: 133 | return "two thousand" 134 | if 2000 < num < 2010: 135 | return "two thousand " + _inflect.number_to_words(num % 100) 136 | if num % 100 == 0: 137 | return _inflect.number_to_words(num // 100) + " hundred" 138 | return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ") 139 | return _inflect.number_to_words(num, andword="") 140 | 141 | 142 | def en_normalize_numbers(text): 143 | text = re.sub(_comma_number_re, _remove_commas, text) 144 | text = re.sub(_currency_re, _expand_currency, text) 145 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 146 | text = re.sub(_ordinal_re, _expand_ordinal, text) 147 | text = re.sub(_number_re, _expand_number, 
text) 148 | return text 149 | 150 | 151 | # expand_abbreviations 152 | abbreviations_en = [(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) for x in [ 153 | ("mrs", "misess"), 154 | ("mr", "mister"), 155 | ("dr", "doctor"), 156 | ("st", "saint"), 157 | ("co", "company"), 158 | ("jr", "junior"), 159 | ("maj", "major"), 160 | ("gen", "general"), 161 | ("drs", "doctors"), 162 | ("rev", "reverend"), 163 | ("lt", "lieutenant"), 164 | ("hon", "honorable"), 165 | ("sgt", "sergeant"), 166 | ("capt", "captain"), 167 | ("esq", "esquire"), 168 | ("ltd", "limited"), 169 | ("col", "colonel"), 170 | ("ft", "fort"), 171 | ]] 172 | 173 | 174 | def expand_abbreviations(text): 175 | _abbreviations = abbreviations_en 176 | for regex, replacement in _abbreviations: 177 | text = re.sub(regex, replacement, text) 178 | return text 179 | 180 | 181 | # replace_symbols 182 | def replace_symbols(text, lang="en"): 183 | text = text.replace(";", ",") 184 | text = text.replace("-", " ") 185 | text = text.replace(":", ",") 186 | if lang == "en": 187 | text = text.replace("&", " and ") 188 | elif lang == "fr": 189 | text = text.replace("&", " et ") 190 | elif lang == "pt": 191 | text = text.replace("&", " e ") 192 | return text 193 | 194 | 195 | # remove_aux_symbols 196 | def remove_aux_symbols(text): 197 | text = re.sub(r"[\<\>\(\)\[\]\"]+", "", text) 198 | return text 199 | 200 | 201 | def collapse_whitespace(text): 202 | _whitespace_re = re.compile(r"\s+") 203 | return re.sub(_whitespace_re, " ", text).strip() 204 | 205 | 206 | def english_cleaners(text): 207 | """Pipeline for English text, including number and abbreviation expansion.""" 208 | # text = convert_to_ascii(text) 209 | text = lowercase(text) 210 | text = expand_time_english(text) 211 | text = en_normalize_numbers(text) 212 | text = expand_abbreviations(text) 213 | text = replace_symbols(text) 214 | text = remove_aux_symbols(text) 215 | text = collapse_whitespace(text) 216 | return text -------------------------------------------------------------------------------- /src/tacotron/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.1" 2 | 3 | from .dataset import BucketBatchSampler, TTSDataset, pad_collate 4 | from .model import Tacotron 5 | from .text import load_cmudict, symbol_to_id, text_to_id 6 | 7 | __all__ = [ 8 | Tacotron, 9 | TTSDataset, 10 | BucketBatchSampler, 11 | load_cmudict, 12 | text_to_id, 13 | symbol_to_id, 14 | pad_collate, 15 | ] 16 | -------------------------------------------------------------------------------- /src/tacotron/config.toml: -------------------------------------------------------------------------------- 1 | [preprocess] 2 | sr = 16000 3 | hop_length = 200 4 | win_length = 800 5 | n_fft = 2048 6 | n_mels = 80 7 | fmin = 50 8 | preemph = 0.97 9 | top_db = 80 10 | ref_db = 20 11 | mulaw.bits = 10 12 | 13 | [train] 14 | batch_size = 64 15 | bucket_size_multiplier = 5 16 | n_steps = 250000 17 | checkpoint_interval = 5000 18 | n_workers = 8 19 | clip_grad_norm = 0.05 20 | 21 | [train.optimizer] 22 | lr = 1e-3 23 | 24 | [train.scheduler] 25 | milestones = [ 26 | 20000, 27 | 40000, 28 | 100000, 29 | 150000, 30 | 200000, 31 | ] 32 | gamma = 0.5 33 | 34 | [model.encoder] 35 | n_symbols = 91 36 | embedding_dim = 256 37 | 38 | [model.encoder.prenet] 39 | input_size = 256 # should match model.encoder.embedding_dim 40 | hidden_size = 256 41 | output_size = 128 42 | dropout = 0.5 43 | 44 | [model.encoder.cbhg] 45 | input_channels = 128 # should match 
model.encoder.prenet.output_size 46 | K = 16 47 | channels = 128 48 | projection_channels = 128 49 | n_highways = 4 50 | highway_size = 128 51 | rnn_size = 128 52 | 53 | [model.decoder] 54 | input_size = 128 # should match model.encoder.cbhg.channels 55 | n_mels = 80 # should match preprocess.n_mels 56 | attn_rnn_size = 256 57 | decoder_rnn_size = 256 58 | reduction_factor = 2 59 | zoneout_prob = 0.1 60 | 61 | [model.decoder.prenet] 62 | input_size = 80 # should match preprocess.n_mels 63 | hidden_size = 256 64 | output_size = 128 65 | dropout = 0.5 66 | 67 | [model.decoder.attention] 68 | attn_rnn_size = 256 # should match model.decoder.attn_rnn_size 69 | hidden_size = 128 70 | static_channels = 8 71 | static_kernel_size = 21 72 | dynamic_channels = 8 73 | dynamic_kernel_size = 21 74 | prior_length = 11 75 | alpha = 0.1 76 | beta = 0.9 77 | -------------------------------------------------------------------------------- /src/tacotron/dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | import torch.utils.data.sampler as samplers 9 | from torch.nn.utils.rnn import pad_sequence 10 | from torch.utils.data import Dataset 11 | 12 | from .text import load_cmudict, symbol_to_id, text_to_id 13 | 14 | 15 | class SortedSampler(samplers.Sampler): 16 | """ 17 | Adapted from https://github.com/PetrochukM/PyTorch-NLP 18 | Copyright (c) James Bradbury and Soumith Chintala 2016, 19 | All rights reserved. 20 | """ 21 | 22 | def __init__(self, data, sort_key): 23 | super().__init__(data) 24 | self.data = data 25 | self.sort_key = sort_key 26 | zip_ = [(i, self.sort_key(row)) for i, row in enumerate(self.data)] 27 | zip_ = sorted(zip_, key=lambda r: r[1], reverse=True) 28 | self.sorted_indexes = [item[0] for item in zip_] 29 | 30 | def __iter__(self): 31 | return iter(self.sorted_indexes) 32 | 33 | def __len__(self): 34 | return len(self.data) 35 | 36 | 37 | class BucketBatchSampler(samplers.BatchSampler): 38 | """ 39 | Adapted from https://github.com/PetrochukM/PyTorch-NLP 40 | Copyright (c) James Bradbury and Soumith Chintala 2016, 41 | All rights reserved. 
42 | """ 43 | 44 | def __init__( 45 | self, 46 | sampler, 47 | batch_size, 48 | drop_last, 49 | sort_key, 50 | bucket_size_multiplier, 51 | ): 52 | super().__init__(sampler, batch_size, drop_last) 53 | self.sort_key = sort_key 54 | self.bucket_sampler = samplers.BatchSampler( 55 | sampler, min(batch_size * bucket_size_multiplier, len(sampler)), False 56 | ) 57 | 58 | def __iter__(self): 59 | for bucket in self.bucket_sampler: 60 | sorted_sampler = SortedSampler(bucket, self.sort_key) 61 | for batch in samplers.SubsetRandomSampler( 62 | list( 63 | samplers.BatchSampler( 64 | sorted_sampler, self.batch_size, self.drop_last 65 | ) 66 | ) 67 | ): 68 | yield [bucket[i] for i in batch] 69 | 70 | def __len__(self): 71 | if self.drop_last: 72 | return len(self.sampler) // self.batch_size 73 | else: 74 | return math.ceil(len(self.sampler) / self.batch_size) 75 | 76 | 77 | class TTSDataset(Dataset): 78 | def __init__(self, root, text_path): 79 | self.root = Path(root) 80 | 81 | with open(self.root / "train.json") as file: 82 | metadata = json.load(file) 83 | self.metadata = [Path(path) for _, path in metadata] 84 | 85 | with open(self.root / "lengths.json") as file: 86 | lengths = json.load(file) 87 | self.lengths = [lengths[path.stem] for path in self.metadata] 88 | 89 | self.index_longest_mel = np.argmax(self.lengths) 90 | 91 | self.cmudict = load_cmudict() 92 | 93 | train_set = {path.stem for path in self.metadata} 94 | with open(text_path) as file: 95 | text = (line.strip().split("|") for line in file) 96 | self.text = { 97 | key: transcript for key, _, transcript in text if key in train_set 98 | } 99 | 100 | def sort_key(self, index): 101 | return self.lengths[index] 102 | 103 | def __len__(self): 104 | return len(self.metadata) 105 | 106 | def __getitem__(self, index): 107 | path = self.root / self.metadata[index] 108 | 109 | mel = np.load(path.with_suffix(".mel.npy")) 110 | 111 | text = text_to_id(self.text[path.stem], self.cmudict) 112 | 113 | return ( 114 | torch.Tensor(mel).transpose_(0, 1), 115 | torch.LongTensor(text), 116 | index == self.index_longest_mel, 117 | ) 118 | 119 | 120 | def pad_collate(batch, reduction_factor=2): 121 | mels, texts, attn_flag = zip(*batch) 122 | mels = list(mels) 123 | texts = list(texts) 124 | 125 | # if len(mels[0]) is not a multiple of reduction_factor pad 126 | # note: mels[0] is the longest mel in the batch so when we call pad_sequence 127 | # the whole batch will get padded to a multiple of reduction_factor 128 | if len(mels[0]) % reduction_factor != 0: 129 | mels[0] = F.pad(mels[0], (0, 0, 0, reduction_factor - 1)) 130 | 131 | mel_lengths = [len(mel) for mel in mels] 132 | text_lengths = [len(text) for text in texts] 133 | 134 | mels = pad_sequence(mels, batch_first=True) 135 | texts = pad_sequence(texts, batch_first=True, padding_value=symbol_to_id["_"]) 136 | 137 | attn_flag = [i for i, flag in enumerate(attn_flag) if flag] 138 | 139 | return mels.transpose_(1, 2), texts, mel_lengths, text_lengths, attn_flag 140 | -------------------------------------------------------------------------------- /src/tacotron/model.py: -------------------------------------------------------------------------------- 1 | import importlib_resources 2 | import numpy as np 3 | import toml 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from scipy.stats import betabinom 8 | 9 | 10 | class Tacotron(nn.Module): 11 | def __init__(self, encoder, decoder): 12 | super().__init__() 13 | self.input_size = 2 * decoder["input_size"] 14 | 
self.attn_rnn_size = decoder["attn_rnn_size"] 15 | self.decoder_rnn_size = decoder["decoder_rnn_size"] 16 | self.n_mels = decoder["n_mels"] 17 | self.reduction_factor = decoder["reduction_factor"] 18 | 19 | self.encoder = Encoder(**encoder) 20 | self.decoder_cell = DecoderCell(**decoder) 21 | 22 | @classmethod 23 | def from_pretrained(cls, url, map_location=None, cfg_path=None): 24 | """ 25 | Loads the Torch serialized object at the given URL 26 | (uses torch.hub.load_state_dict_from_url). 27 | 28 | Parameters: 29 | url (string): URL of the weights to download 30 | map_location: a function or a dict specifying how to remap 31 | storage locations (see torch.load). 32 | cfg_path (Path): path to config file. 33 | Defaults to tacotron/config.toml 34 | """ 35 | cfg_ref = (importlib_resources.files("tacotron").joinpath("config.toml") if cfg_path is None else cfg_path) 36 | with cfg_ref.open() as file: 37 | cfg = toml.load(file) 38 | checkpoint = torch.hub.load_state_dict_from_url(url, map_location=map_location) 39 | model = cls(**cfg["model"]).to(device=map_location) 40 | model.load_state_dict(checkpoint["tacotron"]) 41 | model.eval() 42 | return model 43 | 44 | def forward(self, x, mels): 45 | B, N, T = mels.size() 46 | mels = mels.unbind(-1) 47 | 48 | h = self.encoder(x) 49 | 50 | alpha = F.one_hot(torch.zeros(B, dtype=torch.long, device=x.device), h.size(1)).float() 51 | c = torch.zeros(B, self.input_size, device=x.device) 52 | 53 | attn_hx = ( 54 | torch.zeros(B, self.attn_rnn_size, device=x.device), 55 | torch.zeros(B, self.attn_rnn_size, device=x.device), 56 | ) 57 | 58 | rnn1_hx = ( 59 | torch.zeros(B, self.decoder_rnn_size, device=x.device), 60 | torch.zeros(B, self.decoder_rnn_size, device=x.device), 61 | ) 62 | 63 | rnn2_hx = ( 64 | torch.zeros(B, self.decoder_rnn_size, device=x.device), 65 | torch.zeros(B, self.decoder_rnn_size, device=x.device), 66 | ) 67 | 68 | go_frame = torch.zeros(B, N, device=x.device) 69 | 70 | ys, alphas = [], [] 71 | for t in range(0, T, self.reduction_factor): 72 | y = mels[t - 1] if t > 0 else go_frame 73 | y, alpha, c, attn_hx, rnn1_hx, rnn2_hx = self.decoder_cell(h, y, alpha, c, attn_hx, rnn1_hx, rnn2_hx) 74 | ys.append(y) 75 | alphas.append(alpha) 76 | 77 | ys = torch.cat(ys, dim=-1) 78 | alphas = torch.stack(alphas, dim=2) 79 | return ys, alphas 80 | 81 | def generate(self, x, max_length=10000, stop_threshold=-0.2): 82 | """ 83 | Generates a log-Mel spectrogram from text. 84 | 85 | Parameters: 86 | x (Tensor): The text to synthesize converted to a sequence of symbol ids. 87 | See `text_to_id`. 88 | max_length (int): Maximum number of frames to generate. 89 | Defaults to 10000 frames i.e. 125 seconds. 90 | stop_threshold (float): If a frame is generated with all values exceeding 91 | `stop_threshold` then generation is stopped. 92 | 93 | Returns: 94 | Tensor: a log-Mel spectrogram of the synthesized speech. 
95 | """ 96 | h = self.encoder(x) 97 | B, T, _ = h.size() 98 | 99 | alpha = F.one_hot(torch.zeros(B, dtype=torch.long, device=x.device), T).float() 100 | c = torch.zeros(B, self.input_size, device=x.device) 101 | 102 | attn_hx = ( 103 | torch.zeros(B, self.attn_rnn_size, device=x.device), 104 | torch.zeros(B, self.attn_rnn_size, device=x.device), 105 | ) 106 | 107 | rnn1_hx = ( 108 | torch.zeros(B, self.decoder_rnn_size, device=x.device), 109 | torch.zeros(B, self.decoder_rnn_size, device=x.device), 110 | ) 111 | 112 | rnn2_hx = ( 113 | torch.zeros(B, self.decoder_rnn_size, device=x.device), 114 | torch.zeros(B, self.decoder_rnn_size, device=x.device), 115 | ) 116 | 117 | go_frame = torch.zeros(B, self.n_mels, device=x.device) 118 | 119 | ys, alphas = [], [] 120 | for t in range(0, max_length, self.reduction_factor): 121 | y = ys[-1][:, :, -1] if t > 0 else go_frame 122 | y, alpha, c, attn_hx, rnn1_hx, rnn2_hx = self.decoder_cell(h, y, alpha, c, attn_hx, rnn1_hx, rnn2_hx) 123 | if torch.all(y[:, :, -1] > stop_threshold): 124 | break 125 | ys.append(y) 126 | alphas.append(alpha) 127 | 128 | ys = torch.cat(ys, dim=-1) 129 | alphas = torch.stack(alphas, dim=2) 130 | return ys, alphas 131 | 132 | 133 | class DynamicConvolutionAttention(nn.Module): 134 | def __init__( 135 | self, 136 | attn_rnn_size, 137 | hidden_size, 138 | static_channels, 139 | static_kernel_size, 140 | dynamic_channels, 141 | dynamic_kernel_size, 142 | prior_length, 143 | alpha, 144 | beta, 145 | ): 146 | super(DynamicConvolutionAttention, self).__init__() 147 | 148 | self.prior_length = prior_length 149 | self.dynamic_channels = dynamic_channels 150 | self.dynamic_kernel_size = dynamic_kernel_size 151 | 152 | P = betabinom.pmf(np.arange(prior_length), prior_length - 1, alpha, beta) 153 | 154 | self.register_buffer("P", torch.FloatTensor(P).flip(0)) 155 | self.W = nn.Linear(attn_rnn_size, hidden_size) 156 | self.V = nn.Linear(hidden_size, dynamic_channels * dynamic_kernel_size, bias=False) 157 | self.F = nn.Conv1d( 158 | 1, 159 | static_channels, 160 | static_kernel_size, 161 | padding=(static_kernel_size - 1) // 2, 162 | bias=False, 163 | ) 164 | self.U = nn.Linear(static_channels, hidden_size, bias=False) 165 | self.T = nn.Linear(dynamic_channels, hidden_size) 166 | self.v = nn.Linear(hidden_size, 1, bias=False) 167 | 168 | def forward(self, s, alpha): 169 | p = F.conv1d(F.pad(alpha.unsqueeze(1), (self.prior_length - 1, 0)), self.P.view(1, 1, -1)) 170 | p = torch.log(p.clamp_min_(1e-6)).squeeze(1) 171 | 172 | G = self.V(torch.tanh(self.W(s))) 173 | g = F.conv1d( 174 | alpha.unsqueeze(0), 175 | G.view(-1, 1, self.dynamic_kernel_size), 176 | padding=(self.dynamic_kernel_size - 1) // 2, 177 | groups=s.size(0), 178 | ) 179 | g = g.view(s.size(0), self.dynamic_channels, -1).transpose(1, 2) 180 | 181 | f = self.F(alpha.unsqueeze(1)).transpose(1, 2) 182 | 183 | e = self.v(torch.tanh(self.U(f) + self.T(g))).squeeze(-1) + p 184 | 185 | return F.softmax(e, dim=-1) 186 | 187 | 188 | class PreNet(nn.Module): 189 | def __init__( 190 | self, 191 | input_size, 192 | hidden_size, 193 | output_size, 194 | dropout=0.5, 195 | fixed=False, 196 | ): 197 | super().__init__() 198 | self.fc1 = nn.Linear(input_size, hidden_size) 199 | self.fc2 = nn.Linear(hidden_size, output_size) 200 | self.p = dropout 201 | self.fixed = fixed 202 | 203 | def forward(self, x): 204 | x = self.fc1(x) 205 | x = F.relu(x) 206 | x = F.dropout(x, self.p, training=self.training or self.fixed) 207 | x = self.fc2(x) 208 | x = F.relu(x) 209 | x = F.dropout(x, self.p, 
training=self.training or self.fixed) 210 | return x 211 | 212 | 213 | class BatchNormConv(nn.Module): 214 | def __init__(self, input_channels, output_channels, kernel_size, relu=True): 215 | super().__init__() 216 | self.conv = nn.Conv1d( 217 | input_channels, 218 | output_channels, 219 | kernel_size, 220 | stride=1, 221 | padding=kernel_size // 2, 222 | bias=False, 223 | ) 224 | self.bnorm = nn.BatchNorm1d(output_channels) 225 | self.relu = relu 226 | 227 | def forward(self, x): 228 | x = self.conv(x) 229 | x = F.relu(x) if self.relu is True else x 230 | return self.bnorm(x) 231 | 232 | 233 | class HighwayNetwork(nn.Module): 234 | def __init__(self, size): 235 | super().__init__() 236 | self.linear1 = nn.Linear(size, size) 237 | self.linear2 = nn.Linear(size, size) 238 | nn.init.zeros_(self.linear1.bias) 239 | 240 | def forward(self, x): 241 | x1 = self.linear1(x) 242 | x2 = self.linear2(x) 243 | g = torch.sigmoid(x2) 244 | return g * F.relu(x1) + (1.0 - g) * x 245 | 246 | 247 | class CBHG(nn.Module): 248 | def __init__( 249 | self, 250 | K, 251 | input_channels, 252 | channels, 253 | projection_channels, 254 | n_highways, 255 | highway_size, 256 | rnn_size, 257 | ): 258 | super().__init__() 259 | 260 | self.conv_bank = nn.ModuleList([BatchNormConv(input_channels, channels, kernel_size) for kernel_size in range(1, K + 1)]) 261 | self.max_pool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1) 262 | 263 | self.conv_projections = nn.Sequential( 264 | BatchNormConv(K * channels, projection_channels, 3), 265 | BatchNormConv(projection_channels, input_channels, 3, relu=False), 266 | ) 267 | 268 | self.project = (nn.Linear(input_channels, highway_size, bias=False) if input_channels != highway_size else None) 269 | 270 | self.highway = nn.Sequential(*[HighwayNetwork(highway_size) for _ in range(n_highways)]) 271 | 272 | self.rnn = nn.GRU(highway_size, rnn_size, batch_first=True, bidirectional=True) 273 | 274 | def forward(self, x): 275 | T = x.size(-1) 276 | residual = x 277 | 278 | x = [conv(x)[:, :, :T] for conv in self.conv_bank] 279 | x = torch.cat(x, dim=1) 280 | 281 | x = self.max_pool(x) 282 | 283 | x = self.conv_projections(x[:, :, :T]) 284 | 285 | x = x + residual 286 | x = x.transpose(1, 2) 287 | 288 | if self.project is not None: 289 | x = self.project(x) 290 | 291 | x = self.highway(x) 292 | 293 | x, _ = self.rnn(x) 294 | return x 295 | 296 | 297 | class Encoder(nn.Module): 298 | def __init__(self, n_symbols, embedding_dim, prenet, cbhg): 299 | super().__init__() 300 | self.embedding = nn.Embedding(n_symbols, embedding_dim) 301 | self.pre_net = PreNet(**prenet) 302 | self.cbhg = CBHG(**cbhg) 303 | 304 | def forward(self, x): 305 | x = self.embedding(x) 306 | x = self.pre_net(x) 307 | x = self.cbhg(x.transpose(1, 2)) 308 | return x 309 | 310 | 311 | def zoneout(prev, current, p=0.1): 312 | mask = torch.empty_like(prev).bernoulli_(p) 313 | return mask * prev + (1 - mask) * current 314 | 315 | 316 | class DecoderCell(nn.Module): 317 | def __init__( 318 | self, 319 | prenet, 320 | attention, 321 | input_size, 322 | n_mels, 323 | attn_rnn_size, 324 | decoder_rnn_size, 325 | reduction_factor, 326 | zoneout_prob, 327 | ): 328 | super(DecoderCell, self).__init__() 329 | self.zoneout_prob = zoneout_prob 330 | 331 | self.prenet = PreNet(**prenet) 332 | self.dca = DynamicConvolutionAttention(**attention) 333 | self.attn_rnn = nn.LSTMCell(2 * input_size + prenet["output_size"], attn_rnn_size) 334 | self.linear = nn.Linear(2 * input_size + decoder_rnn_size, decoder_rnn_size) 335 | self.rnn1 = 
nn.LSTMCell(decoder_rnn_size, decoder_rnn_size) 336 | self.rnn2 = nn.LSTMCell(decoder_rnn_size, decoder_rnn_size) 337 | self.proj = nn.Linear(decoder_rnn_size, n_mels * reduction_factor, bias=False) 338 | 339 | def forward(self, h, y, alpha, c, attn_hx, rnn1_hx, rnn2_hx): 340 | B, N = y.size() 341 | 342 | y = self.prenet(y) 343 | attn_h, attn_c = self.attn_rnn(torch.cat((c, y), dim=-1), attn_hx) 344 | if self.training: 345 | attn_h = zoneout(attn_hx[0], attn_h, p=self.zoneout_prob) 346 | 347 | alpha = self.dca(attn_h, alpha) 348 | 349 | c = torch.matmul(alpha.unsqueeze(1), h).squeeze(1) 350 | 351 | x = self.linear(torch.cat((c, attn_h), dim=-1)) 352 | 353 | rnn1_h, rnn1_c = self.rnn1(x, rnn1_hx) 354 | if self.training: 355 | rnn1_h = zoneout(rnn1_hx[0], rnn1_h, p=self.zoneout_prob) 356 | x = x + rnn1_h 357 | 358 | rnn2_h, rnn2_c = self.rnn2(x, rnn2_hx) 359 | if self.training: 360 | rnn2_h = zoneout(rnn2_hx[0], rnn2_h, p=self.zoneout_prob) 361 | x = x + rnn2_h 362 | 363 | y = self.proj(x).view(B, N, 2) 364 | return y, alpha, c, (attn_h, attn_c), (rnn1_h, rnn1_c), (rnn2_h, rnn2_c) 365 | -------------------------------------------------------------------------------- /src/tacotron/text.py: -------------------------------------------------------------------------------- 1 | """ adapted from https://github.com/keithito/tacotron """ 2 | 3 | import re 4 | from itertools import islice 5 | 6 | import importlib_resources 7 | 8 | from .TTS_cleaner import english_cleaners 9 | 10 | # fmt: off 11 | PUNCTUATION = ['!', ',', '.', '?'] 12 | SYMBOLS = [ 13 | '_', '~', ' ', *PUNCTUATION, 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2', 'AO', 'AO0', 14 | 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2', 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 15 | 'ER0', 'ER1', 'ER2', 'EY', 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1', 'IY2', 'JH', 16 | 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 17 | 'UH1', 'UH2', 'UW', 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH' 18 | ] 19 | # fmt: on 20 | 21 | symbol_to_id = {s: i for i, s in enumerate(SYMBOLS)} 22 | id_to_symbol = {i: s for i, s in enumerate(SYMBOLS)} 23 | 24 | abbreviations = [(re.compile(fr"\b{abbreviation}\.", re.IGNORECASE), replacement.upper()) for abbreviation, replacement in [ 25 | ("mrs", "missis"), 26 | ("mr", "mister"), 27 | ("dr", "doctor"), 28 | ("st", "saint"), 29 | ("co", "company"), 30 | ("jr", "junior"), 31 | ("maj", "major"), 32 | ("gen", "general"), 33 | ("drs", "doctors"), 34 | ("rev", "reverend"), 35 | ("lt", "lieutenant"), 36 | ("hon", "honorable"), 37 | ("sgt", "sergeant"), 38 | ("capt", "captain"), 39 | ("esq", "esquire"), 40 | ("ltd", "limited"), 41 | ("col", "colonel"), 42 | ("ft", "fort"), 43 | ("etc", "etcetera"), 44 | ]] 45 | parentheses_pattern = re.compile(r"(?<=[.,!?] )[\(\[]|[\)\]](?=[.,!?])|^[\(\[]|[\)\]]$") 46 | dash_pattern = re.compile(r"(?<=[.,!?] 
)-- ") 47 | alt_entry_pattern = re.compile(r"(?<=\w)\((\d)\)") 48 | tokenizer_pattern = re.compile(r"[\w\{\}']+|[.,!?]") 49 | 50 | 51 | def expand_abbreviations(text): 52 | for pattern, replacement in abbreviations: 53 | text = pattern.sub(replacement, text) 54 | return text 55 | 56 | 57 | def format_alt_entry(text): 58 | return alt_entry_pattern.sub(r"{\1}", text) 59 | 60 | 61 | def replace_symbols(text): 62 | # replace semi-colons and colons with commas 63 | text = text.replace(";", ",") 64 | text = text.replace(":", ",") 65 | 66 | # replace dashes with commas 67 | text = dash_pattern.sub("", text) 68 | text = text.replace(" --", ",") 69 | text = text.replace(" - ", ", ") 70 | 71 | # split hyphenated words 72 | text = text.replace("-", " ") 73 | 74 | # use {#} to indicate alternate pronunciations 75 | text = format_alt_entry(text) 76 | 77 | # replace parentheses with commas 78 | text = parentheses_pattern.sub("", text) 79 | text = text.replace(")", ",") 80 | text = text.replace(" (", ", ") 81 | text = text.replace("]", ",") 82 | text = text.replace(" [", ", ") 83 | return text 84 | 85 | 86 | def clean(text): 87 | text = text.upper() 88 | text = expand_abbreviations(text) 89 | text = replace_symbols(text) 90 | return text 91 | 92 | 93 | def tokenize(text): 94 | return tokenizer_pattern.findall(text) 95 | 96 | 97 | def load_cmudict(): 98 | """Loads the CMU Pronouncing Dictionary""" 99 | 100 | dict_ref = importlib_resources.files("tacotron").joinpath("cmudict-0.7b.txt") 101 | with open(dict_ref, encoding="ISO-8859-1") as file: 102 | cmudict = (line.strip().split(" ") for line in islice(file, 126, 133905)) 103 | cmudict = {format_alt_entry(word): pronunciation for word, pronunciation in cmudict} 104 | return cmudict 105 | 106 | 107 | def parse_text(text, cmudict): 108 | words = tokenize(english_cleaners(text).upper()) 109 | 110 | # check if any words are not in the dictionary 111 | stripped = (word for word in words if word not in PUNCTUATION) 112 | out_of_vocab = set(word for word in stripped if word not in cmudict) 113 | if out_of_vocab: 114 | out_of_vocab_list = ", ".join(out_of_vocab) 115 | words_new = [word for word in words if word not in out_of_vocab] 116 | words = words_new 117 | if len(out_of_vocab) > 1: 118 | print(f'[*] {out_of_vocab_list} will be removed because they are not in the pronunciation dictionary!') 119 | else: 120 | print(f'[*] {out_of_vocab_list} will be removed because it is not in the pronunciation dictionary!') 121 | # raise KeyError( 122 | # f"Please add {out_of_vocab_list} to the pronunciation dictionary." 123 | # ) 124 | 125 | words = (cmudict[word] if word not in PUNCTUATION else word for word in words) 126 | words = (word.split(" ") for word in words) 127 | words = (x for word in words for x in (word, [" "])) 128 | symbols = list(symbol for word in words for symbol in word) 129 | symbols.append("~") 130 | return symbols 131 | 132 | 133 | def text_to_id(text, cmudict): 134 | """ 135 | Converts text to a sequence of symbol ids. 136 | 137 | Parameters: 138 | text (string): The input text. 139 | cmudict (dict): The pronuniation dictionary used for 140 | grapheme-to-phone conversion 141 | 142 | Returns: 143 | Tensor: The sequence of symbol ids. 
144 | """ 145 | symbols = parse_text(text, cmudict) 146 | return [symbol_to_id[symbol] for symbol in symbols] 147 | -------------------------------------------------------------------------------- /src/univoc/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import Vocoder 2 | from .dataset import VocoderDataset -------------------------------------------------------------------------------- /src/univoc/config/config.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | preprocess: 3 | sr: 16000 4 | hop_length: 200 5 | win_length: 800 6 | n_fft: 2048 7 | n_mels: 80 8 | fmin: 50 9 | preemph: 0.97 10 | top_db: 80 11 | ref_db: 20 12 | mulaw: 13 | bits: 10 14 | 15 | model: 16 | n_mels: ${preprocess.n_mels} 17 | conditioning_size: 128 18 | embedding_dim: 256 19 | rnn_size: 896 20 | fc_size: 1024 21 | bits: ${preprocess.mulaw.bits} 22 | hop_length: ${preprocess.hop_length} 23 | sr: ${preprocess.sr} 24 | 25 | train: 26 | batch_size: 32 27 | n_steps: 150000 28 | sample_frames: 24 29 | optimizer: 30 | lr: 4e-4 31 | scheduler: 32 | step_size: 25000 33 | gamma: 0.5 34 | checkpoint_interval: 25000 35 | n_workers: 8 36 | 37 | -------------------------------------------------------------------------------- /src/univoc/config/preprocess.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - config 3 | 4 | in_dir: ??? 5 | out_dir: ??? -------------------------------------------------------------------------------- /src/univoc/config/train.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - config 3 | 4 | resume: false 5 | checkpoint_dir: ??? 6 | dataset_dir: ??? 
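# Note: `???` is OmegaConf's marker for a mandatory value with no default. `checkpoint_dir`
# and `dataset_dir` here (and `in_dir` / `out_dir` in preprocess.yaml) must be supplied at
# runtime, e.g. via Hydra/OmegaConf command-line overrides in the upstream training code.
# The pretrained-checkpoint loading shown in src/univoc/model.py reads only config.yaml.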
-------------------------------------------------------------------------------- /src/univoc/dataset.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import numpy as np 3 | import torch 4 | import json 5 | import random 6 | from torch.utils.data import Dataset 7 | 8 | 9 | class VocoderDataset(Dataset): 10 | def __init__(self, root, sample_frames=24, hop_length=200): 11 | self.root = Path(root) 12 | self.sample_frames = sample_frames 13 | self.hop_length = hop_length 14 | 15 | metadata_path = self.root / "train.json" 16 | with open(metadata_path) as file: 17 | metadata = json.load(file) 18 | self.metadata = [Path(path) for _, path in metadata] 19 | 20 | def __len__(self): 21 | return len(self.metadata) 22 | 23 | def __getitem__(self, index): 24 | path = self.metadata[index] 25 | path = self.root / path 26 | 27 | audio = np.load(path.with_suffix(".wav.npy")) 28 | mel = np.load(path.with_suffix(".mel.npy")) 29 | 30 | pos = random.randint(0, mel.shape[-1] - self.sample_frames - 1) 31 | mel = mel[:, pos : pos + self.sample_frames] 32 | 33 | p, q = pos, pos + self.sample_frames 34 | audio = audio[p * self.hop_length : q * self.hop_length + 1] 35 | 36 | return torch.LongTensor(audio), torch.FloatTensor(mel.T) 37 | -------------------------------------------------------------------------------- /src/univoc/model.py: -------------------------------------------------------------------------------- 1 | from math import log2 2 | import time 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import numpy as np 7 | from tqdm import tqdm 8 | import librosa 9 | 10 | import importlib_resources 11 | from omegaconf import OmegaConf 12 | 13 | from config import Settings, audio_default_settings 14 | from utils import set_seed, SingleExampleOutput 15 | 16 | 17 | def get_gru_cell(gru): 18 | gru_cell = nn.GRUCell(gru.input_size, gru.hidden_size) 19 | gru_cell.weight_hh.data = gru.weight_hh_l0.data 20 | gru_cell.weight_ih.data = gru.weight_ih_l0.data 21 | gru_cell.bias_hh.data = gru.bias_hh_l0.data 22 | gru_cell.bias_ih.data = gru.bias_ih_l0.data 23 | return gru_cell 24 | 25 | 26 | class Vocoder(nn.Module): 27 | 28 | def __init__( 29 | self, 30 | n_mels, 31 | conditioning_size, 32 | embedding_dim, 33 | rnn_size, 34 | fc_size, 35 | bits, 36 | hop_length, 37 | sr, 38 | ): 39 | super().__init__() 40 | self.rnn_size = rnn_size 41 | self.bits = bits 42 | self.hop_length = hop_length 43 | self.sr = sr 44 | 45 | self.rnn1 = nn.GRU( 46 | n_mels, 47 | conditioning_size, 48 | num_layers=2, 49 | batch_first=True, 50 | bidirectional=True, 51 | ) 52 | self.embedding = nn.Embedding(2**self.bits, embedding_dim) 53 | self.rnn2 = nn.GRU(embedding_dim + 2 * conditioning_size, rnn_size, batch_first=True) 54 | self.fc1 = nn.Linear(rnn_size, fc_size) 55 | self.fc2 = nn.Linear(fc_size, 2**self.bits) 56 | 57 | @classmethod 58 | def from_pretrained(cls, url, map_location=None, cfg_path=None): 59 | r""" 60 | Loads the Torch serialized object at the given URL (uses torch.hub.load_state_dict_from_url). 61 | 62 | Parameters: 63 | url (string): URL of the weights to download 64 | cfg_path (Path): path to config file. 
Defaults to univoc/config/config.yaml 65 | """ 66 | cfg_ref = (importlib_resources.files("univoc.config").joinpath("config.yaml") if cfg_path is None else cfg_path) 67 | with cfg_ref.open() as file: 68 | cfg = OmegaConf.load(file) 69 | checkpoint = torch.hub.load_state_dict_from_url(url, map_location=map_location) 70 | model = cls(**cfg.model) 71 | model.load_state_dict(checkpoint["model"]) 72 | model.eval() 73 | return model 74 | 75 | def forward(self, x, mels): 76 | mels, _ = self.rnn1(mels) 77 | 78 | mels = F.interpolate(mels.transpose(1, 2), scale_factor=self.hop_length) 79 | mels = mels.transpose(1, 2) 80 | 81 | x = self.embedding(x) 82 | 83 | x, _ = self.rnn2(torch.cat((x, mels), dim=2)) 84 | 85 | x = F.relu(self.fc1(x)) 86 | x = self.fc2(x) 87 | return x 88 | 89 | @torch.no_grad() 90 | def generate(self, mel): 91 | r""" 92 | Generates an audio waverform from a log-Mel spectrogram. 93 | 94 | Parameters: 95 | mel (Tensor): of shape (1, seq_len, n_mels) containing the log-Mel spectrogram. 96 | 97 | Returns: 98 | Tuple[np.array, int]: The resulting waveform of shape (seq_len * hop_length) and sample rate in Hz. 99 | """ 100 | wav = [] 101 | cell = get_gru_cell(self.rnn2) 102 | 103 | mel, _ = self.rnn1(mel) 104 | 105 | mel = F.interpolate(mel.transpose(1, 2), scale_factor=self.hop_length) 106 | mel = mel.transpose(1, 2) 107 | 108 | h = torch.zeros(mel.size(0), self.rnn_size, device=mel.device) 109 | x = torch.zeros(mel.size(0), device=mel.device, dtype=torch.long) 110 | x = x.fill_(2**(self.bits - 1)) 111 | 112 | for m in tqdm(torch.unbind(mel, dim=1), leave=False): 113 | x = self.embedding(x) 114 | h = cell(torch.cat((x, m), dim=1), h) 115 | 116 | x = F.relu(self.fc1(h)) 117 | logits = self.fc2(x) 118 | 119 | posterior = F.softmax(logits, dim=1) 120 | dist = torch.distributions.Categorical(posterior) 121 | 122 | x = dist.sample() 123 | wav.append(x.item()) 124 | 125 | wav = np.asarray(wav, dtype=np.int) 126 | wav = librosa.mu_expand(wav - 2**(self.bits - 1), mu=2**self.bits - 1) 127 | return wav, self.sr 128 | 129 | @torch.no_grad() 130 | def encode_speech(self, mel, message: str, settings: Settings = audio_default_settings, tqdm_desc: str = 'Enc '): 131 | from stega_cy import encode_step 132 | algo, temp, top_p, _, seed = settings() 133 | set_seed(seed) 134 | 135 | wav = [] 136 | cell = get_gru_cell(self.rnn2) 137 | 138 | mel, _ = self.rnn1(mel) 139 | 140 | mel = F.interpolate(mel.transpose(1, 2), scale_factor=self.hop_length) 141 | mel = mel.transpose(1, 2) 142 | 143 | # mel = mel[:, :48000, :] # generate only first 3 seconds 144 | 145 | h = torch.zeros(mel.size(0), self.rnn_size, device=mel.device) 146 | x = torch.zeros(mel.size(0), device=mel.device, dtype=torch.long) 147 | x = x.fill_(2**(self.bits - 1)) 148 | 149 | total_capacity = 0 150 | total_entropy = 0 151 | total_log_probs = 0 # to calculate the perplexity 152 | total_kld = 0 153 | total_minimum_entropy = 0 154 | max_kld = 0 155 | 156 | message_encoded = '' 157 | start = time.time() 158 | for m in tqdm(torch.unbind(mel, dim=1), desc=tqdm_desc, ncols=70): 159 | x = self.embedding(x) 160 | h = cell(torch.cat((x, m), dim=1), h) 161 | 162 | x = F.relu(self.fc1(h)) 163 | logits = self.fc2(x) 164 | logits, indices = logits[0, :].sort(descending=True) 165 | logits = logits.double() 166 | 167 | logits = logits / temp 168 | 169 | probs = F.softmax(logits, dim=-1) 170 | 171 | if top_p < 1.0: 172 | cum_probs = probs.cumsum(0) 173 | 174 | k = (cum_probs > top_p).nonzero()[0].item() + 1 175 | probs = probs[:k] 176 | indices = indices[:k] 
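                # Nucleus (top-p) truncation: `k` is the smallest prefix length of the
                # descending-sorted distribution whose cumulative mass exceeds `top_p`;
                # the kept probabilities are rescaled below to sum to 1 before being
                # passed to the Discop encode step.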
177 | probs = 1 / cum_probs[k - 1] * probs # Normalization 178 | probs = probs.tolist() 179 | indices = indices.tolist() 180 | 181 | sampled_index, capacity_t, entropy_t, kld_step, min_entropy_t = encode_step(settings, indices, probs, message)() 182 | 183 | indices_idx = indices.index(sampled_index) 184 | total_entropy += entropy_t 185 | total_log_probs += log2(probs[indices_idx]) 186 | total_kld += kld_step 187 | total_minimum_entropy += -log2(probs[0]) 188 | if kld_step > max_kld: 189 | max_kld = kld_step 190 | 191 | x = torch.tensor([sampled_index], device=mel.device) 192 | 193 | if capacity_t > 0: 194 | total_capacity += capacity_t 195 | message_encoded += message[:capacity_t] 196 | message = message[capacity_t:] # remove the encoded part of `message` 197 | 198 | wav.append(x.item()) 199 | end = time.time() 200 | perplexity = 2**(-1 / len(wav) * total_log_probs) 201 | wav = np.asarray(wav, dtype=np.int) 202 | # print(wav) 203 | wav = librosa.mu_expand(wav - 2**(self.bits - 1), mu=2**self.bits - 1) 204 | # return wav, self.sr 205 | return SingleExampleOutput(None, wav, total_capacity, total_entropy, total_kld / len(wav), max_kld, perplexity, 206 | end - start, settings, total_minimum_entropy), self.sr 207 | 208 | @torch.no_grad() 209 | def decode_speech(self, mel, stego_speech: np.ndarray, settings: Settings = audio_default_settings, tqdm_desc: str = 'Dec '): 210 | from stega_cy import decode_step 211 | algo, temp, top_p, _, seed = settings() 212 | set_seed(seed) 213 | 214 | cell = get_gru_cell(self.rnn2) 215 | 216 | mel, _ = self.rnn1(mel) 217 | 218 | mel = F.interpolate(mel.transpose(1, 2), scale_factor=self.hop_length) 219 | mel = mel.transpose(1, 2) 220 | 221 | h = torch.zeros(mel.size(0), self.rnn_size, device=mel.device) 222 | x = torch.zeros(mel.size(0), device=mel.device, dtype=torch.long) 223 | x = x.fill_(2**(self.bits - 1)) 224 | 225 | # mu_compress 226 | stego_speech = librosa.mu_compress(stego_speech, mu=2**self.bits - 1) + 2**(self.bits - 1) 227 | 228 | t = 0 229 | message_decoded = '' 230 | start = time.time() 231 | for m in tqdm(torch.unbind(mel, dim=1), desc=tqdm_desc, ncols=70): 232 | x = self.embedding(x) 233 | h = cell(torch.cat((x, m), dim=1), h) 234 | 235 | x = F.relu(self.fc1(h)) 236 | logits = self.fc2(x) 237 | logits, indices = logits[0, :].sort(descending=True) 238 | logits = logits.double() 239 | 240 | logits = logits / temp 241 | 242 | probs = F.softmax(logits, dim=-1) 243 | 244 | if top_p < 1.0: 245 | cum_probs = probs.cumsum(0) 246 | 247 | k = (cum_probs > top_p).nonzero()[0].item() + 1 248 | probs = probs[:k] 249 | indices = indices[:k] 250 | probs = 1 / cum_probs[k - 1] * probs # Normalization 251 | probs = probs.tolist() 252 | indices = indices.tolist() 253 | 254 | message_decoded_t = decode_step(settings, indices, probs, stego_speech[t]) 255 | 256 | if message_decoded_t == 'x': 257 | raise ValueError('Failed to decode!') 258 | message_decoded += message_decoded_t 259 | 260 | x = torch.tensor([stego_speech[t]], device=mel.device) 261 | t += 1 262 | end = time.time() 263 | print('Decode time: {:.2f}s'.format(end - start)) 264 | return message_decoded 265 | 266 | @torch.no_grad() 267 | def random_sample_speech(self, mel, message, settings: Settings = audio_default_settings, tqdm_desc: str = 'Enc '): 268 | from random_sample_cy import encode_step 269 | algo, temp, top_p, _, seed = settings() 270 | set_seed(seed) 271 | 272 | wav = [] 273 | cell = get_gru_cell(self.rnn2) 274 | 275 | mel, _ = self.rnn1(mel) 276 | 277 | mel = F.interpolate(mel.transpose(1, 2), 
scale_factor=self.hop_length) 278 | mel = mel.transpose(1, 2) 279 | 280 | # mel = mel[:, :48000, :] # generate only first 3 seconds 281 | 282 | h = torch.zeros(mel.size(0), self.rnn_size, device=mel.device) 283 | x = torch.zeros(mel.size(0), device=mel.device, dtype=torch.long) 284 | x = x.fill_(2**(self.bits - 1)) 285 | 286 | total_capacity = 0 287 | total_entropy = 0 288 | total_log_probs = 0 # to calculate the perplexity 289 | total_kld = 0 290 | total_minimum_entropy = 0 291 | max_kld = 0 292 | 293 | message_encoded = '' 294 | start = time.time() 295 | for m in tqdm(torch.unbind(mel, dim=1), desc=tqdm_desc, ncols=70): 296 | x = self.embedding(x) 297 | h = cell(torch.cat((x, m), dim=1), h) 298 | 299 | x = F.relu(self.fc1(h)) 300 | logits = self.fc2(x) 301 | logits, indices = logits[0, :].sort(descending=True) 302 | logits = logits.double() 303 | 304 | logits = logits / temp 305 | 306 | probs = F.softmax(logits, dim=-1) 307 | 308 | if top_p < 1.0: 309 | cum_probs = probs.cumsum(0) 310 | 311 | k = (cum_probs > top_p).nonzero()[0].item() + 1 312 | probs = probs[:k] 313 | indices = indices[:k] 314 | probs = 1 / cum_probs[k - 1] * probs # Normalization 315 | probs = probs.tolist() 316 | indices = indices.tolist() 317 | 318 | sampled_index, capacity_t, entropy_t, kld_step, min_entropy_t = encode_step(indices, probs, message)() 319 | 320 | indices_idx = indices.index(sampled_index) 321 | total_entropy += entropy_t 322 | total_log_probs += log2(probs[indices_idx]) 323 | total_kld += kld_step 324 | total_minimum_entropy += -log2(probs[0]) 325 | if kld_step > max_kld: 326 | max_kld = kld_step 327 | 328 | x = torch.tensor([sampled_index], device=mel.device) 329 | 330 | # if capacity_t > 0: 331 | # total_capacity += capacity_t 332 | # message_encoded += message[:capacity_t] 333 | # message = message[capacity_t:] # remove the encoded part of `message` 334 | 335 | wav.append(x.item()) 336 | end = time.time() 337 | perplexity = 2**(-1 / len(wav) * total_log_probs) 338 | wav = np.asarray(wav, dtype=np.int) 339 | wav = librosa.mu_expand(wav - 2**(self.bits - 1), mu=2**self.bits - 1) 340 | # return wav, self.sr 341 | return SingleExampleOutput(None, wav, total_capacity, total_entropy, total_kld / len(wav), max_kld, perplexity, 342 | end - start, settings, total_minimum_entropy), self.sr 343 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import torch 4 | import torch.nn.functional as F 5 | from typing import List, Tuple 6 | from transformers import PreTrainedTokenizer, PreTrainedModel 7 | 8 | from config import Settings 9 | 10 | 11 | # Sampling (Encoding) results and statistics for single example 12 | class SingleExampleOutput: 13 | def __init__(self, generated_ids, stego_object, n_bits, total_entropy, ave_kld, max_kld, perplexity, time_cost, settings, 14 | total_minimum_entropy): 15 | self.generated_ids = generated_ids 16 | self.stego_object = stego_object 17 | self.algo = settings.algo 18 | self.temp = settings.temp 19 | self.top_p = settings.top_p 20 | self.n_bits = n_bits 21 | if generated_ids is not None: 22 | self.n_tokens = len(generated_ids) 23 | else: 24 | self.n_tokens = len(stego_object) 25 | self.total_entropy = total_entropy 26 | self.ave_kld = ave_kld 27 | self.max_kld = max_kld 28 | self.embedding_rate = n_bits / self.n_tokens 29 | self.utilization_rate = n_bits / total_entropy if total_entropy != 0 else 0 30 | 
30 |         self.perplexity = perplexity
31 |         self.time_cost = time_cost
32 |         self.total_minimum_entropy = total_minimum_entropy
33 |
34 |     def __str__(self) -> str:
35 |         d = self.__dict__
36 |         excluded_attr = ['generated_ids']
37 |         selected_attr = list(d.keys())
38 |         for x in excluded_attr:
39 |             selected_attr.remove(x)
40 |         return '\n'.join('{} = {}'.format(key, d[key]) for key in selected_attr)
41 |
42 |
43 | def set_seed(sd):
44 |     random.seed(sd)
45 |
46 |
47 | # The token indices that should be filtered out and the corresponding reasons
48 | # https://huggingface.co/gpt2/raw/main/vocab.json
49 | # filter_out_indices_gpt = {
50 | #     -1: "endoftext can't happen",
51 | #     198: "1 newline can't happen",
52 | #     628: "2 newlines can't happen",
53 | #     220: "just one space can't happen",
54 | #     302: "`\u0120re` can't happen",
55 | #     797: "`\u0120Re` can't happen",
56 | #     15860: "`\u0120Enh` can't happen",
57 | #     2943: "`EC` can't happen",
58 | #     764: "`\u0120.` (764) may cause failed decoding to `.` (13)",
59 | #     837: "`\u0120,` (837) may cause failed decoding to `,` (11)"
60 | # }
61 | filter_out_indices_gpt = {
62 |     -1: "endoftext can't happen",
63 |     198: "1 newline can't happen",
64 |     628: "2 newlines can't happen",
65 |     764: "`\u0120.` (764) may cause failed decoding to `.` (13)",
66 |     837: "`\u0120,` (837) may cause failed decoding to `,` (11)"
67 | }
68 | contain_dollar_lst = [
69 |     3, 720, 7198, 13702, 16763, 17971, 22799, 25597, 29568, 29953, 32047, 32382, 32624, 34206, 35307, 36737, 38892, 39280, 40111,
70 |     43641, 45491, 47113, 48082
71 | ]
72 | contain_bad_ellipsis_lst = [19424, 20004, 39864, 44713, 44912, 47082]
73 |
74 |
75 | def gen_random_message(seed=None, length: int = 1000, save_path: str = os.path.join('temp', 'message.txt')) -> None:
76 |     # Generate a binary message (str) randomly via the built-in `random` lib
77 |     import random
78 |     random.seed(seed)
79 |
80 |     message = ''
81 |     for _ in range(length):
82 |         message += str(random.randint(0, 1))
83 |     print(message)
84 |
85 |     if save_path is None:
86 |         return message
87 |     with open(save_path, 'w', encoding='utf-8') as fout:
88 |         fout.write(message)
89 |
90 |
91 | def limit_past(past):
92 |     if past is None:
93 |         return None
94 |     past = list(past)
95 |     for i in range(len(past)):
96 |         past[i] = list(past[i])
97 |         for j in range(len(past[i])):
98 |             past[i][j] = past[i][j][:, :, -1022:]
99 |     return past
100 |
101 |
102 | @torch.no_grad()
103 | def get_probs_indices_past(model: PreTrainedModel,
104 |                            prev=None,
105 |                            past=None,
106 |                            settings: Settings = Settings(),
107 |                            gpt_filter: bool = True) -> Tuple:
108 |     # first, get logits from the model
109 |     if settings.task == 'text':
110 |         if 'gpt2' in settings.model_name:
111 |             past = limit_past(past)
112 |             model_output = model(prev, past_key_values=past)
113 |             past = model_output.past_key_values
114 |             logits = model_output.logits[0, -1, :].to(settings.device)
115 |             if gpt_filter:
116 |                 for ele in filter_out_indices_gpt.keys():
117 |                     logits[ele] = -1e10
118 |         elif settings.model_name == 'transfo-xl-wt103':
119 |             model_output = model(prev, mems=past)
120 |             past = model_output.mems
121 |             logits = model_output.logits[0, -1, :].to(settings.device)
122 |             logits[0] = -1e10 #
123 |             logits[24] = -1e10 #
124 |     elif settings.task == 'image':
125 |         model_output = model(prev, past_key_values=past)
126 |         past = model_output.past_key_values
127 |         logits = model_output.logits[0, :].to(settings.device)
128 |
129 |     logits, indices = logits.sort(descending=True)
130 |     logits = logits.double()
131 |     indices = indices.int()
132 |
133 |     if settings.temp is None:
134 |         settings.temp = 1.0
135 |     logits_temp = logits / settings.temp
136 |     probs = F.softmax(logits_temp, dim=-1)
137 |
138 |     # Getting the top-p `probs` and `indices` from the last layer of `logits`
139 |     if not (settings.top_p is None or settings.top_p == 1.0):
140 |         assert settings.top_p > 0 and settings.top_p < 1.0, '`top_p` must be >0 and <1!'
141 |         cum_probs = probs.cumsum(0)
142 |         k = (cum_probs > settings.top_p).nonzero()[0].item() + 1
143 |         probs = probs[:k]
144 |         indices = indices[:k]
145 |         probs = 1 / cum_probs[k - 1] * probs # Normalizing
146 |     return probs, indices, past
147 |
148 |
149 | def is_alpha(s: str) -> bool:
150 |     # A-Za-z
151 |     for i in range(len(s)):
152 |         c = s[i].lower()
153 |         if ord(c) < ord('a') or ord(c) > ord('z'):
154 |             return False
155 |     return True
156 |
157 |
158 | def check_dir(dir: str):
159 |     if not os.path.exists(dir):
160 |         os.makedirs(dir, exist_ok=True)
161 |         print('A folder called "{}" is created.'.format(dir))
162 |
163 |
164 | if __name__ == '__main__':
165 |     gen_random_message(length=1000000)
166 |
--------------------------------------------------------------------------------
/temp/small.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/comydream/Discop/3c3a10099a242eae405b49cc4d09fba1abb148ad/temp/small.png
--------------------------------------------------------------------------------
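Note: the same sort + temperature + nucleus (top-p) truncation step appears above in `encode_speech`, `decode_speech`, and `random_sample_speech` (`src/univoc/model.py`) and in `get_probs_indices_past` (`src/utils.py`). The snippet below is a minimal, self-contained sketch of that shared step for quick reference; `truncate_top_p` is an illustrative name and not an API of this repository.

```python
import torch
import torch.nn.functional as F


def truncate_top_p(logits: torch.Tensor, temp: float = 1.0, top_p: float = 0.92):
    """Illustrative sketch of the sort + temperature + top-p step used above (not a repository API)."""
    # Sort logits in descending order, keeping the matching vocabulary indices
    logits, indices = logits.sort(descending=True)
    probs = F.softmax(logits.double() / temp, dim=-1)
    if top_p < 1.0:
        cum_probs = probs.cumsum(0)
        # Smallest k whose cumulative probability exceeds top_p
        k = (cum_probs > top_p).nonzero()[0].item() + 1
        probs = probs[:k] / cum_probs[k - 1]  # renormalize the kept probability mass
        indices = indices[:k]
    return probs, indices


# Example:
# probs, indices = truncate_top_p(torch.randn(50257), temp=1.0, top_p=0.92)
```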