├── data
│   ├── __init__.py
│   ├── data_collator.py
│   └── dataset.py
├── s2igan
│   ├── __init__.py
│   ├── sen
│   │   ├── __init__.py
│   │   ├── ied.py
│   │   ├── sed.py
│   │   └── utils.py
│   ├── utils.py
│   ├── rdg
│   │   ├── __init__.py
│   │   ├── relation_supervisor.py
│   │   ├── generator.py
│   │   ├── discriminator.py
│   │   └── utils.py
│   └── loss.py
├── .vscode
│   └── settings.json
├── requirements.txt
├── main.py
├── README.md
├── conf
│   ├── sed_config.yaml
│   ├── sen_config.yaml
│   └── rdg_config.yaml
├── .gitignore
├── train_sed.py
├── train_sen.py
└── train_rdg.py
/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /s2igan/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.formatting.provider": "black" 3 | } -------------------------------------------------------------------------------- /s2igan/sen/__init__.py: -------------------------------------------------------------------------------- 1 | from .ied import ImageEncoder 2 | from .sed import SpeechEncoder 3 | -------------------------------------------------------------------------------- /s2igan/utils.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | 3 | 4 | def set_non_grad(model): 5 | for param in model.parameters(): 6 | param.requires_grad = False 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | hydra-core==1.2.0 2 | PySoundFile==0.9.0.post1 3 | python-dotenv==0.21.0 4 | torch==1.11.0 5 | torchaudio==0.11.0 6 | torch-summary==1.4.5 7 | torchvision==0.12.0 8 | tqdm==4.64.1 9 | wandb==0.12.21 -------------------------------------------------------------------------------- /s2igan/rdg/__init__.py: -------------------------------------------------------------------------------- 1 | from .discriminator import ( 2 | DiscriminatorFor64By64, 3 | DiscriminatorFor128By128, 4 | DiscriminatorFor256By256, 5 | ) 6 | from .generator import DenselyStackedGenerator 7 | from .relation_supervisor import RelationClassifier 8 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from data.dataset import SENDataset 2 | 3 | 4 | def main(): 5 | dataset = SENDataset( 6 | r"C:\Users\nvatu\OneDrive\Desktop\s2idata\train_flower_en2vi.json", 7 | r"C:\Users\nvatu\OneDrive\Desktop\s2idata\image_oxford\image_oxford\train", 8 | r"C:\Users\nvatu\OneDrive\Desktop\s2idata\oxford\oxford\train", 9 | ) 10 | 11 | 12 | if __name__ == "__main__": 13 | main() 14 | -------------------------------------------------------------------------------- /s2igan/rdg/relation_supervisor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class RelationClassifier(nn.Module): 6 | def __init__(self, inp_dim: int = 2048, hid_dim: int = 128): 7 | super().__init__() 8 | self.seq = nn.Sequential( 9 | nn.Linear(inp_dim, hid_dim), 10 | nn.BatchNorm1d(hid_dim), 11 | nn.ReLU(), 12 | nn.Linear(hid_dim, 4), 13 | ) 14 | 15 | def get_params(self): 16 | return [p for p in self.parameters() 
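# Note: inp_dim defaults to 2048 because forward() below concatenates two embeddings
# along dim=1; with the 1024-d ImageEncoder output used in the configs this is
# 1024 + 1024. The Linear(hid_dim, 4) head appears to score the four pair-relation
# classes consumed by RSLoss in s2igan/loss.py.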
if p.requires_grad] 17 | 18 | def forward(self, inp_a, inp_b): 19 | inp = torch.cat((inp_a, inp_b), dim=1) 20 | return self.seq(inp) 21 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 1. Update D 2 | 2. Update RS 3 | 3. Update G 4 | 5 | ``` 6 | train: 7 | json_file: /home/admin/workspace/tuan/s2igan_model/dataset/train_flower_en2vi.json 8 | img_path: /home/admin/workspace/tuan/s2igan_model/dataset/image_oxford/image_oxford 9 | audio_path: /home/admin/workspace/tuan/s2igan_model/dataset/oxford_audio/oxford 10 | input_size: ${data.general.input_size} 11 | n_fft: ${data.general.n_fft} 12 | n_mels: ${data.general.n_mels} 13 | win_length: ${data.general.win_length} 14 | hop_length: ${data.general.hop_length} 15 | test: 16 | json_file: /home/admin/workspace/tuan/s2igan_model/dataset/test_flower_en2vi.json 17 | img_path: /home/admin/workspace/tuan/s2igan_model/dataset/image_oxford/image_oxford 18 | audio_path: /home/admin/workspace/tuan/s2igan_model/dataset/oxford_audio/oxford 19 | input_size: ${data.general.input_size} 20 | n_fft: ${data.general.n_fft} 21 | n_mels: ${data.general.n_mels} 22 | win_length: ${data.general.win_length} 23 | hop_length: ${data.general.hop_length} 24 | ``` 25 | -------------------------------------------------------------------------------- /data/data_collator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn.utils.rnn import pad_sequence 4 | 5 | 6 | def sen_collate_fn(batch): 7 | imgs = torch.stack([i[0] for i in batch]) 8 | specs = pad_sequence([i[1] for i in batch], batch_first=True) # (-1, len, n_mels) 9 | len_specs = torch.LongTensor([i[2] for i in batch]) 10 | labels = torch.LongTensor([i[3] for i in batch]) 11 | 12 | specs = specs.permute(0, 2, 1) # (-1, n_mels, len) 13 | 14 | return imgs, specs, len_specs, labels 15 | 16 | def rdg_collate_fn(batch): 17 | real_imgs = torch.stack([i[0] for i in batch]) 18 | similar_imgs = torch.stack([i[1] for i in batch]) 19 | wrong_imgs = torch.stack([i[2] for i in batch]) 20 | specs = pad_sequence([i[3] for i in batch], batch_first=True) # (-1, len, n_mels) 21 | len_specs = torch.LongTensor([i[4] for i in batch]) 22 | raw_audio = [i[5] for i in batch] 23 | 24 | specs = specs.permute(0, 2, 1) # (-1, n_mels, len) 25 | 26 | return real_imgs, similar_imgs, wrong_imgs, specs, len_specs, raw_audio 27 | 28 | # sed 29 | def sed_collate_fn(batch): 30 | specs = pad_sequence([i[0] for i in batch], batch_first=True) # (-1, len, n_mels) 31 | len_specs = torch.LongTensor([i[1] for i in batch]) 32 | labels = torch.LongTensor([i[2] for i in batch]) 33 | 34 | specs = specs.permute(0, 2, 1) # (-1, n_mels, len) 35 | 36 | return specs, len_specs, labels 37 | -------------------------------------------------------------------------------- /s2igan/sen/ied.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torch import nn 4 | 5 | from s2igan.utils import set_non_grad 6 | 7 | 8 | class ImageEncoder(nn.Module): 9 | """ 10 | image encoder pretrained on imagenet 11 | """ 12 | 13 | def __init__(self, output_dim: int = 1024): 14 | super().__init__() 15 | # try: 16 | # weights = torchvision.models.get_weight("Inception_V3_Weights.DEFAULT") 17 | # model = torchvision.models.inception_v3(weights=weights) 18 | # except: 19 | # model = 
torchvision.models.inception_v3(pretrained=True) 20 | model = torchvision.models.resnet50(pretrained=True) #new 21 | # model = torchvision.models.resnet152(pretrained=True) #new 22 | model.AuxLogits = None 23 | model.aux_logits = False 24 | set_non_grad(model) 25 | self.model = model 26 | self.model.fc = nn.Linear(2048, output_dim) 27 | # self.model.fc = nn.Linear(model.fc.in_features, output_dim) #new 28 | 29 | def get_params(self): 30 | return [p for p in self.parameters() if p.requires_grad] 31 | 32 | def freeze_params(self): 33 | for p in self.parameters(): 34 | p.requires_grad = False 35 | 36 | def forward(self, img): 37 | """ 38 | img: (-1, 3, 299, 299) 39 | out: (-1, output_dim) 40 | """ 41 | img = nn.functional.interpolate( 42 | img, size=(299, 299), mode="bilinear", align_corners=False 43 | ) 44 | out = self.model(img) 45 | return nn.functional.normalize(out, p=2, dim=1) 46 | -------------------------------------------------------------------------------- /conf/sed_config.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | general: 3 | input_size: 299 4 | n_fft: 512 5 | n_mels: 40 6 | win_length: 250 7 | hop_length: 100 8 | n_class: 103 # get from dataset externally 9 | batch_size: 12 10 | num_workers: 20 11 | train: 12 | json_file: /kaggle/input/s2igan-oxford/s2igan_oxford/train_flower_en2vi.json 13 | audio_path: /kaggle/input/s2igan-oxford/s2igan_oxford/oxford_audio/oxford_audio 14 | input_size: ${data.general.input_size} 15 | n_fft: ${data.general.n_fft} 16 | n_mels: ${data.general.n_mels} 17 | win_length: ${data.general.win_length} 18 | hop_length: ${data.general.hop_length} 19 | test: 20 | json_file: /kaggle/input/s2igan-oxford/s2igan_oxford/test_flower_en2vi.json 21 | audio_path: /kaggle/input/s2igan-oxford/s2igan_oxford/oxford_audio/oxford_audio 22 | input_size: ${data.general.input_size} 23 | n_fft: ${data.general.n_fft} 24 | n_mels: ${data.general.n_mels} 25 | win_length: ${data.general.win_length} 26 | hop_length: ${data.general.hop_length} 27 | 28 | model: 29 | image_encoder: 30 | output_dim: 1024 31 | speech_encoder: 32 | input_dim: ${data.general.n_mels} 33 | cnn_dim: [64, 128] 34 | kernel_size: 6 35 | stride: 2 36 | rnn_dim: 512 37 | rnn_num_layers: 2 38 | rnn_type: gru 39 | rnn_dropout: 0.1 40 | rnn_bidirectional: True 41 | attn_heads: 1 42 | attn_dropout: 0.1 43 | classifier: 44 | in_features: ${model.image_encoder.output_dim} 45 | out_features: ${data.general.n_class} 46 | 47 | optimizer: 48 | lr: 0.000004 49 | 50 | scheduler: 51 | use: true 52 | pct_start: 0.5 53 | 54 | experiment: 55 | train: true 56 | test: true 57 | max_epoch: 50 58 | log_wandb: true 59 | 60 | loss: 61 | beta: 10 62 | 63 | ckpt: 64 | speech_encoder: null 65 | image_encoder: null 66 | classifier: null 67 | kaggle: 68 | user: "" -------------------------------------------------------------------------------- /conf/sen_config.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | general: 3 | input_size: 299 4 | n_fft: 512 5 | n_mels: 40 6 | win_length: 250 7 | hop_length: 100 8 | n_class: 103 # get from dataset externally 9 | batch_size: 12 10 | num_workers: 20 11 | train: 12 | json_file: /kaggle/input/s2igan-oxford/s2igan_oxford/train_flower_en2vi.json 13 | # img_path: /kaggle/input/s2igan-oxford/s2igan_oxford/image_oxford/image_oxford 14 | audio_path: /kaggle/input/s2igan-oxford/s2igan_oxford/oxford_audio/oxford_audio 15 | input_size: ${data.general.input_size} 16 | n_fft: ${data.general.n_fft} 17 | 
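# note: with the 16 kHz sample rate assumed in data/dataset.py, win_length 250 and
# hop_length 100 correspond to roughly 15.6 ms analysis windows with a 6.25 ms hop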
n_mels: ${data.general.n_mels} 18 | win_length: ${data.general.win_length} 19 | hop_length: ${data.general.hop_length} 20 | test: 21 | json_file: /kaggle/input/s2igan-oxford/s2igan_oxford/test_flower_en2vi.json 22 | # img_path: /kaggle/input/s2igan-oxford/s2igan_oxford/image_oxford/image_oxford 23 | audio_path: /kaggle/input/s2igan-oxford/s2igan_oxford/oxford_audio/oxford_audio 24 | input_size: ${data.general.input_size} 25 | n_fft: ${data.general.n_fft} 26 | n_mels: ${data.general.n_mels} 27 | win_length: ${data.general.win_length} 28 | hop_length: ${data.general.hop_length} 29 | 30 | model: 31 | image_encoder: 32 | output_dim: 1024 33 | speech_encoder: 34 | input_dim: ${data.general.n_mels} 35 | cnn_dim: [64, 128] 36 | kernel_size: 6 37 | stride: 2 38 | rnn_dim: 512 39 | rnn_num_layers: 2 40 | rnn_type: gru 41 | rnn_dropout: 0.1 42 | rnn_bidirectional: True 43 | attn_heads: 1 44 | attn_dropout: 0.1 45 | classifier: 46 | in_features: ${model.image_encoder.output_dim} 47 | out_features: ${data.general.n_class} 48 | 49 | optimizer: 50 | lr: 0.000004 51 | 52 | scheduler: 53 | use: true 54 | pct_start: 0.5 55 | 56 | experiment: 57 | train: true 58 | test: true 59 | max_epoch: 50 60 | log_wandb: true 61 | 62 | loss: 63 | beta: 10 64 | 65 | ckpt: 66 | speech_encoder: null 67 | image_encoder: null 68 | classifier: null 69 | kaggle: 70 | user: "" -------------------------------------------------------------------------------- /conf/rdg_config.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | general: 3 | input_size: 299 4 | n_fft: 512 5 | n_mels: 40 6 | win_length: 250 7 | hop_length: 100 8 | n_class: 103 # get from dataset externally 9 | batch_size: 70 10 | num_workers: 32 11 | train: 12 | json_file: /kaggle/input/s2igan-oxford/s2igan_oxford/train_flower_en2vi.json 13 | img_path: /kaggle/input/s2igan-oxford/s2igan_oxford/image_oxford/image_oxford 14 | audio_path: /kaggle/input/s2igan-oxford/s2igan_oxford/oxford_audio/oxford_audio 15 | input_size: ${data.general.input_size} 16 | n_fft: ${data.general.n_fft} 17 | n_mels: ${data.general.n_mels} 18 | win_length: ${data.general.win_length} 19 | hop_length: ${data.general.hop_length} 20 | test: 21 | json_file: /kaggle/input/s2igan-oxford/s2igan_oxford/test_flower_en2vi.json 22 | img_path: /kaggle/input/s2igan-oxford/s2igan_oxford/image_oxford/image_oxford 23 | audio_path: /kaggle/input/s2igan-oxford/s2igan_oxford/oxford_audio/oxford_audio 24 | input_size: ${data.general.input_size} 25 | n_fft: ${data.general.n_fft} 26 | n_mels: ${data.general.n_mels} 27 | win_length: ${data.general.win_length} 28 | hop_length: ${data.general.hop_length} 29 | 30 | model: 31 | generator: 32 | latent_space_dim: 100 33 | speech_emb_dim: 1024 34 | gan_emb_dim: 128 35 | gen_dim: 64 36 | discriminator: 37 | disc_dim: 64 38 | gan_emb_dim: ${model.generator.gan_emb_dim} 39 | relation_classifier: 40 | inp_dim: 2048 41 | hid_dim: 128 42 | image_encoder: 43 | output_dim: 1024 44 | speech_encoder: 45 | input_dim: ${data.general.n_mels} 46 | cnn_dim: [64, 128] 47 | kernel_size: 6 48 | stride: 2 49 | rnn_dim: 512 50 | rnn_num_layers: 2 51 | rnn_type: gru 52 | rnn_dropout: 0.1 53 | rnn_bidirectional: True 54 | attn_heads: 1 55 | attn_dropout: 0.1 56 | 57 | 58 | optimizer: 59 | lr: 0.0002 60 | 61 | scheduler: 62 | use: true 63 | pct_start: 0.5 64 | 65 | experiment: 66 | train: true 67 | test: true 68 | max_epoch: 100 69 | log_wandb: true 70 | specific_params: 71 | latent_space_dim: ${model.generator.latent_space_dim} 72 | img_dims: 73 
| - 64 74 | - 128 75 | - 256 76 | kl_loss_coef: 2 77 | 78 | loss: 79 | beta: 10 80 | 81 | ckpt: 82 | speech_encoder: ckpt/speech_encoder.pt 83 | image_encoder: ckpt/image_encoder.pt -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.csv 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | sample_data/ 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | ven/ 110 | book_venv/ 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | 135 | outputs/ 136 | SpeechCommands 137 | 138 | tb_logs/ -------------------------------------------------------------------------------- /s2igan/sen/sed.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | from torch import nn 5 | 6 | class SpeechEncoder(nn.Module): 7 | def __init__( 8 | self, 9 | input_dim: int = 40, 10 | cnn_dim: List[int] = [64, 128], 11 | kernel_size: int = 6, 12 | stride: int = 2, 13 | rnn_dim: int = 512, 14 | rnn_num_layers: int = 2, 15 | rnn_type: str = "gru", 16 | rnn_dropout: float = 0.1, 17 | rnn_bidirectional: bool = True, 18 | attn_heads: int = 1, 19 | attn_dropout: float = 0.1, 20 | ): 21 | super().__init__() 22 | assert rnn_type in ["lstm", "gru"] 23 | self.cnn_1 = nn.Sequential( 24 | nn.Conv1d(input_dim, cnn_dim[0], 7 , stride), 25 | nn.BatchNorm1d(cnn_dim[0]), 26 | nn.SiLU(), 27 | nn.Conv1d(cnn_dim[0], 1024, 5 , stride), 28 | nn.BatchNorm1d(1024), 29 | nn.SiLU() 30 | ) 31 | self.cnn_2 = nn.Sequential( 32 | nn.Conv1d(1024, cnn_dim[1], 7 , stride), 33 | nn.BatchNorm1d(cnn_dim[1]), 34 | nn.SiLU(), 35 | nn.Conv1d(cnn_dim[1], 512, 5 , stride), 36 | nn.BatchNorm1d(512), 37 | nn.SiLU(), 38 | ) 39 | 40 | self.kernel_size = kernel_size 41 | self.stride = stride 42 | 43 | rnn_kwargs = dict( 44 | input_size=512, 45 | hidden_size=rnn_dim, 46 | num_layers=rnn_num_layers, 47 | batch_first=True, 48 | dropout=rnn_dropout, 49 | bidirectional=rnn_bidirectional, 50 | ) 51 | if rnn_type == "lstm": 52 | self.rnn = nn.LSTM(**rnn_kwargs) 53 | else: 54 | self.rnn = nn.GRU(**rnn_kwargs) 55 | self.output_dim = rnn_dim * (int(rnn_bidirectional) + 1) 56 | self.self_attention = nn.MultiheadAttention( 57 | embed_dim=self.output_dim, 58 | num_heads=attn_heads, 59 | dropout=attn_dropout, 60 | batch_first=True, 61 | ) 62 | self.feed_forward = nn.Sequential( 63 | nn.Linear(self.output_dim, self.output_dim*2), 64 | nn.Linear(self.output_dim*2, self.output_dim*4), 65 | nn.SiLU(), 66 | nn.Linear(self.output_dim*4, self.output_dim), 67 | ) 68 | 69 | def get_params(self): 70 | return [p for p in self.parameters() if p.requires_grad] 71 | 72 | def forward(self, mel_spec, mel_spec_len): 73 | """ 74 | mel_spec (-1, 40, len) 75 | output (-1, len, rnn_dim * (int(bidirectional) + 1)) 76 | """ 77 | cnn_out = self.cnn_1(mel_spec) 78 | cnn_out = self.cnn_2(cnn_out) 79 | 80 | # l = [ 81 | # torch.div(y - self.kernel_size, self.stride, rounding_mode="trunc") + 1 82 | # for y in mel_spec_len 83 | # ]R 84 | # l = [ 85 | # torch.div(y - self.kernel_size, self.stride, rounding_mode="trunc") + 1 86 | # for y in l 87 | # ] 88 | 89 | cnn_out = cnn_out.permute(0, 2, 1) 90 | 91 | # packed = pack_padded_sequence( 92 | # cnn_out, l, batch_first=True, enforce_sorted=False 93 | # ) 94 | # self.rnn.flatteSn_parameters() 95 | # out, hidden_state = self.rnn(packed) 96 | # out, seq_len = pad_packed_sequence(out, 
batch_first=True) 97 | # pack input before RNN to reduce computing efforts 98 | out, hidden_state = self.rnn(cnn_out) 99 | 100 | out, weights = self.self_attention(out, out, out) 101 | out = out.mean(dim=1) 102 | out = torch.nn.functional.normalize(out) 103 | out = self.feed_forward(out) 104 | return out 105 | -------------------------------------------------------------------------------- /s2igan/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class MatchingLoss(nn.Module): 6 | def __init__(self, beta: float = 10): 7 | """ 8 | from this paper: https://arxiv.org/abs/1711.10485 9 | beta is gamma_3 10 | """ 11 | super().__init__() 12 | self.beta = beta 13 | self.eps = 1e-9 14 | self.criterion = nn.CrossEntropyLoss() 15 | 16 | def forward(self, x, y, labels): 17 | bs = labels.shape[0] 18 | 19 | mask = self.create_mask(labels) 20 | sim = self.cosine_sim(x, y) 21 | sim = sim + mask.log() 22 | 23 | diag_label = torch.autograd.Variable(torch.LongTensor(list(range(bs)))) 24 | if torch.cuda.is_available(): 25 | diag_label = diag_label.to("cuda:0") 26 | loss_0 = self.criterion(sim, diag_label) 27 | loss_1 = self.criterion(sim.T, diag_label) 28 | 29 | return loss_0 + loss_1 30 | 31 | def create_mask(self, labels): 32 | mask = torch.ne(labels.view(1, -1), labels.view(-1, 1)) 33 | mask = mask.fill_diagonal_(1) 34 | mask = mask.to(dtype=torch.float, device=labels.device) 35 | return mask 36 | 37 | def cosine_sim(self, x, y): 38 | x = x.unsqueeze(0) 39 | y = y.unsqueeze(0) 40 | 41 | norm_x = torch.linalg.vector_norm(x, ord=2, dim=-1, keepdim=True) 42 | norm_y = torch.linalg.vector_norm(y, ord=2, dim=-1, keepdim=True) 43 | 44 | num = torch.bmm(x, y.transpose(1, 2)) 45 | den = torch.bmm(norm_x, norm_y.transpose(1, 2)) 46 | 47 | sim = self.beta * (num / den.clamp(min=1e-8)) 48 | 49 | return sim.squeeze() 50 | 51 | 52 | class DistinctiveLoss(nn.Module): 53 | def __init__(self): 54 | super().__init__() 55 | self.crit = nn.CrossEntropyLoss() 56 | 57 | def forward(self, cls_x, cls_y, labels): 58 | return self.crit(cls_x, labels) + self.crit(cls_y, labels) 59 | 60 | 61 | class SENLoss(nn.Module): 62 | def __init__(self, beta: int = 10): 63 | super().__init__() 64 | self.matching_loss = MatchingLoss(beta) 65 | self.distinctive_loss = DistinctiveLoss() 66 | 67 | def forward(self, x, y, cls_x, cls_y, labels): 68 | match_loss = self.matching_loss(x, y, labels) 69 | dist_loss = self.distinctive_loss(cls_x, cls_y, labels) 70 | return match_loss.detach(), dist_loss.detach(), match_loss + dist_loss 71 | 72 | # sed 73 | class DistinctiveLossSED(nn.Module): 74 | def __init__(self): 75 | super().__init__() 76 | self.crit = nn.CrossEntropyLoss() 77 | 78 | def forward(self, cls_x, labels): 79 | return self.crit(cls_x, labels) 80 | class SEDLoss(nn.Module): 81 | def __init__(self): 82 | super().__init__() 83 | self.distinctive_loss = DistinctiveLossSED() 84 | 85 | def forward(self, cls_x,labels): 86 | dist_loss = self.distinctive_loss(cls_x, labels) 87 | return dist_loss 88 | 89 | class KLDivergenceLoss(nn.Module): 90 | def __init__(self): 91 | super(KLDivergenceLoss, self).__init__() 92 | 93 | def forward(self, x_mean, x_logvar): 94 | # Compute kl divergence loss 95 | kl_div = torch.mean(x_mean.pow(2) + x_logvar.exp() - 1 - x_logvar) 96 | 97 | return kl_div 98 | 99 | 100 | class RSLoss(nn.Module): 101 | def __init__(self): 102 | super().__init__() 103 | self.crit = nn.CrossEntropyLoss() 104 | 105 | def forward(self, R1, R2, R3, 
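# Note: R1, R2, R3 and R_GT_FI are relation logits from RelationClassifier for image
# pairs built in the RDG training step; the constant zero/one/two label tensors pin each
# pair type to its own relation class. The exact pair construction lives in
# s2igan/rdg/utils.py (only partially shown here), so treat any finer interpretation as an assumption.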
R_GT_FI, zero_labels, one_labels, two_labels): 106 | return ( 107 | self.crit(R1, one_labels.long()) 108 | + self.crit(R2, zero_labels.long()) 109 | + self.crit(R3, two_labels.long()) 110 | + self.crit(R_GT_FI, zero_labels.long()) 111 | ) 112 | -------------------------------------------------------------------------------- /train_sed.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import hydra 4 | import torch 5 | from dotenv import load_dotenv 6 | from omegaconf import DictConfig, OmegaConf 7 | from torch import nn 8 | from torch.utils.data import DataLoader 9 | from torchsummary import summary 10 | import os 11 | import wandb 12 | from data.data_collator import sed_collate_fn 13 | from data.dataset import SEDDataset 14 | from s2igan.loss import SEDLoss 15 | from s2igan.sen.sed import SpeechEncoder 16 | from s2igan.sen.utils import sed_train_epoch, sed_eval_epoch 17 | 18 | config_path = "conf" 19 | config_name = "sed_config" 20 | 21 | 22 | @hydra.main(version_base=None, config_path=config_path, config_name=config_name) 23 | def main(cfg: DictConfig): 24 | bs = cfg.data.general.batch_size 25 | attn_heads = cfg.model.speech_encoder.get("attn_heads", 1) 26 | attn_dropout = cfg.model.speech_encoder.get("attn_dropout", 0.1) 27 | rnn_dropout = cfg.model.speech_encoder.get("rnn_dropout", 0.0) 28 | lr = cfg.optimizer.get("lr", 0.001) 29 | wandb.init(project="sed", name=f"SEN_bs{bs}_lr{lr}_attn{attn_heads}_ad{attn_dropout}_rd{rnn_dropout}_{cfg.kaggle.user}") 30 | 31 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 32 | multi_gpu = torch.cuda.device_count() > 1 33 | device_ids = list(range(torch.cuda.device_count())) 34 | 35 | train_set = SEDDataset(**cfg.data.train) 36 | test_set = SEDDataset(**cfg.data.test) 37 | 38 | nwkers = cfg.data.general.num_workers 39 | train_dataloader = DataLoader( 40 | train_set, bs, shuffle=True, num_workers=nwkers, collate_fn=sed_collate_fn 41 | ) 42 | test_dataloder = DataLoader( 43 | test_set, bs, shuffle=False, num_workers=nwkers, collate_fn=sed_collate_fn 44 | ) 45 | 46 | speech_encoder = SpeechEncoder(**cfg.model.speech_encoder) 47 | classifier = nn.Linear(**cfg.model.classifier) 48 | nn.init.xavier_uniform_(classifier.weight.data) 49 | if cfg.ckpt.speech_encoder: 50 | print("Loading Speech Encoder state dict...") 51 | print(speech_encoder.load_state_dict(torch.load(cfg.ckpt.speech_encoder))) 52 | 53 | if cfg.ckpt.classifier: 54 | print("Loading Classifier state dict...") 55 | print(classifier.load_state_dict(torch.load(cfg.ckpt.classifier))) 56 | 57 | if multi_gpu: 58 | speech_encoder = nn.DataParallel(speech_encoder, device_ids=device_ids) 59 | classifier = nn.DataParallel(classifier, device_ids=device_ids) 60 | 61 | speech_encoder = speech_encoder.to(device) 62 | classifier = classifier.to(device) 63 | 64 | # try: 65 | # image_encoder = torch.compile(image_encoder) 66 | # speech_encoder = torch.compile(speech_encoder) 67 | # classifier = torch.compile(classifier) 68 | # except: 69 | # print("Can't activate Pytorch 2.0") 70 | 71 | if multi_gpu: 72 | model_params = ( 73 | speech_encoder.module.get_params() 74 | + list(classifier.module.parameters()) 75 | ) 76 | else: 77 | model_params = ( 78 | speech_encoder.get_params() 79 | + list(classifier.parameters()) 80 | ) 81 | 82 | optimizer = torch.optim.AdamW(model_params, **cfg.optimizer) 83 | scheduler = None 84 | if cfg.scheduler.use: 85 | steps_per_epoch = len(train_dataloader) 86 | sched_dict = dict( 87 | epochs=cfg.experiment.max_epoch, 
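# Note: OneCycleLR derives total_steps = epochs * steps_per_epoch from these two kwargs;
# with pct_start=0.5 the learning rate climbs to max_lr over the first half of training
# and is annealed back down over the second half.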
88 | steps_per_epoch=steps_per_epoch, 89 | max_lr=cfg.optimizer.lr, 90 | pct_start=cfg.scheduler.pct_start, 91 | ) 92 | scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, **sched_dict) 93 | 94 | criterion = SEDLoss().to(device) 95 | 96 | log_wandb = cfg.experiment.log_wandb 97 | 98 | if cfg.experiment.train: 99 | for epoch in range(cfg.experiment.max_epoch): 100 | train_result = sed_train_epoch( 101 | speech_encoder, 102 | classifier, 103 | train_dataloader, 104 | optimizer, 105 | scheduler, 106 | criterion, 107 | device, 108 | epoch, 109 | log_wandb, 110 | ) 111 | eval_result = sed_eval_epoch( 112 | speech_encoder, 113 | classifier, 114 | test_dataloder, 115 | criterion, 116 | device, 117 | epoch, 118 | log_wandb, 119 | ) 120 | 121 | save_dir = "/kaggle/working/save_ckpt" 122 | if not os.path.exists(save_dir): 123 | os.makedirs(save_dir) 124 | 125 | # Tiếp tục lưu trữ trọng số của mô hình 126 | torch.save(speech_encoder.state_dict(), os.path.join(save_dir, "speech_encoder_SED.pt")) 127 | torch.save(classifier.state_dict(), os.path.join(save_dir, "classifier.pt")) 128 | 129 | print("Train result:", train_result) 130 | print("Eval result:", eval_result) 131 | 132 | 133 | if __name__ == "__main__": 134 | load_dotenv() 135 | main() -------------------------------------------------------------------------------- /train_sen.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import hydra 4 | import torch 5 | from dotenv import load_dotenv 6 | from omegaconf import DictConfig, OmegaConf 7 | from torch import nn 8 | from torch.utils.data import DataLoader 9 | from torchsummary import summary 10 | import os 11 | import wandb 12 | from data.data_collator import sen_collate_fn 13 | from data.dataset import SENDataset 14 | from s2igan.loss import SENLoss 15 | from s2igan.sen import ImageEncoder, SpeechEncoder 16 | from s2igan.sen.utils import sen_train_epoch, sen_eval_epoch 17 | 18 | config_path = "conf" 19 | config_name = "sen_config" 20 | 21 | 22 | @hydra.main(version_base=None, config_path=config_path, config_name=config_name) 23 | def main(cfg: DictConfig): 24 | bs = cfg.data.general.batch_size 25 | attn_heads = cfg.model.speech_encoder.attn_heads 26 | attn_dropout = cfg.model.speech_encoder.attn_dropout 27 | rnn_dropout = cfg.model.speech_encoder.rnn_dropout 28 | lr = cfg.optimizer.lr 29 | wandb.init(project="speech2image", name=f"SEN_bs{bs}_lr{lr}_attn{attn_heads}_ad{attn_dropout}_rd{rnn_dropout}_{cfg.kaggle.user}") 30 | 31 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 32 | multi_gpu = torch.cuda.device_count() > 1 33 | device_ids = list(range(torch.cuda.device_count())) 34 | 35 | train_set = SENDataset(**cfg.data.train) 36 | test_set = SENDataset(**cfg.data.test) 37 | 38 | nwkers = cfg.data.general.num_workers 39 | train_dataloader = DataLoader( 40 | train_set, bs, shuffle=True, num_workers=nwkers, collate_fn=sen_collate_fn 41 | ) 42 | test_dataloder = DataLoader( 43 | test_set, bs, shuffle=False, num_workers=nwkers, collate_fn=sen_collate_fn 44 | ) 45 | 46 | image_encoder = ImageEncoder(**cfg.model.image_encoder) 47 | speech_encoder = SpeechEncoder(**cfg.model.speech_encoder) 48 | classifier = nn.Linear(**cfg.model.classifier) 49 | nn.init.xavier_uniform_(classifier.weight.data) 50 | 51 | if cfg.ckpt.image_encoder: 52 | print("Loading Image Encoder state dict...") 53 | print(image_encoder.load_state_dict(torch.load(cfg.ckpt.image_encoder))) 54 | 55 | if cfg.ckpt.speech_encoder: 56 | print("Loading Speech 
Encoder state dict...") 57 | print(speech_encoder.load_state_dict(torch.load(cfg.ckpt.speech_encoder))) 58 | 59 | if cfg.ckpt.classifier: 60 | print("Loading Classifier state dict...") 61 | print(classifier.load_state_dict(torch.load(cfg.ckpt.classifier))) 62 | 63 | if multi_gpu: 64 | image_encoder = nn.DataParallel(image_encoder, device_ids=device_ids) 65 | speech_encoder = nn.DataParallel(speech_encoder, device_ids=device_ids) 66 | classifier = nn.DataParallel(classifier, device_ids=device_ids) 67 | 68 | image_encoder = image_encoder.to(device) 69 | speech_encoder = speech_encoder.to(device) 70 | classifier = classifier.to(device) 71 | 72 | # try: 73 | # image_encoder = torch.compile(image_encoder) 74 | # speech_encoder = torch.compile(speech_encoder) 75 | # classifier = torch.compile(classifier) 76 | # except: 77 | # print("Can't activate Pytorch 2.0") 78 | 79 | if multi_gpu: 80 | model_params = ( 81 | image_encoder.module.get_params() 82 | + speech_encoder.module.get_params() 83 | + list(classifier.module.parameters()) 84 | ) 85 | else: 86 | model_params = ( 87 | image_encoder.get_params() 88 | + speech_encoder.get_params() 89 | + list(classifier.parameters()) 90 | ) 91 | 92 | optimizer = torch.optim.AdamW(model_params, **cfg.optimizer) 93 | scheduler = None 94 | if cfg.scheduler.use: 95 | steps_per_epoch = len(train_dataloader) 96 | sched_dict = dict( 97 | epochs=cfg.experiment.max_epoch, 98 | steps_per_epoch=steps_per_epoch, 99 | max_lr=cfg.optimizer.lr, 100 | pct_start=cfg.scheduler.pct_start, 101 | ) 102 | scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, **sched_dict) 103 | 104 | criterion = SENLoss(**cfg.loss).to(device) 105 | 106 | log_wandb = cfg.experiment.log_wandb 107 | 108 | if cfg.experiment.train: 109 | for epoch in range(cfg.experiment.max_epoch): 110 | train_result = sen_train_epoch( 111 | image_encoder, 112 | speech_encoder, 113 | classifier, 114 | train_dataloader, 115 | optimizer, 116 | scheduler, 117 | criterion, 118 | device, 119 | epoch, 120 | log_wandb, 121 | ) 122 | eval_result = sen_eval_epoch( 123 | image_encoder, 124 | speech_encoder, 125 | classifier, 126 | test_dataloder, 127 | criterion, 128 | device, 129 | epoch, 130 | log_wandb, 131 | ) 132 | 133 | save_dir = "/kaggle/working/save_ckpt" 134 | if not os.path.exists(save_dir): 135 | os.makedirs(save_dir) 136 | 137 | # Tiếp tục lưu trữ trọng số của mô hình 138 | torch.save(speech_encoder.state_dict(), os.path.join(save_dir, "speech_encoder.pt")) 139 | torch.save(image_encoder.state_dict(), os.path.join(save_dir, "image_encoder.pt")) 140 | torch.save(classifier.state_dict(), os.path.join(save_dir, "classifier.pt")) 141 | 142 | print("Train result:", train_result) 143 | print("Eval result:", eval_result) 144 | 145 | 146 | if __name__ == "__main__": 147 | load_dotenv() 148 | main() -------------------------------------------------------------------------------- /train_rdg.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import hydra 4 | import torch 5 | from dotenv import load_dotenv 6 | from omegaconf import DictConfig, OmegaConf 7 | from torch import nn 8 | from torch.utils.data import DataLoader 9 | from torchsummary import summary 10 | 11 | import wandb 12 | from data.data_collator import rdg_collate_fn 13 | from data.dataset import RDGDataset 14 | from s2igan.loss import KLDivergenceLoss, RSLoss 15 | from s2igan.rdg import ( 16 | DenselyStackedGenerator, 17 | DiscriminatorFor64By64, 18 | DiscriminatorFor128By128, 19 | 
DiscriminatorFor256By256, 20 | RelationClassifier, 21 | ) 22 | from s2igan.rdg.utils import rdg_train_epoch 23 | from s2igan.sen import ImageEncoder, SpeechEncoder 24 | from s2igan.utils import set_non_grad 25 | 26 | config_path = "conf" 27 | config_name = "rdg_config" 28 | 29 | 30 | @hydra.main(version_base=None, config_path=config_path, config_name=config_name) 31 | def main(cfg: DictConfig): 32 | if cfg.experiment.log_wandb: 33 | wandb.init(project="speech2image", name="RDG") 34 | 35 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 36 | multi_gpu = torch.cuda.device_count() > 1 37 | device_ids = list(range(torch.cuda.device_count())) 38 | 39 | train_set = RDGDataset(**cfg.data.train) 40 | test_set = RDGDataset(**cfg.data.test) 41 | 42 | bs = cfg.data.general.batch_size 43 | nwkers = cfg.data.general.num_workers 44 | train_dataloader = DataLoader( 45 | train_set, bs, shuffle=True, num_workers=nwkers, collate_fn=rdg_collate_fn 46 | ) 47 | test_dataloder = DataLoader( 48 | test_set, bs, shuffle=False, num_workers=nwkers, collate_fn=rdg_collate_fn 49 | ) 50 | 51 | generator = DenselyStackedGenerator(**cfg.model.generator) 52 | discrminator_64 = DiscriminatorFor64By64(**cfg.model.discriminator) 53 | discrminator_128 = DiscriminatorFor128By128(**cfg.model.discriminator) 54 | discrminator_256 = DiscriminatorFor256By256(**cfg.model.discriminator) 55 | relation_classifier = RelationClassifier(**cfg.model.relation_classifier) 56 | image_encoder = ImageEncoder(**cfg.model.image_encoder) 57 | speech_encoder = SpeechEncoder(**cfg.model.speech_encoder) 58 | 59 | if cfg.ckpt.image_encoder: 60 | print("Loading Image Encoder state dict...") 61 | print(image_encoder.load_state_dict(torch.load(cfg.ckpt.image_encoder))) 62 | if cfg.ckpt.speech_encoder: 63 | print("Loading Speech Encoder state dict...") 64 | print(speech_encoder.load_state_dict(torch.load(cfg.ckpt.speech_encoder))) 65 | set_non_grad(image_encoder) 66 | set_non_grad(speech_encoder) 67 | 68 | if multi_gpu: 69 | generator = nn.DataParallel(generator, device_ids=device_ids) 70 | discrminator_64 = nn.DataParallel(discrminator_64, device_ids=device_ids) 71 | discrminator_128 = nn.DataParallel(discrminator_128, device_ids=device_ids) 72 | discrminator_256 = nn.DataParallel(discrminator_256, device_ids=device_ids) 73 | relation_classifier = nn.DataParallel( 74 | relation_classifier, device_ids=device_ids 75 | ) 76 | image_encoder = nn.DataParallel(image_encoder, device_ids=device_ids) 77 | speech_encoder = nn.DataParallel(speech_encoder, device_ids=device_ids) 78 | 79 | generator = generator.to(device) 80 | discrminator_64 = discrminator_64.to(device) 81 | discrminator_128 = discrminator_128.to(device) 82 | discrminator_256 = discrminator_256.to(device) 83 | relation_classifier = relation_classifier.to(device) 84 | image_encoder = image_encoder.to(device) 85 | speech_encoder = speech_encoder.to(device) 86 | 87 | # try: 88 | # image_encoder = torch.compile(image_encoder) 89 | # discrminator_64 = torch.compile(discrminator_64) 90 | # discrminator_128 = torch.compile(discrminator_128) 91 | # discrminator_256 = torch.compile(discrminator_256) 92 | # relation_classifier = torch.compile(relation_classifier) 93 | # image_encoder = torch.compile(image_encoder) 94 | # speech_encoder = torch.compile(speech_encoder) 95 | # except: 96 | # print("Can't activate Pytorch 2.0") 97 | 98 | discriminators = { 99 | 64: discrminator_64, 100 | 128: discrminator_128, 101 | 256: discrminator_256, 102 | } 103 | 104 | models = { 105 | "gen": generator, 106 | 
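# Note: this bundle is what rdg_train_epoch consumes; "gen", "disc" (one per output
# resolution) and "rs" are optimized below, while "ied" and "sed" are the SEN encoders
# loaded from the configured checkpoints (if given) and frozen via set_non_grad above.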
"disc": discriminators, 107 | "rs": relation_classifier, 108 | "ied": image_encoder, 109 | "sed": speech_encoder, 110 | } 111 | 112 | optimizer_generator = torch.optim.AdamW(generator.parameters(), **cfg.optimizer) 113 | optimizer_discrminator = { 114 | key: torch.optim.AdamW(discriminators[key].parameters(), **cfg.optimizer) 115 | for key in discriminators.keys() 116 | } 117 | optimizer_rs = torch.optim.AdamW(relation_classifier.parameters(), **cfg.optimizer) 118 | optimizers = { 119 | "gen": optimizer_generator, 120 | "disc": optimizer_discrminator, 121 | "rs": optimizer_rs, 122 | } 123 | 124 | steps_per_epoch = len(train_dataloader) 125 | sched_dict = dict( 126 | epochs=cfg.experiment.max_epoch, 127 | steps_per_epoch=steps_per_epoch, 128 | max_lr=cfg.optimizer.lr, 129 | pct_start=cfg.scheduler.pct_start, 130 | ) 131 | schedulers = { 132 | "gen": torch.optim.lr_scheduler.OneCycleLR(optimizer_generator, **sched_dict), 133 | "disc": { 134 | key: torch.optim.lr_scheduler.OneCycleLR( 135 | optimizer_discrminator[key], **sched_dict 136 | ) 137 | for key in discriminators.keys() 138 | }, 139 | "rs": torch.optim.lr_scheduler.OneCycleLR(optimizer_rs, **sched_dict), 140 | } 141 | 142 | criterions = { 143 | "kl": KLDivergenceLoss().to(device), 144 | "rs": RSLoss().to(device), 145 | "bce": nn.BCELoss().to(device), 146 | "ce": nn.CrossEntropyLoss().to(device), 147 | } 148 | 149 | log_wandb = cfg.experiment.log_wandb 150 | specific_params = cfg.experiment.specific_params 151 | if cfg.experiment.train: 152 | for epoch in range(cfg.experiment.max_epoch): 153 | train_result = rdg_train_epoch( 154 | models, 155 | train_dataloader, 156 | optimizers, 157 | schedulers, 158 | criterions, 159 | specific_params, 160 | device, 161 | epoch, 162 | log_wandb, 163 | ) 164 | print("Train result:", train_result) 165 | 166 | 167 | if __name__ == "__main__": 168 | load_dotenv() 169 | main() 170 | -------------------------------------------------------------------------------- /s2igan/rdg/generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class ConditioningAugmentationNetwork(nn.Module): 6 | def __init__( 7 | self, 8 | speech_emb_dim: int = 1024, 9 | gan_emb_dim: int = 128, 10 | negative_slope: float = 0.2, 11 | ): 12 | super().__init__() 13 | self.speech_dim = speech_emb_dim 14 | self.gan_emb_dim = gan_emb_dim 15 | self.fc = nn.Linear(self.speech_dim, self.gan_emb_dim * 4) 16 | # multiply by 2 if leaky relu 17 | # multiply by 4 if GLU 18 | # self.act = nn.LeakyReLU(negative_slope) 19 | self.act = nn.GLU(dim=1) 20 | 21 | def forward(self, x): 22 | """ 23 | x: speech embedding vector 24 | """ 25 | x = self.act(self.fc(x)) 26 | mu = x[:, : self.gan_emb_dim] 27 | logvar = x[:, self.gan_emb_dim :] 28 | std = torch.exp(logvar.mul(0.5)) 29 | z = torch.randn(std.size()).to(x.device) 30 | z_ca = mu + std * z 31 | return z_ca, mu, logvar 32 | 33 | 34 | class Interpolate(nn.Module): 35 | def __init__(self, scale_factor, mode, size=None): 36 | super(Interpolate, self).__init__() 37 | self.interp = nn.functional.interpolate 38 | self.scale_factor = scale_factor 39 | self.mode = mode 40 | self.size = size 41 | 42 | def forward(self, x): 43 | return self.interp( 44 | x, scale_factor=self.scale_factor, mode=self.mode, size=self.size 45 | ) 46 | 47 | 48 | class UpBlock(nn.Module): 49 | def __init__(self, in_channels, out_channels): 50 | super().__init__() 51 | self.first_seq = Interpolate(scale_factor=2, mode="nearest") 52 | self.seq = 
nn.Sequential( 53 | nn.Conv2d( 54 | in_channels, 55 | out_channels * 2, 56 | kernel_size=3, 57 | stride=1, 58 | padding=1, 59 | bias=False, 60 | ), 61 | nn.BatchNorm2d(out_channels * 2), 62 | nn.GLU(dim=1), 63 | ) 64 | 65 | def forward(self, x): 66 | x = self.first_seq(x) 67 | x = self.seq(x) 68 | return x 69 | 70 | 71 | class Block(nn.Module): 72 | def __init__(self, desc): 73 | super().__init__() 74 | self.desc = desc 75 | 76 | def forward(self, x): 77 | return x 78 | 79 | 80 | class InitStateGenerator(nn.Module): 81 | def __init__(self, in_dim: int, gen_dim: int): 82 | super().__init__() 83 | # in_dim = z_dim + gan_emb_dim 84 | self.in_dim = in_dim 85 | self.gen_dim = gen_dim 86 | self.input_projection = nn.Sequential( 87 | nn.Linear(self.in_dim, self.gen_dim * 4 * 4 * 2, bias=False), 88 | nn.BatchNorm1d(self.gen_dim * 4 * 4 * 2), 89 | nn.GLU(dim=1), 90 | ) 91 | self.seq_upsample = nn.Sequential( 92 | UpBlock(gen_dim, gen_dim // 2), 93 | UpBlock(gen_dim // 2, gen_dim // 4), 94 | UpBlock(gen_dim // 4, gen_dim // 8), 95 | UpBlock(gen_dim // 8, gen_dim // 16), 96 | ) 97 | 98 | def forward(self, z_code, c_code): 99 | inp = torch.cat((z_code, c_code), 1) 100 | out = self.input_projection(inp) 101 | out = out.view(-1, self.gen_dim, 4, 4) 102 | # up from 4x4 -> 64x64 103 | out = self.seq_upsample(out) 104 | return out 105 | 106 | 107 | class ResBlock(nn.Module): 108 | def __init__(self, in_channels): 109 | super().__init__() 110 | self.block = nn.Sequential( 111 | nn.Conv2d( 112 | in_channels, 113 | in_channels * 2, 114 | kernel_size=3, 115 | stride=1, 116 | padding=1, 117 | bias=False, 118 | ), 119 | nn.BatchNorm2d(in_channels * 2), 120 | nn.GLU(dim=1), 121 | nn.Conv2d( 122 | in_channels, 123 | in_channels, 124 | kernel_size=3, 125 | stride=1, 126 | padding=1, 127 | bias=False, 128 | ), 129 | nn.BatchNorm2d(in_channels), 130 | ) 131 | 132 | def forward(self, x): 133 | return self.block(x) + x 134 | 135 | 136 | class NextStageGenerator(nn.Module): 137 | def __init__(self, gan_emb_dim: int, gen_dim: int): 138 | super().__init__() 139 | self.gan_emb_dim = gan_emb_dim 140 | self.gen_dim = gen_dim 141 | self.joint_conv = nn.Sequential( 142 | nn.Conv2d( 143 | gan_emb_dim + gen_dim, 144 | gen_dim * 2, 145 | kernel_size=3, 146 | stride=1, 147 | padding=1, 148 | bias=False, 149 | ), 150 | nn.BatchNorm2d(gen_dim * 2), 151 | nn.GLU(dim=1), 152 | ) 153 | self.residual = nn.Sequential(ResBlock(gen_dim), ResBlock(gen_dim)) 154 | self.upsample = UpBlock(gen_dim * 2, gen_dim // 2) 155 | 156 | def forward(self, h_code, c_code): 157 | s_size = h_code.size(2) 158 | c_code = c_code.view(-1, self.gan_emb_dim, 1, 1) 159 | c_code = c_code.repeat(1, 1, s_size, s_size) 160 | 161 | concat_code = torch.cat((c_code, h_code), 1) 162 | 163 | out = self.joint_conv(concat_code) 164 | out = self.residual(out) 165 | out = torch.cat((out, h_code), 1) 166 | out = self.upsample(out) 167 | return out 168 | 169 | 170 | class ImageGenerator(nn.Module): 171 | def __init__(self, gen_dim: int): 172 | super().__init__() 173 | self.img = nn.Sequential( 174 | nn.Conv2d(gen_dim, 3, kernel_size=3, stride=1, padding=1, bias=False,), 175 | nn.Tanh(), 176 | ) 177 | 178 | def forward(self, h_code): 179 | return self.img(h_code) 180 | 181 | 182 | class DenselyStackedGenerator(nn.Module): 183 | def __init__( 184 | self, latent_space_dim: int, speech_emb_dim: int, gen_dim: int, gan_emb_dim: int 185 | ): 186 | super().__init__() 187 | inp_dim = latent_space_dim + gan_emb_dim 188 | 189 | self.conditioning_augmentation = 
ConditioningAugmentationNetwork( 190 | speech_emb_dim=speech_emb_dim, gan_emb_dim=gan_emb_dim 191 | ) 192 | 193 | self.F0 = InitStateGenerator(in_dim=inp_dim, gen_dim=gen_dim * 16) 194 | self.G0 = ImageGenerator(gen_dim=gen_dim) 195 | 196 | self.F1 = NextStageGenerator(gan_emb_dim=gan_emb_dim, gen_dim=gen_dim) 197 | self.G1 = ImageGenerator(gen_dim=gen_dim // 2) 198 | 199 | self.F2 = NextStageGenerator(gan_emb_dim=gan_emb_dim, gen_dim=gen_dim // 2) 200 | self.G2 = ImageGenerator(gen_dim=gen_dim // 4) 201 | 202 | def get_params(self): 203 | return [p for p in self.parameters() if p.requires_grad] 204 | 205 | def forward(self, z_code, speech_emb): 206 | c_code, mu, logvar = self.conditioning_augmentation(speech_emb) 207 | 208 | h0 = self.F0(z_code, c_code) 209 | h1 = self.F1(h0, c_code) 210 | h2 = self.F2(h1, c_code) 211 | 212 | fake_imgs = {64: self.G0(h0), 128: self.G1(h1), 256: self.G2(h2)} 213 | 214 | return fake_imgs, mu, logvar 215 | -------------------------------------------------------------------------------- /s2igan/sen/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | 4 | import wandb 5 | 6 | 7 | def sen_train_epoch( 8 | image_encoder, 9 | speech_encoder, 10 | classifier, 11 | dataloader, 12 | optimizer, 13 | scheduler, 14 | criterion, 15 | device, 16 | epoch, 17 | log_wandb: bool = False, 18 | ): 19 | size = len(dataloader) 20 | run_loss = 0 21 | run_img_acc = 0 22 | run_speech_acc = 0 23 | pbar = tqdm(dataloader, total=size) 24 | for (imgs, specs, len_specs, labels) in pbar: 25 | imgs, specs, len_specs, labels = ( 26 | imgs.to(device), 27 | specs.to(device), 28 | len_specs.to(device), 29 | labels.to(device), 30 | ) 31 | 32 | optimizer.zero_grad() 33 | 34 | V = image_encoder(imgs) 35 | A = speech_encoder(specs, len_specs) 36 | cls_img = classifier(V) 37 | cls_speech = classifier(A) 38 | 39 | match_loss, dist_loss, loss = criterion(V, A, cls_img, cls_speech, labels) 40 | loss.backward() 41 | 42 | optimizer.step() 43 | scheduler.step() 44 | 45 | loss = loss.item() 46 | 47 | img_acc = (cls_img.argmax(-1) == labels).sum() / labels.size(0) * 100 48 | speech_acc = (cls_speech.argmax(-1) == labels).sum() / labels.size(0) * 100 49 | 50 | img_acc = img_acc.item() 51 | speech_acc = speech_acc.item() 52 | 53 | run_loss += loss 54 | run_img_acc += img_acc 55 | run_speech_acc += speech_acc 56 | 57 | if log_wandb: 58 | wandb.log({"train/sen_loss": loss}) 59 | wandb.log({"train/matching_loss": match_loss.item()}) 60 | wandb.log({"train/distinctive_loss": dist_loss.item()}) 61 | wandb.log({"train/image_accuracy": img_acc}) 62 | wandb.log({"train/speech_accuracy": speech_acc}) 63 | wandb.log({"train/epoch": epoch}) 64 | wandb.log({"train/lr-OneCycleLR": scheduler.get_last_lr()[0]}) 65 | 66 | pbar.set_description( 67 | f"[Epoch: {epoch}] Loss: {loss:.2f} | Image Acc: {img_acc:.2f}% | Speech Acc: {speech_acc:.2f}%" 68 | ) 69 | 70 | return { 71 | "loss": run_loss / size, 72 | "img_acc": run_img_acc / size, 73 | "speech_acc": run_speech_acc / size, 74 | } 75 | 76 | 77 | def sen_eval_epoch( 78 | image_encoder, 79 | speech_encoder, 80 | classifier, 81 | dataloader, 82 | criterion, 83 | device, 84 | epoch, 85 | log_wandb: bool = False, 86 | ): 87 | size = len(dataloader) 88 | run_loss = 0 89 | run_img_acc = 0 90 | run_speech_acc = 0 91 | pbar = tqdm(dataloader, total=size) 92 | with torch.no_grad(): 93 | for (imgs, specs, len_specs, labels) in pbar: 94 | imgs, specs, len_specs, labels = ( 95 | imgs.to(device), 96 | 
specs.to(device), 97 | len_specs.to(device), 98 | labels.to(device), 99 | ) 100 | 101 | V = image_encoder(imgs) 102 | A = speech_encoder(specs, len_specs) 103 | cls_img = classifier(V) 104 | cls_speech = classifier(A) 105 | 106 | match_loss, dist_loss, loss = criterion(V, A, cls_img, cls_speech, labels) 107 | 108 | loss = loss.item() 109 | 110 | img_acc = (cls_img.argmax(-1) == labels).sum() / labels.size(0) * 100 111 | speech_acc = (cls_speech.argmax(-1) == labels).sum() / labels.size(0) * 100 112 | 113 | img_acc = img_acc.item() 114 | speech_acc = speech_acc.item() 115 | 116 | run_loss += loss 117 | run_img_acc += img_acc 118 | run_speech_acc += speech_acc 119 | 120 | if log_wandb: 121 | wandb.log({"val/sen_loss": loss}) 122 | wandb.log({"val/matching_loss": match_loss.item()}) 123 | wandb.log({"val/distinctive_loss": dist_loss.item()}) 124 | wandb.log({"val/image_accuracy": img_acc}) 125 | wandb.log({"val/speech_accuracy": speech_acc}) 126 | 127 | pbar.set_description( 128 | f"[Epoch: {epoch}] Loss: {loss:.2f} | Image Acc: {img_acc:.2f}% | Speech Acc: {speech_acc:.2f}%" 129 | ) 130 | 131 | return { 132 | "loss": run_loss / size, 133 | "img_acc": run_img_acc / size, 134 | "speech_acc": run_speech_acc / size, 135 | } 136 | 137 | # sed 138 | 139 | def sed_train_epoch( 140 | speech_encoder, 141 | classifier, 142 | dataloader, 143 | optimizer, 144 | scheduler, 145 | criterion, 146 | device, 147 | epoch, 148 | log_wandb: bool = False, 149 | ): 150 | size = len(dataloader) 151 | run_loss = 0 152 | run_speech_acc = 0 153 | pbar = tqdm(dataloader, total=size) 154 | for (specs, len_specs, labels) in pbar: 155 | specs, len_specs, labels = ( 156 | specs.to(device), 157 | len_specs.to(device), 158 | labels.to(device), 159 | ) 160 | 161 | optimizer.zero_grad() 162 | 163 | A = speech_encoder(specs, len_specs) 164 | cls_speech = classifier(A) 165 | 166 | loss = criterion(cls_speech, labels) 167 | loss.backward() 168 | 169 | optimizer.step() 170 | scheduler.step() 171 | 172 | loss = loss.item() 173 | 174 | speech_acc = (cls_speech.argmax(-1) == labels).sum() / labels.size(0) * 100 175 | 176 | speech_acc = speech_acc.item() 177 | 178 | run_loss += loss 179 | run_speech_acc += speech_acc 180 | 181 | if log_wandb: 182 | wandb.log({"train/sed_loss": loss}) 183 | wandb.log({"train/speech_accuracy": speech_acc}) 184 | wandb.log({"train/epoch": epoch}) 185 | wandb.log({"train/lr-OneCycleLR": scheduler.get_last_lr()[0]}) 186 | 187 | pbar.set_description( 188 | f"[Epoch: {epoch}] Loss: {loss:.2f} | Speech Acc: {speech_acc:.2f}%" 189 | ) 190 | 191 | return { 192 | "loss": run_loss / size, 193 | "speech_acc": run_speech_acc / size, 194 | } 195 | 196 | 197 | def sed_eval_epoch( 198 | speech_encoder, 199 | classifier, 200 | dataloader, 201 | criterion, 202 | device, 203 | epoch, 204 | log_wandb: bool = False, 205 | ): 206 | size = len(dataloader) 207 | run_loss = 0 208 | run_speech_acc = 0 209 | pbar = tqdm(dataloader, total=size) 210 | with torch.no_grad(): 211 | for (specs, len_specs, labels) in pbar: 212 | specs, len_specs, labels = ( 213 | specs.to(device), 214 | len_specs.to(device), 215 | labels.to(device), 216 | ) 217 | 218 | A = speech_encoder(specs, len_specs) 219 | cls_speech = classifier(A) 220 | 221 | loss = criterion( cls_speech, labels) 222 | 223 | loss = loss.item() 224 | 225 | speech_acc = (cls_speech.argmax(-1) == labels).sum() / labels.size(0) * 100 226 | 227 | speech_acc = speech_acc.item() 228 | 229 | run_loss += loss 230 | run_speech_acc += speech_acc 231 | 232 | if log_wandb: 233 | 
wandb.log({"val/sed_loss": loss}) 234 | wandb.log({"val/speech_accuracy": speech_acc}) 235 | 236 | pbar.set_description( 237 | f"[Epoch: {epoch}] Loss: {loss:.2f} | Speech Acc: {speech_acc:.2f}%" 238 | ) 239 | 240 | return { 241 | "loss": run_loss / size, 242 | "speech_acc": run_speech_acc / size, 243 | } 244 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | from collections import defaultdict 5 | 6 | import torch 7 | import torchaudio 8 | from PIL import Image 9 | from torch import nn 10 | from torch.utils.data import DataLoader, Dataset 11 | from torchaudio.transforms import MelSpectrogram 12 | from torchvision import transforms as T 13 | 14 | 15 | class SENDataset(Dataset): 16 | def __init__( 17 | self, 18 | json_file: str, 19 | img_path: str, 20 | audio_path: str, 21 | input_size=299, 22 | n_fft=512, 23 | n_mels=40, 24 | win_length=250, 25 | hop_length=100, 26 | ): 27 | super().__init__() 28 | data = json.load(open(json_file, "r", encoding="utf-8"))["data"] 29 | walker = [ 30 | [ 31 | dict( 32 | label=datum["class"], 33 | img=img_path + os.sep + datum["img"], 34 | audio=audio_path + os.sep + wav, 35 | ) 36 | for wav in datum["wav"] 37 | ] 38 | for datum in data 39 | ] 40 | # check exits 41 | self.walker = [j for i in walker for j in i] 42 | subset = json_file.rsplit(os.sep, 1)[-1].split("_", 1)[0] 43 | 44 | # self.img_transform = { 45 | # 'train': T.Compose( 46 | # [ 47 | # T.ToTensor(), 48 | # T.RandomRotation(degrees=(0, 180)), 49 | # T.RandomHorizontalFlip(p=0.5), 50 | # T.RandomVerticalFlip(p=0.5), 51 | # T.Resize((input_size, input_size)), 52 | # T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 53 | # ] 54 | # ), 55 | # 'test': T.Compose( 56 | # [ 57 | # T.ToTensor(), 58 | # T.Resize((input_size, input_size)), 59 | # T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 60 | # ] 61 | # ) 62 | # }[subset] 63 | 64 | # Augmentation for image 65 | 66 | 67 | self.img_transform = T.Compose( 68 | [ 69 | T.ToTensor(), 70 | T.Resize((input_size, input_size)), 71 | T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 72 | ] 73 | ) 74 | 75 | sample_rate = 16000 # default value 76 | self.audio_transform = MelSpectrogram( 77 | sample_rate, n_fft, win_length, hop_length, n_mels=n_mels 78 | ) 79 | 80 | def __len__(self): 81 | return len(self.walker) 82 | 83 | def __getitem__(self, idx): 84 | item = self.walker[idx] 85 | 86 | img = Image.open(item["img"]) 87 | img = self.img_transform(img) 88 | 89 | wav, sr = torchaudio.load(item["audio"]) 90 | mel_spec = self.audio_transform(wav) 91 | mel_spec = mel_spec.squeeze().permute(1, 0) # (len, n_mels) 92 | 93 | return img, mel_spec, mel_spec.size(0), item["label"] 94 | 95 | # sed 96 | class SEDDataset(Dataset): 97 | def __init__( 98 | self, 99 | json_file: str, 100 | audio_path: str, 101 | input_size=299, 102 | n_fft=512, 103 | n_mels=40, 104 | win_length=250, 105 | hop_length=100, 106 | ): 107 | super().__init__() 108 | data = json.load(open(json_file, "r", encoding="utf-8"))["data"] 109 | walker = [ 110 | [ 111 | dict( 112 | label=datum["class"], 113 | audio=audio_path + os.sep + wav, 114 | ) 115 | for wav in datum["wav"] 116 | ] 117 | for datum in data 118 | ] 119 | # check exits 120 | self.walker = [j for i in walker for j in i] 121 | 122 | 123 | sample_rate = 16000 # default value 124 | self.audio_transform = MelSpectrogram( 125 | sample_rate, n_fft, win_length, hop_length, 
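# Note: the positional arguments follow torchaudio.transforms.MelSpectrogram(sample_rate,
# n_fft, win_length, hop_length, ...); the transform yields (channel, n_mels, time), which
# __getitem__ squeezes and permutes to (time, n_mels) below.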
n_mels=n_mels 126 | ) 127 | 128 | def __len__(self): 129 | return len(self.walker) 130 | 131 | def __getitem__(self, idx): 132 | item = self.walker[idx] 133 | 134 | 135 | wav, sr = torchaudio.load(item["audio"]) 136 | mel_spec = self.audio_transform(wav) 137 | mel_spec = mel_spec.squeeze().permute(1, 0) # (len, n_mels) 138 | 139 | return mel_spec, mel_spec.size(0), item["label"] 140 | 141 | class RDGDataset(Dataset): 142 | def __init__( 143 | self, 144 | json_file: str, 145 | img_path: str, 146 | audio_path: str, 147 | input_size=299, 148 | n_fft=512, 149 | n_mels=40, 150 | win_length=250, 151 | hop_length=100, 152 | ): 153 | super().__init__() 154 | data = json.load(open(json_file, "r", encoding="utf-8"))["data"] 155 | walker = [ 156 | [ 157 | dict( 158 | label=datum["class"], 159 | img=img_path + os.sep + datum["img"], 160 | audio=audio_path + os.sep + wav, 161 | ) 162 | for wav in datum["wav"] 163 | ] 164 | for datum in data 165 | ] 166 | # check exits 167 | self.walker = [j for i in walker for j in i] 168 | 169 | self.data_class = defaultdict(list) 170 | for data in self.walker: 171 | self.data_class[data.get("label")].append(data) 172 | 173 | subset = json_file.rsplit(os.sep, 1)[-1].split("_", 1)[0] 174 | 175 | # self.img_transform = { 176 | # 'train': T.Compose( 177 | # [ 178 | # T.ToTensor(), 179 | # T.RandomRotation(degrees=(0, 180)), 180 | # T.RandomHorizontalFlip(p=0.5), 181 | # T.RandomVerticalFlip(p=0.5), 182 | # T.Resize((input_size, input_size)), 183 | # T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 184 | # ] 185 | # ), 186 | # 'test': T.Compose( 187 | # [ 188 | # T.ToTensor(), 189 | # T.Resize((input_size, input_size)), 190 | # T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 191 | # ] 192 | # ) 193 | # }[subset] 194 | 195 | self.img_transform = T.Compose( 196 | [ 197 | T.ToTensor(), 198 | T.Resize((input_size, input_size)), 199 | T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 200 | ] 201 | ) 202 | 203 | sample_rate = 16000 # default value 204 | self.audio_transform = MelSpectrogram( 205 | sample_rate, n_fft, win_length, hop_length, n_mels=n_mels 206 | ) 207 | 208 | def __len__(self): 209 | return len(self.walker) 210 | 211 | def __get_random_same_class__(self, label): 212 | data = self.data_class[label] 213 | l, h = 1, 102 214 | return data[random.randint(l, h)] 215 | 216 | def __get_random_diff_class__(self, diff_label): 217 | l, h = 1, 102 218 | label = random.randint(l, h) 219 | while label == diff_label: 220 | label = random.randint(l, h) 221 | return self.__get_random_same_class__(label) 222 | 223 | def __getitem__(self, index): 224 | item = self.walker[index] 225 | label = item["label"] 226 | 227 | real_img = Image.open(item["img"]) 228 | similar_img = Image.open(self.__get_random_same_class__(label)["img"]) 229 | wrong_img = Image.open(self.__get_random_diff_class__(label)["img"]) 230 | 231 | real_img = self.img_transform(real_img) 232 | similar_img = self.img_transform(similar_img) 233 | wrong_img = self.img_transform(wrong_img) 234 | 235 | wav, sr = torchaudio.load(item["audio"]) 236 | mel_spec = self.audio_transform(wav) 237 | mel_spec = mel_spec.squeeze().permute(1, 0) # (len, n_mels) 238 | 239 | return real_img, similar_img, wrong_img, mel_spec, mel_spec.size(0), (item['audio'], sr) -------------------------------------------------------------------------------- /s2igan/rdg/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class DownBlock(nn.Module): 6 | def 
__init__(self, in_channels, out_channels): 7 | super().__init__() 8 | self.block = nn.Sequential( 9 | nn.Conv2d( 10 | in_channels, 11 | out_channels, 12 | kernel_size=4, 13 | stride=2, 14 | padding=1, 15 | bias=False, 16 | ), 17 | nn.BatchNorm2d(out_channels), 18 | nn.LeakyReLU(0.2, inplace=True), 19 | ) 20 | 21 | def forward(self, x): 22 | return self.block(x) 23 | 24 | 25 | class DownScale16TimesBlock(nn.Module): 26 | def __init__(self, disc_dim: int): 27 | super().__init__() 28 | self.seq = nn.Sequential( 29 | nn.Conv2d(3, disc_dim, kernel_size=4, stride=2, padding=1, bias=False,), 30 | nn.LeakyReLU(0.2, inplace=True), 31 | nn.Conv2d( 32 | disc_dim, disc_dim * 2, kernel_size=4, stride=2, padding=1, bias=False, 33 | ), 34 | nn.BatchNorm2d(disc_dim * 2), 35 | nn.LeakyReLU(0.2, inplace=True), 36 | nn.Conv2d( 37 | disc_dim * 2, 38 | disc_dim * 4, 39 | kernel_size=4, 40 | stride=2, 41 | padding=1, 42 | bias=False, 43 | ), 44 | nn.BatchNorm2d(disc_dim * 4), 45 | nn.LeakyReLU(0.2, inplace=True), 46 | nn.Conv2d( 47 | disc_dim * 4, 48 | disc_dim * 8, 49 | kernel_size=4, 50 | stride=2, 51 | padding=1, 52 | bias=False, 53 | ), 54 | nn.BatchNorm2d(disc_dim * 8), 55 | nn.LeakyReLU(0.2, inplace=True), 56 | ) 57 | 58 | def forward(self, img): 59 | return self.seq(img) 60 | 61 | 62 | class DiscriminatorFor64By64(nn.Module): 63 | def __init__(self, disc_dim: int, gan_emb_dim: int): 64 | super().__init__() 65 | self.disc_dim = disc_dim 66 | self.gan_emb_dim = gan_emb_dim 67 | self.down_scale = DownScale16TimesBlock(disc_dim) 68 | self.joint_conv = nn.Sequential( 69 | nn.Conv2d( 70 | disc_dim * 8 + gan_emb_dim, 71 | disc_dim * 8, 72 | kernel_size=3, 73 | stride=1, 74 | padding=1, 75 | bias=False, 76 | ), 77 | nn.BatchNorm2d(disc_dim * 8), 78 | nn.LeakyReLU(0.2, inplace=True), 79 | ) 80 | self.logits = nn.Sequential( 81 | nn.Conv2d(disc_dim * 8, 1, kernel_size=4, stride=4), nn.Sigmoid() 82 | ) 83 | self.uncond_logits = nn.Sequential( 84 | nn.Conv2d(disc_dim * 8, 1, kernel_size=4, stride=1), nn.Sigmoid() 85 | ) 86 | 87 | def get_params(self): 88 | return [p for p in self.parameters() if p.requires_grad] 89 | 90 | def forward(self, x_var, c_code): 91 | x_code = self.down_scale(x_var) 92 | c_code = c_code.view(-1, self.gan_emb_dim, 1, 1) 93 | c_code = c_code.repeat(1, 1, 4, 4) 94 | code = torch.cat((c_code, x_code), 1) 95 | code = self.joint_conv(code) 96 | 97 | output = self.logits(code) 98 | uncond_output = self.uncond_logits(x_code) 99 | 100 | return {"cond": output.view(-1), "uncond": uncond_output.view(-1)} 101 | 102 | 103 | class DiscriminatorFor128By128(nn.Module): 104 | def __init__(self, disc_dim: int, gan_emb_dim: int): 105 | super().__init__() 106 | self.disc_dim = disc_dim 107 | self.gan_emb_dim = gan_emb_dim 108 | self.down_scale = nn.Sequential( 109 | DownScale16TimesBlock(disc_dim), 110 | DownBlock(disc_dim * 8, disc_dim * 16), 111 | nn.Sequential( 112 | nn.Conv2d( 113 | disc_dim * 16, 114 | disc_dim * 8, 115 | kernel_size=3, 116 | stride=1, 117 | padding=1, 118 | bias=False, 119 | ), 120 | nn.BatchNorm2d(disc_dim * 8), 121 | nn.LeakyReLU(0.2, inplace=True), 122 | ), 123 | ) 124 | self.joint_conv = nn.Sequential( 125 | nn.Conv2d( 126 | disc_dim * 8 + gan_emb_dim, 127 | disc_dim * 8, 128 | kernel_size=3, 129 | stride=1, 130 | padding=1, 131 | bias=False, 132 | ), 133 | nn.BatchNorm2d(disc_dim * 8), 134 | nn.LeakyReLU(0.2, inplace=True), 135 | ) 136 | self.logits = nn.Sequential( 137 | nn.Conv2d(disc_dim * 8, 1, kernel_size=4, stride=4), nn.Sigmoid() 138 | ) 139 | self.uncond_logits = 
nn.Sequential( 140 | nn.Conv2d(disc_dim * 8, 1, kernel_size=4, stride=1), nn.Sigmoid() 141 | ) 142 | 143 | def get_params(self): 144 | return [p for p in self.parameters() if p.requires_grad] 145 | 146 | def forward(self, x_var, c_code): 147 | x_code = self.down_scale(x_var) 148 | 149 | c_code = c_code.view(-1, self.gan_emb_dim, 1, 1) 150 | c_code = c_code.repeat(1, 1, 4, 4) 151 | code = torch.cat((c_code, x_code), 1) 152 | code = self.joint_conv(code) 153 | 154 | output = self.logits(code) 155 | uncond_output = self.uncond_logits(x_code) 156 | 157 | return {"cond": output.view(-1), "uncond": uncond_output.view(-1)} 158 | 159 | 160 | class DiscriminatorFor256By256(nn.Module): 161 | def __init__(self, disc_dim: int, gan_emb_dim: int): 162 | super().__init__() 163 | self.disc_dim = disc_dim 164 | self.gan_emb_dim = gan_emb_dim 165 | self.down_scale = nn.Sequential( 166 | DownScale16TimesBlock(disc_dim), 167 | DownBlock(disc_dim * 8, disc_dim * 16), 168 | DownBlock(disc_dim * 16, disc_dim * 32), 169 | nn.Sequential( 170 | nn.Conv2d( 171 | disc_dim * 32, 172 | disc_dim * 16, 173 | kernel_size=3, 174 | stride=1, 175 | padding=1, 176 | bias=False, 177 | ), 178 | nn.BatchNorm2d(disc_dim * 16), 179 | nn.LeakyReLU(0.2, inplace=True), 180 | ), 181 | nn.Sequential( 182 | nn.Conv2d( 183 | disc_dim * 16, 184 | disc_dim * 8, 185 | kernel_size=3, 186 | stride=1, 187 | padding=1, 188 | bias=False, 189 | ), 190 | nn.BatchNorm2d(disc_dim * 8), 191 | nn.LeakyReLU(0.2, inplace=True), 192 | ), 193 | ) 194 | self.joint_conv = nn.Sequential( 195 | nn.Conv2d( 196 | disc_dim * 8 + gan_emb_dim, 197 | disc_dim * 8, 198 | kernel_size=3, 199 | stride=1, 200 | padding=1, 201 | bias=False, 202 | ), 203 | nn.BatchNorm2d(disc_dim * 8), 204 | nn.LeakyReLU(0.2, inplace=True), 205 | ) 206 | self.logits = nn.Sequential( 207 | nn.Conv2d(disc_dim * 8, 1, kernel_size=4, stride=4), nn.Sigmoid() 208 | ) 209 | self.uncond_logits = nn.Sequential( 210 | nn.Conv2d(disc_dim * 8, 1, kernel_size=4, stride=1), nn.Sigmoid() 211 | ) 212 | 213 | def get_params(self): 214 | return [p for p in self.parameters() if p.requires_grad] 215 | 216 | def forward(self, x_var, c_code): 217 | x_code = self.down_scale(x_var) 218 | 219 | c_code = c_code.view(-1, self.gan_emb_dim, 1, 1) 220 | c_code = c_code.repeat(1, 1, 4, 4) 221 | code = torch.cat((c_code, x_code), 1) 222 | code = self.joint_conv(code) 223 | 224 | output = self.logits(code) 225 | uncond_output = self.uncond_logits(x_code) 226 | 227 | return {"cond": output.view(-1), "uncond": uncond_output.view(-1)} 228 | -------------------------------------------------------------------------------- /s2igan/rdg/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | import random 4 | import wandb 5 | from torchvision import transforms as T 6 | 7 | 8 | def get_transform(img_dim): 9 | return T.Compose([T.Resize(img_dim), T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 10 | 11 | 12 | Resizer = {64: get_transform(64), 128: get_transform(128), 256: get_transform(256)} 13 | 14 | 15 | def update_D( 16 | models, batch, optimizers, schedulers, criterions, specific_params, device 17 | ): 18 | origin_real_img, origin_similar_img, origin_wrong_img, spec, spec_len, raw_audio = batch 19 | 20 | real_imgs, wrong_imgs, similar_imgs = {}, {}, {} 21 | for img_dim in specific_params.img_dims: 22 | real_imgs[img_dim] = Resizer[img_dim](origin_real_img) 23 | wrong_imgs[img_dim] = Resizer[img_dim](origin_wrong_img) 24 | similar_imgs[img_dim] = 
Resizer[img_dim](origin_similar_img) 25 | 26 | for key in optimizers.keys(): 27 | optimizers[key].zero_grad() 28 | 29 | bs = origin_real_img.size(0) 30 | 31 | Z = torch.randn(bs, specific_params.latent_space_dim, device=device) 32 | A = models["sed"](spec, spec_len) 33 | 34 | fake_imgs, mu, logvar = models["gen"](Z, A) 35 | 36 | zero_labels = torch.zeros(bs, device=device, dtype=torch.float) 37 | one_labels = torch.ones(bs, device=device, dtype=torch.float) 38 | two_labels = torch.zeros(bs, device=device, dtype=torch.float) + 2 39 | 40 | D_loss = 0 41 | for img_dim in specific_params.img_dims: 42 | optimizers[img_dim].zero_grad() 43 | 44 | real_img = real_imgs[img_dim] 45 | wrong_img = wrong_imgs[img_dim] 46 | 47 | real_out = models["disc"][img_dim](real_img, mu.detach()) 48 | wrong_out = models["disc"][img_dim](wrong_img, mu.detach()) 49 | fake_out = models["disc"][img_dim](fake_imgs[img_dim].detach(), mu.detach()) 50 | 51 | loss_real_cond = criterions["bce"](real_out["cond"], one_labels) 52 | loss_real_uncond = criterions["bce"](real_out["uncond"], one_labels) 53 | # --- 54 | loss_wrong_cond = criterions["bce"](wrong_out["cond"], zero_labels) 55 | loss_wrong_uncond = criterions["bce"](wrong_out["uncond"], one_labels) 56 | # --- 57 | loss_fake_cond = criterions["bce"](fake_out["cond"], zero_labels) 58 | loss_fake_uncond = criterions["bce"](fake_out["uncond"], zero_labels) 59 | 60 | curr_D_loss = ( 61 | loss_real_cond 62 | + loss_real_uncond 63 | + loss_fake_cond 64 | + loss_fake_uncond 65 | + loss_wrong_cond 66 | + loss_wrong_uncond 67 | ) 68 | curr_D_loss.backward() 69 | optimizers[img_dim].step() 70 | schedulers[img_dim].step() 71 | 72 | D_loss += curr_D_loss.detach().item() 73 | 74 | return D_loss 75 | 76 | 77 | def update_RS( 78 | models, batch, optimizers, schedulers, criterions, specific_params, device 79 | ): 80 | origin_real_img, origin_similar_img, origin_wrong_img, spec, spec_len, raw_audio = batch 81 | 82 | real_imgs, wrong_imgs, similar_imgs = {}, {}, {} 83 | for img_dim in specific_params.img_dims: 84 | real_imgs[img_dim] = Resizer[img_dim](origin_real_img) 85 | wrong_imgs[img_dim] = Resizer[img_dim](origin_wrong_img) 86 | similar_imgs[img_dim] = Resizer[img_dim](origin_similar_img) 87 | 88 | optimizers.zero_grad() 89 | 90 | bs = origin_real_img.size(0) 91 | 92 | Z = torch.randn(bs, specific_params.latent_space_dim, device=device) 93 | A = models["sed"](spec, spec_len) 94 | 95 | fake_imgs, mu, logvar = models["gen"](Z, A) 96 | 97 | zero_labels = torch.zeros(bs, device=device, dtype=torch.float) 98 | one_labels = torch.ones(bs, device=device, dtype=torch.float) 99 | two_labels = torch.zeros(bs, device=device, dtype=torch.float) + 2 100 | 101 | real_img = Resizer[256](origin_real_img) 102 | similar_img = Resizer[256](origin_similar_img) 103 | wrong_img = Resizer[256](origin_wrong_img) 104 | 105 | real_feat = models["ied"](real_img) 106 | similar_feat = models["ied"](similar_img) 107 | fake_feat = models["ied"](fake_imgs[256].detach()) 108 | wrong_feat = models["ied"](wrong_img) 109 | 110 | R1 = models["rs"](similar_feat.detach(), real_feat.detach()) 111 | R2 = models["rs"](wrong_feat.detach(), real_feat.detach()) 112 | R3 = models["rs"](real_feat.detach(), real_feat.detach()) 113 | R_GT_FI = models["rs"](fake_feat.detach(), real_feat.detach()) 114 | 115 | RS_loss = criterions["rs"](R1, R2, R3, R_GT_FI, zero_labels, one_labels, two_labels) 116 | 117 | RS_loss.backward() 118 | optimizers.step() 119 | schedulers.step() 120 | 121 | return RS_loss.detach().item() 122 | 123 | 124 | 
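# update_G (below) performs one generator step: for every resolution in
# specific_params.img_dims it adds a conditional and an unconditional BCE loss
# from the matching discriminator, plus a cross-entropy term that asks the
# relation classifier to rate the (real_feat, fake_feat) pair as class 1, and
# finally a single KL term on (mu, logvar) scaled by specific_params.kl_loss_coef.
# One random (fake, real) image pair per resolution and the corresponding
# audio clip are also logged to wandb on every call.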
def update_G( 125 | models, batch, optimizers, schedulers, criterions, specific_params, device 126 | ): 127 | origin_real_img, origin_similar_img, origin_wrong_img, spec, spec_len, raw_audio = batch 128 | 129 | real_imgs, wrong_imgs, similar_imgs = {}, {}, {} 130 | for img_dim in specific_params.img_dims: 131 | real_imgs[img_dim] = Resizer[img_dim](origin_real_img) 132 | wrong_imgs[img_dim] = Resizer[img_dim](origin_wrong_img) 133 | similar_imgs[img_dim] = Resizer[img_dim](origin_similar_img) 134 | 135 | optimizers.zero_grad() 136 | 137 | bs = origin_real_img.size(0) 138 | 139 | Z = torch.randn(bs, specific_params.latent_space_dim, device=device) 140 | A = models["sed"](spec, spec_len) 141 | 142 | fake_imgs, mu, logvar = models["gen"](Z, A) 143 | 144 | zero_labels = torch.zeros(bs, device=device, dtype=torch.float) 145 | one_labels = torch.ones(bs, device=device, dtype=torch.float) 146 | two_labels = torch.zeros(bs, device=device, dtype=torch.float) + 2 147 | 148 | G_loss = 0 149 | for img_dim in specific_params.img_dims: 150 | 151 | fake_out = models["disc"][img_dim](fake_imgs[img_dim], mu) 152 | cond_loss = criterions["bce"](fake_out["cond"], one_labels) 153 | uncond_loss = criterions["bce"](fake_out["uncond"], one_labels) 154 | 155 | wandb.log({f"train/cond_loss_{img_dim}": cond_loss.item()}) 156 | wandb.log({f"train/uncond_loss_{img_dim}": uncond_loss.item()}) 157 | 158 | G_loss += cond_loss + uncond_loss 159 | 160 | real_feat = models["ied"](real_imgs[img_dim]) 161 | fake_feat = models["ied"](fake_imgs[img_dim]) 162 | rs_out = models["rs"](real_feat, fake_feat) 163 | 164 | G_loss += criterions["ce"](rs_out, one_labels.long()) 165 | 166 | # real_img = Resizer[256](origin_real_img) 167 | # similar_img = Resizer[256](origin_similar_img) 168 | # wrong_img = Resizer[256](origin_wrong_img) 169 | 170 | # real_feat = models["ied"](real_img) 171 | # similar_feat = models["ied"](similar_img) 172 | # fake_feat = models["ied"](fake_imgs[256]) 173 | # wrong_feat = models["ied"](wrong_img) 174 | 175 | # R1 = models["rs"](similar_feat, real_feat) 176 | # R2 = models["rs"](wrong_feat, real_feat) 177 | # R3 = models["rs"](real_feat, real_feat) 178 | # R_GT_FI = models["rs"](fake_feat, real_feat) 179 | 180 | # RS_loss = criterions["rs"](R1, R2, R3, R_GT_FI, zero_labels, one_labels, two_labels) 181 | 182 | KL_loss = criterions["kl"](mu, logvar) * specific_params.kl_loss_coef 183 | # G_loss += RS_loss 184 | G_loss += KL_loss 185 | G_loss.backward() 186 | optimizers.step() 187 | schedulers.step() 188 | 189 | i = random.randint(0, origin_real_img.size(0) - 1) 190 | audio_path, sr = raw_audio[i] 191 | 192 | image_64 = torch.cat((fake_imgs[64][i:i+1], real_imgs[64][i:i+1]), 0) * 0.5 + 0.5 193 | image_128 = torch.cat((fake_imgs[128][i:i+1], real_imgs[128][i:i+1]), 0) * 0.5 + 0.5 194 | image_256 = torch.cat((fake_imgs[256][i:i+1], real_imgs[256][i:i+1]), 0) * 0.5 + 0.5 195 | 196 | wandb.log({"train/image_64": wandb.Image(image_64)}) 197 | wandb.log({"train/image_128": wandb.Image(image_128)}) 198 | wandb.log({"train/image_256": wandb.Image(image_256)}) 199 | wandb.log({"train/speech_description": wandb.Audio(audio_path, sample_rate=sr)}) 200 | 201 | return G_loss.detach().item(), KL_loss.detach().item() 202 | 203 | 204 | def rdg_train_epoch( 205 | models, 206 | dataloader, 207 | optimizers, 208 | schedulers, 209 | criterions, 210 | specific_params, 211 | device, 212 | epoch, 213 | log_wandb, 214 | ): 215 | size = len(dataloader) 216 | pbar = tqdm(dataloader, total=size) 217 | for ( 218 | origin_real_img, 219 | 
origin_similar_img, 220 | origin_wrong_img, 221 | spec, 222 | spec_len, 223 | raw_audio 224 | ) in pbar: 225 | # origin_real_img, origin_similar_img, origin_wrong_img, spec, spec_len 226 | batch = ( 227 | origin_real_img.to(device), 228 | origin_similar_img.to(device), 229 | origin_wrong_img.to(device), 230 | spec.to(device), 231 | spec_len.to(device), 232 | raw_audio 233 | ) 234 | 235 | D_loss = update_D( 236 | models, 237 | batch, 238 | optimizers["disc"], 239 | schedulers["disc"], 240 | criterions, 241 | specific_params, 242 | device, 243 | ) 244 | RS_loss = update_RS( 245 | models, 246 | batch, 247 | optimizers["rs"], 248 | schedulers["rs"], 249 | criterions, 250 | specific_params, 251 | device, 252 | ) 253 | G_loss, KL_loss = update_G( 254 | models, 255 | batch, 256 | optimizers["gen"], 257 | schedulers["gen"], 258 | criterions, 259 | specific_params, 260 | device, 261 | ) 262 | 263 | if log_wandb: 264 | wandb.log({"train/G_loss": G_loss}) 265 | wandb.log({"train/D_loss": D_loss}) 266 | wandb.log({"train/KL_loss": KL_loss}) 267 | wandb.log({"train/RS_loss": RS_loss}) 268 | wandb.log({"train/epoch": epoch}) 269 | wandb.log({"train/lr-OneCycleLR_G": schedulers["gen"].get_last_lr()[0]}) 270 | 271 | pbar.set_description( 272 | f"[Epoch: {epoch}] G_Loss: {G_loss:.2f} | D_Loss: {D_loss:.2f} | RS_loss: {RS_loss:.2f}" 273 | ) 274 | --------------------------------------------------------------------------------
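The generator update above calls `criterions["kl"](mu, logvar)` and scales the result by `specific_params.kl_loss_coef`; the criterion itself lives in `s2igan/loss.py`, which is not reproduced in this excerpt. For reference, below is a minimal sketch of the conditioning-augmentation KL term commonly used in StackGAN-style generators, which matches the `(mu, logvar)` call signature seen in `update_G`; the function name `kl_loss` and the exact reduction are assumptions, and the repository's own implementation may differ.

```python
import torch


def kl_loss(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Conditioning-augmentation KL divergence against N(0, I), averaged over the batch.

    Sketch only: the repository's actual `criterions["kl"]` (defined in
    s2igan/loss.py, not shown here) may be implemented differently.
    """
    # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2), summed over the embedding dimension
    return torch.mean(-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1))


if __name__ == "__main__":
    # Tiny smoke test with statistics shaped like (batch, gan_emb_dim)
    mu = torch.randn(4, 128)
    logvar = torch.randn(4, 128)
    print(kl_loss(mu, logvar))
```

Note also, from `rdg_train_epoch`, that `optimizers["disc"]` and `schedulers["disc"]` are expected to be dicts keyed by image resolution (64, 128, 256), while `optimizers["gen"]`, `optimizers["rs"]` and their schedulers are single objects; these containers are presumably assembled by the RDG training entry point, which is not part of this excerpt.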