├── Recommender ├── data │ └── .gitkeep ├── requirements.txt ├── README.md ├── utils.py ├── gen_label.py ├── FMRecommender.py ├── run_classifier.py ├── run.py └── fm_audio.py ├── Speech ├── vits_lib │ ├── pretrain │ │ └── .gitkeep │ ├── text │ │ ├── __pycache__ │ │ │ ├── symbols.cpython-38.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ └── cleaners.cpython-38.pyc │ │ ├── symbols.py │ │ ├── LICENSE │ │ ├── __init__.py │ │ └── cleaners.py │ ├── monotonic_align │ │ ├── setup.py │ │ ├── __init__.py │ │ └── core.pyx │ ├── configs │ │ ├── ljs_base.json │ │ ├── ljs_nosdp.json │ │ └── vctk_base.json │ ├── commons.py │ ├── utils.py │ ├── transforms.py │ ├── attentions.py │ └── modules.py ├── requirements.txt ├── README.md ├── inference_user.py ├── inference.py ├── data │ └── speaker_info │ │ ├── speaker-info.txt │ │ └── vctk_audio_sid_text_val_filelist.txt.cleaned.txt └── utils.py ├── .gitignore ├── images ├── logo.png └── framework.png ├── Dialogue ├── requirements.txt ├── data │ ├── .DS_Store │ └── ml-1m │ │ └── user_preference.npy ├── README.md ├── config │ ├── recommend_pattern.py │ └── thanks_pattern.py ├── movie_utils.py ├── coat_utils.py ├── gen_coat.py ├── gen_movie.py └── coat_attr.py ├── Evaluate ├── requirements.txt ├── evaluate.py └── fed.py └── README.md /Recommender/data/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Speech/vits_lib/pretrain/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | */.DS_Store 3 | .vscode 4 | */.vscode -------------------------------------------------------------------------------- /images/logo.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyllll/VCRS/HEAD/images/logo.png -------------------------------------------------------------------------------- /Dialogue/requirements.txt: -------------------------------------------------------------------------------- 1 | random 2 | numpy 3 | pandas 4 | lightgbm 5 | re 6 | json -------------------------------------------------------------------------------- /images/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyllll/VCRS/HEAD/images/framework.png -------------------------------------------------------------------------------- /Dialogue/data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyllll/VCRS/HEAD/Dialogue/data/.DS_Store -------------------------------------------------------------------------------- /Evaluate/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | torch 3 | transformers 4 | random 5 | json 6 | argparse -------------------------------------------------------------------------------- /Recommender/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchaudio 3 | argparse 4 | tqdm 5 | transformers 6 | sklearn -------------------------------------------------------------------------------- /Dialogue/data/ml-1m/user_preference.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyllll/VCRS/HEAD/Dialogue/data/ml-1m/user_preference.npy -------------------------------------------------------------------------------- /Speech/vits_lib/text/__pycache__/symbols.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hyllll/VCRS/HEAD/Speech/vits_lib/text/__pycache__/symbols.cpython-38.pyc -------------------------------------------------------------------------------- /Speech/vits_lib/text/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyllll/VCRS/HEAD/Speech/vits_lib/text/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /Speech/vits_lib/text/__pycache__/cleaners.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyllll/VCRS/HEAD/Speech/vits_lib/text/__pycache__/cleaners.cpython-38.pyc -------------------------------------------------------------------------------- /Speech/requirements.txt: -------------------------------------------------------------------------------- 1 | Cython==0.29.21 2 | librosa==0.8.0 3 | matplotlib==3.3.1 4 | phonemizer==2.2.1 5 | scipy==1.5.2 6 | tensorboard==2.3.0 7 | Unidecode==1.1.1 8 | torch 9 | torchvision 10 | pydub 11 | random 12 | tqdm 13 | ipython 14 | numpy -------------------------------------------------------------------------------- /Speech/vits_lib/monotonic_align/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from Cython.Build import cythonize 3 | import numpy 4 | 5 | setup( 6 | name = 'monotonic_align', 7 | ext_modules = cythonize("core.pyx"), 8 | include_dirs=[numpy.get_include()] 9 | ) 10 | -------------------------------------------------------------------------------- /Speech/vits_lib/text/symbols.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Defines the set of symbols used in text input to the model. 
5 | ''' 6 | _pad = '_' 7 | _punctuation = ';:,.!?¡¿—…"«»“” ' 8 | _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' 9 | _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" 10 | 11 | 12 | # Export all symbols: 13 | symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) 14 | 15 | # Special symbol ids 16 | SPACE_ID = symbols.index(" ") 17 | -------------------------------------------------------------------------------- /Speech/vits_lib/monotonic_align/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from .monotonic_align.core import maximum_path_c 4 | 5 | 6 | def maximum_path(neg_cent, mask): 7 | """ Cython optimized version. 8 | neg_cent: [b, t_t, t_s] 9 | mask: [b, t_t, t_s] 10 | """ 11 | device = neg_cent.device 12 | dtype = neg_cent.dtype 13 | neg_cent = neg_cent.data.cpu().numpy().astype(np.float32) 14 | path = np.zeros(neg_cent.shape, dtype=np.int32) 15 | 16 | t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32) 17 | t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32) 18 | maximum_path_c(path, neg_cent, t_t_max, t_s_max) 19 | return torch.from_numpy(path).to(device=device, dtype=dtype) 20 | -------------------------------------------------------------------------------- /Speech/vits_lib/text/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017 Keith Ito 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the 
following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /Dialogue/README.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | This is the code for text-based conversation generation. 3 | 4 | ## Installation 5 | Install dependencies via: 6 | ``` 7 | cd ./Dialogue/ 8 | pip install -r requirements.txt 9 | ``` 10 | 11 | ## Prepare Data 12 | 1. If you want to generate conversations in the movie domain, please download [ml-1m.csv](https://drive.google.com/file/d/1iOum0fcgPzyvV5Mj8EuNdgtd0eLPz2qt/view?usp=sharing) and put it in `./data/ml-1m/`. The ml-1m.csv file contains movie information scraped from the IMDb website and is divided into 10 columns: 13 | 14 | ` | user_id | item_id | rating | timestamp | movie | director | actors | country | title | genres |` 15 | 16 | 2. The corresponding templates are in the `./config/` directory. You can also add new templates to enrich the conversations. 17 | 18 | 3. For convenience, we also summarize the templates we use in this [document](https://fssntlo70a.feishu.cn/sheets/shtcn5rtVsQa69LA1efCJpWeIEd). 19 | 20 | ## RUN 21 | Generate conversations in the e-commerce domain: 22 | ``` 23 | python gen_coat.py 24 | ``` 25 | Generate conversations in the movie domain: 
26 | ``` 27 | python gen_movie.py 28 | ``` 29 | You can find the generated conversations in `./res/<dataset>/`, e.g. `./res/coat/dialogue_info_coat.json` or `./res/ml-1m/dialogue_info_ml-1m.json`. -------------------------------------------------------------------------------- /Speech/README.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | This is the code for voice-based conversation generation. 3 | 4 | ## Installation 5 | 1. Install dependencies via: 6 | ``` 7 | cd ./Speech/ 8 | pip install -r requirements.txt 9 | ``` 10 | 2. Build Monotonic Alignment Search: 11 | ``` 12 | cd ./vits_lib/monotonic_align/ 13 | python3 setup.py build_ext --inplace 14 | ``` 15 | 16 | ## Pretrained model 17 | 1. Download the [pretrained VITS models](https://drive.google.com/drive/folders/1ksarh-cJf3F5eKJjLVWY0X1j1qsQqiS2) from Google Drive. The pretrained models are provided by the [VITS repo](https://github.com/jaywalnut310/vits). 18 | 19 | 2. Put the pretrained models in the `./vits_lib/pretrain/` directory. 20 | 21 | ## RUN 22 | 1. Generate audio containing only the user's side of each conversation. 23 | ``` 24 | python inference_user.py --dataset='xxx' 25 | ``` 26 | ```xxx``` is ```coat``` or ```ml-1m``` 27 | 2. Generate audio containing full conversations (i.e., both user and agent turns). 28 | ``` 29 | python inference.py --dataset='xxx' 30 | ``` 31 | ```xxx``` is ```coat``` or ```ml-1m``` 32 | 3. Note that both scripts generate audio for every dialogue in the selected dataset (e.g., all of ml-1m); you can adapt them to generate only as many clips as you need. 33 | 4. All results are saved in the `./speech_res/` directory. 
34 | -------------------------------------------------------------------------------- /Speech/vits_lib/monotonic_align/core.pyx: -------------------------------------------------------------------------------- 1 | cimport cython 2 | from cython.parallel import prange 3 | 4 | 5 | @cython.boundscheck(False) 6 | @cython.wraparound(False) 7 | cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil: 8 | cdef int x 9 | cdef int y 10 | cdef float v_prev 11 | cdef float v_cur 12 | cdef float tmp 13 | cdef int index = t_x - 1 14 | 15 | for y in range(t_y): 16 | for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): 17 | if x == y: 18 | v_cur = max_neg_val 19 | else: 20 | v_cur = value[y-1, x] 21 | if x == 0: 22 | if y == 0: 23 | v_prev = 0. 24 | else: 25 | v_prev = max_neg_val 26 | else: 27 | v_prev = value[y-1, x-1] 28 | value[y, x] += max(v_prev, v_cur) 29 | 30 | for y in range(t_y - 1, -1, -1): 31 | path[y, index] = 1 32 | if index != 0 and (index == y or value[y-1, index] < value[y-1, index-1]): 33 | index = index - 1 34 | 35 | 36 | @cython.boundscheck(False) 37 | @cython.wraparound(False) 38 | cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_ys, int[::1] t_xs) nogil: 39 | cdef int b = paths.shape[0] 40 | cdef int i 41 | for i in prange(b, nogil=True): 42 | maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i]) 43 | -------------------------------------------------------------------------------- /Speech/vits_lib/configs/ljs_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "log_interval": 200, 4 | "eval_interval": 1000, 5 | "seed": 1234, 6 | "epochs": 20000, 7 | "learning_rate": 2e-4, 8 | "betas": [0.8, 0.99], 9 | "eps": 1e-9, 10 | "batch_size": 64, 11 | "fp16_run": true, 12 | "lr_decay": 0.999875, 13 | "segment_size": 8192, 14 | "init_lr_ratio": 1, 15 | "warmup_epochs": 0, 16 | "c_mel": 45, 17 | "c_kl": 1.0 18 | 
}, 19 | "data": { 20 | "training_files":"filelists/ljs_audio_text_train_filelist.txt.cleaned", 21 | "validation_files":"filelists/ljs_audio_text_val_filelist.txt.cleaned", 22 | "text_cleaners":["english_cleaners2"], 23 | "max_wav_value": 32768.0, 24 | "sampling_rate": 22050, 25 | "filter_length": 1024, 26 | "hop_length": 256, 27 | "win_length": 1024, 28 | "n_mel_channels": 80, 29 | "mel_fmin": 0.0, 30 | "mel_fmax": null, 31 | "add_blank": true, 32 | "n_speakers": 0, 33 | "cleaned_text": true 34 | }, 35 | "model": { 36 | "inter_channels": 192, 37 | "hidden_channels": 192, 38 | "filter_channels": 768, 39 | "n_heads": 2, 40 | "n_layers": 6, 41 | "kernel_size": 3, 42 | "p_dropout": 0.1, 43 | "resblock": "1", 44 | "resblock_kernel_sizes": [3,7,11], 45 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 46 | "upsample_rates": [8,8,2,2], 47 | "upsample_initial_channel": 512, 48 | "upsample_kernel_sizes": [16,16,4,4], 49 | "n_layers_q": 3, 50 | "use_spectral_norm": false 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /Speech/vits_lib/configs/ljs_nosdp.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "log_interval": 200, 4 | "eval_interval": 1000, 5 | "seed": 1234, 6 | "epochs": 20000, 7 | "learning_rate": 2e-4, 8 | "betas": [0.8, 0.99], 9 | "eps": 1e-9, 10 | "batch_size": 64, 11 | "fp16_run": true, 12 | "lr_decay": 0.999875, 13 | "segment_size": 8192, 14 | "init_lr_ratio": 1, 15 | "warmup_epochs": 0, 16 | "c_mel": 45, 17 | "c_kl": 1.0 18 | }, 19 | "data": { 20 | "training_files":"filelists/ljs_audio_text_train_filelist.txt.cleaned", 21 | "validation_files":"filelists/ljs_audio_text_val_filelist.txt.cleaned", 22 | "text_cleaners":["english_cleaners2"], 23 | "max_wav_value": 32768.0, 24 | "sampling_rate": 22050, 25 | "filter_length": 1024, 26 | "hop_length": 256, 27 | "win_length": 1024, 28 | "n_mel_channels": 80, 29 | "mel_fmin": 0.0, 30 | "mel_fmax": 
null, 31 | "add_blank": true, 32 | "n_speakers": 0, 33 | "cleaned_text": true 34 | }, 35 | "model": { 36 | "inter_channels": 192, 37 | "hidden_channels": 192, 38 | "filter_channels": 768, 39 | "n_heads": 2, 40 | "n_layers": 6, 41 | "kernel_size": 3, 42 | "p_dropout": 0.1, 43 | "resblock": "1", 44 | "resblock_kernel_sizes": [3,7,11], 45 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 46 | "upsample_rates": [8,8,2,2], 47 | "upsample_initial_channel": 512, 48 | "upsample_kernel_sizes": [16,16,4,4], 49 | "n_layers_q": 3, 50 | "use_spectral_norm": false, 51 | "use_sdp": false 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /Speech/vits_lib/configs/vctk_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "log_interval": 200, 4 | "eval_interval": 1000, 5 | "seed": 1234, 6 | "epochs": 10000, 7 | "learning_rate": 2e-4, 8 | "betas": [0.8, 0.99], 9 | "eps": 1e-9, 10 | "batch_size": 64, 11 | "fp16_run": true, 12 | "lr_decay": 0.999875, 13 | "segment_size": 8192, 14 | "init_lr_ratio": 1, 15 | "warmup_epochs": 0, 16 | "c_mel": 45, 17 | "c_kl": 1.0 18 | }, 19 | "data": { 20 | "training_files":"filelists/vctk_audio_sid_text_train_filelist.txt.cleaned", 21 | "validation_files":"filelists/vctk_audio_sid_text_val_filelist.txt.cleaned", 22 | "text_cleaners":["english_cleaners2"], 23 | "max_wav_value": 32768.0, 24 | "sampling_rate": 22050, 25 | "filter_length": 1024, 26 | "hop_length": 256, 27 | "win_length": 1024, 28 | "n_mel_channels": 80, 29 | "mel_fmin": 0.0, 30 | "mel_fmax": null, 31 | "add_blank": true, 32 | "n_speakers": 109, 33 | "cleaned_text": true 34 | }, 35 | "model": { 36 | "inter_channels": 192, 37 | "hidden_channels": 192, 38 | "filter_channels": 768, 39 | "n_heads": 2, 40 | "n_layers": 6, 41 | "kernel_size": 3, 42 | "p_dropout": 0.1, 43 | "resblock": "1", 44 | "resblock_kernel_sizes": [3,7,11], 45 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], 
[1,3,5]], 46 | "upsample_rates": [8,8,2,2], 47 | "upsample_initial_channel": 512, 48 | "upsample_kernel_sizes": [16,16,4,4], 49 | "n_layers_q": 3, 50 | "use_spectral_norm": false, 51 | "gin_channels": 256 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /Recommender/README.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | This is the code for exploring the two VCRS datasets. 3 | 4 | ## Installation 5 | Install dependencies via: 6 | ``` 7 | cd ./Recommender/ 8 | pip install -r requirements.txt 9 | ``` 10 | 11 | ## Prepare Data 12 | 1. Download the datasets from Google Drive: [coat.tar.gz](https://drive.google.com/file/d/1FnpYhMaeskckxGheKjar0U4YHIdDKM6K/view?usp=share_link) and [ml-1m.tar.gz](https://drive.google.com/file/d/195ugsUrU51VMUjjI329M84qegtK2QuGC/view?usp=sharing). 13 | 14 | 2. Put the downloaded archives in the `./data/` directory. 15 | 16 | 3. Extract each archive there with `tar -zxvf xxx.tar.gz`. 17 | 18 | ## Multi-task Classification Module (MCM) 19 | MCM aims to extract explicit semantic features (e.g., user age) from our created VCRS datasets. 20 | * Extract features on the Coat dataset 21 | ``` 22 | python run_classifier.py --dataset='coat' --batch_size=64 --hidden_size=1024 --dropout=0.2 --epochs=20 --lr=0.0001 --test=1 23 | ``` 24 | 25 | * Extract features on the ML-1M dataset 26 | ``` 27 | python run_classifier.py --dataset='ml-1m' --batch_size=64 --hidden_size=1024 --dropout=0.2 --epochs=20 --lr=0.0001 --test=1 28 | ``` 29 | 30 | * The trained model is saved as `clf_model_xxx.pt`. 31 | 32 | ## Voice Feature Integration Module (VFIM) 33 | VFIM integrates the extracted voice-related features into the recommendation model to improve recommendation performance; together with MCM, it forms a two-phase fusion framework. 
34 | 35 | * Factorization Machine (FM) on the Coat dataset 36 | ``` 37 | python gen_label.py --dataset=coat 38 | python fm_audio.py --dataset=coat 39 | ``` 40 | 41 | * Factorization Machine (FM) on the ML-1M dataset 42 | ``` 43 | python gen_label.py --dataset=ml-1m 44 | python fm_audio.py --dataset=ml-1m 45 | ``` -------------------------------------------------------------------------------- /Speech/vits_lib/text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | from vits_lib.text import cleaners 3 | from vits_lib.text.symbols import symbols 4 | 5 | 6 | # Mappings from symbol to numeric ID and vice versa: 7 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 8 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 9 | 10 | 11 | def text_to_sequence(text, cleaner_names): 12 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 13 | Args: 14 | text: string to convert to a sequence 15 | cleaner_names: names of the cleaner functions to run the text through 16 | Returns: 17 | List of integers corresponding to the symbols in the text 18 | ''' 19 | sequence = [] 20 | 21 | clean_text = _clean_text(text, cleaner_names) 22 | for symbol in clean_text: 23 | symbol_id = _symbol_to_id[symbol] 24 | sequence += [symbol_id] 25 | return sequence 26 | 27 | 28 | def cleaned_text_to_sequence(cleaned_text): 29 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
30 | Args: 31 | text: string to convert to a sequence 32 | Returns: 33 | List of integers corresponding to the symbols in the text 34 | ''' 35 | sequence = [_symbol_to_id[symbol] for symbol in cleaned_text] 36 | return sequence 37 | 38 | 39 | def sequence_to_text(sequence): 40 | '''Converts a sequence of IDs back to a string''' 41 | result = '' 42 | for symbol_id in sequence: 43 | s = _id_to_symbol[symbol_id] 44 | result += s 45 | return result 46 | 47 | 48 | def _clean_text(text, cleaner_names): 49 | for name in cleaner_names: 50 | cleaner = getattr(cleaners, name) 51 | if not cleaner: 52 | raise Exception('Unknown cleaner: %s' % name) 53 | text = cleaner(text) 54 | return text 55 | -------------------------------------------------------------------------------- /Evaluate/evaluate.py: -------------------------------------------------------------------------------- 1 | import fed 2 | import json 3 | import os 4 | import random 5 | import argparse 6 | import numpy as np 7 | from collections import defaultdict 8 | 9 | 10 | def load_data(dataset): 11 | with open(f"../Dialogue/res/{dataset}/dialogue_info_{dataset}.json",'r') as load_f: 12 | coat = json.load(load_f) 13 | contexts = [] 14 | for _, dialogue in coat.items(): 15 | dialogue_content = dialogue["content"] 16 | context = [] 17 | for _, text in dialogue_content.items(): 18 | context.append(text) 19 | context = ' <|endoftext|> '.join(context) 20 | context = '<|endoftext|> ' + context + ' <|endoftext|>' 21 | contexts.append(context) 22 | 23 | return contexts 24 | 25 | 26 | if __name__ == "__main__": 27 | parser = argparse.ArgumentParser(description='dialogue evaluate') 28 | parser.add_argument('--dataset', 29 | type=str, 30 | default='coat', 31 | help='select dataset, option: coat, ml-1m, opendialKG, redial, inspire') 32 | parser.add_argument('--eval_num', 33 | type=int, 34 | default=400, 35 | help='the number of dialogues') 36 | args = parser.parse_args() 37 | data = load_data(args.dataset) 38 | model, tokenizer 
= fed.load_models("microsoft/DialoGPT-large") 39 | print("model load successful") 40 | 41 | scores = [] 42 | data = random.sample(data, args.eval_num) 43 | for conversation in data: 44 | scores.append(fed.evaluate(conversation, model, tokenizer)) 45 | 46 | fed_scores = defaultdict(list) 47 | for result in scores: 48 | score_val = 0.0 49 | for key, val in result.items(): 50 | fed_scores[key].append(val) 51 | score_val += val 52 | fed_scores['fed_overall'].append(score_val / len(result)) 53 | 54 | save_dir = f'./res/{args.dataset}/' 55 | if not os.path.exists(save_dir): 56 | os.makedirs(save_dir) 57 | with open(save_dir + f'{args.dataset}_results.json', 'w') as f: 58 | json.dump(fed_scores, f) 59 | -------------------------------------------------------------------------------- /Dialogue/config/recommend_pattern.py: -------------------------------------------------------------------------------- 1 | movie_rec = [ 2 | "Okay, I will suggest some movies you may like.", 3 | "Okay, I will recommend some movies you may like.", 4 | "All right, I'll give you some movie recommendations.", 5 | "Okay, I can recommend some movies that you might enjoy.", 6 | "Based on your preferences, I think you might enjoy the following movies.", 7 | "I have some movie suggestions that might match your interests.", 8 | "Okay, I'll offer some movie recommendations for you.", 9 | "Here are some movies that I think you might like based on your preferences.", 10 | "Here are some movies that I think would be a good fit for your interests.", 11 | "Okay, I'll suggest some movies for you to watch.", 12 | "Great, I can suggest some movies that you might like.", 13 | "I'll suggest some movies based on your interests.", 14 | "Ok, I have some movie recommendations that you might like.", 15 | "Let me recommend some movies for you based on your preferences.", 16 | "Great, I have suggestions ideas for movies that you might enjoy.", 17 | "Now, I have some movies that I think you might like based on your 
preferences.", 18 | "Okay, I have some movie recommendations that might match your interests.", 19 | "I will offer some movie suggestions based on what you've told me." 20 | ] 21 | 22 | coat_rec = [ 23 | "Okay, I will suggest some coats you may like.", 24 | "Okay, I will recommend some coats you may like.", 25 | "All right, I'll give you some coats recommendations.", 26 | "Okay, I can recommend some coats that you might enjoy.", 27 | "Based on your preferences, I think you might enjoy the following coats.", 28 | "I have some coats suggestions that might match your interests.", 29 | "Okay, I'll offer some coats recommendations for you.", 30 | "Here are some coats that I think you might like based on your preferences.", 31 | "Here are some coats that I think would be a good fit for your interests.", 32 | "Okay, I'll suggest some coats for you.", 33 | "Great, I can suggest some coats that you might like.", 34 | "I'll suggest some coats based on your interests.", 35 | "Ok, I have some coats recommendations that you might like.", 36 | "Let me recommend some coats for you based on your preferences.", 37 | "Great, I have suggestions ideas for coats that you might enjoy.", 38 | "Now, I have some coats that I think you might like based on your preferences.", 39 | "Okay, I have some coats recommendations that might match your interests.", 40 | "I will offer some coats suggestions based on what you've told me." 
41 | ] -------------------------------------------------------------------------------- /Dialogue/movie_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def get_user_age(age): 5 | if age < 20: 6 | age_str = 'under 20' 7 | elif age >= 20 and age <= 30: 8 | age_str = '20-30' 9 | elif age > 30: 10 | age_str = 'over 30' 11 | 12 | return age_str 13 | 14 | def get_user_gender(gender): 15 | if gender == 'F': 16 | gender_str = 'women' 17 | elif gender == 'M': 18 | gender_str = 'men' 19 | 20 | return gender_str 21 | 22 | def get_item_actor(actors): 23 | actors = actors.strip('[').strip(']').replace("'","").split(',') 24 | actors = [actor.strip().strip('"') for actor in actors] 25 | 26 | 27 | return actors 28 | 29 | def get_item_genre(genre): 30 | genre = genre.split("|") 31 | genre = [gen.lower() for gen in genre] 32 | 33 | return genre 34 | 35 | def modify_country(country): 36 | country_dict = {'Australia': 'Australian', 'United States': 'American', 37 | 'Japan': 'Japanese', 'United Kingdom': 'British', 38 | 'Mexico': 'Mexican', 'Italy': 'Italian', 39 | 'France': 'French', 'Germany': 'German', 40 | 'Brazil': 'Brazilian', 'Spain': 'Spanish', 41 | 'Netherlands': 'Dutch', 'Canada': 'Canadian', 42 | 'Hong Kong': 'Hong Kong', 'Cuba': 'Cuban', 43 | 'Monaco': 'Monaco', 'Belgium': 'Belgian', 44 | 'Czech Republic': 'Czech', 'West Germany': 'West German', 45 | 'Ireland': 'Irish', 'Soviet Union': 'Soviet Union', 46 | 'New Zealand': 'New Zealand', 'China': 'Chinese', 47 | 'Luxembourg': 'Luxembourg', 'Taiwan': 'Taiwanese', 48 | 'South Africa': 'South African', 'Sweden': 'Swedish', 49 | 'Switzerland': 'Swiss', 'Denmark': 'Danish', 50 | 'Argentina': 'Argentine', 'Russia': 'Russian', 51 | 'Aruba': 'Aruba', 'India': 'Indian', 52 | 'Norway':'Norwegian', 'Federal Republic of Yugoslavia': 'Yugoslav', 53 | 'Iran': 'Iranian', 'Bhutan': 'Bhutanese', 54 | 'Vietnam': 'Vietnamese', 'Dominican Republic': 'Dominican', 55 | 'Hungary': 
'Hungarian', 'Poland': 'Polish'} 56 | 57 | return country_dict[country] 58 | 59 | 60 | def check_in_english(utterance): 61 | for _, text in utterance.items(): 62 | text = re.sub(r'[.,"\'-?:!;]', '', text) 63 | text = text.replace(' ','') 64 | if not text.encode('UTF-8').isalpha(): 65 | return False 66 | 67 | return True 68 | -------------------------------------------------------------------------------- /Speech/inference_user.py: -------------------------------------------------------------------------------- 1 | import IPython.display as ipd 2 | import argparse 3 | import os 4 | import json 5 | import random 6 | import numpy as np 7 | from tqdm import tqdm 8 | from pydub import AudioSegment 9 | from utils import * 10 | import logging 11 | 12 | 13 | if __name__ == '__main__': 14 | parser = argparse.ArgumentParser(description='dialogue') 15 | parser.add_argument('--dataset', 16 | type=str, 17 | default='ml-1m', 18 | help='select dataset, option: coat, ml-1m') 19 | parser.add_argument('--format', 20 | type=str, 21 | default='mp3', 22 | help='select format, option: wav, flac, mp3') 23 | parser.add_argument('--start_id', 24 | type=int, 25 | default=0) 26 | args = parser.parse_args() 27 | with open(f"../Dialogue/res/{args.dataset}/dialogue_info_{args.dataset}.json",'r') as load_f: 28 | dialogue = json.load(load_f) 29 | logger = logging.getLogger() 30 | logger.setLevel(logging.ERROR) 31 | vctk_to_speaker_id, total_speaker = preprocess_speaker_info() 32 | agent, user = load_vits_model() 33 | random.seed(2023) 34 | save_dir = f'./speech_res/{args.dataset}_{args.format}_user/' 35 | if not os.path.exists(save_dir): 36 | os.makedirs(save_dir) 37 | for k, v in tqdm(dialogue.items()): 38 | dia_id = int(k) 39 | content = v['content'] 40 | user_id = v['user_id'] 41 | item_id = v['item_id'] 42 | gender = v['user_gender'] 43 | age = v['user_age'] 44 | 45 | speaker_list_id = selet_speaker_list_idx(age, gender) 46 | speaker_list = total_speaker[speaker_list_id] 47 | speaker = 
random.choice(speaker_list) 48 | speaker_id = vctk_to_speaker_id[speaker] 49 | speaker_speed = 1.2 50 | 51 | audio_list = [] 52 | count = 0 53 | save_name = f'diaid{dia_id}_uid{user_id}_iid{item_id}_{age}_{gender}_{speaker}' 54 | tag = 0 55 | for uttr_name, uttr in content.items(): 56 | if count % 2 == 0: 57 | s_audio = generate_user_speech(user, uttr, speaker_id, speaker_speed) 58 | if tag == 0: 59 | combine_audio = s_audio 60 | tag = 1 61 | else: 62 | combine_audio = np.hstack((combine_audio, s_audio)) 63 | count += 1 64 | audio = ipd.Audio(combine_audio, rate=22050, normalize=False) 65 | with open(f'{save_dir}{save_name}.mp3', 'wb') as f: 66 | f.write(audio.data) 67 | 68 | 69 | 70 | 71 | 72 | 73 | -------------------------------------------------------------------------------- /Speech/inference.py: -------------------------------------------------------------------------------- 1 | import IPython.display as ipd 2 | import argparse 3 | import os 4 | import json 5 | import random 6 | import numpy as np 7 | from tqdm import tqdm 8 | from pydub import AudioSegment 9 | from utils import * 10 | import logging 11 | import random 12 | 13 | 14 | if __name__ == '__main__': 15 | parser = argparse.ArgumentParser(description='dialogue') 16 | parser.add_argument('--dataset', 17 | type=str, 18 | default='ml-1m', 19 | help='select dataset, option: coat, ml-1m') 20 | parser.add_argument('--format', 21 | type=str, 22 | default='mp3', 23 | help='select format, option: wav, flac, mp3') 24 | args = parser.parse_args() 25 | with open(f"../Dialogue/res/{args.dataset}/dialogue_info_{args.dataset}.json",'r') as load_f: 26 | dialogue = json.load(load_f) 27 | random.seed(2023) 28 | logger = logging.getLogger() 29 | logger.setLevel(logging.ERROR) 30 | vctk_to_speaker_id, total_speaker = preprocess_speaker_info() 31 | agent, user = load_vits_model() 32 | zero_audio = np.zeros((20000,), dtype=np.float32) 33 | zero_audio = ipd.Audio(zero_audio, rate=user[0].data.sampling_rate, 
normalize=False) 34 | with open(f'zero.wav', 'wb') as f: 35 | f.write(zero_audio.data) 36 | zero_wav = AudioSegment.from_wav("./zero.wav") 37 | save_dir = f'./speech_res/{args.dataset}_{args.format}/' 38 | if not os.path.exists(save_dir): 39 | os.makedirs(save_dir) 40 | for k, v in tqdm(dialogue.items()): 41 | dia_id = int(k) 42 | 43 | content = v['content'] 44 | user_id = v['user_id'] 45 | item_id = v['item_id'] 46 | gender = v['user_gender'] 47 | age = v['user_age'] 48 | 49 | speaker_list_id = selet_speaker_list_idx(age, gender) 50 | speaker_list = total_speaker[speaker_list_id] 51 | speaker = random.choice(speaker_list) 52 | speaker_id = vctk_to_speaker_id[speaker] 53 | speaker_speed = 1.2 54 | 55 | audio_list = [] 56 | count = 0 57 | save_name = f'diaid{dia_id}_uid{user_id}_iid{item_id}_{age}_{gender}_{speaker}' 58 | for uttr_name, uttr in content.items(): 59 | if (count % 2 == 0): 60 | s_audio = generate_user_speech_audio(user, uttr, speaker_id, speaker_speed) 61 | elif (count % 2 != 0): 62 | s_audio = generate_agent_speech_audio(agent, uttr) 63 | count += 1 64 | with open(f'{uttr_name}.wav', 'wb') as f: 65 | f.write(s_audio.data) 66 | audio_list.append(uttr_name) 67 | tag = 0 68 | for i, uttr_name in enumerate(audio_list): 69 | wav_file = AudioSegment.from_wav(f"./{uttr_name}.wav") 70 | if tag == 0: 71 | combine_wav = wav_file 72 | combine_wav = combine_wav + zero_wav 73 | tag = 1 74 | else: 75 | combine_wav = combine_wav + wav_file 76 | if (i != (len(audio_list) - 1)) and (i != (len(audio_list) - 3)): 77 | combine_wav = combine_wav + zero_wav 78 | os.remove(f"./{uttr_name}.wav") 79 | if args.format == 'wav': 80 | combine_wav.export(f"{save_dir}{save_name}.wav", format="wav") 81 | elif args.format == 'mp3': 82 | combine_wav.export(f"{save_dir}{save_name}.mp3", format="mp3") 83 | elif args.format == 'flac': 84 | combine_wav.export(f"{save_dir}{save_name}.flac", format="flac") 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 
-------------------------------------------------------------------------------- /Dialogue/coat_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def age_map(age): 4 | if age == 0: 5 | age = 1 6 | elif age == 1: 7 | age = 2 8 | elif age == 2: 9 | age = 2 10 | elif age == 3: 11 | age = 2 12 | elif age == 4: 13 | age = 2 14 | elif age == 5: 15 | age = 0 16 | 17 | return age 18 | 19 | def get_item_index(feature, user_id): 20 | item_feature = feature[user_id] 21 | index = np.where(item_feature == 1)[0][ : -1] 22 | gender = index[0] 23 | jacket = index[1] - 2 24 | color = index[2] - 18 25 | 26 | return gender, jacket, color 27 | 28 | def get_user_age(age): 29 | if age == 0: 30 | age_str = 'under 20' 31 | elif age == 1: 32 | age_str = '20-30' 33 | elif age == 2: 34 | age_str = 'over 30' 35 | 36 | return age_str 37 | 38 | def get_user_gender(gender): 39 | if gender == 0: 40 | gender_str = 'men' 41 | elif gender == 1: 42 | gender_str = 'women' 43 | return gender_str 44 | 45 | def get_item_gender(gender): 46 | if gender == 0: 47 | gender_str = 'men' 48 | elif gender == 1: 49 | gender_str = 'women' 50 | 51 | return gender_str 52 | 53 | def get_item_type(jacket): 54 | if jacket == 0: 55 | jacket_str = 'bomber' 56 | elif jacket == 1: 57 | jacket_str = 'cropped' 58 | elif jacket == 2: 59 | jacket_str = 'field' 60 | elif jacket == 3: 61 | jacket_str = 'fleece' 62 | elif jacket == 4: 63 | jacket_str = 'insulated' 64 | elif jacket == 5: 65 | jacket_str = 'motorcycle' 66 | elif jacket == 6: 67 | jacket_str = 'other' 68 | elif jacket == 7: 69 | jacket_str = 'packable' 70 | elif jacket == 8: 71 | jacket_str = 'parkas' 72 | elif jacket == 9: 73 | jacket_str = 'pea' 74 | elif jacket == 10: 75 | jacket_str = 'rain' 76 | elif jacket == 11: 77 | jacket_str = 'shells' 78 | elif jacket == 12: 79 | jacket_str = 'track' 80 | elif jacket == 13: 81 | jacket_str = 'trench' 82 | elif jacket == 14: 83 | jacket_str = 'vests' 84 
| elif jacket == 15: 85 | jacket_str = 'waterproof' 86 | 87 | return jacket_str 88 | 89 | def get_item_type_index(jacket_val): 90 | index_dict = {'bomber':0, 'cropped':1, 'field':2, 'fleece':3, 'insulated':4, 91 | 'motorcycle':5, 'other':6, 'packable':7, 'parkas':8, 'pea':9, 92 | 'rain':10, 'shells':11, 'track':12, 'trench':13, 'vests':14, 'waterproof':15} 93 | 94 | return index_dict[jacket_val] 95 | 96 | def get_item_color(color): 97 | if color == 0: 98 | color_str = 'beige' 99 | elif color == 1: 100 | color_str = 'black' 101 | elif color == 2: 102 | color_str = 'blue' 103 | elif color == 3: 104 | color_str = 'brown' 105 | elif color == 4: 106 | color_str = 'gray' 107 | elif color == 5: 108 | color_str = 'green' 109 | elif color == 6: 110 | color_str = 'multi' 111 | elif color == 7: 112 | color_str = 'navy' 113 | elif color == 8: 114 | color_str = 'olive' 115 | elif color == 9: 116 | color_str = 'other' 117 | elif color == 10: 118 | color_str = 'pink' 119 | elif color == 11: 120 | color_str = 'purple' 121 | elif color == 12: 122 | color_str = 'red' 123 | 124 | return color_str 125 | 126 | def get_item_color_index(color_val): 127 | index_dict = {'beige':0, 'black':1, 'blue':2, 'brown':3, 'gray':4, 128 | 'green':5, 'multi':6, 'navy':7, 'olive':8, 'other':9, 129 | 'pink':10, 'purple':11, 'red':12,} 130 | 131 | return index_dict[color_val] -------------------------------------------------------------------------------- /Speech/data/speaker_info/speaker-info.txt: -------------------------------------------------------------------------------- 1 | ID AGE GENDER ACCENTS REGION COMMENTS 2 | p225 23 F English Southern England 3 | p226 22 M English Surrey 4 | p227 38 M English Cumbria 5 | p228 22 F English Southern England 6 | p229 23 F English Southern England 7 | p230 22 F English Stockton-on-tees 8 | p231 23 F English Southern England 9 | p232 23 M English Southern England 10 | p233 23 F English Staffordshire 11 | p234 22 F Scottish West Dumfries 12 | p236 23 F English 
Manchester 13 | p237 22 M Scottish Fife 14 | p238 22 F NorthernIrish Belfast 15 | p239 22 F English SW England 16 | p240 21 F English Southern England 17 | p241 21 M Scottish Perth 18 | p243 22 M English London 19 | p244 22 F English Manchester 20 | p245 25 M Irish Dublin 21 | p246 22 M Scottish Selkirk 22 | p247 22 M Scottish Argyll 23 | p248 23 F Indian 24 | p249 22 F Scottish Aberdeen 25 | p250 22 F English SE England 26 | p251 26 M Indian 27 | p252 22 M Scottish Edinburgh 28 | p253 22 F Welsh Cardiff 29 | p254 21 M English Surrey 30 | p255 19 M Scottish Galloway 31 | p256 24 M English Birmingham 32 | p257 24 F English Southern England 33 | p258 22 M English Southern England 34 | p259 23 M English Nottingham 35 | p260 21 M Scottish Orkney 36 | p261 26 F NorthernIrish Belfast 37 | p262 23 F Scottish Edinburgh 38 | p263 22 M Scottish Aberdeen 39 | p264 23 F Scottish West Lothian 40 | p265 23 F Scottish Ross 41 | p266 22 F Irish Athlone 42 | p267 23 F English Yorkshire 43 | p268 23 F English Southern England 44 | p269 20 F English Newcastle 45 | p270 21 M English Yorkshire 46 | p271 19 M Scottish Fife 47 | p272 23 M Scottish Edinburgh 48 | p273 23 M English Suffolk 49 | p274 22 M English Essex 50 | p275 23 M Scottish Midlothian 51 | p276 24 F English Oxford 52 | p277 23 F English NE England 53 | p278 22 M English Cheshire 54 | p279 23 M English Leicester 55 | p280 25 F Unknown France (mic2 files unavailable) 56 | p281 29 M Scottish Edinburgh 57 | p282 23 F English Newcastle 58 | p283 24 F Irish Cork 59 | p284 20 M Scottish Fife 60 | p285 21 M Scottish Edinburgh 61 | p286 23 M English Newcastle 62 | p287 23 M English York 63 | p288 22 F Irish Dublin 64 | p292 23 M NorthernIrish Belfast 65 | p293 22 F NorthernIrish Belfast 66 | p294 33 F American San Francisco 67 | p295 23 F Irish Dublin 68 | p297 20 F American New York 69 | p298 19 M Irish Tipperary 70 | p299 25 F American California 71 | p300 23 F American California 72 | p301 23 F American North Carolina 73 | p302 
20 M Canadian Montreal 74 | p303 24 F Canadian Toronto 75 | p304 22 M NorthernIrish Belfast 76 | p305 19 F American Philadelphia 77 | p306 21 F American New York 78 | p307 23 F Canadian Ontario 79 | p308 18 F American Alabama 80 | p310 21 F American Tennessee 81 | p311 21 M American Iowa 82 | p312 19 F Canadian Hamilton 83 | p313 24 F Irish County Down 84 | p314 26 F SouthAfrican Cape Town 85 | p315 18 M American New England (Text and mic2 files unavailable) 86 | p316 20 M Canadian Alberta 87 | p317 23 F Canadian Hamilton 88 | p318 32 F American Napa 89 | p323 19 F SouthAfrican Pretoria 90 | p326 26 M Australian English Sydney 91 | p329 23 F American 92 | p330 26 F American 93 | p333 19 F American Indiana 94 | p334 18 M American Chicago 95 | p335 25 F NewZealand English 96 | p336 18 F SouthAfrican Johannesburg 97 | p339 21 F American Pennsylvania 98 | p340 18 F Irish Dublin 99 | p341 26 F American Ohio 100 | p343 27 F Canadian Alberta 101 | p345 22 M American Florida 102 | p347 26 M SouthAfrican Johannesburg 103 | p351 21 F NorthernIrish Derry 104 | p360 19 M American New Jersey 105 | p361 19 F American New Jersey 106 | p362 29 F American 107 | p363 22 M Canadian Toronto 108 | p364 23 M Irish Donegal 109 | p374 28 M Australian English 110 | p376 22 M Indian 111 | s5 22 F British (new speaker, more data would be released for training a speaker-dependent model) 112 | -------------------------------------------------------------------------------- /Dialogue/config/thanks_pattern.py: -------------------------------------------------------------------------------- 1 | movie_agent = ["It was my pleasure to help.", 2 | "I'm happy to have been able to assist.", 3 | "No problem, I'm glad I could help.", 4 | "I'm glad I could make a difference.", 5 | "I'm glad I could be of service.", 6 | "It was my pleasure to assist you.", 7 | "I'm glad I could provide some assistance.", 8 | "You're welcome. 
I'm always happy to help with movie recommendations.", 9 | "It was my pleasure to provide some suggestions. I hope you enjoy the movies I recommended.", 10 | "I'm glad I could assist with your movie selection. Let me know if you have any other questions or need further recommendations.", 11 | "I'm always happy to assist with movie recommendations.", 12 | "I'm glad I could provide some movie recommendations for you. Let me know if you have any feedback or need further suggestions.", 13 | "I'm glad I could be of service with your movie selection. I hope you enjoy the movies I recommended.", 14 | "You're welcome.", 15 | "Hope I made some good suggestions.", 16 | "Glad to help.", 17 | ] 18 | 19 | movie_user = ["Thanks for all the help.", 20 | "Thanks for the advice, I'll check them.", 21 | "Thanks a lot for your recommendations.", 22 | "Thanks.", 23 | "Thank you.", 24 | "Thanks a lot for your help in finding something.", 25 | "Thank you so much.", 26 | "Thank you so much for the recommendations. I really appreciate your help.", 27 | "Thank you for your recommendation.", 28 | "I'm grateful for your suggestions.", 29 | "Thanks for the recommendations. I'm excited to try out some of the films you suggested.", 30 | "I really appreciate your help with movie recommendations.", 31 | "I appreciate your recommendations.", 32 | "Thanks for your help with movie suggestions.", 33 | "I'm grateful for your assistance with movie recommendations.", 34 | "I appreciate your help with movie suggestions.", 35 | ] 36 | 37 | 38 | coat_agent = ["It was my pleasure to help.", 39 | "I'm happy to have been able to assist.", 40 | "No problem, I'm glad I could help.", 41 | "I'm glad I could make a difference.", 42 | "I'm glad I could be of service.", 43 | "It was my pleasure to assist you.", 44 | "I'm glad I could provide some assistance.", 45 | "You're welcome. I'm always happy to help with coats recommendations.", 46 | "It was my pleasure to provide some suggestions. 
I hope you enjoy the coats I recommended.", 47 | "I'm glad I could assist with your coats selection. Let me know if you have any other questions or need further recommendations.", 48 | "I'm always happy to assist with coats recommendations.", 49 | "I'm glad I could provide some coats recommendations for you. Let me know if you have any feedback or need further suggestions.", 50 | "I'm glad I could be of service with your coats selection. I hope you enjoy the coats I recommended.", 51 | "You're welcome.", 52 | "Hope I made some good suggestions.", 53 | "Glad to help." 54 | ] 55 | 56 | 57 | coat_user = ["Thank you so much for the recommendations.", 58 | "Your suggestions were really helpful.", 59 | "I really appreciate your advice.", 60 | "Thanks for taking the time to help me out.", 61 | "I'm so grateful for your recommendations.", 62 | "Thank you for your kind assistance.", 63 | "Thanks for all the help.", 64 | "Thanks for the advice, I'll check them.", 65 | "Thanks a lot for your recommendations.", 66 | "Thanks.", 67 | "Thank you.", 68 | "Thanks a lot for your help in finding something.", 69 | "Thank you so much.", 70 | "Thank you so much for the recommendations. I really appreciate your help.", 71 | "Thank you for your recommendation.", 72 | "Thanks for the recommendations. I'm excited to try out some of the coats you suggested.", 73 | "I really appreciate your help with coats recommendations.", 74 | "I'm grateful for your assistance with coats recommendations.", 75 | "I appreciate your help with coats suggestions." 76 | ] -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | Logo 4 | 5 | 6 |

Voice-based Conversational Recommender Systems

7 | 8 |

9 | Towards Building Voice-based Conversational Recommender Systems: Datasets, Potential Solutions, and Prospects 10 |

11 |
12 | 13 |
14 | Table of Contents 15 |
    16 |
  1. 17 | About The Project 18 |
  2. 19 |
  3. 20 | Dataset Description 21 |
  4. 22 |
  5. Potential Solution Exploration
  6. 23 |
  7. Data Construction
  8. 24 |
  9. Acknowledgement
  10. 25 |
26 |
27 | 28 | # About The Project 29 | This project contains the code for the paper "[Towards Building Voice-based Conversational Recommender Systems: Datasets, Potential Solutions, and Prospects](https://arxiv.org/abs/2306.08219)". In this project, we aim to provide two voice-based conversational recommender system (VCRS) datasets in the e-commerce and movie domains. 30 | 31 | 32 | 33 | # Dataset Description 34 | You can download the datasets from Google Drive. They consist of two parts: [coat.tar.gz](https://drive.google.com/file/d/1FnpYhMaeskckxGheKjar0U4YHIdDKM6K/view?usp=share_link) and [ml-1m.tar.gz](https://drive.google.com/file/d/195ugsUrU51VMUjjI329M84qegtK2QuGC/view?usp=sharing). 35 | 36 | ## Dataset files 37 | Each data file is an MP3 file whose name follows the pattern `diaidxx_uidxx_iidxx_xx_xx_xx.mp3`. 38 | 39 | For example, the file name `diaid21_uid249_iid35_20-30_men_251.mp3` is interpreted as follows: 40 | ``` 41 | diaid21: corresponds to dialogue 21 in the text-based conversation dataset 42 | uid249: user id is 249 43 | iid35: item id is 35 44 | 20-30: user's age is between 20 and 30 45 | men: user's gender is male 46 | 251: corresponds to speaker p251 in the VCTK dataset 47 | ``` 48 | Speaker information for the VCTK dataset can be found [here](https://www.udialogue.org/download/cstr-vctk-corpus.html). 49 | 50 | ## Case study 51 | Here we provide a demo of a data file (i.e., `diaid21_uid249_iid35_20-30_men_251.mp3`) that contains the text and audio dialogue between the user and the agent. 52 | 53 | https://github.com/hyllll/VCRS/assets/38367896/b2b1b4d7-e860-46f4-9d6f-24ac8e1a6192 54 | 55 | Note that since we currently only explore the impact of speech on VCRS from the user's perspective, only the user's speech is included in the provided dataset. If you want the complete dialogue audio, you can generate it with the code we provide.
56 | 57 | # Potential Solution Exploration 58 | We propose to extract explicit semantic features from the voice data and then incorporate them into the recommendation model in a two-phase fusion manner. 59 | 60 | Please refer to [here](https://github.com/hyllll/VCRS/tree/main/Recommender) for how to run the code. 61 | 62 | # Data Construction 63 | Our VCRS dataset creation process includes four steps: (1) backbone dataset selection; (2) text-based conversation generation; (3) voice-based conversation generation; and (4) quality evaluation. 64 |
65 |

66 | 67 |
68 |

69 |
70 | 71 | ## Backbone dataset selection 72 | We choose [Coat](https://www.cs.cornell.edu/~schnabts/mnar/) and [ML-1M](https://grouplens.org/datasets/movielens/1m/) as our backbone datasets. We use user-item interactions and item features to simulate text-based conversations between users and agents for recommendation, and we use user features to assign suitable speakers. 73 | 74 | ## Text-based conversation generation 75 | Please refer to [here](https://github.com/hyllll/VCRS/tree/main/Dialogue) for how to generate the text-based conversations; the code is in the `./Dialogue/` directory. 76 | 77 | ## Voice-based conversation generation 78 | Please refer to [here](https://github.com/hyllll/VCRS/tree/main/Speech) for how to generate the voice-based conversations. 79 | 80 | ## Quality evaluation 81 | We adopt the fine-grained evaluation of dialogue (FED) metric to measure the quality of the generated text-based conversations. 82 | ### Installation 83 | ``` 84 | pip install -r ./Evaluate/requirements.txt 85 | ``` 86 | ### RUN 87 | 1. ```cd ./Evaluate/``` 88 | 2. ``` python evaluate.py --dataset='xxx'```, where ```xxx``` is ```coat``` or ```ml-1m```. 89 | 3. All results are saved in the `./res/` directory. 90 | 91 | 92 | 93 | # Acknowledgement 94 | * Convert text to audio using [VITS](https://github.com/jaywalnut310/vits), a SOTA end-to-end text-to-speech (TTS) model. 95 | 96 | * Improve code efficiency with [conv_rec_sys](https://github.com/xxkkrr/conv_rec_sys). 97 | 98 | * Evaluate text-based conversations with [FED](https://github.com/exe1023/DialEvalMetrics). 99 | -------------------------------------------------------------------------------- /Speech/vits_lib/text/cleaners.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Cleaners are transformations that run over the input text at both training and eval time.
5 | 6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 8 | 1. "english_cleaners" for English text 9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 12 | the symbols in symbols.py to match your data). 13 | ''' 14 | 15 | import re 16 | from unidecode import unidecode 17 | from phonemizer import phonemize 18 | import logging 19 | import sys 20 | from logging import Logger 21 | 22 | 23 | # Regular expression matching whitespace: 24 | _whitespace_re = re.compile(r'\s+') 25 | 26 | # List of (regular expression, replacement) pairs for abbreviations: 27 | _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [ 28 | ('mrs', 'misess'), 29 | ('mr', 'mister'), 30 | ('dr', 'doctor'), 31 | ('st', 'saint'), 32 | ('co', 'company'), 33 | ('jr', 'junior'), 34 | ('maj', 'major'), 35 | ('gen', 'general'), 36 | ('drs', 'doctors'), 37 | ('rev', 'reverend'), 38 | ('lt', 'lieutenant'), 39 | ('hon', 'honorable'), 40 | ('sgt', 'sergeant'), 41 | ('capt', 'captain'), 42 | ('esq', 'esquire'), 43 | ('ltd', 'limited'), 44 | ('col', 'colonel'), 45 | ('ft', 'fort'), 46 | ]] 47 | 48 | 49 | def expand_abbreviations(text): 50 | for regex, replacement in _abbreviations: 51 | text = re.sub(regex, replacement, text) 52 | return text 53 | 54 | 55 | def expand_numbers(text): 56 | return normalize_numbers(text) 57 | 58 | 59 | def lowercase(text): 60 | return text.lower() 61 | 62 | 63 | def collapse_whitespace(text): 64 | return re.sub(_whitespace_re, ' ', text) 65 | 66 | 67 | def convert_to_ascii(text): 68 | return unidecode(text) 69 | 70 | 71 | def basic_cleaners(text): 72 | '''Basic pipeline that lowercases and collapses whitespace 
without transliteration.''' 73 | text = lowercase(text) 74 | text = collapse_whitespace(text) 75 | return text 76 | 77 | 78 | def transliteration_cleaners(text): 79 | '''Pipeline for non-English text that transliterates to ASCII.''' 80 | text = convert_to_ascii(text) 81 | text = lowercase(text) 82 | text = collapse_whitespace(text) 83 | return text 84 | 85 | 86 | def english_cleaners(text): 87 | '''Pipeline for English text, including abbreviation expansion.''' 88 | text = convert_to_ascii(text) 89 | text = lowercase(text) 90 | text = expand_abbreviations(text) 91 | phonemes = phonemize(text, language='en-us', backend='espeak', strip=True, words_mismatch='ignore', logger=get_logger()) 92 | phonemes = collapse_whitespace(phonemes) 93 | return phonemes 94 | 95 | 96 | def english_cleaners2(text): 97 | '''Pipeline for English text, including abbreviation expansion. + punctuation + stress''' 98 | text = convert_to_ascii(text) 99 | text = lowercase(text) 100 | text = expand_abbreviations(text) 101 | phonemes = phonemize(text, language='en-us', backend='espeak', strip=True, preserve_punctuation=True, with_stress=True, words_mismatch='ignore', logger=get_logger()) 102 | phonemes = collapse_whitespace(phonemes) 103 | return phonemes 104 | 105 | 106 | 107 | def get_logger(verbosity: str = 'quiet', name: str = 'phonemizer') -> Logger: 108 | """Returns a configured logging.Logger instance 109 | from https://github.com/bootphon/phonemizer/blob/master/phonemizer/logger.py 110 | The logger is configured to output messages on the standard error stream 111 | (stderr). 112 | Parameters 113 | ---------- 114 | verbosity (str) : The level of verbosity, must be 'verbose' (displays 115 | debug/info and warning messages), 'normal' (warnings only) or 'quiet' (do 116 | not display anything). 117 | name (str) : The logger name, default to 'phonemizer' 118 | Raises 119 | ------ 120 | RuntimeError if `verbosity` is not 'normal', 'verbose', or 'quiet'. 
121 | """ 122 | # make sure the verbosity argument is valid 123 | valid_verbosity = ['normal', 'verbose', 'quiet'] 124 | if verbosity not in valid_verbosity: 125 | raise RuntimeError( 126 | f'verbosity is {verbosity} but must be in ' 127 | f'{", ".join(valid_verbosity)}') 128 | 129 | logger = logging.getLogger(name) 130 | 131 | # setup output to stderr 132 | logger.handlers = [] 133 | handler = logging.StreamHandler(sys.stderr) 134 | 135 | # setup verbosity level 136 | logger.setLevel(logging.ERROR) 137 | # print(verbosity) 138 | if verbosity == 'verbose': 139 | logger.setLevel(logging.DEBUG) 140 | elif verbosity == 'quiet': 141 | handler = logging.NullHandler() 142 | 143 | # setup messages format 144 | handler.setFormatter(logging.Formatter('[%(levelname)s] %(message)s')) 145 | logger.addHandler(handler) 146 | return logger 147 | -------------------------------------------------------------------------------- /Speech/vits_lib/commons.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | 8 | def init_weights(m, mean=0.0, std=0.01): 9 | classname = m.__class__.__name__ 10 | if classname.find("Conv") != -1: 11 | m.weight.data.normal_(mean, std) 12 | 13 | 14 | def get_padding(kernel_size, dilation=1): 15 | return int((kernel_size*dilation - dilation)/2) 16 | 17 | 18 | def convert_pad_shape(pad_shape): 19 | l = pad_shape[::-1] 20 | pad_shape = [item for sublist in l for item in sublist] 21 | return pad_shape 22 | 23 | 24 | def intersperse(lst, item): 25 | result = [item] * (len(lst) * 2 + 1) 26 | result[1::2] = lst 27 | return result 28 | 29 | 30 | def kl_divergence(m_p, logs_p, m_q, logs_q): 31 | """KL(P||Q)""" 32 | kl = (logs_q - logs_p) - 0.5 33 | kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. 
* logs_q) 34 | return kl 35 | 36 | 37 | def rand_gumbel(shape): 38 | """Sample from the Gumbel distribution, protect from overflows.""" 39 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 40 | return -torch.log(-torch.log(uniform_samples)) 41 | 42 | 43 | def rand_gumbel_like(x): 44 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) 45 | return g 46 | 47 | 48 | def slice_segments(x, ids_str, segment_size=4): 49 | ret = torch.zeros_like(x[:, :, :segment_size]) 50 | for i in range(x.size(0)): 51 | idx_str = ids_str[i] 52 | idx_end = idx_str + segment_size 53 | ret[i] = x[i, :, idx_str:idx_end] 54 | return ret 55 | 56 | 57 | def rand_slice_segments(x, x_lengths=None, segment_size=4): 58 | b, d, t = x.size() 59 | if x_lengths is None: 60 | x_lengths = t 61 | ids_str_max = x_lengths - segment_size + 1 62 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) 63 | ret = slice_segments(x, ids_str, segment_size) 64 | return ret, ids_str 65 | 66 | 67 | def get_timing_signal_1d( 68 | length, channels, min_timescale=1.0, max_timescale=1.0e4): 69 | position = torch.arange(length, dtype=torch.float) 70 | num_timescales = channels // 2 71 | log_timescale_increment = ( 72 | math.log(float(max_timescale) / float(min_timescale)) / 73 | (num_timescales - 1)) 74 | inv_timescales = min_timescale * torch.exp( 75 | torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment) 76 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) 77 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) 78 | signal = F.pad(signal, [0, 0, 0, channels % 2]) 79 | signal = signal.view(1, channels, length) 80 | return signal 81 | 82 | 83 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): 84 | b, channels, length = x.size() 85 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 86 | return x + signal.to(dtype=x.dtype, device=x.device) 87 | 88 | 89 | def 
cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): 90 | b, channels, length = x.size() 91 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 92 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) 93 | 94 | 95 | def subsequent_mask(length): 96 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) 97 | return mask 98 | 99 | 100 | @torch.jit.script 101 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 102 | n_channels_int = n_channels[0] 103 | in_act = input_a + input_b 104 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 105 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 106 | acts = t_act * s_act 107 | return acts 108 | 109 | 110 | def convert_pad_shape(pad_shape): 111 | l = pad_shape[::-1] 112 | pad_shape = [item for sublist in l for item in sublist] 113 | return pad_shape 114 | 115 | 116 | def shift_1d(x): 117 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] 118 | return x 119 | 120 | 121 | def sequence_mask(length, max_length=None): 122 | if max_length is None: 123 | max_length = length.max() 124 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 125 | return x.unsqueeze(0) < length.unsqueeze(1) 126 | 127 | 128 | def generate_path(duration, mask): 129 | """ 130 | duration: [b, 1, t_x] 131 | mask: [b, 1, t_y, t_x] 132 | """ 133 | device = duration.device 134 | 135 | b, _, t_y, t_x = mask.shape 136 | cum_duration = torch.cumsum(duration, -1) 137 | 138 | cum_duration_flat = cum_duration.view(b * t_x) 139 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 140 | path = path.view(b, t_x, t_y) 141 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 142 | path = path.unsqueeze(1).transpose(2,3) * mask 143 | return path 144 | 145 | 146 | def clip_grad_value_(parameters, clip_value, norm_type=2): 147 | if isinstance(parameters, torch.Tensor): 148 | parameters = 
[parameters] 149 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 150 | norm_type = float(norm_type) 151 | if clip_value is not None: 152 | clip_value = float(clip_value) 153 | 154 | total_norm = 0 155 | for p in parameters: 156 | param_norm = p.grad.data.norm(norm_type) 157 | total_norm += param_norm.item() ** norm_type 158 | if clip_value is not None: 159 | p.grad.data.clamp_(min=-clip_value, max=clip_value) 160 | total_norm = total_norm ** (1. / norm_type) 161 | return total_norm 162 | -------------------------------------------------------------------------------- /Evaluate/fed.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler 7 | from torch.utils.data.distributed import DistributedSampler 8 | 9 | from transformers import AutoTokenizer, AutoModelWithLMHead 10 | 11 | #tokenizer = GPT2Tokenizer.from_pretrained('dialogpt') 12 | #model = GPT2LMHeadModel.from_pretrained('gpt2') 13 | #weights = torch.load("dialogpt/small_fs.pkl") 14 | #weights = {k.replace("module.", ""): v for k,v in weights.items()} 15 | #weights["lm_head.weight"] = weights["lm_head.decoder.weight"] 16 | #weights.pop("lm_head.decoder.weight",None) 17 | #model.load_state_dict(weights) 18 | 19 | 20 | def load_models(name="microsoft/DialoGPT-large"): 21 | tokenizer = AutoTokenizer.from_pretrained(name, mirror='tuna') 22 | model = AutoModelWithLMHead.from_pretrained(name, mirror='tuna') 23 | model.to("cuda") 24 | return model, tokenizer 25 | 26 | def score(text, tokenizer, model): 27 | if not text.startswith("<|endoftext|> "): 28 | text = "<|endoftext|> " + text 29 | input_ids = torch.tensor(tokenizer.encode(text)).unsqueeze(0) # Batch size 1 30 | tokenize_input = tokenizer.tokenize(text) 31 | #50256 is the token_id for <|endoftext|> 32 | tensor_input = torch.tensor([ 
tokenizer.convert_tokens_to_ids(tokenize_input)]).cuda() 33 | with torch.no_grad(): 34 | outputs = model(tensor_input, labels=tensor_input) 35 | loss, logits = outputs[:2] 36 | 37 | return loss.item() 38 | 39 | def evaluate(conversation, model, tokenizer): 40 | scores = {} 41 | turn_level_utts = { 42 | "interesting": { 43 | "positive": ["Wow that is really interesting.", "That's really interesting!", "Cool! That sounds super interesting."], 44 | "negative": ["That's not very interesting.", "That's really boring.", "That was a really boring response."] 45 | }, 46 | "engaging": { 47 | "positive": ["Wow! That's really cool!", "Tell me more!", "I'm really interested in learning more about this."], 48 | "negative": ["Let's change the topic.", "I don't really care. That's pretty boring.", "I want to talk about something else."] 49 | }, 50 | "specific": { 51 | "positive": ["That's good to know. Cool!", "I see, that's interesting.", "That's a good point."], 52 | "negative": ["That's a very generic response.", "Not really relevant here.", "That's not really relevant here."] 53 | }, 54 | "relevant": { 55 | "positive": [], 56 | "negative": ["That's not even related to what I said.", "Don't change the topic!", "Why are you changing the topic?"] 57 | }, 58 | "correct": { 59 | "positive": [], 60 | "negative": ["You're not understanding me!", "I am so confused right now!", "I don't understand what you're saying."] 61 | }, 62 | "semantically appropriate": { 63 | "positive": ["That makes sense!", "You have a good point."], 64 | "negative": ["That makes no sense!"] 65 | }, 66 | "understandable": { 67 | "positive": ["That makes sense!", "You have a good point."], 68 | "negative": ["I don't understand at all!", "I'm so confused!", "That makes no sense!", "What does that even mean?"] 69 | }, 70 | "fluent": { 71 | "positive": ["That makes sense!", "You have a good point."], 72 | "negative": ["Is that real English?", "I'm so confused right now!", "That makes no sense!"] 73 | }, 74 | } 75 
| for metric,utts in turn_level_utts.items(): 76 | pos = utts["positive"] 77 | neg = utts["negative"] 78 | 79 | # Positive score 80 | high_score = 0 81 | for m in pos: 82 | hs = score(conversation + " <|endoftext|> " + m, tokenizer, model) 83 | high_score += hs 84 | 85 | high_score = high_score/max(len(pos), 1) 86 | 87 | # Negative score 88 | low_score = 0 89 | for m in neg: 90 | ls = score(conversation + " <|endoftext|> " + m, tokenizer, model) 91 | low_score += ls 92 | low_score = low_score/max(len(neg), 1) 93 | 94 | scores[metric] = (low_score - high_score) 95 | 96 | dialog_level_utts = { 97 | "coherent": { 98 | "positive": [], 99 | "negative": ["You're making no sense at all.", "You're changing the topic so much!", "You are so confusing."] 100 | }, 101 | "error recovery": { 102 | "positive": [], 103 | "negative": ["I am so confused right now.", "You're really confusing.", "I don't understand what you're saying."] 104 | }, 105 | "consistent": { 106 | "positive": [], 107 | "negative": ["That's not what you said earlier!", "Stop contradicting yourself!"], 108 | }, 109 | "diverse": { 110 | "positive": [], 111 | "negative": ["Stop saying the same thing repeatedly.", "Why are you repeating yourself?", "Stop repeating yourself!"] 112 | }, 113 | "depth": { 114 | "positive": [], 115 | "negative": ["Stop changing the topic so much.", "Don't change the topic!"], 116 | }, 117 | "likeable": { 118 | "positive": ["I like you!", "You're super polite and fun to talk to", "Great talking to you."], 119 | "negative": ["You're not very nice.", "You're not very fun to talk to.", "I don't like you."] 120 | }, 121 | "understand": { 122 | "positive": [], 123 | "negative": ["You're not understanding me!", "What are you trying to say?", "I don't understand what you're saying."] 124 | }, 125 | "flexible": { 126 | "positive": ["You're very easy to talk to!", "Wow you can talk about a lot of things!"], 127 | "negative": ["I don't want to talk about that!", "Do you know how to talk about 
something else?"], 128 | }, 129 | "informative": { 130 | "positive": ["Thanks for all the information!", "Wow that's a lot of information.", "You know a lot of facts!"], 131 | "negative": ["You're really boring.", "You don't really know much."], 132 | }, 133 | "inquisitive": { 134 | "positive": ["You ask a lot of questions!", "That's a lot of questions!"], 135 | "negative": ["You don't ask many questions.", "You don't seem interested."], 136 | }, 137 | } 138 | for metric,utts in dialog_level_utts.items(): 139 | pos = utts["positive"] 140 | neg = utts["negative"] 141 | 142 | # Positive 143 | high_score = 0 144 | for m in pos: 145 | hs = score(conversation + " <|endoftext|> " + m, tokenizer, model) 146 | high_score += hs 147 | 148 | high_score = high_score/max(len(pos), 1) 149 | 150 | # Negative 151 | low_score = 0 152 | for m in neg: 153 | ls = score(conversation + " <|endoftext|> " + m, tokenizer, model) 154 | low_score += ls 155 | low_score = low_score/max(len(neg), 1) 156 | 157 | scores[metric] = (low_score - high_score) 158 | 159 | return scores -------------------------------------------------------------------------------- /Speech/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import vits_lib.commons as commons 3 | import vits_lib.utils as vits_utils 4 | from vits_lib.models import SynthesizerTrn 5 | from vits_lib.text.symbols import symbols 6 | from vits_lib.text import text_to_sequence 7 | import IPython.display as ipd 8 | 9 | 10 | 11 | def get_text(text, hps): 12 | text_norm = text_to_sequence(text, hps.data.text_cleaners) 13 | if hps.data.add_blank: 14 | text_norm = commons.intersperse(text_norm, 0) 15 | text_norm = torch.LongTensor(text_norm) 16 | return text_norm 17 | 18 | def get_vid_sid(line): 19 | ''' 20 | Return: 21 | vctk_id: vctk id 22 | speaker_id: speaker id in vits 23 | ''' 24 | line = line.split('|') 25 | vctk_id = int(line[0].split('/')[1][1:]) 26 | speaker_id = int(line[1]) 27 | 
28 | return vctk_id, speaker_id 29 | 30 | def split_speaker(s_info): 31 | # TODO: optimieze if else 32 | under_20_male = [] 33 | under_20_female = [] 34 | between_20_30_male = [] 35 | between_20_30_female = [] 36 | above_30_male = [] 37 | above_30_female = [] 38 | 39 | for i, info in enumerate(s_info): 40 | if i == 0: 41 | continue 42 | else: 43 | gender = info[10:11] 44 | age = int(info[6:8]) 45 | user_id = int(info[1:4]) 46 | if age < 20 and gender == 'M': 47 | under_20_male.append(user_id) 48 | elif age < 20 and gender == 'F': 49 | under_20_female.append(user_id) 50 | elif age >= 20 and age <=30 and gender == 'M': 51 | between_20_30_male.append(user_id) 52 | elif age >= 20 and age <=30 and gender == 'F': 53 | between_20_30_female.append(user_id) 54 | elif age > 30 and gender == 'M': 55 | above_30_male.append(user_id) 56 | else: 57 | above_30_female.append(user_id) 58 | 59 | between_20_30_female.remove(5) 60 | under_20_male.remove(315) 61 | 62 | total_speaker = [] 63 | total_speaker.append(under_20_male) 64 | total_speaker.append(under_20_female) 65 | total_speaker.append(between_20_30_male) 66 | total_speaker.append(between_20_30_female) 67 | total_speaker.append(above_30_male) 68 | total_speaker.append(above_30_female) 69 | 70 | return total_speaker 71 | 72 | 73 | def preprocess_speaker_info(): 74 | with open('./data/speaker_info/speaker-info.txt') as f: 75 | lines = f.readlines() 76 | s_info = [] 77 | for sub in lines: 78 | s_info.append(sub.replace("\n", "")) 79 | 80 | with open('./data/speaker_info/vctk_audio_sid_text_test_filelist.txt.cleaned.txt') as f: 81 | sid_text_test_filelist = f.readlines() 82 | with open('./data/speaker_info/vctk_audio_sid_text_train_filelist.txt.cleaned.txt') as f: 83 | sid_text_train_filelist = f.readlines() 84 | with open('./data/speaker_info/vctk_audio_sid_text_val_filelist.txt.cleaned.txt') as f: 85 | sid_text_val_filelist = f.readlines() 86 | filelist = sid_text_test_filelist + sid_text_train_filelist + sid_text_val_filelist 
87 | vctk_to_speaker_id = {} 88 | for line in filelist: 89 | v_id, s_id = get_vid_sid(line) 90 | if v_id not in vctk_to_speaker_id.keys(): 91 | vctk_to_speaker_id[v_id] = s_id 92 | total_speaker = split_speaker(s_info) 93 | 94 | return vctk_to_speaker_id, total_speaker 95 | 96 | 97 | def load_vits_model(): 98 | hps_agent = vits_utils.get_hparams_from_file("./vits_lib/configs/ljs_base.json") 99 | net_agent = SynthesizerTrn( 100 | len(symbols), 101 | hps_agent.data.filter_length // 2 + 1, 102 | hps_agent.train.segment_size // hps_agent.data.hop_length, 103 | **hps_agent.model).cuda() 104 | _ = net_agent.eval() 105 | _ = vits_utils.load_checkpoint("./vits_lib/pretrain/pretrained_ljs.pth", net_agent, None) 106 | 107 | hps_user = vits_utils.get_hparams_from_file("./vits_lib/configs/vctk_base.json") 108 | net_user = SynthesizerTrn( 109 | len(symbols), 110 | hps_user.data.filter_length // 2 + 1, 111 | hps_user.train.segment_size // hps_user.data.hop_length, 112 | n_speakers=hps_user.data.n_speakers, 113 | **hps_user.model).cuda() 114 | _ = net_user.eval() 115 | _ = vits_utils.load_checkpoint("./vits_lib/pretrain/pretrained_vctk.pth", net_user, None) 116 | 117 | return (hps_agent, net_agent), (hps_user, net_user) 118 | 119 | 120 | def generate_agent_speech_audio(agent, text): 121 | hps_agent = agent[0] 122 | net_agent = agent[1] 123 | 124 | stn_tst = get_text(text, hps_agent) 125 | with torch.no_grad(): 126 | x_tst = stn_tst.cuda().unsqueeze(0) 127 | x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cuda() 128 | audio = net_agent.infer(x_tst, x_tst_lengths, noise_scale=.667, noise_scale_w=0.8, length_scale=1.0)[0][0,0].data.cpu().float().numpy() 129 | audio = ipd.Audio(audio, rate=hps_agent.data.sampling_rate, normalize=False) 130 | 131 | return audio 132 | 133 | 134 | def generate_user_speech(user, text, sid, speaker_speed): 135 | hps_user = user[0] 136 | net_user = user[1] 137 | 138 | stn_tst = get_text(text, hps_user) 139 | with torch.no_grad(): 140 | x_tst = 
stn_tst.cuda().unsqueeze(0) 141 | x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cuda() 142 | sid = torch.LongTensor([sid]).cuda() 143 | audio = net_user.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=.667, noise_scale_w=0.8, length_scale=speaker_speed)[0][0,0].data.cpu().float().numpy() 144 | # audio = ipd.Audio(audio, rate=hps_user.data.sampling_rate, normalize=False) 145 | 146 | return audio 147 | 148 | 149 | def generate_user_speech_audio(user, text, sid, speaker_speed): 150 | hps_user = user[0] 151 | net_user = user[1] 152 | 153 | stn_tst = get_text(text, hps_user) 154 | with torch.no_grad(): 155 | x_tst = stn_tst.cuda().unsqueeze(0) 156 | x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cuda() 157 | sid = torch.LongTensor([sid]).cuda() 158 | audio = net_user.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=.667, noise_scale_w=0.8, length_scale=speaker_speed)[0][0,0].data.cpu().float().numpy() 159 | audio = ipd.Audio(audio, rate=hps_user.data.sampling_rate, normalize=False) 160 | 161 | return audio 162 | 163 | 164 | 165 | def selet_speaker_list_idx(age, gender): 166 | if age == 'under 20' and gender == 'men': 167 | return 0 168 | elif age == 'under 20' and gender == 'women': 169 | return 1 170 | elif age == '20-30' and gender == 'men': 171 | return 2 172 | elif age == '20-30' and gender == 'women': 173 | return 3 174 | elif age == 'over 30' and gender == 'men': 175 | return 4 176 | else: 177 | return 5 -------------------------------------------------------------------------------- /Speech/data/speaker_info/vctk_audio_sid_text_val_filelist.txt.cleaned.txt: -------------------------------------------------------------------------------- 1 | DUMMY2/p364/p364_240.wav|88|ɪt hɐd hˈæpənd tə hˌɪm. 2 | DUMMY2/p280/p280_148.wav|52|ɪt ɪz ˈoʊpən sˈiːzən ɑːnðɪ ˈoʊld fˈɜːm. 3 | DUMMY2/p231/p231_320.wav|50|haʊˈɛvɚ, hiː ɪz ɐ kˈoʊtʃ, ænd hiː ɹɪmˈeɪnz ɐ kˈoʊtʃ æt hˈɑːɹt. 4 | DUMMY2/p282/p282_129.wav|83|ɪt ɪz nˌɑːɾə jˈuːtˈɜːn. 
5 | DUMMY2/p254/p254_015.wav|41|ðə ɡɹˈiːks jˈuːzd tʊ ɪmˈædʒɪn ðˌɐɾɪt wʌzɐ sˈaɪn fɹʌmðə ɡˈɑːdz tə foːɹtˈɛl wˈɔːɹ ɔːɹ hˈɛvi ɹˈeɪn. 6 | DUMMY2/p228/p228_285.wav|57|ðə sˈɔŋz ɑːɹ dʒˈʌst sˌoʊ ɡˈʊd. 7 | DUMMY2/p334/p334_307.wav|38|ɪf ðeɪ dˈoʊnt, ðeɪ kæn ɛkspˈɛkt ðɛɹ fˈʌndɪŋ təbi kˈʌt. 8 | DUMMY2/p287/p287_081.wav|77|aɪv nˈɛvɚ sˈiːn ˈɛnɪθˌɪŋ lˈaɪk ɪt. 9 | DUMMY2/p247/p247_083.wav|14|ɪt ɪz ɐ dʒˈɑːb kɹiːˈeɪʃən skˈiːm. 10 | DUMMY2/p264/p264_051.wav|65|wiː wɜː lˈiːdɪŋ baɪ tˈuː ɡˈoʊlz. 11 | DUMMY2/p335/p335_058.wav|49|lˈɛts sˈiː ðæt ˈɪnkɹiːs ˌoʊvɚ ðə jˈɪɹz. 12 | DUMMY2/p236/p236_225.wav|75|ðɛɹ ɪz nˈoʊ kwˈɪk fˈɪks. 13 | DUMMY2/p374/p374_353.wav|11|ænd ðæt bɹˈɪŋz ˌʌs tə ðə pˈɔɪnt. 14 | DUMMY2/p272/p272_076.wav|69|sˈaʊndz lˈaɪk ðə sˈɪksθ sˈɛns? 15 | DUMMY2/p271/p271_152.wav|27|ðə pətˈɪʃən wʌz fˈɔːɹməli pɹɪzˈɛntᵻd æt dˈaʊnɪŋ stɹˈiːt jˈɛstɚdˌeɪ. 16 | DUMMY2/p228/p228_127.wav|57|ðeɪv ɡɑːt tʊ ɐkˈaʊnt fɔːɹ ɪt. 17 | DUMMY2/p276/p276_223.wav|106|ɪts bˌɪn ɐ hˈʌmblɪŋ jˈɪɹ. 18 | DUMMY2/p262/p262_248.wav|45|ðə pɹˈɑːdʒɛkt hɐz ɔːlɹˌɛdi sɪkjˈʊɹd ðə səpˈoːɹt ʌv sˌɜː ʃˈɔːn kɑːnɚɹi. 19 | DUMMY2/p314/p314_086.wav|51|ðə tˈiːm ðɪs jˈɪɹ ɪz ɡˌoʊɪŋ plˈeɪsᵻz. 20 | DUMMY2/p225/p225_038.wav|101|dˈaɪvɪŋ ɪz nˈoʊ pˈɑːɹt ʌv fˈʊtbɔːl. 21 | DUMMY2/p279/p279_088.wav|25|ðə ʃˈɛɹhoʊldɚz wɪl vˈoʊt tə wˈaɪnd ˈʌp ðə kˈʌmpəni ˌɑːn fɹˈaɪdeɪ mˈɔːɹnɪŋ. 22 | DUMMY2/p272/p272_018.wav|69|ˈæɹɪstˌɑːɾəl θˈɔːt ðætðə ɹˈeɪnboʊ wʌz kˈɔːzd baɪ ɹɪflˈɛkʃən ʌvðə sˈʌnz ɹˈeɪz baɪ ðə ɹˈeɪn. 23 | DUMMY2/p256/p256_098.wav|90|ʃiː tˈoʊld ðə hˈɛɹəld. 24 | DUMMY2/p261/p261_218.wav|100|ˈɔːl wɪl biː ɹɪvˈiːld ɪn dˈuː kˈoːɹs. 25 | DUMMY2/p265/p265_063.wav|73|ɪt ʃˌʊdənt kˈʌm æz ɐ sɚpɹˈaɪz, bˌʌt ɪt dˈʌz. 26 | DUMMY2/p314/p314_042.wav|51|ɪt ɪz ˈɔːl ɐbˌaʊt pˈiːpəl bˌiːɪŋ ɐsˈɑːltᵻd, ɐbjˈuːsd. 27 | DUMMY2/p241/p241_188.wav|86|ˈaɪ wˈɪʃ ˈaɪ kʊd sˈeɪ sˈʌmθɪŋ. 28 | DUMMY2/p283/p283_111.wav|95|ɪts ɡˈʊd tə hæv ɐ vˈɔɪs. 29 | DUMMY2/p275/p275_006.wav|40|wˌɛn ðə sˈʌnlaɪt stɹˈaɪks ɹˈeɪndɹɑːps ɪnðɪ ˈɛɹ, ðeɪ ˈækt æz ɐ pɹˈɪzəm ænd fˈɔːɹm ɐ ɹˈeɪnboʊ. 
30 | DUMMY2/p228/p228_092.wav|57|tədˈeɪ ˈaɪ kˌʊdənt ɹˈʌn ˈɑːn ɪt. 31 | DUMMY2/p295/p295_343.wav|92|ðɪ ˈætməsfˌɪɹ ɪz bˈɪznəslˌaɪk. 32 | DUMMY2/p228/p228_187.wav|57|ðeɪ wɪl ɹˈʌn ɐ mˈaɪl. 33 | DUMMY2/p294/p294_317.wav|104|ɪt dˈɪdnt pˌʊt mˌiː ˈɔf. 34 | DUMMY2/p231/p231_445.wav|50|ɪt sˈaʊndᵻd lˈaɪk ɐ bˈɑːm. 35 | DUMMY2/p272/p272_086.wav|69|tədˈeɪ ʃiː hɐzbɪn ɹɪlˈiːsd. 36 | DUMMY2/p255/p255_210.wav|31|ɪt wʌz wˈɜːθ ɐ fˈoʊɾəɡɹˌæf. 37 | DUMMY2/p229/p229_060.wav|67|ænd ɐ fˈɪlm mˈeɪkɚ wʌz bˈɔːɹn. 38 | DUMMY2/p260/p260_232.wav|81|ðə hˈoʊm ˈɑːfɪs wʊd nˌɑːt ɹɪlˈiːs ˌɛni fˈɜːðɚ diːtˈeɪlz ɐbˌaʊt ðə ɡɹˈuːp. 39 | DUMMY2/p245/p245_025.wav|59|dʒˈɑːnsən wʌz pɹˈɪɾi lˈoʊ. 40 | DUMMY2/p333/p333_185.wav|64|ðɪs ˈɛɹiə ɪz pˈɜːfɛkt fɔːɹ tʃˈɪldɹən. 41 | DUMMY2/p244/p244_242.wav|78|hiː ɪz ɐ mˈæn ʌvðə pˈiːpəl. 42 | DUMMY2/p376/p376_187.wav|71|"ɪt ɪz ɐ tˈɛɹəbəl lˈɔs." 43 | DUMMY2/p239/p239_156.wav|48|ɪt ɪz ɐ ɡˈʊd lˈaɪfstaɪl. 44 | DUMMY2/p307/p307_037.wav|22|hiː ɹɪlˈiːsd ɐ hˈæfdˈʌzən sˈoʊloʊ ˈælbəmz. 45 | DUMMY2/p305/p305_185.wav|54|ˈaɪ æm nˌɑːt ˈiːvən θˈɪŋkɪŋ ɐbˌaʊt ðˈæt. 46 | DUMMY2/p272/p272_081.wav|69|ɪt wʌz mˈædʒɪk. 47 | DUMMY2/p302/p302_297.wav|30|aɪm tɹˈaɪɪŋ tə stˈeɪ ˈoʊpən ˌɑːn ðˈæt. 48 | DUMMY2/p275/p275_320.wav|40|wiː ɑːɹ ɪnðɪ ˈɛnd ɡˈeɪm. 49 | DUMMY2/p239/p239_231.wav|48|ðˈɛn wiː wɪl fˈeɪs ðə dˈeɪnɪʃ tʃˈæmpiənz. 50 | DUMMY2/p268/p268_301.wav|87|ɪt wʌz ˈoʊnli lˈeɪɾɚ ðætðə kəndˈɪʃən wʌz dˌaɪəɡnˈoʊzd. 51 | DUMMY2/p336/p336_088.wav|98|ðeɪ fˈeɪld tə ɹˈiːtʃ ɐɡɹˈiːmənt jˈɛstɚdˌeɪ. 52 | DUMMY2/p278/p278_255.wav|10|ðeɪ mˌeɪd sˈʌtʃ dᵻsˈɪʒənz ɪn lˈʌndən. 53 | DUMMY2/p361/p361_132.wav|79|ðæt ɡɑːt mˌiː ˈaʊt. 54 | DUMMY2/p307/p307_146.wav|22|juː hˈoʊp hiː pɹɪvˈeɪlz. 55 | DUMMY2/p244/p244_147.wav|78|ðeɪ kʊd nˌɑːt ɪɡnˈoːɹ ðə wɪl ʌv pˈɑːɹləmənt, hiː klˈeɪmd. 56 | DUMMY2/p294/p294_283.wav|104|ðɪs ɪz ˌaʊɚɹ ʌnfˈɪnɪʃt bˈɪznəs. 57 | DUMMY2/p283/p283_300.wav|95|ˈaɪ wʊdhɐv ðə hˈæmɚɹ ɪnðə kɹˈaʊd. 58 | DUMMY2/p239/p239_079.wav|48|ˈaɪ kæn ˌʌndɚstˈænd ðə fɹʌstɹˈeɪʃənz ʌv ˌaʊɚ fˈænz. 
59 | DUMMY2/p264/p264_009.wav|65|ðɛɹˈɪz , ɐkˈoːɹdɪŋ tə lˈɛdʒənd, ɐ bˈɔɪlɪŋ pˈɑːt ʌv ɡˈoʊld æt wˈʌn ˈɛnd. 60 | DUMMY2/p307/p307_348.wav|22|hiː dɪdnˌɑːt əpˈoʊz ðə dɪvˈoːɹs. 61 | DUMMY2/p304/p304_308.wav|72|wiː ɑːɹ ðə ɡˈeɪtweɪ tə dʒˈʌstɪs. 62 | DUMMY2/p281/p281_056.wav|36|nˈʌn hɐz ˈɛvɚ bˌɪn fˈaʊnd. 63 | DUMMY2/p267/p267_158.wav|0|wiː wɜː ɡˈɪvən ɐ wˈɔːɹm ænd fɹˈɛndli ɹɪsˈɛpʃən. 64 | DUMMY2/p300/p300_169.wav|102|hˌuː dˈuː ðiːz pˈiːpəl θˈɪŋk ðeɪ ɑːɹ? 65 | DUMMY2/p276/p276_177.wav|106|ðeɪ ɛɡzˈɪst ɪn nˈeɪm ɐlˈoʊn. 66 | DUMMY2/p228/p228_245.wav|57|ɪt ɪz ɐ pˈɑːlɪsi wˌɪtʃ hɐz ðə fˈʊl səpˈoːɹt ʌvðə mˈɪnɪstɚ. 67 | DUMMY2/p300/p300_303.wav|102|aɪm wˈʌndɚɹɪŋ wˌʌt juː fˈiːl ɐbˌaʊt ðə jˈʌŋɡəst. 68 | DUMMY2/p362/p362_247.wav|15|ðɪs wʊd ɡˈɪv skˈɑːtlənd ɐɹˈaʊnd ˈeɪt mˈɛmbɚz. 69 | DUMMY2/p326/p326_031.wav|28|juːnˈaɪɾᵻd wɜːɹ ɪn kəntɹˈoʊl wɪðˌaʊt ˈɔːlweɪz bˌiːɪŋ dˈɑːmɪnənt. 70 | DUMMY2/p361/p361_288.wav|79|ˈaɪ dɪdnˌɑːt θˈɪŋk ɪt wʌz vˈɛɹi pɹˈɑːpɚ. 71 | DUMMY2/p286/p286_145.wav|63|tˈaɪɡɚɹ ɪz nˌɑːt ðə nˈɔːɹm. 72 | DUMMY2/p234/p234_071.wav|3|ʃiː dˈɪd ðæt fɚðə ɹˈɛst ʌv hɜː lˈaɪf. 73 | DUMMY2/p263/p263_296.wav|39|ðə dᵻsˈɪʒən wʌz ɐnˈaʊnst æt ɪts ˈænjuːəl kˈɑːnfɹəns ɪn dˈʌnfɚmlˌaɪn. 74 | DUMMY2/p323/p323_228.wav|34|ʃiː bɪkˌeɪm ɐ hˈɛɹoʊˌɪn ʌv maɪ tʃˈaɪldhʊd. 75 | DUMMY2/p280/p280_346.wav|52|ɪt wʌzɐ bˈɪt lˈaɪk hˌævɪŋ tʃˈɪldɹən. 76 | DUMMY2/p333/p333_080.wav|64|bˌʌt ðə tɹˈædʒədi dɪdnˌɑːt stˈɑːp ðˈɛɹ. 77 | DUMMY2/p226/p226_268.wav|43|ðæt dᵻsˈɪʒən ɪz fɚðə bɹˈɪɾɪʃ pˈɑːɹləmənt ænd pˈiːpəl. 78 | DUMMY2/p362/p362_314.wav|15|ɪz ðæt ɹˈaɪt? 79 | DUMMY2/p240/p240_047.wav|93|ɪt ɪz sˌoʊ sˈæd. 80 | DUMMY2/p250/p250_207.wav|24|juː kʊd fˈiːl ðə hˈiːt. 81 | DUMMY2/p273/p273_176.wav|56|nˈiːðɚ sˈaɪd wʊd ɹɪvˈiːl ðə diːtˈeɪlz ʌvðɪ ˈɑːfɚ. 82 | DUMMY2/p316/p316_147.wav|85|ænd fɹˈæŋkli, ɪts bˌɪn ɐ wˈaɪl. 83 | DUMMY2/p265/p265_047.wav|73|ɪt ɪz juːnˈiːk. 84 | DUMMY2/p336/p336_353.wav|98|sˈʌmtaɪmz juː ɡˈɛt ðˌɛm, sˈʌmtaɪmz juː dˈoʊnt. 85 | DUMMY2/p230/p230_376.wav|35|ðɪs hˈæzənt hˈæpənd ɪn ɐ vˈækjuːm. 
86 | DUMMY2/p308/p308_209.wav|107|ðɛɹ ɪz ɡɹˈeɪt pətˈɛnʃəl ˌɑːn ðɪs ɹˈɪvɚ. 87 | DUMMY2/p250/p250_442.wav|24|wiː hɐvnˌɑːt jˈɛt ɹɪsˈiːvd ɐ lˈɛɾɚ fɹʌmðɪ ˈaɪɹɪʃ. 88 | DUMMY2/p260/p260_037.wav|81|ɪts ɐ fˈækt. 89 | DUMMY2/p299/p299_345.wav|58|wɪɹ vˈɛɹi ɛksˈaɪɾᵻd ænd tʃˈælɪndʒd baɪ ðə pɹˈɑːdʒɛkt. 90 | DUMMY2/p269/p269_218.wav|94|ɐ ɡɹˈæmpiən pəlˈiːs spˈoʊksmən sˈɛd. 91 | DUMMY2/p306/p306_014.wav|12|tə ðə hˈiːbɹuːz ɪt wʌzɐ tˈoʊkən ðæt ðɛɹ wʊd biː nˈoʊmˌoːɹ jˌuːnɪvˈɜːsəl flˈʌdz. 92 | DUMMY2/p271/p271_292.wav|27|ɪts ɐ ɹˈɛkɚd lˈeɪbəl, nˌɑːɾə fˈɔːɹm ʌv mjˈuːzɪk. 93 | DUMMY2/p247/p247_225.wav|14|ˈaɪ æm kənsˈɪdɚd ɐ tˈiːneɪdʒɚ. 94 | DUMMY2/p294/p294_094.wav|104|ɪt ʃˌʊd biː ɐ kəndˈɪʃən ʌv ɛmplˈɔɪmənt. 95 | DUMMY2/p269/p269_031.wav|94|ɪz ðɪs ˈækjʊɹət? 96 | DUMMY2/p275/p275_116.wav|40|ɪts nˌɑːt fˈɛɹ. 97 | DUMMY2/p265/p265_006.wav|73|wˌɛn ðə sˈʌnlaɪt stɹˈaɪks ɹˈeɪndɹɑːps ɪnðɪ ˈɛɹ, ðeɪ ˈækt æz ɐ pɹˈɪzəm ænd fˈɔːɹm ɐ ɹˈeɪnboʊ. 98 | DUMMY2/p285/p285_072.wav|2|mˈɪstɚɹ ˈɜːvaɪn sˈɛd mˈɪstɚ ɹˈæfɚɾi wʌz nˈaʊ ɪn ɡˈʊd spˈɪɹɪts. 99 | DUMMY2/p270/p270_167.wav|8|wiː dˈɪd wˌʌt wiː hædtə dˈuː. 100 | DUMMY2/p360/p360_397.wav|60|ɪt ɪz ɐ ɹɪlˈiːf. 
101 | -------------------------------------------------------------------------------- /Dialogue/gen_coat.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import lightgbm as lgb 4 | import random 5 | import json 6 | import os 7 | from coat_utils import * 8 | from collections import defaultdict 9 | from config.coat_pattern import start_pattern, agent_pattern, user_pattern 10 | from config.thanks_pattern import coat_agent, coat_user 11 | from config.recommend_pattern import coat_rec 12 | from coat_attr import generate_gender_dialogue, generate_jacket_dialogue, generate_color_dialogue 13 | 14 | def load_data(): 15 | coat = pd.read_csv("./data/coat/coat_info.csv") 16 | coat['age'] = coat['age'].apply(lambda age: age_map(age)) 17 | item_feature = np.genfromtxt('./data/coat/item_features.ascii', dtype=None) 18 | 19 | return (coat, item_feature) 20 | 21 | def calculate_user_preference(data): 22 | coat = data[0] 23 | item_feature = data[1] 24 | user_record = defaultdict(set) 25 | item_matrix = np.zeros((300, 3)) 26 | for _, row in coat.iterrows(): 27 | user_id = int(row['user']) 28 | item_id = int(row['item']) 29 | user_record[user_id].add(item_id) 30 | gender, jacket, color = get_item_index(item_feature, item_id) 31 | item_matrix[item_id][0] = gender 32 | item_matrix[item_id][1] = jacket 33 | item_matrix[item_id][2] = color 34 | 35 | item_csv = pd.DataFrame(item_matrix) 36 | item_csv.insert(item_csv.shape[1], 'label', 0) 37 | item_csv.columns = ['gender','jacket', 'color', 'label'] 38 | features_cols = ['gender', 'jacket', 'color'] 39 | user_preference = defaultdict(list) 40 | for user in range(300): 41 | user_item_csv = item_csv.copy() 42 | record = list(user_record[user]) 43 | for item in record: 44 | user_item_csv.loc[item, 'label'] = 1 45 | X = user_item_csv[features_cols] 46 | Y = user_item_csv.label 47 | 48 | cls = lgb.LGBMClassifier(importance_type='gain') 49 | cls.fit(X, Y) 50 | 
51 | indices = np.argsort(cls.booster_.feature_importance(importance_type='gain')) 52 | feature = [features_cols[i] for i in indices] 53 | 54 | user_preference[user] = feature 55 | 56 | return user_preference 57 | 58 | def get_user_item_info(data): 59 | user_info = {} 60 | item_info = {} 61 | coat = data[0] 62 | item_feature = data[1] 63 | for _, row in coat.iterrows(): 64 | user_id = int(row['user']) 65 | item_id = int(row['item']) 66 | user_info[user_id] = {} 67 | user_info[user_id]['age'] = get_user_age(row['age']) 68 | user_info[user_id]['gender'] = get_user_gender(row['gender']) 69 | 70 | if item_id not in item_info.keys(): 71 | gender, jacket, color = get_item_index(item_feature, item_id) 72 | item_info[item_id] = {} 73 | item_info[item_id]['gender'] = get_item_gender(gender) 74 | item_info[item_id]['jacket'] = get_item_type(jacket) 75 | item_info[item_id]['color'] = get_item_color(color) 76 | 77 | return user_info, item_info 78 | 79 | def calculate_attr_weights(data): 80 | gender_all = {} 81 | jacket_all = {} 82 | color_all = {} 83 | coat = data[0] 84 | item_feature = data[1] 85 | for _, row in coat.iterrows(): 86 | item_id = int(row['item']) 87 | gender, jacket, color = get_item_index(item_feature, item_id) 88 | if gender not in gender_all.keys(): 89 | gender_all[gender] = 1 90 | else: 91 | gender_all[gender] += 1 92 | 93 | if jacket not in jacket_all.keys(): 94 | jacket_all[jacket] = 1 95 | else: 96 | jacket_all[jacket] += 1 97 | 98 | if color not in color_all.keys(): 99 | color_all[color] = 1 100 | else: 101 | color_all[color] += 1 102 | 103 | gender_weight = [gender_all[i] for i in sorted(gender_all)] 104 | jacket_weight = [jacket_all[i] for i in sorted(jacket_all)] 105 | color_weight = [color_all[i] for i in sorted(color_all)] 106 | 107 | return (gender_weight, jacket_weight, color_weight), (gender_all, jacket_all, color_all) 108 | 109 | 110 | if __name__ == '__main__': 111 | coat_data = load_data() 112 | user_preference = 
calculate_user_preference(coat_data) 113 | user_info, item_info = get_user_item_info(coat_data) 114 | weights, attr_counts = calculate_attr_weights(coat_data) 115 | print("data load complete") 116 | dialogue_info = {} 117 | dialogue_id = 0 118 | for _, row in coat_data[0].iterrows(): 119 | user_id = int(row['user']) 120 | item_id = int(row['item']) 121 | 122 | new_dialogue = {} 123 | new_dialogue["user_id"] = user_id 124 | new_dialogue["item_id"] = item_id 125 | 126 | new_dialogue["user_gender"] = user_info[user_id]["gender"] 127 | new_dialogue["user_age"] = user_info[user_id]["age"] 128 | new_dialogue["content"] = {} 129 | new_dialogue["content"]["start"] = random.choice(start_pattern) 130 | 131 | dialouge_order = user_preference[user_id] 132 | tmp_new_dialogue = [] 133 | 134 | for slot in dialouge_order: 135 | if slot == "gender": 136 | gender_val = item_info[item_id]["gender"] 137 | utterance = generate_gender_dialogue(agent_pattern, user_pattern, gender_val, attr_counts[0], weights[0]) 138 | elif slot == "jacket": 139 | jacket_val = item_info[item_id]["jacket"] 140 | utterance = generate_jacket_dialogue(agent_pattern, user_pattern, jacket_val, attr_counts[1], weights[1]) 141 | elif slot == "color": 142 | color_val = item_info[item_id]["color"] 143 | utterance = generate_color_dialogue(agent_pattern, user_pattern, color_val, attr_counts[2], weights[2]) 144 | tmp_new_dialogue.append(utterance) 145 | end_dialogue = {} 146 | end_dialogue["rec"] = random.choice(coat_rec) 147 | end_dialogue["thanks_user"] = random.choice(coat_user) 148 | end_dialogue["thanks_agent"] = random.choice(coat_agent) 149 | tmp_new_dialogue.append(end_dialogue) 150 | print("finish:", dialogue_id) 151 | start_index = 0 152 | end_index = 0 153 | step = 0 154 | name = ["Q1", "A1", "Q2", "A2", "Q3", "A3", "Q4", "A4", "Q5", "A5", "Q6", "A6", "Q7", "A7", "Q8", "A8", "Q9", "A9", "Q10", "A10", "Q11", "A11", "Q12", "A12"] 155 | for dia in tmp_new_dialogue: 156 | end_index = len(dia) + end_index 157 | 
tmp_name = name[start_index : end_index] 158 | tmp_dia = [] 159 | for _, v in dia.items(): 160 | tmp_dia.append(v) 161 | for i, val in enumerate(tmp_name): 162 | new_dialogue["content"][val] = tmp_dia[i] 163 | start_index = end_index 164 | dialogue_info[dialogue_id] = new_dialogue 165 | dialogue_id = dialogue_id + 1 166 | 167 | 168 | res_path = './res/coat/' 169 | if not os.path.exists(res_path): 170 | os.makedirs(res_path) 171 | with open(res_path + 'dialogue_info_coat.json', 'w') as f: 172 | json.dump(dialogue_info, f, indent=4) 173 | 174 | 175 | 176 | 177 | 178 | 179 | -------------------------------------------------------------------------------- /Recommender/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | from collections import defaultdict 4 | 5 | 6 | def get_ur(df): 7 | """ 8 | Method of getting user-rating pairs 9 | Parameters 10 | ---------- 11 | df : pd.DataFrame, rating dataframe 12 | 13 | Returns 14 | ------- 15 | ur : dict, dictionary stored user-items interactions 16 | """ 17 | ur = defaultdict(set) 18 | for _, row in df.iterrows(): 19 | ur[int(row['user'])].add(int(row['item'])) 20 | 21 | return ur 22 | 23 | 24 | def build_candidates_set(test_ur, train_ur, item_pool, candidates_num=1000): 25 | """ 26 | method of building candidate items for ranking 27 | Parameters 28 | ---------- 29 | test_ur : dict, ground_truth that represents the relationship of user and item in the test set 30 | train_ur : dict, this represents the relationship of user and item in the train set 31 | item_pool : the set of all items 32 | candidates_num : int, the number of candidates 33 | Returns 34 | ------- 35 | test_ucands : dict, dictionary storing candidates for each user in test set 36 | """ 37 | # random.seed(1) 38 | test_ucands = defaultdict(list) 39 | for k, v in test_ur.items(): 40 | sample_num = candidates_num - len(v) if len(v) < candidates_num else 0 41 | sub_item_pool = item_pool 
- v - train_ur[k] # remove GT & interacted 42 | sample_num = min(len(sub_item_pool), sample_num) 43 | if sample_num == 0: 44 | samples = random.sample(v, candidates_num) 45 | test_ucands[k] = list(set(samples)) 46 | else: 47 | samples = random.sample(sub_item_pool, sample_num) 48 | test_ucands[k] = list(v | set(samples)) 49 | 50 | return test_ucands 51 | 52 | 53 | def precision_at_k(r, k): 54 | """ 55 | Precision calculation method 56 | Parameters 57 | ---------- 58 | r : List, list of the rank items 59 | k : int, top-K number 60 | 61 | Returns 62 | ------- 63 | pre : float, precision value 64 | """ 65 | assert k >= 1 66 | r = np.asarray(r)[:k] != 0 67 | if r.size != k: 68 | raise ValueError('Relevance score length < k') 69 | # return np.mean(r) 70 | pre = sum(r) / len(r) 71 | 72 | return pre 73 | 74 | 75 | def recall_at_k(rs, test_ur, k): 76 | """ 77 | Recall calculation method 78 | Parameters 79 | ---------- 80 | rs : Dict, {user : rank items} for test set 81 | test_ur : Dict, {user : items} for test set ground truth 82 | k : int, top-K number 83 | 84 | Returns 85 | ------- 86 | rec : float recall value 87 | """ 88 | assert k >= 1 89 | res = [] 90 | for user in test_ur.keys(): 91 | r = np.asarray(rs[user])[:k] != 0 92 | if r.size != k: 93 | raise ValueError('Relevance score length < k') 94 | if len(test_ur[user]) == 0: 95 | raise KeyError(f'Invalid User Index: {user}') 96 | res.append(sum(r) / len(test_ur[user])) 97 | rec = np.mean(res) 98 | 99 | return rec 100 | 101 | 102 | def mrr_at_k(rs, k): 103 | """ 104 | Mean Reciprocal Rank calculation method 105 | Parameters 106 | ---------- 107 | rs : Dict, {user : rank items} for test set 108 | k : int, topK number 109 | 110 | Returns 111 | ------- 112 | mrr : float, MRR value 113 | """ 114 | assert k >= 1 115 | res = 0 116 | for r in rs.values(): 117 | r = np.asarray(r)[:k] != 0 118 | for index, item in enumerate(r): 119 | if item == 1: 120 | res += 1 / (index + 1) 121 | mrr = res / len(rs) 122 | 123 | return mrr 124 
| 125 | 126 | def ap(r): 127 | """ 128 | Average precision calculation method 129 | Parameters 130 | ---------- 131 | r : List, Relevance scores (list or numpy) in rank order (first element is the first item) 132 | 133 | Returns 134 | ------- 135 | a_p : float, Average precision value 136 | """ 137 | r = np.asarray(r) != 0 138 | out = [precision_at_k(r, k + 1) for k in range(r.size) if r[k]] 139 | if not out: 140 | return 0. 141 | a_p = np.sum(out) / len(r) 142 | 143 | return a_p 144 | 145 | 146 | def map_at_k(rs): 147 | """ 148 | Mean Average Precision calculation method 149 | Parameters 150 | ---------- 151 | rs : Dict, {user : rank items} for test set 152 | 153 | Returns 154 | ------- 155 | m_a_p : float, MAP value 156 | """ 157 | m_a_p = np.mean([ap(r) for r in rs]) 158 | return m_a_p 159 | 160 | 161 | def dcg_at_k(r, k): 162 | """ 163 | Discounted Cumulative Gain calculation method 164 | Parameters 165 | ---------- 166 | r : List, Relevance scores (list or numpy) in rank order 167 | (first element is the first item) 168 | k : int, top-K number 169 | 170 | Returns 171 | ------- 172 | dcg : float, DCG value 173 | """ 174 | assert k >= 1 175 | r = np.asfarray(r)[:k] != 0 176 | if r.size: 177 | dcg = np.sum(np.subtract(np.power(2, r), 1) / np.log2(np.arange(2, r.size + 2))) 178 | return dcg 179 | return 0. 180 | 181 | 182 | def ndcg_at_k(r, k): 183 | """ 184 | Normalized Discounted Cumulative Gain calculation method 185 | Parameters 186 | ---------- 187 | r : List, Relevance scores (list or numpy) in rank order 188 | (first element is the first item) 189 | k : int, top-K number 190 | 191 | Returns 192 | ------- 193 | ndcg : float, NDCG value 194 | """ 195 | assert k >= 1 196 | idcg = dcg_at_k(sorted(r, reverse=True), k) 197 | if not idcg: 198 | return 0. 
199 | ndcg = dcg_at_k(r, k) / idcg 200 | 201 | return ndcg 202 | 203 | 204 | def hr_at_k(rs, test_ur): 205 | """ 206 | Hit Ratio calculation method 207 | Parameters 208 | ---------- 209 | rs : Dict, {user : rank items} for test set 210 | test_ur : (Deprecated) Dict, {user : items} for test set ground truth 211 | 212 | Returns 213 | ------- 214 | hr : float, HR value 215 | """ 216 | # another way for calculating hit rate 217 | # numer, denom = 0., 0. 218 | # for user in test_ur.keys(): 219 | # numer += np.sum(rs[user]) 220 | # denom += len(test_ur[user]) 221 | 222 | # return numer / denom 223 | uhr = 0 224 | for r in rs.values(): 225 | if np.sum(r) != 0: 226 | uhr += 1 227 | hr = uhr / len(rs) 228 | 229 | return hr 230 | 231 | def get_feature(feature_list): 232 | if 'gender' in feature_list: 233 | feature_num = 0 234 | elif 'age' in feature_list: 235 | feature_num = 1 236 | elif 'gender' in feature_list and 'age' in feature_list: 237 | feature_num = 2 238 | elif 'gender' not in feature_list and 'age' not in feature_list: 239 | feature_num = 3 240 | 241 | return feature_num 242 | 243 | def get_user_info(df, feature_num): 244 | user_info = dict() 245 | for _, row in df.iterrows(): 246 | if feature_num == 0: 247 | user_info[int(row['user'])] = [int(row['gender'])] 248 | elif feature_num == 1: 249 | user_info[int(row['user'])] = [int(row['age'])] 250 | elif feature_num == 2: 251 | user_info[int(row['user'])] = [int(row['gender']), int(row['age'])] 252 | return user_info 253 | 254 | -------------------------------------------------------------------------------- /Recommender/gen_label.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import pandas as pd 5 | from torch.utils.data import Dataset,DataLoader 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torchaudio 9 | import tqdm as tqdm 10 | import argparse 11 | from transformers import Wav2Vec2Model, 
Wav2Vec2PreTrainedModel, AutoConfig 12 | 13 | 14 | class Wav2Vec2ClassificationModel(Wav2Vec2PreTrainedModel): 15 | def __init__(self, config, hidden_size, dropout): 16 | super().__init__(config) 17 | 18 | self.wav2vec2 = Wav2Vec2Model(config) 19 | self.hidden_size = hidden_size 20 | self.fc = nn.Linear(config.hidden_size, self.hidden_size) 21 | self.dropout = nn.Dropout(dropout) 22 | self.gender_fc = nn.Linear(self.hidden_size, 2) 23 | self.age_fc = nn.Linear(self.hidden_size, 3) 24 | self.tanh = nn.Tanh() 25 | 26 | self.init_weights() 27 | 28 | def freeze_feature_extractor(self): 29 | self.wav2vec2.feature_extractor._freeze_parameters() 30 | 31 | def merged_strategy(self, hidden_states): 32 | outputs = torch.mean(hidden_states, dim=1) 33 | 34 | return outputs 35 | 36 | def forward( 37 | self, 38 | input_values, 39 | attention_mask=None, 40 | output_attentions=None, 41 | output_hidden_states=None, 42 | return_dict=None, 43 | labels=None, 44 | ): 45 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 46 | with torch.no_grad(): 47 | outputs = self.wav2vec2( 48 | input_values, 49 | attention_mask=attention_mask, 50 | output_attentions=output_attentions, 51 | output_hidden_states=output_hidden_states, 52 | return_dict=return_dict, 53 | ) 54 | 55 | hidden_states = outputs[0] 56 | x = self.merged_strategy(hidden_states) 57 | x = self.dropout(x) 58 | x = self.fc(x) 59 | x = self.tanh(x) 60 | x = self.dropout(x) 61 | gender_logits = self.gender_fc(x) 62 | gender_logits = F.log_softmax(gender_logits, dim=-1) 63 | age_logits = self.age_fc(x) 64 | age_logits = F.log_softmax(age_logits, dim=-1) 65 | 66 | return age_logits, gender_logits 67 | 68 | def construct_data(input_dir): 69 | datalist = os.listdir(input_dir) 70 | user = [] 71 | item = [] 72 | gender = [] 73 | age = [] 74 | name = [] 75 | for file in datalist: 76 | u_id = int(file.split('_')[1][3:]) 77 | i_id = int(file.split('_')[2][3:]) 78 | user.append(u_id) 79 | item.append(i_id) 
80 | 81 | name.append(file) 82 | gender.append(file.split('_')[-2]) 83 | age.append(file.split('_')[-3]) 84 | data = pd.DataFrame({'user':user, 'item':item, 'gender':gender, 'age':age, 'audio_name':name}) 85 | 86 | return data 87 | 88 | class GenDataset(Dataset): 89 | def __init__(self, df, audio_path): 90 | self.df = df 91 | self.audio_path = audio_path 92 | self.user = self.df.iloc[:,0].values 93 | self.item = self.df.iloc[:,1].values 94 | self.gender = self.df.iloc[:,2].values 95 | self.age = self.df.iloc[:,3].values 96 | self.audios = self.df.iloc[:,4].values 97 | self.age_labels = {'under 20':0, '20-30':1, 'over 30':2} 98 | self.gender_labels = {'women':0, 'men':1} 99 | 100 | def __len__(self): 101 | return len(self.df) 102 | 103 | def __getitem__(self, index): 104 | user = self.user[index] 105 | item = self.item[index] 106 | gender = self.gender_labels[self.gender[index]] 107 | age = self.age_labels[self.age[index]] 108 | 109 | audio_file_path = os.path.join(self.audio_path, self.audios[index]) 110 | waveform, sample_rate = torchaudio.load(audio_file_path) 111 | waveform = self._resample(waveform, sample_rate) 112 | 113 | return user, item, gender, age, waveform 114 | 115 | def _resample(self, waveform, sample_rate): 116 | resampler = torchaudio.transforms.Resample(sample_rate,16000) 117 | 118 | return resampler(waveform) 119 | 120 | def pad_sequence(batch): 121 | # Make all tensor in a batch the same length by padding with zeros 122 | batch = [item.t() for item in batch] 123 | batch = torch.nn.utils.rnn.pad_sequence(batch, batch_first=True, padding_value=0.) 
124 | return batch.permute(0, 2, 1) 125 | 126 | def collate_fn(batch): 127 | 128 | users, items, audios, age_labels, gender_labels = [], [], [], [], [] 129 | 130 | for user, item, gender, age, waveform in batch: 131 | audios += [waveform] 132 | users += [torch.tensor(user)] 133 | items += [torch.tensor(item)] 134 | age_labels += [torch.tensor(age)] 135 | gender_labels += [torch.tensor(gender)] 136 | 137 | audios = pad_sequence(audios) 138 | users = torch.stack(users) 139 | items = torch.stack(items) 140 | age_labels = torch.stack(age_labels) 141 | gender_labels = torch.stack(gender_labels) 142 | 143 | return users, items, gender_labels, age_labels, audios.squeeze(dim=1) 144 | 145 | 146 | def get_likely_index(tensor): 147 | # find most likely label index for each element in the batch 148 | return tensor.argmax(dim=-1) 149 | 150 | def number_of_correct(pred, target): 151 | # count number of correct predictions 152 | return pred.eq(target).sum().item() 153 | 154 | if __name__ == '__main__': 155 | parser = argparse.ArgumentParser(description='generate label on audio') 156 | parser.add_argument('--model', 157 | type=str, 158 | default='ml-1m') 159 | parser.add_argument('--dataset', 160 | type=str, 161 | default='coat') 162 | args = parser.parse_args() 163 | 164 | model = torch.load(f'./clf_model_{args.model}.pt') 165 | audio_path = f'./data/{args.dataset}/' 166 | data = construct_data(audio_path) 167 | 168 | pre_data = GenDataset(data, audio_path) 169 | pre_loader = DataLoader( 170 | pre_data, 171 | batch_size=4, 172 | shuffle=True, 173 | collate_fn=collate_fn, 174 | num_workers=4, 175 | pin_memory=False, 176 | ) 177 | 178 | model.eval() 179 | age_correct = 0 180 | gender_correct = 0 181 | tag = 1 182 | for user, item, gender_label, age_label, waveform in tqdm.tqdm(pre_loader): 183 | waveform = waveform.to("cuda") 184 | age_label = age_label.to("cuda") 185 | gender_label = gender_label.to("cuda") 186 | 187 | age_logits, gender_logits = model(waveform) 188 | 189 | 
age_pred = get_likely_index(age_logits) # batch 190 | gender_pred = get_likely_index(gender_logits) 191 | 192 | age_correct += number_of_correct(age_pred, age_label) 193 | gender_correct += number_of_correct(gender_pred, gender_label) 194 | 195 | if tag == 1: 196 | users = user.numpy() 197 | items = item.numpy() 198 | genders = gender_pred.cpu().numpy() 199 | ages = age_pred.cpu().numpy() 200 | tag = 0 201 | else: 202 | users = np.hstack((users, user.numpy())) 203 | items = np.hstack((items, item.numpy())) 204 | genders = np.hstack((genders, gender_pred.cpu().numpy())) 205 | ages = np.hstack((ages, age_pred.cpu().numpy())) 206 | 207 | 208 | age_accu = age_correct / len(pre_loader.dataset) * 100. 209 | gender_accu = gender_correct / len(pre_loader.dataset) * 100. 210 | 211 | print(f"Age: {age_accu:.2f}% Gender: {gender_accu:.2f}%") 212 | result_save_path = './res/' 213 | if not os.path.exists(result_save_path): 214 | os.makedirs(result_save_path) 215 | data = pd.DataFrame({'user':list(users), 'item':list(items), 'gender':list(genders), 'age':list(ages)}) 216 | data.to_csv(f'./res/{args.dataset}_predict.csv', index=False) -------------------------------------------------------------------------------- /Speech/vits_lib/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import sys 4 | import argparse 5 | import logging 6 | import json 7 | import subprocess 8 | import numpy as np 9 | from scipy.io.wavfile import read 10 | import torch 11 | 12 | MATPLOTLIB_FLAG = False 13 | 14 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) 15 | logger = logging 16 | 17 | 18 | def load_checkpoint(checkpoint_path, model, optimizer=None): 19 | assert os.path.isfile(checkpoint_path) 20 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') 21 | iteration = checkpoint_dict['iteration'] 22 | learning_rate = checkpoint_dict['learning_rate'] 23 | if optimizer is not None: 24 | 
optimizer.load_state_dict(checkpoint_dict['optimizer']) 25 | saved_state_dict = checkpoint_dict['model'] 26 | if hasattr(model, 'module'): 27 | state_dict = model.module.state_dict() 28 | else: 29 | state_dict = model.state_dict() 30 | new_state_dict= {} 31 | for k, v in state_dict.items(): 32 | try: 33 | new_state_dict[k] = saved_state_dict[k] 34 | except: 35 | logger.info("%s is not in the checkpoint" % k) 36 | new_state_dict[k] = v 37 | if hasattr(model, 'module'): 38 | model.module.load_state_dict(new_state_dict) 39 | else: 40 | model.load_state_dict(new_state_dict) 41 | logger.info("Loaded checkpoint '{}' (iteration {})" .format( 42 | checkpoint_path, iteration)) 43 | return model, optimizer, learning_rate, iteration 44 | 45 | 46 | def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): 47 | logger.info("Saving model and optimizer state at iteration {} to {}".format( 48 | iteration, checkpoint_path)) 49 | if hasattr(model, 'module'): 50 | state_dict = model.module.state_dict() 51 | else: 52 | state_dict = model.state_dict() 53 | torch.save({'model': state_dict, 54 | 'iteration': iteration, 55 | 'optimizer': optimizer.state_dict(), 56 | 'learning_rate': learning_rate}, checkpoint_path) 57 | 58 | 59 | def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050): 60 | for k, v in scalars.items(): 61 | writer.add_scalar(k, v, global_step) 62 | for k, v in histograms.items(): 63 | writer.add_histogram(k, v, global_step) 64 | for k, v in images.items(): 65 | writer.add_image(k, v, global_step, dataformats='HWC') 66 | for k, v in audios.items(): 67 | writer.add_audio(k, v, global_step, audio_sampling_rate) 68 | 69 | 70 | def latest_checkpoint_path(dir_path, regex="G_*.pth"): 71 | f_list = glob.glob(os.path.join(dir_path, regex)) 72 | f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) 73 | x = f_list[-1] 74 | print(x) 75 | return x 76 | 77 | 78 | def 
plot_spectrogram_to_numpy(spectrogram): 79 | global MATPLOTLIB_FLAG 80 | if not MATPLOTLIB_FLAG: 81 | import matplotlib 82 | matplotlib.use("Agg") 83 | MATPLOTLIB_FLAG = True 84 | mpl_logger = logging.getLogger('matplotlib') 85 | mpl_logger.setLevel(logging.WARNING) 86 | import matplotlib.pylab as plt 87 | import numpy as np 88 | 89 | fig, ax = plt.subplots(figsize=(10,2)) 90 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", 91 | interpolation='none') 92 | plt.colorbar(im, ax=ax) 93 | plt.xlabel("Frames") 94 | plt.ylabel("Channels") 95 | plt.tight_layout() 96 | 97 | fig.canvas.draw() 98 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 99 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 100 | plt.close() 101 | return data 102 | 103 | 104 | def plot_alignment_to_numpy(alignment, info=None): 105 | global MATPLOTLIB_FLAG 106 | if not MATPLOTLIB_FLAG: 107 | import matplotlib 108 | matplotlib.use("Agg") 109 | MATPLOTLIB_FLAG = True 110 | mpl_logger = logging.getLogger('matplotlib') 111 | mpl_logger.setLevel(logging.WARNING) 112 | import matplotlib.pylab as plt 113 | import numpy as np 114 | 115 | fig, ax = plt.subplots(figsize=(6, 4)) 116 | im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower', 117 | interpolation='none') 118 | fig.colorbar(im, ax=ax) 119 | xlabel = 'Decoder timestep' 120 | if info is not None: 121 | xlabel += '\n\n' + info 122 | plt.xlabel(xlabel) 123 | plt.ylabel('Encoder timestep') 124 | plt.tight_layout() 125 | 126 | fig.canvas.draw() 127 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 128 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 129 | plt.close() 130 | return data 131 | 132 | 133 | def load_wav_to_torch(full_path): 134 | sampling_rate, data = read(full_path) 135 | return torch.FloatTensor(data.astype(np.float32)), sampling_rate 136 | 137 | 138 | def load_filepaths_and_text(filename, split="|"): 139 | with open(filename, 
encoding='utf-8') as f: 140 | filepaths_and_text = [line.strip().split(split) for line in f] 141 | return filepaths_and_text 142 | 143 | 144 | def get_hparams(init=True): 145 | parser = argparse.ArgumentParser() 146 | parser.add_argument('-c', '--config', type=str, default="./configs/base.json", 147 | help='JSON file for configuration') 148 | parser.add_argument('-m', '--model', type=str, required=True, 149 | help='Model name') 150 | 151 | args = parser.parse_args() 152 | model_dir = os.path.join("./logs", args.model) 153 | 154 | if not os.path.exists(model_dir): 155 | os.makedirs(model_dir) 156 | 157 | config_path = args.config 158 | config_save_path = os.path.join(model_dir, "config.json") 159 | if init: 160 | with open(config_path, "r") as f: 161 | data = f.read() 162 | with open(config_save_path, "w") as f: 163 | f.write(data) 164 | else: 165 | with open(config_save_path, "r") as f: 166 | data = f.read() 167 | config = json.loads(data) 168 | 169 | hparams = HParams(**config) 170 | hparams.model_dir = model_dir 171 | return hparams 172 | 173 | 174 | def get_hparams_from_dir(model_dir): 175 | config_save_path = os.path.join(model_dir, "config.json") 176 | with open(config_save_path, "r") as f: 177 | data = f.read() 178 | config = json.loads(data) 179 | 180 | hparams =HParams(**config) 181 | hparams.model_dir = model_dir 182 | return hparams 183 | 184 | 185 | def get_hparams_from_file(config_path): 186 | with open(config_path, "r") as f: 187 | data = f.read() 188 | config = json.loads(data) 189 | 190 | hparams =HParams(**config) 191 | return hparams 192 | 193 | 194 | def check_git_hash(model_dir): 195 | source_dir = os.path.dirname(os.path.realpath(__file__)) 196 | if not os.path.exists(os.path.join(source_dir, ".git")): 197 | logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format( 198 | source_dir 199 | )) 200 | return 201 | 202 | cur_hash = subprocess.getoutput("git rev-parse HEAD") 203 | 204 | path = 
os.path.join(model_dir, "githash") 205 | if os.path.exists(path): 206 | saved_hash = open(path).read() 207 | if saved_hash != cur_hash: 208 | logger.warn("git hash values are different. {}(saved) != {}(current)".format( 209 | saved_hash[:8], cur_hash[:8])) 210 | else: 211 | open(path, "w").write(cur_hash) 212 | 213 | 214 | def get_logger(model_dir, filename="train.log"): 215 | global logger 216 | logger = logging.getLogger(os.path.basename(model_dir)) 217 | logger.setLevel(logging.DEBUG) 218 | 219 | formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") 220 | if not os.path.exists(model_dir): 221 | os.makedirs(model_dir) 222 | h = logging.FileHandler(os.path.join(model_dir, filename)) 223 | h.setLevel(logging.DEBUG) 224 | h.setFormatter(formatter) 225 | logger.addHandler(h) 226 | return logger 227 | 228 | 229 | class HParams(): 230 | def __init__(self, **kwargs): 231 | for k, v in kwargs.items(): 232 | if type(v) == dict: 233 | v = HParams(**v) 234 | self[k] = v 235 | 236 | def keys(self): 237 | return self.__dict__.keys() 238 | 239 | def items(self): 240 | return self.__dict__.items() 241 | 242 | def values(self): 243 | return self.__dict__.values() 244 | 245 | def __len__(self): 246 | return len(self.__dict__) 247 | 248 | def __getitem__(self, key): 249 | return getattr(self, key) 250 | 251 | def __setitem__(self, key, value): 252 | return setattr(self, key, value) 253 | 254 | def __contains__(self, key): 255 | return key in self.__dict__ 256 | 257 | def __repr__(self): 258 | return self.__dict__.__repr__() 259 | -------------------------------------------------------------------------------- /Dialogue/gen_movie.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import lightgbm as lgb 4 | import random 5 | import json 6 | import os 7 | from collections import defaultdict 8 | from movie_utils import * 9 | from config.movie_pattern import start_pattern, 
agent_pattern, user_pattern 10 | from config.thanks_pattern import movie_agent, movie_user 11 | from config.recommend_pattern import movie_rec 12 | from movie_attr import generate_country_dialogue, generate_genre_dialogue, generate_director_dialogue, generate_actor_dialogue 13 | 14 | 15 | def load_data(): 16 | movies = pd.read_csv("./data/ml-1m/ml-1m.csv") 17 | users = pd.read_csv("./data/ml-1m/users.dat", sep='::', names=['user_id', 'gender', 'age', 'occupation', 'zip'], engine='python') 18 | 19 | return (movies, users) 20 | 21 | def get_user_item_info(data): 22 | movies = data[0] 23 | users = data[1] 24 | user_record = defaultdict(set) 25 | user_info = {} 26 | item_info = {} 27 | for _, row in movies.iterrows(): 28 | user_id = int(row['user_id']) 29 | item_id = int(row['movie_id']) 30 | user_record[user_id].add(item_id) 31 | 32 | if user_id not in user_info.keys(): 33 | user_info[user_id] = {} 34 | age = list(users[users['user_id'] == user_id]['age'])[0] 35 | gender = list(users[users['user_id'] == user_id]['gender'])[0] 36 | user_info[user_id]['age'] = get_user_age(age) 37 | user_info[user_id]['gender'] = get_user_gender(gender) 38 | 39 | if item_id not in item_info.keys(): 40 | item_info[item_id] = {} 41 | item_info[item_id]['director'] = row['director'] 42 | item_info[item_id]['country'] = modify_country(row['country']) 43 | item_info[item_id]['actor'] = get_item_actor(row['actors']) 44 | item_info[item_id]['genre'] = get_item_genre(row['genres']) 45 | 46 | 47 | return user_record, user_info, item_info 48 | 49 | def calculate_user_preference(record, data): 50 | movies = data[0] 51 | columns = ["movie_id", "director", "actors", "country", "genres"] 52 | features_cols = ["director", "actors", "country", "genres"] 53 | user_preference = defaultdict(list) 54 | 55 | df_movie = movies[columns] 56 | df_movie = df_movie.drop_duplicates(keep='first') 57 | df_movie = df_movie.reset_index(drop=True) 58 | df_movie['director'] = pd.Categorical(df_movie['director']).codes 
59 | df_movie['actors'] = pd.Categorical(df_movie['actors']).codes 60 | df_movie['country'] = pd.Categorical(df_movie['country']).codes 61 | df_movie['genres'] = pd.Categorical(df_movie['genres']).codes 62 | df_movie.insert(df_movie.shape[1], 'label', 0) 63 | 64 | for user in record: 65 | user_item_csv = df_movie.copy() 66 | record = list(user_record[user]) 67 | for movie in record: 68 | user_item_csv.loc[user_item_csv['movie_id']==movie, 'label'] = 1 69 | X = user_item_csv[features_cols] 70 | Y = user_item_csv.label 71 | 72 | cls = lgb.LGBMClassifier(importance_type='gain') 73 | cls.fit(X, Y) 74 | 75 | indices = np.argsort(cls.booster_.feature_importance(importance_type='gain')) 76 | feature = [features_cols[i] for i in indices] 77 | 78 | user_preference[user] = feature 79 | 80 | return user_preference 81 | 82 | def calculate_attr_weights(info, movie): 83 | genre_all = {} 84 | country_all = {} 85 | actor_all = {} 86 | director_all = {} 87 | 88 | for _, row in movie.iterrows(): 89 | item_id = int(row['movie_id']) 90 | 91 | genres = info[item_id]['genre'] 92 | for genre in genres: 93 | if genre not in genre_all.keys(): 94 | genre_all[genre] = 1 95 | else: 96 | genre_all[genre] += 1 97 | 98 | country = info[item_id]['country'] 99 | if country not in country_all.keys(): 100 | country_all[country] = 1 101 | else: 102 | country_all[country] += 1 103 | 104 | director = info[item_id]['director'] 105 | if director not in director_all.keys(): 106 | director_all[director] = 1 107 | else: 108 | director_all[director] += 1 109 | 110 | actors = info[item_id]['actor'] 111 | for actor in actors: 112 | if actor not in actor_all.keys(): 113 | actor_all[actor] = 1 114 | else: 115 | actor_all[actor] += 1 116 | 117 | other_directors = [k for k,v in director_all.items() if v < 50] #402, 10415 118 | other_actors = [k for k,v in actor_all.items() if v < 50] #1267, 33893 119 | 120 | genre_weight = [genre_all[i] for i in sorted(genre_all)] 121 | country_weight = [country_all[i] for i in 
sorted(country_all)] 122 | director_weight = [director_all[i] for i in sorted(director_all)] 123 | actor_weight = [actor_all[i] for i in sorted(actor_all)] 124 | 125 | return (genre_weight, country_weight, director_weight, actor_weight), (genre_all, country_all, director_all, actor_all), (other_directors, other_actors) 126 | 127 | 128 | if __name__ == '__main__': 129 | movie_data = load_data() 130 | user_record, user_info, item_info = get_user_item_info(movie_data) 131 | # user_preference = calculate_user_preference(user_record, movie_data) 132 | user_preference = np.load("./data/ml-1m/user_preference.npy", allow_pickle=True).item() 133 | weights, attr_counts, other_attr= calculate_attr_weights(item_info, movie_data[0]) 134 | print("data load complete") 135 | dialogue_info = {} 136 | dialogue_id = 0 137 | for idx, row in movie_data[0].iterrows(): 138 | user_id = int(row['user_id']) 139 | item_id = int(row['movie_id']) 140 | 141 | new_dialogue = {} 142 | new_dialogue["user_id"] = user_id 143 | new_dialogue["item_id"] = item_id 144 | 145 | new_dialogue["user_gender"] = user_info[user_id]["gender"] 146 | new_dialogue["user_age"] = user_info[user_id]["age"] 147 | new_dialogue["content"] = {} 148 | new_dialogue["content"]["start"] = random.choice(start_pattern) 149 | 150 | dialouge_order = user_preference[user_id] 151 | tmp_new_dialogue = [] 152 | 153 | for slot in dialouge_order: 154 | english = True 155 | if slot == "country": 156 | country_val = item_info[item_id]["country"] 157 | utterance = generate_country_dialogue(agent_pattern, user_pattern, country_val) 158 | elif slot == "genres": 159 | genre_val = item_info[item_id]["genre"] 160 | utterance = generate_genre_dialogue(agent_pattern, user_pattern, genre_val, attr_counts[0], weights[0]) 161 | elif slot == "director": 162 | director_val = item_info[item_id]["director"] 163 | utterance = generate_director_dialogue(agent_pattern, user_pattern, director_val, attr_counts[2], weights[2], other_attr[0]) 164 | english = 
check_in_english(utterance) 165 | if not english: 166 | break 167 | elif slot == "actors": 168 | actor_val = item_info[item_id]["actor"] 169 | utterance = generate_actor_dialogue(agent_pattern, user_pattern, actor_val, attr_counts[3], weights[3], other_attr[1]) 170 | english = check_in_english(utterance) 171 | if not english: 172 | break 173 | tmp_new_dialogue.append(utterance) 174 | 175 | if not english: 176 | continue 177 | end_dialogue = {} 178 | end_dialogue["rec"] = random.choice(movie_rec) 179 | end_dialogue["thanks_user"] = random.choice(movie_user) 180 | end_dialogue["thanks_agent"] = random.choice(movie_agent) 181 | tmp_new_dialogue.append(end_dialogue) 182 | print("finish:", dialogue_id) 183 | start_index = 0 184 | end_index = 0 185 | step = 0 186 | name = ["Q1", "A1", "Q2", "A2", "Q3", "A3", "Q4", "A4", "Q5", "A5", "Q6", "A6", "Q7", "A7", "Q8", "A8", "Q9", "A9", "Q10", "A10", "Q11", "A11", "Q12", "A12"] 187 | for dia in tmp_new_dialogue: 188 | end_index = len(dia) + end_index 189 | tmp_name = name[start_index : end_index] 190 | tmp_dia = [] 191 | for _, v in dia.items(): 192 | tmp_dia.append(v) 193 | for i, val in enumerate(tmp_name): 194 | new_dialogue["content"][val] = tmp_dia[i] 195 | start_index = end_index 196 | dialogue_info[dialogue_id] = new_dialogue 197 | dialogue_id = dialogue_id + 1 198 | 199 | res_path = './res/ml-1m/' 200 | if not os.path.exists(res_path): 201 | os.makedirs(res_path) 202 | with open(res_path + 'dialogue_info_ml-1m.json', 'w') as f: 203 | json.dump(dialogue_info, f, indent=4) 204 | 205 | -------------------------------------------------------------------------------- /Recommender/FMRecommender.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | from tqdm import tqdm 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | import torch.backends.cudnn as cudnn 10 | from daisy.utils.config import 
model_config, initializer_config, optimizer_config 11 | 12 | class PointFM(nn.Module): 13 | def __init__(self, 14 | user_num, 15 | item_num, 16 | factors=84, 17 | epochs=20, 18 | lr=0.001, 19 | reg_1 = 0.001, 20 | reg_2 = 0.001, 21 | loss_type='CL', 22 | optimizer='sgd', 23 | initializer='normal', 24 | gpuid='0', 25 | feature=1, 26 | early_stop=True): 27 | """ 28 | Point-wise FM Recommender Class 29 | Parameters 30 | ---------- 31 | user_num : int, the number of users 32 | item_num : int, the number of items 33 | factors : int, the number of latent factor 34 | epochs : int, number of training epochs 35 | lr : float, learning rate 36 | reg_1 : float, first-order regularization term 37 | reg_2 : float, second-order regularization term 38 | loss_type : str, loss function type 39 | optimizer : str, optimization method for training the algorithms 40 | initializer: str, parameter initializer 41 | gpuid : str, GPU ID 42 | early_stop : bool, whether to activate early stop mechanism 43 | 44 | """ 45 | super(PointFM, self).__init__() 46 | 47 | os.environ['CUDA_VISIBLE_DEVICES'] = gpuid 48 | cudnn.benchmark = True 49 | 50 | self.epochs = epochs 51 | self.lr = lr 52 | self.reg_1 = reg_1 53 | self.reg_2 = reg_2 54 | self.feature = feature 55 | 56 | self.embed_user = nn.Embedding(user_num, factors) 57 | self.embed_item = nn.Embedding(item_num, factors) 58 | 59 | self.u_bias = nn.Embedding(user_num, 1) 60 | self.i_bias = nn.Embedding(item_num, 1) 61 | 62 | if self.feature == 0: 63 | self.embed_gender = nn.Embedding(2, factors) 64 | self.g_bias = nn.Embedding(2, 1) 65 | elif self.feature == 1: 66 | self.embed_age = nn.Embedding(3, factors) 67 | self.a_bias = nn.Embedding(3, 1) 68 | elif self.feature == 2: 69 | self.embed_gender = nn.Embedding(2, factors) 70 | self.g_bias = nn.Embedding(2, 1) 71 | self.embed_age = nn.Embedding(3, factors) 72 | self.a_bias = nn.Embedding(3, 1) 73 | 74 | self.bias_ = nn.Parameter(torch.tensor([0.0])) 75 | 76 | # init weight 77 | 
nn.init.normal_(self.embed_user.weight) 78 | nn.init.normal_(self.embed_item.weight) 79 | if self.feature == 0: 80 | nn.init.normal_(self.embed_gender.weight) 81 | nn.init.constant_(self.g_bias.weight, 0.0) 82 | elif self.feature == 1: 83 | nn.init.normal_(self.embed_age.weight) 84 | nn.init.constant_(self.a_bias.weight, 0.0) 85 | elif self.feature == 2: 86 | nn.init.normal_(self.embed_gender.weight) 87 | nn.init.constant_(self.g_bias.weight, 0.0) 88 | nn.init.normal_(self.embed_age.weight) 89 | nn.init.constant_(self.a_bias.weight, 0.0) 90 | 91 | 92 | nn.init.constant_(self.u_bias.weight, 0.0) 93 | nn.init.constant_(self.i_bias.weight, 0.0) 94 | 95 | self.loss_type = loss_type 96 | self.optimizer = optimizer 97 | self.early_stop = early_stop 98 | 99 | def forward(self, user, item, gender=None, age=None): 100 | embed_user = self.embed_user(user) 101 | embed_item = self.embed_item(item) 102 | 103 | pred = (embed_user * embed_item).sum(dim=-1, keepdim=True) 104 | pred += self.u_bias(user) + self.i_bias(item) + self.bias_ 105 | if self.feature == 0: 106 | embed_gender = self.embed_gender(gender) 107 | pred += (embed_gender * embed_user).sum(dim=-1, keepdim=True) + (embed_gender * embed_item).sum(dim=-1, keepdim=True) 108 | pred += self.g_bias(gender) 109 | elif self.feature == 1: 110 | embed_age = self.embed_age(age) 111 | pred += (embed_age * embed_user).sum(dim=-1, keepdim=True) + (embed_age * embed_item).sum(dim=-1, keepdim=True) 112 | pred += self.a_bias(age) 113 | elif self.feature == 2: 114 | embed_gender = self.embed_gender(gender) 115 | embed_age = self.embed_age(age) 116 | pred += (embed_gender * embed_user).sum(dim=-1, keepdim=True) + (embed_gender * embed_item).sum(dim=-1, keepdim=True) 117 | pred += (embed_age * embed_user).sum(dim=-1, keepdim=True) + (embed_age * embed_item).sum(dim=-1, keepdim=True) 118 | pred += (embed_age * embed_gender).sum(dim=-1, keepdim=True) 119 | pred += self.g_bias(gender) + self.a_bias(age) 120 | 121 | return pred.view(-1) 122 
| 123 | def fit(self, train_loader): 124 | if torch.cuda.is_available(): 125 | self.cuda() 126 | else: 127 | self.cpu() 128 | 129 | optimizer = optim.SGD(self.parameters(), lr=self.lr) 130 | 131 | if self.loss_type == 'CL': 132 | criterion = nn.BCEWithLogitsLoss(reduction='sum') 133 | elif self.loss_type == 'SL': 134 | criterion = nn.MSELoss(reduction='sum') 135 | else: 136 | raise ValueError(f'Invalid loss type: {self.loss_type}') 137 | 138 | last_loss = 0. 139 | for epoch in range(1, self.epochs + 1): 140 | self.train() 141 | 142 | current_loss = 0. 143 | # set process bar display 144 | pbar = tqdm(train_loader) 145 | pbar.set_description(f'[Epoch {epoch:03d}]') 146 | for user, item, gender, age, label in pbar: 147 | user = user.cuda() 148 | item = item.cuda() 149 | label = label.cuda() 150 | if self.feature == 0: 151 | gender = gender.cuda() 152 | elif self.feature == 1: 153 | age = age.cuda() 154 | elif self.feature == 2: 155 | gender = gender.cuda() 156 | age = age.cuda() 157 | 158 | self.zero_grad() 159 | if self.feature == 0: 160 | prediction = self.forward(user, item, gender=gender) 161 | elif self.feature == 1: 162 | prediction = self.forward(user, item, age=age) 163 | elif self.feature == 2: 164 | prediction = self.forward(user, item, gender=gender, age=age) 165 | else: 166 | prediction = self.forward(user, item) 167 | 168 | loss = criterion(prediction, label) 169 | loss += self.reg_1 * (self.embed_item.weight.norm(p=1) + self.embed_user.weight.norm(p=1)) 170 | loss += self.reg_2 * (self.embed_item.weight.norm() + self.embed_user.weight.norm()) 171 | 172 | if self.feature == 0: 173 | loss += self.reg_1 * (self.embed_gender.weight.norm(p=1)) 174 | loss += self.reg_2 * (self.embed_gender.weight.norm()) 175 | elif self.feature == 1: 176 | loss += self.reg_1 * (self.embed_age.weight.norm(p=1)) 177 | loss += self.reg_2 * (self.embed_age.weight.norm()) 178 | elif self.feature == 2: 179 | loss += self.reg_1 * (self.embed_gender.weight.norm(p=1) + 
self.embed_age.weight.norm(p=1)) 180 | loss += self.reg_2 * (self.embed_gender.weight.norm() + self.embed_age.weight.norm()) 181 | 182 | if torch.isnan(loss): 183 | raise ValueError(f'Loss=Nan or Infinity: current settings does not fit the recommender') 184 | 185 | loss.backward() 186 | optimizer.step() 187 | 188 | pbar.set_postfix(loss=loss.item()) 189 | current_loss += loss.item() 190 | 191 | self.eval() 192 | delta_loss = float(current_loss - last_loss) 193 | if (abs(delta_loss) < 1e-5) and self.early_stop: 194 | print('Satisfy early stop mechanism') 195 | break 196 | else: 197 | last_loss = current_loss 198 | 199 | def predict(self, u, i, g=None, a=None): 200 | if self.feature == 0: 201 | pred = self.forward(u, i, gender=g).cpu() 202 | elif self.feature == 1: 203 | pred = self.forward(u, i, age=a).cpu() 204 | elif self.feature == 2: 205 | pred = self.forward(u, i, gender=g, age=a).cpu() 206 | else: 207 | pred = self.forward(u, i).cpu() 208 | 209 | return pred 210 | -------------------------------------------------------------------------------- /Speech/vits_lib/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | import numpy as np 5 | 6 | 7 | DEFAULT_MIN_BIN_WIDTH = 1e-3 8 | DEFAULT_MIN_BIN_HEIGHT = 1e-3 9 | DEFAULT_MIN_DERIVATIVE = 1e-3 10 | 11 | 12 | def piecewise_rational_quadratic_transform(inputs, 13 | unnormalized_widths, 14 | unnormalized_heights, 15 | unnormalized_derivatives, 16 | inverse=False, 17 | tails=None, 18 | tail_bound=1., 19 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 20 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 21 | min_derivative=DEFAULT_MIN_DERIVATIVE): 22 | 23 | if tails is None: 24 | spline_fn = rational_quadratic_spline 25 | spline_kwargs = {} 26 | else: 27 | spline_fn = unconstrained_rational_quadratic_spline 28 | spline_kwargs = { 29 | 'tails': tails, 30 | 'tail_bound': tail_bound 31 | } 32 | 33 | outputs, logabsdet = spline_fn( 34 
| inputs=inputs, 35 | unnormalized_widths=unnormalized_widths, 36 | unnormalized_heights=unnormalized_heights, 37 | unnormalized_derivatives=unnormalized_derivatives, 38 | inverse=inverse, 39 | min_bin_width=min_bin_width, 40 | min_bin_height=min_bin_height, 41 | min_derivative=min_derivative, 42 | **spline_kwargs 43 | ) 44 | return outputs, logabsdet 45 | 46 | 47 | def searchsorted(bin_locations, inputs, eps=1e-6): 48 | bin_locations[..., -1] += eps 49 | return torch.sum( 50 | inputs[..., None] >= bin_locations, 51 | dim=-1 52 | ) - 1 53 | 54 | 55 | def unconstrained_rational_quadratic_spline(inputs, 56 | unnormalized_widths, 57 | unnormalized_heights, 58 | unnormalized_derivatives, 59 | inverse=False, 60 | tails='linear', 61 | tail_bound=1., 62 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 63 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 64 | min_derivative=DEFAULT_MIN_DERIVATIVE): 65 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) 66 | outside_interval_mask = ~inside_interval_mask 67 | 68 | outputs = torch.zeros_like(inputs) 69 | logabsdet = torch.zeros_like(inputs) 70 | 71 | if tails == 'linear': 72 | unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) 73 | constant = np.log(np.exp(1 - min_derivative) - 1) 74 | unnormalized_derivatives[..., 0] = constant 75 | unnormalized_derivatives[..., -1] = constant 76 | 77 | outputs[outside_interval_mask] = inputs[outside_interval_mask] 78 | logabsdet[outside_interval_mask] = 0 79 | else: 80 | raise RuntimeError('{} tails are not implemented.'.format(tails)) 81 | 82 | outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline( 83 | inputs=inputs[inside_interval_mask], 84 | unnormalized_widths=unnormalized_widths[inside_interval_mask, :], 85 | unnormalized_heights=unnormalized_heights[inside_interval_mask, :], 86 | unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], 87 | inverse=inverse, 88 | left=-tail_bound, right=tail_bound, 
bottom=-tail_bound, top=tail_bound, 89 | min_bin_width=min_bin_width, 90 | min_bin_height=min_bin_height, 91 | min_derivative=min_derivative 92 | ) 93 | 94 | return outputs, logabsdet 95 | 96 | def rational_quadratic_spline(inputs, 97 | unnormalized_widths, 98 | unnormalized_heights, 99 | unnormalized_derivatives, 100 | inverse=False, 101 | left=0., right=1., bottom=0., top=1., 102 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 103 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 104 | min_derivative=DEFAULT_MIN_DERIVATIVE): 105 | if torch.min(inputs) < left or torch.max(inputs) > right: 106 | raise ValueError('Input to a transform is not within its domain') 107 | 108 | num_bins = unnormalized_widths.shape[-1] 109 | 110 | if min_bin_width * num_bins > 1.0: 111 | raise ValueError('Minimal bin width too large for the number of bins') 112 | if min_bin_height * num_bins > 1.0: 113 | raise ValueError('Minimal bin height too large for the number of bins') 114 | 115 | widths = F.softmax(unnormalized_widths, dim=-1) 116 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths 117 | cumwidths = torch.cumsum(widths, dim=-1) 118 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0) 119 | cumwidths = (right - left) * cumwidths + left 120 | cumwidths[..., 0] = left 121 | cumwidths[..., -1] = right 122 | widths = cumwidths[..., 1:] - cumwidths[..., :-1] 123 | 124 | derivatives = min_derivative + F.softplus(unnormalized_derivatives) 125 | 126 | heights = F.softmax(unnormalized_heights, dim=-1) 127 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights 128 | cumheights = torch.cumsum(heights, dim=-1) 129 | cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0) 130 | cumheights = (top - bottom) * cumheights + bottom 131 | cumheights[..., 0] = bottom 132 | cumheights[..., -1] = top 133 | heights = cumheights[..., 1:] - cumheights[..., :-1] 134 | 135 | if inverse: 136 | bin_idx = searchsorted(cumheights, inputs)[..., None] 137 | else: 138 | 
bin_idx = searchsorted(cumwidths, inputs)[..., None] 139 | 140 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] 141 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0] 142 | 143 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] 144 | delta = heights / widths 145 | input_delta = delta.gather(-1, bin_idx)[..., 0] 146 | 147 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] 148 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] 149 | 150 | input_heights = heights.gather(-1, bin_idx)[..., 0] 151 | 152 | if inverse: 153 | a = (((inputs - input_cumheights) * (input_derivatives 154 | + input_derivatives_plus_one 155 | - 2 * input_delta) 156 | + input_heights * (input_delta - input_derivatives))) 157 | b = (input_heights * input_derivatives 158 | - (inputs - input_cumheights) * (input_derivatives 159 | + input_derivatives_plus_one 160 | - 2 * input_delta)) 161 | c = - input_delta * (inputs - input_cumheights) 162 | 163 | discriminant = b.pow(2) - 4 * a * c 164 | assert (discriminant >= 0).all() 165 | 166 | root = (2 * c) / (-b - torch.sqrt(discriminant)) 167 | outputs = root * input_bin_widths + input_cumwidths 168 | 169 | theta_one_minus_theta = root * (1 - root) 170 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) 171 | * theta_one_minus_theta) 172 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2) 173 | + 2 * input_delta * theta_one_minus_theta 174 | + input_derivatives * (1 - root).pow(2)) 175 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 176 | 177 | return outputs, -logabsdet 178 | else: 179 | theta = (inputs - input_cumwidths) / input_bin_widths 180 | theta_one_minus_theta = theta * (1 - theta) 181 | 182 | numerator = input_heights * (input_delta * theta.pow(2) 183 | + input_derivatives * theta_one_minus_theta) 184 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one 
- 2 * input_delta) 185 | * theta_one_minus_theta) 186 | outputs = input_cumheights + numerator / denominator 187 | 188 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2) 189 | + 2 * input_delta * theta_one_minus_theta 190 | + input_derivatives * (1 - theta).pow(2)) 191 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 192 | 193 | return outputs, logabsdet 194 | -------------------------------------------------------------------------------- /Recommender/run_classifier.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import torch 5 | import torchaudio 6 | import argparse 7 | import time 8 | import tqdm as tqdm 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.optim as optim 12 | from torch.utils.tensorboard import SummaryWriter 13 | from torch.utils.data import Dataset,DataLoader 14 | from transformers import Wav2Vec2Model, Wav2Vec2PreTrainedModel, AutoConfig 15 | from sklearn.model_selection import train_test_split 16 | 17 | 18 | 19 | class Logger(object): 20 | def __init__(self, filename='default.log', stream=sys.stdout): 21 | self.terminal = stream 22 | self.log = open(filename, 'a') 23 | 24 | def write(self, message): 25 | self.terminal.write(message) 26 | self.log.write(message) 27 | 28 | def flush(self): 29 | pass 30 | 31 | class AudioDataset(Dataset): 32 | def __init__(self, datalist, audio_path, target_sample_rate, transformation=None): 33 | self.datalist = datalist 34 | self.audio_path = audio_path 35 | self.target_sample_rate = target_sample_rate 36 | self.transformation = None 37 | if transformation: 38 | self.transformation = transformation 39 | 40 | 41 | def __len__(self): 42 | return len(self.datalist) 43 | 44 | def __getitem__(self,idx): 45 | audio_file_path = os.path.join(self.audio_path, self.datalist[idx]) 46 | audio_name = self.datalist[idx] 47 | age_label, 
gender_label = self._get_label(audio_name) 48 | waveform, sample_rate = torchaudio.load(audio_file_path) 49 | if sample_rate != self.target_sample_rate: 50 | waveform = self._resample(waveform, sample_rate) 51 | 52 | return waveform, age_label, gender_label 53 | 54 | 55 | def _get_label(self, audio_name): 56 | name_list = audio_name.split('_') 57 | age_labels = {'under 20':0, '20-30':1, 'over 30':2} 58 | gender_labels = {'women':0, 'men':1} 59 | 60 | age = age_labels[name_list[-3]] 61 | gender = gender_labels[name_list[-2]] 62 | 63 | return age, gender 64 | 65 | def _resample(self, waveform, sample_rate): 66 | resampler = torchaudio.transforms.Resample(sample_rate,self.target_sample_rate) 67 | 68 | return resampler(waveform) 69 | 70 | 71 | class Wav2Vec2ClassificationModel(Wav2Vec2PreTrainedModel): 72 | def __init__(self, config, hidden_size, dropout): 73 | super().__init__(config) 74 | 75 | self.wav2vec2 = Wav2Vec2Model(config) 76 | self.hidden_size = hidden_size 77 | self.fc = nn.Linear(config.hidden_size, self.hidden_size) 78 | self.dropout = nn.Dropout(dropout) 79 | self.gender_fc = nn.Linear(self.hidden_size, 2) 80 | self.age_fc = nn.Linear(self.hidden_size, 3) 81 | self.tanh = nn.Tanh() 82 | self.relu = nn.ReLU() 83 | 84 | self.init_weights() 85 | 86 | def freeze_feature_extractor(self): 87 | self.wav2vec2.feature_extractor._freeze_parameters() 88 | 89 | def merged_strategy(self, hidden_states): 90 | outputs = torch.mean(hidden_states, dim=1) 91 | 92 | return outputs 93 | 94 | def forward( 95 | self, 96 | input_values, 97 | attention_mask=None, 98 | output_attentions=None, 99 | output_hidden_states=None, 100 | return_dict=None, 101 | labels=None, 102 | ): 103 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 104 | with torch.no_grad(): 105 | outputs = self.wav2vec2( 106 | input_values, 107 | attention_mask=attention_mask, 108 | output_attentions=output_attentions, 109 | output_hidden_states=output_hidden_states, 110 | 
return_dict=return_dict, 111 | ) 112 | hidden_states = outputs[0] 113 | x = self.merged_strategy(hidden_states) 114 | x = self.fc(x) 115 | x = self.relu(x) 116 | x = self.dropout(x) 117 | gender_logits = self.gender_fc(x) 118 | gender_logits = F.log_softmax(gender_logits, dim=-1) 119 | age_logits = self.age_fc(x) 120 | age_logits = F.log_softmax(age_logits, dim=-1) 121 | 122 | return age_logits, gender_logits 123 | 124 | 125 | def pad_sequence(batch): 126 | # Make all tensor in a batch the same length by padding with zeros 127 | batch = [item.t() for item in batch] 128 | batch = torch.nn.utils.rnn.pad_sequence(batch, batch_first=True, padding_value=0.) 129 | return batch.permute(0, 2, 1) 130 | 131 | def collate_fn(batch): 132 | 133 | audios, age_labels, gender_labels = [], [], [] 134 | 135 | for waveform, age, gender in batch: 136 | audios += [waveform] 137 | age_labels += [torch.tensor(age)] 138 | gender_labels += [torch.tensor(gender)] 139 | 140 | audios = pad_sequence(audios) 141 | age_labels = torch.stack(age_labels) 142 | gender_labels = torch.stack(gender_labels) 143 | 144 | return audios.squeeze(dim=1), age_labels, gender_labels 145 | 146 | def train_single_epoch(model, dataloader, optimizer, device): 147 | model.train() 148 | losses = [] 149 | for waveform, age_label, gender_label in tqdm.tqdm(dataloader): 150 | waveform = waveform.to(device) 151 | age_label = age_label.to(device) 152 | gender_label = gender_label.to(device) 153 | 154 | age_logits, gender_logits = model(waveform) 155 | 156 | age_loss = F.nll_loss(age_logits, age_label) 157 | gender_loss = F.nll_loss(gender_logits, gender_label) 158 | loss = age_loss + gender_loss 159 | 160 | optimizer.zero_grad() 161 | loss.backward() 162 | optimizer.step() 163 | losses.append(loss.item()) 164 | return losses 165 | 166 | def get_likely_index(tensor): 167 | # find most likely label index for each element in the batch 168 | return tensor.argmax(dim=-1) 169 | 170 | def number_of_correct(pred, target): 171 | # 
count number of correct predictions 172 | return pred.eq(target).sum().item() 173 | 174 | def test_val_single_epoch(model, dataloader, device): 175 | model.eval() 176 | age_correct = 0 177 | gender_correct = 0 178 | for waveform, age_label, gender_label in tqdm.tqdm(dataloader): 179 | waveform = waveform.to(device) 180 | age_label = age_label.to(device) 181 | gender_label = gender_label.to(device) 182 | 183 | age_logits, gender_logits = model(waveform) # batch * output 184 | 185 | age_pred = get_likely_index(age_logits) # batch 186 | gender_pred = get_likely_index(gender_logits) 187 | 188 | age_correct += number_of_correct(age_pred, age_label) 189 | gender_correct += number_of_correct(gender_pred, gender_label) 190 | 191 | age_accu = age_correct / len(dataloader.dataset) * 100. 192 | gender_accu = gender_correct / len(dataloader.dataset) * 100. 193 | 194 | return age_accu, gender_accu 195 | 196 | def train(model, train_loader, val_loader, optimizer, device, epochs): 197 | for epoch in tqdm.tqdm(range(epochs)): 198 | losses = train_single_epoch(model, train_loader, optimizer, device) 199 | loss = np.mean(losses) 200 | age_accu, gender_accu = test_val_single_epoch(model, val_loader, device) 201 | writer.add_scalar('loss', loss, epoch+1) 202 | writer.add_scalar('age_accuracy', age_accu, epoch+1) 203 | writer.add_scalar('gender_accuracy', gender_accu, epoch+1) 204 | print(f"\nTrain Epoch: {epoch + 1} Loss: {loss:.6f} Age: {age_accu:.2f}% Gender: {gender_accu:.2f}%") 205 | print('Finished Training') 206 | 207 | 208 | if __name__ == '__main__': 209 | parser = argparse.ArgumentParser(description='speech classification') 210 | parser.add_argument('--batch_size', 211 | type=int, 212 | default=4,) 213 | parser.add_argument('--hidden_size', 214 | type=int, 215 | default=1024) 216 | parser.add_argument('--dropout', 217 | type=float, 218 | default=0.2) 219 | parser.add_argument('--epochs', 220 | type=int, 221 | default=20) 222 | parser.add_argument('--lr', 223 | type=float, 224 
| default=0.0001) 225 | parser.add_argument('--test', 226 | type=int, 227 | default=1) 228 | parser.add_argument('--dataset', 229 | type=str, 230 | default='ml-1m') 231 | args = parser.parse_args() 232 | 233 | cur_time = time.strftime("%Y-%m-%d-%H_%M_%S",time.localtime(time.time()))[5:] 234 | 235 | save_tb_log_path = f'./tb_log/{cur_time}_{args.batch_size}' 236 | if not os.path.exists('./tb_log/'): 237 | os.makedirs('./tb_log/') 238 | writer = SummaryWriter(save_tb_log_path, flush_secs=30) 239 | 240 | save_log = f'./log/{cur_time}_{args.lr}_{args.dropout}_{args.batch_size}.log' 241 | sys.stdout = Logger(save_log, sys.stdout) 242 | 243 | seed = 2023 244 | np.random.seed(seed) 245 | torch.manual_seed(seed) 246 | 247 | input_dir = f'./data/{args.dataset}/' 248 | datalist = os.listdir(input_dir) 249 | 250 | ratio_train = 0.8 251 | ratio_val = 0.1 252 | ratio_test = 0.1 253 | 254 | remaining, test_set = train_test_split(datalist, test_size=ratio_test, random_state=2023) 255 | ratio_remaining = 1 - ratio_test 256 | ratio_val_adjusted = ratio_val / ratio_remaining 257 | train_set, val_set = train_test_split(remaining, test_size=ratio_val_adjusted, random_state=2023) 258 | 259 | 260 | if torch.cuda.is_available(): 261 | device='cuda' 262 | torch.cuda.manual_seed_all(seed) 263 | torch.backends.cudnn.benchmark = False 264 | else: 265 | device='cpu' 266 | 267 | train_set = AudioDataset(train_set, input_dir, 16000) 268 | val_set = AudioDataset(val_set, input_dir, 16000) 269 | if args.test: 270 | test_set = AudioDataset(test_set, input_dir, 16000) 271 | 272 | if device == "cuda": 273 | num_workers = 1 274 | pin_memory = False 275 | else: 276 | num_workers = 0 277 | pin_memory = False 278 | 279 | train_loader = DataLoader( 280 | train_set, 281 | batch_size=args.batch_size, 282 | shuffle=True, 283 | collate_fn=collate_fn, 284 | num_workers=num_workers, 285 | pin_memory=pin_memory, 286 | ) 287 | 288 | val_loader = DataLoader( 289 | val_set, 290 | batch_size=args.batch_size, 291 | 
shuffle=False, 292 | drop_last=False, 293 | collate_fn=collate_fn, 294 | num_workers=num_workers, 295 | pin_memory=pin_memory, 296 | ) 297 | 298 | if args.test: 299 | test_loader = DataLoader( 300 | test_set, 301 | batch_size=args.batch_size, 302 | shuffle=False, 303 | drop_last=False, 304 | collate_fn=collate_fn, 305 | num_workers=num_workers, 306 | pin_memory=pin_memory, 307 | ) 308 | 309 | model_name_or_path = "lighteternal/wav2vec2-large-xlsr-53-greek" 310 | config = AutoConfig.from_pretrained( 311 | model_name_or_path, 312 | finetuning_task="wav2vec2_clf", 313 | ) 314 | 315 | clf_model = Wav2Vec2ClassificationModel.from_pretrained( 316 | model_name_or_path, 317 | config=config, 318 | hidden_size=args.hidden_size, 319 | dropout=args.dropout, 320 | ) 321 | 322 | clf_model = clf_model.to(device) 323 | 324 | optimizer = optim.Adam(clf_model.parameters(), lr=args.lr) 325 | clf_model.freeze_feature_extractor() 326 | 327 | train(clf_model, train_loader, val_loader, optimizer, device, args.epochs) 328 | 329 | if args.test: 330 | age_accu, gender_accu = test_val_single_epoch(clf_model, test_loader, device) 331 | torch.save(clf_model, f'./clf_model_{args.dataset}.pt') 332 | print(f"Test set: Age: {age_accu:.2f}% Gender: {gender_accu:.2f}%") 333 | 334 | -------------------------------------------------------------------------------- /Recommender/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import datetime 5 | import numpy as np 6 | import pandas as pd 7 | import scipy.sparse as sp 8 | import torch.utils.data as data 9 | from tqdm import tqdm 10 | from sklearn.model_selection import train_test_split 11 | from utils import get_ur, build_candidates_set, precision_at_k, recall_at_k, map_at_k, hr_at_k, ndcg_at_k, mrr_at_k 12 | from FMRecommender import PointFM 13 | 14 | 15 | class Sample(object): 16 | def __init__(self, user_num, item_num, feature_num, num_ng=4): 17 | 
self.user_num = user_num 18 | self.item_num = item_num 19 | self.num_ng = num_ng 20 | self.feature_num = feature_num 21 | 22 | def transform(self, data, is_training=True): 23 | if not is_training: 24 | neg_set = [] 25 | for _, row in data.iterrows(): 26 | u = int(row['user']) 27 | i = int(row['item']) 28 | r = row['rating'] 29 | js = [] 30 | if self.feature_num == 0: 31 | neg_set.append([u, i, r, js]) 32 | else: 33 | g = int(row['gender']) 34 | a = int(row['age']) 35 | neg_set.append([u, i, g, a, r, js]) 36 | return neg_set 37 | 38 | user_num = self.user_num 39 | item_num = self.item_num 40 | pair_pos = sp.dok_matrix((user_num, item_num), dtype=np.float32) 41 | for _, row in data.iterrows(): 42 | pair_pos[int(row['user']), int(row['item'])] = 1.0 43 | neg_set = [] 44 | for _, row in data.iterrows(): 45 | u = int(row['user']) 46 | i = int(row['item']) 47 | r = row['rating'] 48 | js = [] 49 | for _ in range(self.num_ng): 50 | j = np.random.randint(item_num) 51 | while (u, j) in pair_pos: 52 | j = np.random.randint(item_num) 53 | js.append(j) 54 | if self.feature_num == 0: 55 | neg_set.append([u, i, r, js]) 56 | else: 57 | g = int(row['gender']) 58 | a = int(row['age']) 59 | neg_set.append([u, i, g, a, r, js]) 60 | return neg_set 61 | 62 | 63 | class FMData(data.Dataset): 64 | def __init__(self, neg_set, feature_num, is_training=True, neg_label_val=0.): 65 | super(FMData, self).__init__() 66 | self.features_fill = [] 67 | self.labels_fill = [] 68 | self.feature_num = feature_num 69 | self.neg_label = neg_label_val 70 | 71 | if self.feature_num == 0: 72 | self.init_normal(neg_set, is_training) 73 | else: 74 | self.init_double_feature(neg_set, is_training) 75 | 76 | def __len__(self): 77 | return len(self.labels_fill) 78 | 79 | def __getitem__(self, index): 80 | features = self.features_fill 81 | labels = self.labels_fill 82 | user = features[index][0] 83 | item = features[index][1] 84 | label = labels[index] 85 | if self.feature_num == 0: 86 | return user, item, item, 
item, label 87 | else: 88 | gender = features[index][2] 89 | age = features[index][3] 90 | return user, item, gender, age, label 91 | 92 | def init_normal(self, neg_set, is_training=True): 93 | for u, i, r, js in neg_set: 94 | self.features_fill.append([int(u), int(i)]) 95 | self.labels_fill.append(r) 96 | 97 | if is_training: 98 | for j in js: 99 | self.features_fill.append([int(u), int(j)]) 100 | self.labels_fill.append(self.neg_label) 101 | self.labels_fill = np.array(self.labels_fill, dtype=np.float32) 102 | 103 | def init_double_feature(self, neg_set, is_training=True): 104 | for u, i, g, a, r, js in neg_set: 105 | self.features_fill.append([int(u), int(i), int(g), int(a)]) 106 | self.labels_fill.append(r) 107 | 108 | if is_training: 109 | for j in js: 110 | self.features_fill.append([int(u), int(j), int(g), int(a)]) 111 | self.labels_fill.append(self.neg_label) 112 | self.labels_fill = np.array(self.labels_fill, dtype=np.float32) 113 | 114 | 115 | 116 | def age_map(age): 117 | age_dict = {'under 20':0, '20-30':1, 'over 30':2} 118 | 119 | return age_dict[age] 120 | 121 | def gender_map_ml(gender): 122 | gender_dict = {'women':0, 'men':1} 123 | 124 | return gender_dict[gender] 125 | 126 | def gender_map_coat(gender): 127 | gender_dict = {'women':1, 'men':0} 128 | 129 | return gender_dict[gender] 130 | 131 | 132 | def load_data(path): 133 | df = pd.read_csv(path) 134 | df['user'] = pd.Categorical(df['user']).codes 135 | df['item'] = pd.Categorical(df['item']).codes 136 | user_num = df['user'].nunique() 137 | item_num = df['item'].nunique() 138 | 139 | return df, len(df), user_num, item_num 140 | 141 | def get_user_info(df, feature_num): 142 | user_info = dict() 143 | for _, row in df.iterrows(): 144 | if feature_num == 0: 145 | user_info[int(row['user'])] = [int(row['item'])] 146 | else: 147 | user_info[int(row['user'])] = [int(row['gender']), int(row['age'])] 148 | 149 | return user_info 150 | 151 | 152 | 153 | if __name__ == '__main__': 154 | parser = 
argparse.ArgumentParser(description='fm recommender') 155 | parser.add_argument('--feature', 156 | type=str, 157 | default='user,item,gender,age,rating,normal') 158 | parser.add_argument('--val', 159 | type=int, 160 | default=1) 161 | parser.add_argument('--num_ng', 162 | type=int, 163 | default=5) 164 | parser.add_argument('--factors', 165 | type=int, 166 | default=16) 167 | parser.add_argument('--epochs', 168 | type=int, 169 | default=1) 170 | parser.add_argument('--lr', 171 | type=float, 172 | default=0.01) 173 | parser.add_argument('--batch_size', 174 | type=int, 175 | default=256) 176 | parser.add_argument('--dataset', 177 | type=str, 178 | default='coat') 179 | args = parser.parse_args() 180 | time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') 181 | s_time = time[5:].replace(' ', '-').replace(':', '-') 182 | # seed = 1 183 | # np.random.seed(seed) 184 | # torch.manual_seed(seed) 185 | path = f'/home/workshop/dataset/lhy/Speech/recommend/data/{args.dataset}_mp3.csv' 186 | df, record_num, user_num, item_num = load_data(path) 187 | df.insert(df.shape[1], 'rating', 1) 188 | feature_list = args.feature.split(',') 189 | if 'gender' not in feature_list and 'age' not in feature_list: 190 | feature_num = 0 191 | elif 'gender' in feature_list and 'age' in feature_list and 'normal' not in feature_list: 192 | feature_num = 1 193 | elif 'gender' in feature_list and 'age' in feature_list and 'normal' in feature_list: 194 | feature_num = 2 195 | feature_list = ['user', 'item', 'age', 'gender', 'rating'] 196 | df = df[feature_list] 197 | if feature_num != 0: 198 | df['age'] = df['age'].apply(lambda age: age_map(age)) 199 | if args.dataset == 'coat': 200 | df['gender'] = df['gender'].apply(lambda gender: gender_map_coat(gender)) 201 | else: 202 | df['gender'] = df['gender'].apply(lambda gender: gender_map_ml(gender)) 203 | 204 | 205 | user_info = get_user_info(df, feature_num) 206 | 207 | ratio_train = 0.8 208 | ratio_val = 0.1 209 | ratio_test = 0.2 210 | 211 | 
train_set, test_set = train_test_split(df, test_size=ratio_test, random_state=2023) 212 | if args.val == 1: 213 | ratio_remaining = 1 - ratio_test 214 | ratio_val_adjusted = ratio_val / ratio_remaining 215 | train_set, val_set = train_test_split(train_set, test_size=ratio_val_adjusted, random_state=2023) 216 | 217 | test_ur = get_ur(test_set) 218 | total_train_ur = get_ur(train_set) 219 | item_pool = set(range(item_num)) 220 | sampler = Sample(user_num, item_num, feature_num, args.num_ng) 221 | neg_set = sampler.transform(train_set, is_training=True) 222 | print("data sample complete") 223 | train_dataset = FMData(neg_set, feature_num, is_training=True) 224 | model = PointFM(user_num, item_num, args.factors, args.epochs, args.lr, feature=feature_num) 225 | if torch.cuda.is_available(): 226 | # torch.cuda.manual_seed_all(seed) 227 | # torch.backends.cudnn.benchmark = False 228 | model = model.to("cuda") 229 | train_loader = data.DataLoader( 230 | train_dataset, 231 | batch_size=args.batch_size, 232 | shuffle=True, 233 | num_workers=4 234 | ) 235 | model.fit(train_loader) 236 | print('Start Calculating Metrics......') 237 | test_ucands = build_candidates_set(test_ur, total_train_ur, item_pool, 1000) 238 | print('') 239 | print('Generate recommend list...') 240 | print('') 241 | preds = {} 242 | for u in tqdm(test_ucands.keys()): 243 | if feature_num == 0: 244 | tmp = pd.DataFrame({ 245 | 'user': [u for _ in test_ucands[u]], 246 | 'item': test_ucands[u], 247 | 'rating': [0. for _ in test_ucands[u]], # fake label, make nonsense 248 | }) 249 | else: 250 | gender = user_info[u][0] 251 | age = user_info[u][1] 252 | tmp = pd.DataFrame({ 253 | 'user': [u for _ in test_ucands[u]], 254 | 'item': test_ucands[u], 255 | 'gender': [gender for _ in test_ucands[u]], 256 | 'age': [age for _ in test_ucands[u]], 257 | 'rating': [0. 
for _ in test_ucands[u]], # fake label, make nonsense 258 | }) 259 | tmp_neg_set = sampler.transform(tmp, is_training=False) 260 | tmp_dataset = FMData(tmp_neg_set, feature_num, is_training=False) 261 | tmp_loader = data.DataLoader( 262 | tmp_dataset, 263 | batch_size=1000, 264 | shuffle=False, 265 | num_workers=4 266 | ) 267 | for user, item, gender, age, label in tmp_loader: 268 | user = user.cuda() 269 | item = item.cuda() 270 | label = label.cuda() 271 | 272 | if feature_num != 0: 273 | gender = gender.cuda() 274 | age = age.cuda() 275 | prediction = model.predict(user, item, g=gender, a=age) 276 | else: 277 | prediction = model.predict(user, item) 278 | _, indices = torch.topk(prediction, 50) 279 | top_n = torch.take(torch.tensor(test_ucands[u]), indices).cpu().numpy() 280 | preds[u] = top_n 281 | 282 | for u in preds.keys(): 283 | preds[u] = [1 if i in test_ur[u] else 0 for i in preds[u]] 284 | 285 | print('Save metric@k result to res folder...') 286 | if feature_num == 0: 287 | result_save_path = f'./res/fm/{args.dataset}/wo_ag/' 288 | elif feature_num == 1: 289 | result_save_path = f'./res/fm/{args.dataset}/ag_audio/' 290 | elif feature_num == 2: 291 | result_save_path = f'./res/fm/{args.dataset}/ag_normal/' 292 | if not os.path.exists(result_save_path): 293 | os.makedirs(result_save_path) 294 | res = pd.DataFrame({'metric@K': ['pre', 'rec', 'hr', 'map', 'mrr', 'ndcg']}) 295 | for k in [1, 5, 10, 20, 30, 50]: 296 | tmp_preds = preds.copy() 297 | tmp_preds = {key: rank_list[:k] for key, rank_list in tmp_preds.items()} 298 | 299 | pre_k = np.mean([precision_at_k(r, k) for r in tmp_preds.values()]) 300 | rec_k = recall_at_k(tmp_preds, test_ur, k) 301 | hr_k = hr_at_k(tmp_preds, test_ur) 302 | map_k = map_at_k(tmp_preds.values()) 303 | mrr_k = mrr_at_k(tmp_preds, k) 304 | ndcg_k = np.mean([ndcg_at_k(r, k) for r in tmp_preds.values()]) 305 | 306 | if k == 10: 307 | print(f'Precision@{k}: {pre_k:.4f}') 308 | print(f'Recall@{k}: {rec_k:.4f}') 309 | print(f'HR@{k}: 
{hr_k:.4f}') 310 | print(f'MAP@{k}: {map_k:.4f}') 311 | print(f'MRR@{k}: {mrr_k:.4f}') 312 | print(f'NDCG@{k}: {ndcg_k:.4f}') 313 | 314 | res[k] = np.array([pre_k, rec_k, hr_k, map_k, mrr_k, ndcg_k]) 315 | 316 | common_prefix = f'{s_time}_{args.num_ng}_{args.factors}_{args.lr}_{args.batch_size}_{args.epochs}' 317 | 318 | res.to_csv( 319 | f'{result_save_path}{common_prefix}_results.csv', 320 | index=False 321 | ) -------------------------------------------------------------------------------- /Speech/vits_lib/attentions.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | import vits_lib.commons as commons 8 | import vits_lib.modules as modules 9 | from vits_lib.modules import LayerNorm 10 | 11 | 12 | class Encoder(nn.Module): 13 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, **kwargs): 14 | super().__init__() 15 | self.hidden_channels = hidden_channels 16 | self.filter_channels = filter_channels 17 | self.n_heads = n_heads 18 | self.n_layers = n_layers 19 | self.kernel_size = kernel_size 20 | self.p_dropout = p_dropout 21 | self.window_size = window_size 22 | 23 | self.drop = nn.Dropout(p_dropout) 24 | self.attn_layers = nn.ModuleList() 25 | self.norm_layers_1 = nn.ModuleList() 26 | self.ffn_layers = nn.ModuleList() 27 | self.norm_layers_2 = nn.ModuleList() 28 | for i in range(self.n_layers): 29 | self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size)) 30 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 31 | self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)) 32 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 33 | 34 | def forward(self, x, x_mask): 35 | attn_mask = x_mask.unsqueeze(2) 
* x_mask.unsqueeze(-1) 36 | x = x * x_mask 37 | for i in range(self.n_layers): 38 | y = self.attn_layers[i](x, x, attn_mask) 39 | y = self.drop(y) 40 | x = self.norm_layers_1[i](x + y) 41 | 42 | y = self.ffn_layers[i](x, x_mask) 43 | y = self.drop(y) 44 | x = self.norm_layers_2[i](x + y) 45 | x = x * x_mask 46 | return x 47 | 48 | 49 | class Decoder(nn.Module): 50 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., proximal_bias=False, proximal_init=True, **kwargs): 51 | super().__init__() 52 | self.hidden_channels = hidden_channels 53 | self.filter_channels = filter_channels 54 | self.n_heads = n_heads 55 | self.n_layers = n_layers 56 | self.kernel_size = kernel_size 57 | self.p_dropout = p_dropout 58 | self.proximal_bias = proximal_bias 59 | self.proximal_init = proximal_init 60 | 61 | self.drop = nn.Dropout(p_dropout) 62 | self.self_attn_layers = nn.ModuleList() 63 | self.norm_layers_0 = nn.ModuleList() 64 | self.encdec_attn_layers = nn.ModuleList() 65 | self.norm_layers_1 = nn.ModuleList() 66 | self.ffn_layers = nn.ModuleList() 67 | self.norm_layers_2 = nn.ModuleList() 68 | for i in range(self.n_layers): 69 | self.self_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, proximal_bias=proximal_bias, proximal_init=proximal_init)) 70 | self.norm_layers_0.append(LayerNorm(hidden_channels)) 71 | self.encdec_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)) 72 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 73 | self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True)) 74 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 75 | 76 | def forward(self, x, x_mask, h, h_mask): 77 | """ 78 | x: decoder input 79 | h: encoder output 80 | """ 81 | self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype) 82 | 
encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 83 | x = x * x_mask 84 | for i in range(self.n_layers): 85 | y = self.self_attn_layers[i](x, x, self_attn_mask) 86 | y = self.drop(y) 87 | x = self.norm_layers_0[i](x + y) 88 | 89 | y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) 90 | y = self.drop(y) 91 | x = self.norm_layers_1[i](x + y) 92 | 93 | y = self.ffn_layers[i](x, x_mask) 94 | y = self.drop(y) 95 | x = self.norm_layers_2[i](x + y) 96 | x = x * x_mask 97 | return x 98 | 99 | 100 | class MultiHeadAttention(nn.Module): 101 | def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False): 102 | super().__init__() 103 | assert channels % n_heads == 0 104 | 105 | self.channels = channels 106 | self.out_channels = out_channels 107 | self.n_heads = n_heads 108 | self.p_dropout = p_dropout 109 | self.window_size = window_size 110 | self.heads_share = heads_share 111 | self.block_length = block_length 112 | self.proximal_bias = proximal_bias 113 | self.proximal_init = proximal_init 114 | self.attn = None 115 | 116 | self.k_channels = channels // n_heads 117 | self.conv_q = nn.Conv1d(channels, channels, 1) 118 | self.conv_k = nn.Conv1d(channels, channels, 1) 119 | self.conv_v = nn.Conv1d(channels, channels, 1) 120 | self.conv_o = nn.Conv1d(channels, out_channels, 1) 121 | self.drop = nn.Dropout(p_dropout) 122 | 123 | if window_size is not None: 124 | n_heads_rel = 1 if heads_share else n_heads 125 | rel_stddev = self.k_channels**-0.5 126 | self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) 127 | self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) 128 | 129 | nn.init.xavier_uniform_(self.conv_q.weight) 130 | nn.init.xavier_uniform_(self.conv_k.weight) 131 | nn.init.xavier_uniform_(self.conv_v.weight) 132 | if proximal_init: 133 | with 
torch.no_grad(): 134 | self.conv_k.weight.copy_(self.conv_q.weight) 135 | self.conv_k.bias.copy_(self.conv_q.bias) 136 | 137 | def forward(self, x, c, attn_mask=None): 138 | q = self.conv_q(x) 139 | k = self.conv_k(c) 140 | v = self.conv_v(c) 141 | 142 | x, self.attn = self.attention(q, k, v, mask=attn_mask) 143 | 144 | x = self.conv_o(x) 145 | return x 146 | 147 | def attention(self, query, key, value, mask=None): 148 | # reshape [b, d, t] -> [b, n_h, t, d_k] 149 | b, d, t_s, t_t = (*key.size(), query.size(2)) 150 | query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) 151 | key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 152 | value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 153 | 154 | scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) 155 | if self.window_size is not None: 156 | assert t_s == t_t, "Relative attention is only available for self-attention." 157 | key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) 158 | rel_logits = self._matmul_with_relative_keys(query /math.sqrt(self.k_channels), key_relative_embeddings) 159 | scores_local = self._relative_position_to_absolute_position(rel_logits) 160 | scores = scores + scores_local 161 | if self.proximal_bias: 162 | assert t_s == t_t, "Proximal bias is only available for self-attention." 163 | scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) 164 | if mask is not None: 165 | scores = scores.masked_fill(mask == 0, -1e4) 166 | if self.block_length is not None: 167 | assert t_s == t_t, "Local attention is only available for self-attention." 
168 | block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length) 169 | scores = scores.masked_fill(block_mask == 0, -1e4) 170 | p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] 171 | p_attn = self.drop(p_attn) 172 | output = torch.matmul(p_attn, value) 173 | if self.window_size is not None: 174 | relative_weights = self._absolute_position_to_relative_position(p_attn) 175 | value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) 176 | output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) 177 | output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t] 178 | return output, p_attn 179 | 180 | def _matmul_with_relative_values(self, x, y): 181 | """ 182 | x: [b, h, l, m] 183 | y: [h or 1, m, d] 184 | ret: [b, h, l, d] 185 | """ 186 | ret = torch.matmul(x, y.unsqueeze(0)) 187 | return ret 188 | 189 | def _matmul_with_relative_keys(self, x, y): 190 | """ 191 | x: [b, h, l, d] 192 | y: [h or 1, m, d] 193 | ret: [b, h, l, m] 194 | """ 195 | ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) 196 | return ret 197 | 198 | def _get_relative_embeddings(self, relative_embeddings, length): 199 | max_relative_position = 2 * self.window_size + 1 200 | # Pad first before slice to avoid using cond ops. 
201 | pad_length = max(length - (self.window_size + 1), 0) 202 | slice_start_position = max((self.window_size + 1) - length, 0) 203 | slice_end_position = slice_start_position + 2 * length - 1 204 | if pad_length > 0: 205 | padded_relative_embeddings = F.pad( 206 | relative_embeddings, 207 | commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])) 208 | else: 209 | padded_relative_embeddings = relative_embeddings 210 | used_relative_embeddings = padded_relative_embeddings[:,slice_start_position:slice_end_position] 211 | return used_relative_embeddings 212 | 213 | def _relative_position_to_absolute_position(self, x): 214 | """ 215 | x: [b, h, l, 2*l-1] 216 | ret: [b, h, l, l] 217 | """ 218 | batch, heads, length, _ = x.size() 219 | # Concat columns of pad to shift from relative to absolute indexing. 220 | x = F.pad(x, commons.convert_pad_shape([[0,0],[0,0],[0,0],[0,1]])) 221 | 222 | # Concat extra elements so to add up to shape (len+1, 2*len-1). 223 | x_flat = x.view([batch, heads, length * 2 * length]) 224 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0,0],[0,0],[0,length-1]])) 225 | 226 | # Reshape and slice out the padded elements. 
227 | x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:] 228 | return x_final 229 | 230 | def _absolute_position_to_relative_position(self, x): 231 | """ 232 | x: [b, h, l, l] 233 | ret: [b, h, l, 2*l-1] 234 | """ 235 | batch, heads, length, _ = x.size() 236 | # padd along column 237 | x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]])) 238 | x_flat = x.view([batch, heads, length**2 + length*(length -1)]) 239 | # add 0's in the beginning that will skew the elements after reshape 240 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) 241 | x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:] 242 | return x_final 243 | 244 | def _attention_bias_proximal(self, length): 245 | """Bias for self-attention to encourage attention to close positions. 246 | Args: 247 | length: an integer scalar. 248 | Returns: 249 | a Tensor with shape [1, 1, length, length] 250 | """ 251 | r = torch.arange(length, dtype=torch.float32) 252 | diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) 253 | return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) 254 | 255 | 256 | class FFN(nn.Module): 257 | def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None, causal=False): 258 | super().__init__() 259 | self.in_channels = in_channels 260 | self.out_channels = out_channels 261 | self.filter_channels = filter_channels 262 | self.kernel_size = kernel_size 263 | self.p_dropout = p_dropout 264 | self.activation = activation 265 | self.causal = causal 266 | 267 | if causal: 268 | self.padding = self._causal_padding 269 | else: 270 | self.padding = self._same_padding 271 | 272 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) 273 | self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) 274 | self.drop = nn.Dropout(p_dropout) 275 | 276 | def forward(self, x, x_mask): 277 | x = 
self.conv_1(self.padding(x * x_mask)) 278 | if self.activation == "gelu": 279 | x = x * torch.sigmoid(1.702 * x) 280 | else: 281 | x = torch.relu(x) 282 | x = self.drop(x) 283 | x = self.conv_2(self.padding(x * x_mask)) 284 | return x * x_mask 285 | 286 | def _causal_padding(self, x): 287 | if self.kernel_size == 1: 288 | return x 289 | pad_l = self.kernel_size - 1 290 | pad_r = 0 291 | padding = [[0, 0], [0, 0], [pad_l, pad_r]] 292 | x = F.pad(x, commons.convert_pad_shape(padding)) 293 | return x 294 | 295 | def _same_padding(self, x): 296 | if self.kernel_size == 1: 297 | return x 298 | pad_l = (self.kernel_size - 1) // 2 299 | pad_r = self.kernel_size // 2 300 | padding = [[0, 0], [0, 0], [pad_l, pad_r]] 301 | x = F.pad(x, commons.convert_pad_shape(padding)) 302 | return x 303 | -------------------------------------------------------------------------------- /Speech/vits_lib/modules.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import numpy as np 4 | import scipy 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d 10 | from torch.nn.utils import weight_norm, remove_weight_norm 11 | 12 | import vits_lib.commons as commons 13 | from vits_lib.commons import init_weights, get_padding 14 | from vits_lib.transforms import piecewise_rational_quadratic_transform 15 | 16 | 17 | LRELU_SLOPE = 0.1 18 | 19 | 20 | class LayerNorm(nn.Module): 21 | def __init__(self, channels, eps=1e-5): 22 | super().__init__() 23 | self.channels = channels 24 | self.eps = eps 25 | 26 | self.gamma = nn.Parameter(torch.ones(channels)) 27 | self.beta = nn.Parameter(torch.zeros(channels)) 28 | 29 | def forward(self, x): 30 | x = x.transpose(1, -1) 31 | x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) 32 | return x.transpose(1, -1) 33 | 34 | 35 | class ConvReluNorm(nn.Module): 36 | def __init__(self, 
in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): 37 | super().__init__() 38 | self.in_channels = in_channels 39 | self.hidden_channels = hidden_channels 40 | self.out_channels = out_channels 41 | self.kernel_size = kernel_size 42 | self.n_layers = n_layers 43 | self.p_dropout = p_dropout 44 | assert n_layers > 1, "Number of layers should be larger than 0." 45 | 46 | self.conv_layers = nn.ModuleList() 47 | self.norm_layers = nn.ModuleList() 48 | self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2)) 49 | self.norm_layers.append(LayerNorm(hidden_channels)) 50 | self.relu_drop = nn.Sequential( 51 | nn.ReLU(), 52 | nn.Dropout(p_dropout)) 53 | for _ in range(n_layers-1): 54 | self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2)) 55 | self.norm_layers.append(LayerNorm(hidden_channels)) 56 | self.proj = nn.Conv1d(hidden_channels, out_channels, 1) 57 | self.proj.weight.data.zero_() 58 | self.proj.bias.data.zero_() 59 | 60 | def forward(self, x, x_mask): 61 | x_org = x 62 | for i in range(self.n_layers): 63 | x = self.conv_layers[i](x * x_mask) 64 | x = self.norm_layers[i](x) 65 | x = self.relu_drop(x) 66 | x = x_org + self.proj(x) 67 | return x * x_mask 68 | 69 | 70 | class DDSConv(nn.Module): 71 | """ 72 | Dialted and Depth-Separable Convolution 73 | """ 74 | def __init__(self, channels, kernel_size, n_layers, p_dropout=0.): 75 | super().__init__() 76 | self.channels = channels 77 | self.kernel_size = kernel_size 78 | self.n_layers = n_layers 79 | self.p_dropout = p_dropout 80 | 81 | self.drop = nn.Dropout(p_dropout) 82 | self.convs_sep = nn.ModuleList() 83 | self.convs_1x1 = nn.ModuleList() 84 | self.norms_1 = nn.ModuleList() 85 | self.norms_2 = nn.ModuleList() 86 | for i in range(n_layers): 87 | dilation = kernel_size ** i 88 | padding = (kernel_size * dilation - dilation) // 2 89 | self.convs_sep.append(nn.Conv1d(channels, channels, 
kernel_size, 90 | groups=channels, dilation=dilation, padding=padding 91 | )) 92 | self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) 93 | self.norms_1.append(LayerNorm(channels)) 94 | self.norms_2.append(LayerNorm(channels)) 95 | 96 | def forward(self, x, x_mask, g=None): 97 | if g is not None: 98 | x = x + g 99 | for i in range(self.n_layers): 100 | y = self.convs_sep[i](x * x_mask) 101 | y = self.norms_1[i](y) 102 | y = F.gelu(y) 103 | y = self.convs_1x1[i](y) 104 | y = self.norms_2[i](y) 105 | y = F.gelu(y) 106 | y = self.drop(y) 107 | x = x + y 108 | return x * x_mask 109 | 110 | 111 | class WN(torch.nn.Module): 112 | def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0): 113 | super(WN, self).__init__() 114 | assert(kernel_size % 2 == 1) 115 | self.hidden_channels =hidden_channels 116 | self.kernel_size = kernel_size, 117 | self.dilation_rate = dilation_rate 118 | self.n_layers = n_layers 119 | self.gin_channels = gin_channels 120 | self.p_dropout = p_dropout 121 | 122 | self.in_layers = torch.nn.ModuleList() 123 | self.res_skip_layers = torch.nn.ModuleList() 124 | self.drop = nn.Dropout(p_dropout) 125 | 126 | if gin_channels != 0: 127 | cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1) 128 | self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') 129 | 130 | for i in range(n_layers): 131 | dilation = dilation_rate ** i 132 | padding = int((kernel_size * dilation - dilation) / 2) 133 | in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size, 134 | dilation=dilation, padding=padding) 135 | in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') 136 | self.in_layers.append(in_layer) 137 | 138 | # last one is not necessary 139 | if i < n_layers - 1: 140 | res_skip_channels = 2 * hidden_channels 141 | else: 142 | res_skip_channels = hidden_channels 143 | 144 | res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) 145 | 
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight') 146 | self.res_skip_layers.append(res_skip_layer) 147 | 148 | def forward(self, x, x_mask, g=None, **kwargs): 149 | output = torch.zeros_like(x) 150 | n_channels_tensor = torch.IntTensor([self.hidden_channels]) 151 | 152 | if g is not None: 153 | g = self.cond_layer(g) 154 | 155 | for i in range(self.n_layers): 156 | x_in = self.in_layers[i](x) 157 | if g is not None: 158 | cond_offset = i * 2 * self.hidden_channels 159 | g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:] 160 | else: 161 | g_l = torch.zeros_like(x_in) 162 | 163 | acts = commons.fused_add_tanh_sigmoid_multiply( 164 | x_in, 165 | g_l, 166 | n_channels_tensor) 167 | acts = self.drop(acts) 168 | 169 | res_skip_acts = self.res_skip_layers[i](acts) 170 | if i < self.n_layers - 1: 171 | res_acts = res_skip_acts[:,:self.hidden_channels,:] 172 | x = (x + res_acts) * x_mask 173 | output = output + res_skip_acts[:,self.hidden_channels:,:] 174 | else: 175 | output = output + res_skip_acts 176 | return output * x_mask 177 | 178 | def remove_weight_norm(self): 179 | if self.gin_channels != 0: 180 | torch.nn.utils.remove_weight_norm(self.cond_layer) 181 | for l in self.in_layers: 182 | torch.nn.utils.remove_weight_norm(l) 183 | for l in self.res_skip_layers: 184 | torch.nn.utils.remove_weight_norm(l) 185 | 186 | 187 | class ResBlock1(torch.nn.Module): 188 | def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): 189 | super(ResBlock1, self).__init__() 190 | self.convs1 = nn.ModuleList([ 191 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 192 | padding=get_padding(kernel_size, dilation[0]))), 193 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 194 | padding=get_padding(kernel_size, dilation[1]))), 195 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], 196 | padding=get_padding(kernel_size, dilation[2]))) 197 | ]) 198 | 
self.convs1.apply(init_weights) 199 | 200 | self.convs2 = nn.ModuleList([ 201 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 202 | padding=get_padding(kernel_size, 1))), 203 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 204 | padding=get_padding(kernel_size, 1))), 205 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 206 | padding=get_padding(kernel_size, 1))) 207 | ]) 208 | self.convs2.apply(init_weights) 209 | 210 | def forward(self, x, x_mask=None): 211 | for c1, c2 in zip(self.convs1, self.convs2): 212 | xt = F.leaky_relu(x, LRELU_SLOPE) 213 | if x_mask is not None: 214 | xt = xt * x_mask 215 | xt = c1(xt) 216 | xt = F.leaky_relu(xt, LRELU_SLOPE) 217 | if x_mask is not None: 218 | xt = xt * x_mask 219 | xt = c2(xt) 220 | x = xt + x 221 | if x_mask is not None: 222 | x = x * x_mask 223 | return x 224 | 225 | def remove_weight_norm(self): 226 | for l in self.convs1: 227 | remove_weight_norm(l) 228 | for l in self.convs2: 229 | remove_weight_norm(l) 230 | 231 | 232 | class ResBlock2(torch.nn.Module): 233 | def __init__(self, channels, kernel_size=3, dilation=(1, 3)): 234 | super(ResBlock2, self).__init__() 235 | self.convs = nn.ModuleList([ 236 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 237 | padding=get_padding(kernel_size, dilation[0]))), 238 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 239 | padding=get_padding(kernel_size, dilation[1]))) 240 | ]) 241 | self.convs.apply(init_weights) 242 | 243 | def forward(self, x, x_mask=None): 244 | for c in self.convs: 245 | xt = F.leaky_relu(x, LRELU_SLOPE) 246 | if x_mask is not None: 247 | xt = xt * x_mask 248 | xt = c(xt) 249 | x = xt + x 250 | if x_mask is not None: 251 | x = x * x_mask 252 | return x 253 | 254 | def remove_weight_norm(self): 255 | for l in self.convs: 256 | remove_weight_norm(l) 257 | 258 | 259 | class Log(nn.Module): 260 | def forward(self, x, x_mask, reverse=False, 
**kwargs): 261 | if not reverse: 262 | y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask 263 | logdet = torch.sum(-y, [1, 2]) 264 | return y, logdet 265 | else: 266 | x = torch.exp(x) * x_mask 267 | return x 268 | 269 | 270 | class Flip(nn.Module): 271 | def forward(self, x, *args, reverse=False, **kwargs): 272 | x = torch.flip(x, [1]) 273 | if not reverse: 274 | logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) 275 | return x, logdet 276 | else: 277 | return x 278 | 279 | 280 | class ElementwiseAffine(nn.Module): 281 | def __init__(self, channels): 282 | super().__init__() 283 | self.channels = channels 284 | self.m = nn.Parameter(torch.zeros(channels,1)) 285 | self.logs = nn.Parameter(torch.zeros(channels,1)) 286 | 287 | def forward(self, x, x_mask, reverse=False, **kwargs): 288 | if not reverse: 289 | y = self.m + torch.exp(self.logs) * x 290 | y = y * x_mask 291 | logdet = torch.sum(self.logs * x_mask, [1,2]) 292 | return y, logdet 293 | else: 294 | x = (x - self.m) * torch.exp(-self.logs) * x_mask 295 | return x 296 | 297 | 298 | class ResidualCouplingLayer(nn.Module): 299 | def __init__(self, 300 | channels, 301 | hidden_channels, 302 | kernel_size, 303 | dilation_rate, 304 | n_layers, 305 | p_dropout=0, 306 | gin_channels=0, 307 | mean_only=False): 308 | assert channels % 2 == 0, "channels should be divisible by 2" 309 | super().__init__() 310 | self.channels = channels 311 | self.hidden_channels = hidden_channels 312 | self.kernel_size = kernel_size 313 | self.dilation_rate = dilation_rate 314 | self.n_layers = n_layers 315 | self.half_channels = channels // 2 316 | self.mean_only = mean_only 317 | 318 | self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) 319 | self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels) 320 | self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) 321 | self.post.weight.data.zero_() 322 | self.post.bias.data.zero_() 323 
| 324 | def forward(self, x, x_mask, g=None, reverse=False): 325 | x0, x1 = torch.split(x, [self.half_channels]*2, 1) 326 | h = self.pre(x0) * x_mask 327 | h = self.enc(h, x_mask, g=g) 328 | stats = self.post(h) * x_mask 329 | if not self.mean_only: 330 | m, logs = torch.split(stats, [self.half_channels]*2, 1) 331 | else: 332 | m = stats 333 | logs = torch.zeros_like(m) 334 | 335 | if not reverse: 336 | x1 = m + x1 * torch.exp(logs) * x_mask 337 | x = torch.cat([x0, x1], 1) 338 | logdet = torch.sum(logs, [1,2]) 339 | return x, logdet 340 | else: 341 | x1 = (x1 - m) * torch.exp(-logs) * x_mask 342 | x = torch.cat([x0, x1], 1) 343 | return x 344 | 345 | 346 | class ConvFlow(nn.Module): 347 | def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0): 348 | super().__init__() 349 | self.in_channels = in_channels 350 | self.filter_channels = filter_channels 351 | self.kernel_size = kernel_size 352 | self.n_layers = n_layers 353 | self.num_bins = num_bins 354 | self.tail_bound = tail_bound 355 | self.half_channels = in_channels // 2 356 | 357 | self.pre = nn.Conv1d(self.half_channels, filter_channels, 1) 358 | self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.) 359 | self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1) 360 | self.proj.weight.data.zero_() 361 | self.proj.bias.data.zero_() 362 | 363 | def forward(self, x, x_mask, g=None, reverse=False): 364 | x0, x1 = torch.split(x, [self.half_channels]*2, 1) 365 | h = self.pre(x0) 366 | h = self.convs(h, x_mask, g=g) 367 | h = self.proj(h) * x_mask 368 | 369 | b, c, t = x0.shape 370 | h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] 
371 | 372 | unnormalized_widths = h[..., :self.num_bins] / math.sqrt(self.filter_channels) 373 | unnormalized_heights = h[..., self.num_bins:2*self.num_bins] / math.sqrt(self.filter_channels) 374 | unnormalized_derivatives = h[..., 2 * self.num_bins:] 375 | 376 | x1, logabsdet = piecewise_rational_quadratic_transform(x1, 377 | unnormalized_widths, 378 | unnormalized_heights, 379 | unnormalized_derivatives, 380 | inverse=reverse, 381 | tails='linear', 382 | tail_bound=self.tail_bound 383 | ) 384 | 385 | x = torch.cat([x0, x1], 1) * x_mask 386 | logdet = torch.sum(logabsdet * x_mask, [1,2]) 387 | if not reverse: 388 | return x, logdet 389 | else: 390 | return x 391 | -------------------------------------------------------------------------------- /Recommender/fm_audio.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import datetime 4 | import torch 5 | import pandas as pd 6 | import numpy as np 7 | import scipy.sparse as sp 8 | import torch.utils.data as data 9 | from tqdm import tqdm 10 | from FMRecommender import PointFM 11 | from utils import precision_at_k, recall_at_k, map_at_k, hr_at_k, ndcg_at_k, mrr_at_k 12 | from utils import get_ur, get_feature, build_candidates_set, get_user_info 13 | 14 | 15 | def age_map(age): 16 | if age < 20: 17 | return 0 18 | elif age >= 20 & age <= 30: 19 | return 1 20 | elif age > 30: 21 | return 2 22 | 23 | def vote_user_info(df, u, mode): 24 | user_df = df[df['user'] == u] 25 | if mode == 0: 26 | gender_series = user_df['gender'].value_counts() 27 | gender = int(gender_series.idxmax()) 28 | return gender 29 | elif mode == 1: 30 | age_series = user_df['age'].value_counts() 31 | age = int(age_series.idxmax()) 32 | return age 33 | else: 34 | gender_series = user_df['gender'].value_counts() 35 | age_series = user_df['age'].value_counts() 36 | 37 | gender = int(gender_series.idxmax()) 38 | age = int(age_series.idxmax()) 39 | 40 | return gender, age 41 | 42 | def 
load_data(path): 43 | df = pd.read_csv(path) 44 | df['user'] = pd.Categorical(df['user']).codes 45 | df['item'] = pd.Categorical(df['item']).codes 46 | user_num = df['user'].nunique() 47 | item_num = df['item'].nunique() 48 | print(user_num, item_num) 49 | return df, len(df), user_num, item_num 50 | 51 | def split_testset(data, test_size=0.2): 52 | split_idx = int(np.ceil(len(data) * (1 - test_size))) 53 | train, test = data.iloc[:split_idx, :].copy(), data.iloc[split_idx:, :].copy() 54 | 55 | return train, test 56 | 57 | class Sample(object): 58 | def __init__(self, user_num, item_num, feature_num, num_ng=4) -> None: 59 | self.user_num = user_num 60 | self.item_num = item_num 61 | self.num_ng = num_ng 62 | self.feature_num = feature_num 63 | 64 | def transform(self, data, is_training=True): 65 | if not is_training: 66 | neg_set = [] 67 | for _, row in data.iterrows(): 68 | u = int(row['user']) 69 | i = int(row['item']) 70 | r = row['rating'] 71 | js = [] 72 | if self.feature_num == 0: 73 | g = int(row['gender']) 74 | neg_set.append([u, i, g, r, js]) 75 | elif self.feature_num == 1: 76 | a = int(row['age']) 77 | neg_set.append([u, i, a, r, js]) 78 | elif self.feature_num == 2: 79 | g = int(row['gender']) 80 | a = int(row['age']) 81 | neg_set.append([u, i, g, a, r, js]) 82 | else: 83 | neg_set.append([u, i, r, js]) 84 | return neg_set 85 | 86 | user_num = self.user_num 87 | item_num = self.item_num 88 | pair_pos = sp.dok_matrix((user_num, item_num), dtype=np.float32) 89 | for _, row in data.iterrows(): 90 | pair_pos[int(row['user']), int(row['item'])] = 1.0 91 | neg_set = [] 92 | for _, row in data.iterrows(): 93 | u = int(row['user']) 94 | i = int(row['item']) 95 | r = row['rating'] 96 | js = [] 97 | for _ in range(self.num_ng): 98 | j = np.random.randint(item_num) 99 | while (u, j) in pair_pos: 100 | j = np.random.randint(item_num) 101 | js.append(j) 102 | if self.feature_num == 0: 103 | g = int(row['gender']) 104 | neg_set.append([u, i, g, r, js]) 105 | elif 
self.feature_num == 1: 106 | a = int(row['age']) 107 | neg_set.append([u, i, a, r, js]) 108 | elif self.feature_num == 2: 109 | g = int(row['gender']) 110 | a = int(row['age']) 111 | neg_set.append([u, i, g, a, r, js]) 112 | else: 113 | neg_set.append([u, i, r, js]) 114 | 115 | return neg_set 116 | 117 | 118 | class FMData(data.Dataset): 119 | def __init__(self, neg_set, feature_num, is_training=True, neg_label_val=0.) -> None: 120 | super(FMData, self).__init__() 121 | self.features_fill = [] 122 | self.labels_fill = [] 123 | self.feature_num = feature_num 124 | self.neg_label = neg_label_val 125 | 126 | if self.feature_num == 0: 127 | self.init_gender(neg_set, is_training) 128 | elif self.feature_num == 1: 129 | self.init_age(neg_set, is_training) 130 | elif self.feature_num == 2: 131 | self.init_gender_age(neg_set, is_training) 132 | else: 133 | self.init_normal(neg_set, is_training) 134 | 135 | def __len__(self): 136 | return len(self.labels_fill) 137 | 138 | def __getitem__(self, index): 139 | features = self.features_fill 140 | labels = self.labels_fill 141 | user = features[index][0] 142 | item = features[index][1] 143 | label = labels[index] 144 | if self.feature_num == 0: 145 | gender = features[index][2] 146 | return user, item, gender, gender, label 147 | elif self.feature_num == 1: 148 | age = features[index][2] 149 | return user, item, age, age, label 150 | elif self.feature_num == 2: 151 | gender = features[index][2] 152 | age = features[index][3] 153 | return user, item, gender, age, label 154 | else: 155 | return user, item, item, item, label 156 | 157 | def init_gender(self, neg_set, is_training=True): 158 | for u, i, g, r, js in neg_set: 159 | self.features_fill.append([int(u), int(i), int(g)]) 160 | self.labels_fill.append(r) 161 | 162 | if is_training: 163 | for j in js: 164 | self.features_fill.append([int(u), int(j), int(g)]) 165 | self.labels_fill.append(self.neg_label) 166 | self.labels_fill = np.array(self.labels_fill, dtype=np.float32) 167 
| 168 | def init_age(self, neg_set, is_training=True): 169 | for u, i, a, r, js in neg_set: 170 | self.features_fill.append([int(u), int(i), int(a)]) 171 | self.labels_fill.append(r) 172 | 173 | if is_training: 174 | for j in js: 175 | self.features_fill.append([int(u), int(j), int(a)]) 176 | self.labels_fill.append(self.neg_label) 177 | self.labels_fill = np.array(self.labels_fill, dtype=np.float32) 178 | 179 | def init_gender_age(self, neg_set, is_training=True): 180 | for u, i, g, a, r, js in neg_set: 181 | self.features_fill.append([int(u), int(i), int(g), int(a)]) 182 | self.labels_fill.append(r) 183 | 184 | if is_training: 185 | for j in js: 186 | self.features_fill.append([int(u), int(j), int(g), int(a)]) 187 | self.labels_fill.append(self.neg_label) 188 | self.labels_fill = np.array(self.labels_fill, dtype=np.float32) 189 | 190 | def init_normal(self, neg_set, is_training=True): 191 | for u, i, r, js in neg_set: 192 | self.features_fill.append([int(u), int(i)]) 193 | self.labels_fill.append(r) 194 | 195 | if is_training: 196 | for j in js: 197 | self.features_fill.append([int(u), int(j)]) 198 | self.labels_fill.append(self.neg_label) 199 | self.labels_fill = np.array(self.labels_fill, dtype=np.float32) 200 | 201 | 202 | if __name__ == '__main__': 203 | parser = argparse.ArgumentParser(description='fm recommender') 204 | parser.add_argument('--feature', 205 | type=str, 206 | default='user,item,rating') 207 | parser.add_argument('--val', 208 | type=int, 209 | default=0) 210 | parser.add_argument('--num_ng', 211 | type=int, 212 | default=5) 213 | parser.add_argument('--factors', 214 | type=int, 215 | default=32) 216 | parser.add_argument('--epochs', 217 | type=int, 218 | default=50) 219 | parser.add_argument('--lr', 220 | type=float, 221 | default=0.01) 222 | parser.add_argument('--batch_size', 223 | type=int, 224 | default=64) 225 | parser.add_argument('--dataset', 226 | type=str, 227 | default='ml-1m') 228 | args = parser.parse_args() 229 | time = 
datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') 230 | s_time = time[5:].replace(' ', '-').replace(':', '-') 231 | path = f'./res/{args.dataset}_predict.csv' 232 | df, record_num, user_num, item_num = load_data(path) 233 | df.insert(df.shape[1], 'rating', 1) 234 | feature_list = args.feature.split(',') 235 | feature_num = get_feature(feature_list) 236 | df_data = df[feature_list] 237 | user_info = get_user_info(df_data, feature_num) 238 | train_set, test_set = split_testset(df_data, 0.2) 239 | if args.val: 240 | train_set, test_set = split_testset(train_set, 0.1) 241 | test_ur = get_ur(test_set) 242 | total_train_ur = get_ur(train_set) 243 | item_pool = set(range(item_num)) 244 | sampler = Sample(user_num, item_num, feature_num, args.num_ng) 245 | neg_set = sampler.transform(train_set, is_training=True) 246 | print("data sample complete") 247 | train_dataset = FMData(neg_set, feature_num, is_training=True) 248 | model = PointFM(user_num, item_num, args.factors, args.epochs, args.lr, feature=feature_num) 249 | train_loader = data.DataLoader( 250 | train_dataset, 251 | batch_size=args.batch_size, 252 | shuffle=True, 253 | num_workers=4 254 | ) 255 | model.fit(train_loader) 256 | print('Start Calculating Metrics......') 257 | test_ucands = build_candidates_set(test_ur, total_train_ur, item_pool, 1000) 258 | print('') 259 | print('Generate recommend list...') 260 | print('') 261 | preds = {} 262 | for u in tqdm(test_ucands.keys()): 263 | if feature_num == 0: 264 | gender = vote_user_info(df_data, u, 0) 265 | tmp = pd.DataFrame({ 266 | 'user': [u for _ in test_ucands[u]], 267 | 'item': test_ucands[u], 268 | 'gender': [gender for _ in test_ucands[u]], 269 | 'rating': [0. for _ in test_ucands[u]], # fake label, make nonsense 270 | }) 271 | elif feature_num == 1: 272 | age = vote_user_info(df_data, u, 1) 273 | tmp = pd.DataFrame({ 274 | 'user': [u for _ in test_ucands[u]], 275 | 'item': test_ucands[u], 276 | 'age': [age for _ in test_ucands[u]], 277 | 'rating': [0. 
for _ in test_ucands[u]], # fake label, make nonsense 278 | }) 279 | elif feature_num == 2: 280 | gender, age = vote_user_info(df_data, u, 2) 281 | tmp = pd.DataFrame({ 282 | 'user': [u for _ in test_ucands[u]], 283 | 'item': test_ucands[u], 284 | 'gender': [gender for _ in test_ucands[u]], 285 | 'age': [age for _ in test_ucands[u]], 286 | 'rating': [0. for _ in test_ucands[u]], # fake label, make nonsense 287 | }) 288 | else: 289 | tmp = pd.DataFrame({ 290 | 'user': [u for _ in test_ucands[u]], 291 | 'item': test_ucands[u], 292 | 'rating': [0. for _ in test_ucands[u]], # fake label, make nonsense 293 | }) 294 | tmp_neg_set = sampler.transform(tmp, is_training=False) 295 | tmp_dataset = FMData(tmp_neg_set, feature_num, is_training=False) 296 | tmp_loader = data.DataLoader( 297 | tmp_dataset, 298 | batch_size=1000, 299 | shuffle=False, 300 | num_workers=4 301 | ) 302 | for user, item, gender, age, label in tmp_loader: 303 | user = user.cuda() 304 | item = item.cuda() 305 | label = label.cuda() 306 | if feature_num == 0: 307 | gender = gender.cuda() 308 | elif feature_num == 1: 309 | age = age.cuda() 310 | elif feature_num == 2: 311 | gender = gender.cuda() 312 | age = age.cuda() 313 | 314 | if feature_num == 0: 315 | prediction = model.predict(user, item, g=gender) 316 | elif feature_num == 1: 317 | prediction = model.predict(user, item, a=age) 318 | elif feature_num == 2: 319 | prediction = model.predict(user, item, g=gender, a=age) 320 | else: 321 | prediction = model.predict(user, item) 322 | _, indices = torch.topk(prediction, 50) 323 | top_n = torch.take(torch.tensor(test_ucands[u]), indices).cpu().numpy() 324 | preds[u] = top_n 325 | 326 | for u in preds.keys(): 327 | preds[u] = [1 if i in test_ur[u] else 0 for i in preds[u]] 328 | 329 | print('Save metric@k result to res folder...') 330 | if feature_num == 0: 331 | result_save_path = f'./res/fm_{args.dataset}_audio/gender/' 332 | elif feature_num == 1: 333 | result_save_path = 
f'./res/fm_{args.dataset}_audio/age/' 334 | elif feature_num == 2: 335 | result_save_path = f'./res/fm_{args.dataset}_audio/gender_age/' 336 | else: 337 | result_save_path = f'./res/fm_{args.dataset}_audio/no_ag/' 338 | if not os.path.exists(result_save_path): 339 | os.makedirs(result_save_path) 340 | res = pd.DataFrame({'metric@K': ['pre', 'rec', 'hr', 'map', 'mrr', 'ndcg']}) 341 | for k in [1, 5, 10, 20, 30, 50]: 342 | tmp_preds = preds.copy() 343 | tmp_preds = {key: rank_list[:k] for key, rank_list in tmp_preds.items()} 344 | 345 | pre_k = np.mean([precision_at_k(r, k) for r in tmp_preds.values()]) 346 | rec_k = recall_at_k(tmp_preds, test_ur, k) 347 | hr_k = hr_at_k(tmp_preds, test_ur) 348 | map_k = map_at_k(tmp_preds.values()) 349 | mrr_k = mrr_at_k(tmp_preds, k) 350 | ndcg_k = np.mean([ndcg_at_k(r, k) for r in tmp_preds.values()]) 351 | 352 | if k == 10: 353 | print(f'Precision@{k}: {pre_k:.4f}') 354 | print(f'Recall@{k}: {rec_k:.4f}') 355 | print(f'HR@{k}: {hr_k:.4f}') 356 | print(f'MAP@{k}: {map_k:.4f}') 357 | print(f'MRR@{k}: {mrr_k:.4f}') 358 | print(f'NDCG@{k}: {ndcg_k:.4f}') 359 | 360 | res[k] = np.array([pre_k, rec_k, hr_k, map_k, mrr_k, ndcg_k]) 361 | 362 | common_prefix = f'{args.num_ng}_{args.factors}_{args.lr}_{args.batch_size}_{args.val}' 363 | 364 | res.to_csv( 365 | f'{result_save_path}{common_prefix}_results.csv', 366 | index=False 367 | ) 368 | 369 | 370 | 371 | 372 | 373 | -------------------------------------------------------------------------------- /Dialogue/coat_attr.py: -------------------------------------------------------------------------------- 1 | import random 2 | from coat_utils import * 3 | 4 | gender_tag = {'0': ['00', '01', '02'], 5 | '1': ['10', '11']} 6 | 7 | jacket_tag = {'0': ['00', '01', '02'], 8 | '1': ['10', '11'], 9 | '2': ['20'], 10 | '3': ['30']} 11 | 12 | color_tag = {'0': ['00', '01', '02'], 13 | '1': ['10', '11'], 14 | '2': ['20'], 15 | '3': ['30']} 16 | 17 | def index_attr_pattern(utterance_pattern, tags, 
pattern_mode=None): 18 | ''' 19 | return: 20 | [ 21 | { 22 | "tag":['00'], 23 | "nl": "Do you want coats for men or women?" 24 | }, 25 | { 26 | "tag":['00'], 27 | "nl": "Do you want men's coats or women's coats?" 28 | }, 29 | ] 30 | ''' 31 | utterances = [] 32 | if pattern_mode != None: 33 | tag = tags[str(pattern_mode)] 34 | for utterance in utterance_pattern: 35 | if set(utterance["tag"]) & set(tag): 36 | utterances.append(utterance) 37 | else: 38 | for utterance in utterance_pattern: 39 | if set(utterance["tag"]) & set(tags): 40 | utterances.append(utterance) 41 | 42 | return utterances 43 | 44 | def check_repeat(tmp): 45 | tmp_set = set(tmp) 46 | if len(tmp) == len(tmp_set): 47 | return True 48 | else: 49 | return False 50 | 51 | def generate_gender_dialogue(agent_pattern, user_pattern, gender_val, gender_all, gender_weight): 52 | gender_dialogue = {} 53 | pattern_mode = random.choices([0, 1], weights=[12, 88], k=1)[0] 54 | agent_pattern = index_attr_pattern(agent_pattern["gender"], gender_tag, pattern_mode) 55 | utterance = random.choice(agent_pattern) 56 | utterance_content = utterance['nl'] 57 | utterance_tag = utterance['tag'] 58 | if pattern_mode == 0: 59 | gender_dialogue['Q'] = utterance_content 60 | utterance = random.choice(index_attr_pattern(user_pattern["gender"], utterance_tag)) 61 | utterance_content = utterance['nl'] 62 | utterance_content = utterance_content.replace('$gender$', gender_val, 1) 63 | gender_dialogue['A'] = utterance_content[0].upper() + utterance_content[1:] 64 | elif pattern_mode == 1: 65 | gender_tmp = random.choices(sorted(gender_all), weights=gender_weight, k=1) 66 | gender_tmp = get_item_gender(gender_tmp[0]) 67 | gender_dialogue['Q'] = utterance_content.replace('$gender$', gender_tmp, 1) 68 | if gender_tmp == gender_val: 69 | utterance = random.choice(index_attr_pattern(user_pattern["gender_pos"], utterance_tag)) 70 | else: 71 | utterance = random.choice(index_attr_pattern(user_pattern["gender_neg"], utterance_tag)) 72 | 
gender_dialogue['A'] = utterance['nl'].replace('$gender$', gender_tmp, 1) 73 | 74 | return gender_dialogue 75 | 76 | 77 | def gen_pattern_mode_one(attr, u_content, u_tag, g_tag, other, info): 78 | ''' 79 | u_content: Do you like $jacket$ coats? 80 | g_tag: 0 or 1, whether slot ground truth 81 | ''' 82 | val = info[0] 83 | all = info[1] 84 | weight = info[2] 85 | user_pattern = info[3] 86 | dialogue = {} 87 | tmp_other = other 88 | slot = '$' + attr + '$' 89 | if attr == 'jacket': 90 | pos = "jacket_pos" 91 | neg = "jacket_neg" 92 | add_val = get_item_type_index(val) 93 | get_item_func = get_item_type 94 | elif attr == 'color': 95 | pos = "color_pos" 96 | neg = "color_neg" 97 | add_val = get_item_color_index(val) 98 | get_item_func = get_item_color 99 | if g_tag: 100 | slot_val = val 101 | dialogue['Q'] = u_content.replace(slot, val, 1) 102 | utterance = random.choice(index_attr_pattern(user_pattern[pos], u_tag)) 103 | answer = utterance['nl'] 104 | if slot in answer: 105 | answer = answer.replace(slot, slot_val, 1) 106 | dialogue['A'] = answer 107 | tmp = [] 108 | else: 109 | tmp_other.append(add_val) 110 | 111 | while True: 112 | tmp = random.choices(sorted(all), weights=weight, k=1) 113 | if not set(tmp_other) & set(tmp): 114 | break 115 | dialogue['Q'] = u_content.replace(slot, get_item_func(tmp[0]), 1) 116 | utterance = random.choice(index_attr_pattern(user_pattern[neg], u_tag)) 117 | answer = utterance['nl'] 118 | if slot in answer: 119 | answer = answer.replace(slot, get_item_func(tmp[0]), 1) 120 | dialogue['A'] = answer[0].upper() + answer[1:] 121 | 122 | return dialogue['Q'], dialogue['A'], tmp 123 | 124 | 125 | def gen_pattern_mode_two(attr, u_content, u_tag, g_tag, other, info): 126 | ''' 127 | u_content: Do you prefer $jacket$ or $jacket$ coats? 
128 | g_tag: 0 or 1, whether slot ground truth 129 | ''' 130 | val = info[0] 131 | all = info[1] 132 | weight = info[2] 133 | user_pattern = info[3] 134 | dialogue = {} 135 | tmp_other = other 136 | slot = '$' + attr + '$' 137 | if attr == 'jacket': 138 | pos = "jacket" 139 | neg = "jacket_neg" 140 | add_val = get_item_type_index(val) 141 | get_item_func = get_item_type 142 | elif attr == 'color': 143 | pos = "color" 144 | neg = "color_neg" 145 | add_val = get_item_color_index(val) 146 | get_item_func = get_item_color 147 | if g_tag: 148 | tmp_other.append(add_val) 149 | while True: 150 | tmp = random.choices(sorted(all), weights=weight, k=1) 151 | if not set(tmp_other) & set(tmp): 152 | break 153 | order = random.randint(0, 1) 154 | if order: 155 | u_content = u_content.replace(slot, get_item_func(tmp[0]), 1) 156 | u_content = u_content.replace(slot, val, 1) 157 | else: 158 | u_content = u_content.replace(slot, val, 1) 159 | u_content = u_content.replace(slot, get_item_func(tmp[0]), 1) 160 | dialogue['Q'] = u_content 161 | utterance = random.choice(index_attr_pattern(user_pattern[pos], u_tag)) 162 | utterance_content = utterance['nl'] 163 | utterance_content = utterance_content.replace(slot, val, 1) 164 | dialogue['A'] = utterance_content[0].upper() + utterance_content[1:] 165 | else: 166 | tmp_other.append(add_val) 167 | while True: 168 | tmp = random.choices(sorted(all), weights=weight, k=2) 169 | if (not set(tmp_other) & set(tmp)) and (check_repeat(tmp)): 170 | break 171 | random.shuffle(tmp) 172 | for v in tmp: 173 | u_content = u_content.replace(slot, get_item_func(v), 1) 174 | dialogue['Q'] = u_content 175 | utterance = random.choice(index_attr_pattern(user_pattern[neg], u_tag)) 176 | dialogue['A'] = utterance['nl'][0].upper() + utterance['nl'][1:] 177 | 178 | return dialogue['Q'], dialogue['A'], tmp 179 | 180 | 181 | def select_gen_pattern(mode): 182 | if mode == 1: 183 | func = gen_pattern_mode_one 184 | else: 185 | func = gen_pattern_mode_two 186 | 187 | 
return func 188 | 189 | 190 | def get_agent_content_tag(attr, pattern, mode, tag): 191 | pattern = index_attr_pattern(pattern[attr], tag, mode) 192 | utterance = random.choice(pattern) 193 | utterance_content = utterance['nl'] 194 | utterance_tag = utterance['tag'] 195 | 196 | return utterance_content, utterance_tag 197 | 198 | 199 | def generate_jacket_dialogue(agent_pattern, user_pattern, jacket_val, jacket_all, jacket_weight): 200 | jacket_dialogue = {} 201 | jacket_info = (jacket_val, jacket_all, jacket_weight, user_pattern) 202 | jacket_other = [6] 203 | if jacket_val == 'other': 204 | pattern_mode = 3 205 | else: 206 | pattern_mode = random.choices([0, 1, 2], weights=[10, 45, 45], k=1)[0] 207 | utterance_content, utterance_tag = get_agent_content_tag("jacket", agent_pattern, pattern_mode, jacket_tag) 208 | if pattern_mode == 0: 209 | jacket_dialogue['Q'] = utterance_content 210 | utterance = random.choice(index_attr_pattern(user_pattern["jacket"], utterance_tag)) 211 | utterance_content = utterance['nl'] 212 | utterance_content = utterance_content.replace('$jacket$', jacket_val, 1) 213 | jacket_dialogue['A'] = utterance_content[0].upper() + utterance_content[1:] 214 | elif pattern_mode == 1 or pattern_mode == 2: 215 | rounds = random.randint(1,3) 216 | gen_func = select_gen_pattern(pattern_mode) 217 | if rounds == 1: 218 | jacket_dialogue['Q'], jacket_dialogue['A'], _ = gen_func('jacket', utterance_content, utterance_tag, 1, jacket_other, jacket_info) 219 | elif rounds >= 2: 220 | jacket_dialogue['Q'], jacket_dialogue['A'], added_attr = gen_func('jacket', utterance_content, utterance_tag, 0, jacket_other, jacket_info) 221 | last_pm = pattern_mode 222 | pattern_mode = random.randint(1,2) 223 | gen_func = select_gen_pattern(pattern_mode) 224 | jacket_other = jacket_other + added_attr 225 | if rounds == 2: 226 | g_tag = 1 227 | else: 228 | g_tag = 0 229 | utterance_content, utterance_tag = get_agent_content_tag("jacket", agent_pattern, pattern_mode, jacket_tag) 
230 | if last_pm == pattern_mode: 231 | jacket_dialogue['Q1'], jacket_dialogue['A1'], added_attr = gen_func('jacket', utterance_content, utterance_tag, g_tag, jacket_other, jacket_info) 232 | else: 233 | jacket_dialogue['Q1'], jacket_dialogue['A1'], added_attr = gen_func('jacket', utterance_content, utterance_tag, g_tag, jacket_other, jacket_info) 234 | if rounds == 3: 235 | g_tag = 1 236 | last_pm = pattern_mode 237 | pattern_mode = random.randint(1,2) 238 | gen_func = select_gen_pattern(pattern_mode) 239 | jacket_other = jacket_other + added_attr 240 | utterance_content, utterance_tag = get_agent_content_tag("jacket", agent_pattern, pattern_mode, jacket_tag) 241 | if last_pm == pattern_mode: 242 | jacket_dialogue['Q2'], jacket_dialogue['A2'], _ = gen_func('jacket', utterance_content, utterance_tag, g_tag, jacket_other, jacket_info) 243 | else: 244 | jacket_dialogue['Q2'], jacket_dialogue['A2'], _ = gen_func('jacket', utterance_content, utterance_tag, g_tag, jacket_other, jacket_info) 245 | elif pattern_mode == 3: 246 | rounds = random.randint(1,2) 247 | while True: 248 | added_attr = random.choices(sorted(jacket_all), weights=jacket_weight, k=3) 249 | if (not set(jacket_other) & set(added_attr)) and (check_repeat(added_attr)): 250 | break 251 | random.shuffle(added_attr) 252 | for v in added_attr: 253 | utterance_content = utterance_content.replace('$jacket$', get_item_type(v), 1) 254 | jacket_dialogue['Q'] = utterance_content 255 | utterance = random.choice(index_attr_pattern(user_pattern['jacket_neg'], utterance_tag)) 256 | jacket_dialogue['A'] = utterance['nl'][0].upper() + utterance['nl'][1:] 257 | if rounds == 2: 258 | pattern_mode = random.randint(1,3) 259 | utterance_content, utterance_tag = get_agent_content_tag("jacket", agent_pattern, pattern_mode, jacket_tag) 260 | jacket_other = jacket_other + added_attr 261 | if pattern_mode != 3: 262 | gen_func = select_gen_pattern(pattern_mode) 263 | jacket_dialogue['Q1'], jacket_dialogue['A1'], _ = 
gen_func('jacket', utterance_content, utterance_tag, 0, jacket_other, jacket_info) 264 | else: 265 | while True: 266 | added_attr = random.choices(sorted(jacket_all), weights=jacket_weight, k=3) 267 | if (not set(jacket_other) & set(added_attr)) and (check_repeat(added_attr)): 268 | break 269 | random.shuffle(added_attr) 270 | for v in added_attr: 271 | utterance_content = utterance_content.replace('$jacket$', get_item_type(v), 1) 272 | jacket_dialogue['Q1'] = utterance_content 273 | utterance = random.choice(index_attr_pattern(user_pattern['jacket_neg'], utterance_tag)) 274 | jacket_dialogue['A1'] = utterance['nl'][0].upper() + utterance['nl'][1:] 275 | 276 | return jacket_dialogue 277 | 278 | 279 | def generate_color_dialogue(agent_pattern, user_pattern, color_val, color_all, color_weight): 280 | color_dialogue = {} 281 | color_info = (color_val, color_all, color_weight, user_pattern) 282 | color_other = [9] 283 | if color_val == 'other': 284 | pattern_mode = 3 285 | else: 286 | pattern_mode = random.choices([0, 1, 2], weights=[20, 40, 40], k=1)[0] 287 | utterance_content, utterance_tag = get_agent_content_tag("color", agent_pattern, pattern_mode, color_tag) 288 | if pattern_mode == 0: 289 | color_dialogue['Q'] = utterance_content 290 | utterance = random.choice(index_attr_pattern(user_pattern["color"], utterance_tag)) 291 | utterance_content = utterance['nl'] 292 | utterance_content = utterance_content.replace('$color$', color_val, 1) 293 | color_dialogue['A'] = utterance_content[0].upper() + utterance_content[1:] 294 | elif pattern_mode == 1 or pattern_mode == 2: 295 | rounds = random.randint(1,3) 296 | gen_func = select_gen_pattern(pattern_mode) 297 | if rounds == 1: 298 | color_dialogue['Q'], color_dialogue['A'], _ = gen_func('color', utterance_content, utterance_tag, 1, color_other, color_info) 299 | elif rounds >= 2: 300 | color_dialogue['Q'], color_dialogue['A'], added_attr = gen_func('color', utterance_content, utterance_tag, 0, color_other, color_info) 
301 | last_pm = pattern_mode 302 | pattern_mode = random.randint(1,2) 303 | gen_func = select_gen_pattern(pattern_mode) 304 | color_other = color_other + added_attr 305 | if rounds == 2: 306 | g_tag = 1 307 | else: 308 | g_tag = 0 309 | utterance_content, utterance_tag = get_agent_content_tag("color", agent_pattern, pattern_mode, color_tag) 310 | if last_pm == pattern_mode: 311 | color_dialogue['Q1'], color_dialogue['A1'], added_attr = gen_func('color', utterance_content, utterance_tag, g_tag, color_other, color_info) 312 | else: 313 | color_dialogue['Q1'], color_dialogue['A1'], added_attr = gen_func('color', utterance_content, utterance_tag, g_tag, color_other, color_info) 314 | if rounds == 3: 315 | g_tag = 1 316 | last_pm = pattern_mode 317 | pattern_mode = random.randint(1,2) 318 | gen_func = select_gen_pattern(pattern_mode) 319 | color_other = color_other + added_attr 320 | utterance_content, utterance_tag = get_agent_content_tag("color", agent_pattern, pattern_mode, color_tag) 321 | if last_pm == pattern_mode: 322 | color_dialogue['Q2'], color_dialogue['A2'], _ = gen_func('color', utterance_content, utterance_tag, g_tag, color_other, color_info) 323 | else: 324 | color_dialogue['Q2'], color_dialogue['A2'], _ = gen_func('color', utterance_content, utterance_tag, g_tag, color_other, color_info) 325 | elif pattern_mode == 3: 326 | rounds = random.randint(1,2) 327 | while True: 328 | added_attr = random.choices(sorted(color_all), weights=color_weight, k=3) 329 | if (not set(color_other) & set(added_attr)) and (check_repeat(added_attr)): 330 | break 331 | random.shuffle(added_attr) 332 | for v in added_attr: 333 | utterance_content = utterance_content.replace('$color$', get_item_color(v), 1) 334 | color_dialogue['Q'] = utterance_content 335 | utterance = random.choice(index_attr_pattern(user_pattern['color_neg'], utterance_tag)) 336 | color_dialogue['A'] = utterance['nl'][0].upper() + utterance['nl'][1:] 337 | if rounds == 2: 338 | pattern_mode = 
random.randint(1,3) 339 | utterance_content, utterance_tag = get_agent_content_tag("color", agent_pattern, pattern_mode, color_tag) 340 | color_other = color_other + added_attr 341 | if pattern_mode != 3: 342 | gen_func = select_gen_pattern(pattern_mode) 343 | color_dialogue['Q1'], color_dialogue['A1'], _ = gen_func('color', utterance_content, utterance_tag, 0, color_other, color_info) 344 | else: 345 | while True: 346 | added_attr = random.choices(sorted(color_all), weights=color_weight, k=3) 347 | if (not set(color_other) & set(added_attr)) and (check_repeat(added_attr)): 348 | break 349 | random.shuffle(added_attr) 350 | for v in added_attr: 351 | utterance_content = utterance_content.replace('$color$', get_item_color(v), 1) 352 | color_dialogue['Q1'] = utterance_content 353 | utterance = random.choice(index_attr_pattern(user_pattern['color_neg'], utterance_tag)) 354 | color_dialogue['A1'] = utterance['nl'][0].upper() + utterance['nl'][1:] 355 | 356 | return color_dialogue 357 | 358 | 359 | 360 | 361 | 362 | 363 | 364 | 365 | 366 | 367 | 368 | 369 | 370 | 371 | 372 | 373 | 374 | 375 | 376 | 377 | 378 | 379 | 380 | 381 | 382 | 383 | 384 | 385 | 386 | 387 | 388 | 389 | 390 | 391 | 392 | 393 | 394 | 395 | 396 | 397 | 398 | 399 | 400 | 401 | 402 | 403 | 404 | 405 | --------------------------------------------------------------------------------