├── Recommender
│   ├── data
│   │   └── .gitkeep
│   ├── requirements.txt
│   ├── README.md
│   ├── utils.py
│   ├── gen_label.py
│   ├── FMRecommender.py
│   ├── run_classifier.py
│   ├── run.py
│   └── fm_audio.py
├── Speech
│   ├── vits_lib
│   │   ├── pretrain
│   │   │   └── .gitkeep
│   │   ├── text
│   │   │   ├── __pycache__
│   │   │   │   ├── symbols.cpython-38.pyc
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   └── cleaners.cpython-38.pyc
│   │   │   ├── symbols.py
│   │   │   ├── LICENSE
│   │   │   ├── __init__.py
│   │   │   └── cleaners.py
│   │   ├── monotonic_align
│   │   │   ├── setup.py
│   │   │   ├── __init__.py
│   │   │   └── core.pyx
│   │   ├── configs
│   │   │   ├── ljs_base.json
│   │   │   ├── ljs_nosdp.json
│   │   │   └── vctk_base.json
│   │   ├── commons.py
│   │   ├── utils.py
│   │   ├── transforms.py
│   │   ├── attentions.py
│   │   └── modules.py
│   ├── requirements.txt
│   ├── README.md
│   ├── inference_user.py
│   ├── inference.py
│   ├── data
│   │   └── speaker_info
│   │       ├── speaker-info.txt
│   │       └── vctk_audio_sid_text_val_filelist.txt.cleaned.txt
│   └── utils.py
├── .gitignore
├── images
│   ├── logo.png
│   └── framework.png
├── Dialogue
│   ├── requirements.txt
│   ├── data
│   │   ├── .DS_Store
│   │   └── ml-1m
│   │       └── user_preference.npy
│   ├── README.md
│   ├── config
│   │   ├── recommend_pattern.py
│   │   └── thanks_pattern.py
│   ├── movie_utils.py
│   ├── coat_utils.py
│   ├── gen_coat.py
│   ├── gen_movie.py
│   └── coat_attr.py
├── Evaluate
│   ├── requirements.txt
│   ├── evaluate.py
│   └── fed.py
└── README.md
/Recommender/data/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/Speech/vits_lib/pretrain/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | */.DS_Store
3 | .vscode
4 | */.vscode
--------------------------------------------------------------------------------
/images/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hyllll/VCRS/HEAD/images/logo.png
--------------------------------------------------------------------------------
/Dialogue/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | pandas
3 | lightgbm
--------------------------------------------------------------------------------
/images/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hyllll/VCRS/HEAD/images/framework.png
--------------------------------------------------------------------------------
/Dialogue/data/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hyllll/VCRS/HEAD/Dialogue/data/.DS_Store
--------------------------------------------------------------------------------
/Evaluate/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | torch
3 | transformers
--------------------------------------------------------------------------------
/Recommender/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchaudio
3 | tqdm
4 | transformers
5 | scikit-learn
--------------------------------------------------------------------------------
/Dialogue/data/ml-1m/user_preference.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hyllll/VCRS/HEAD/Dialogue/data/ml-1m/user_preference.npy
--------------------------------------------------------------------------------
/Speech/vits_lib/text/__pycache__/symbols.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hyllll/VCRS/HEAD/Speech/vits_lib/text/__pycache__/symbols.cpython-38.pyc
--------------------------------------------------------------------------------
/Speech/vits_lib/text/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hyllll/VCRS/HEAD/Speech/vits_lib/text/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/Speech/vits_lib/text/__pycache__/cleaners.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hyllll/VCRS/HEAD/Speech/vits_lib/text/__pycache__/cleaners.cpython-38.pyc
--------------------------------------------------------------------------------
/Speech/requirements.txt:
--------------------------------------------------------------------------------
1 | Cython==0.29.21
2 | librosa==0.8.0
3 | matplotlib==3.3.1
4 | phonemizer==2.2.1
5 | scipy==1.5.2
6 | tensorboard==2.3.0
7 | Unidecode==1.1.1
8 | torch
9 | torchvision
10 | pydub
11 | tqdm
12 | ipython
13 | numpy
--------------------------------------------------------------------------------
/Speech/vits_lib/monotonic_align/setup.py:
--------------------------------------------------------------------------------
1 | from distutils.core import setup
2 | from Cython.Build import cythonize
3 | import numpy
4 |
5 | setup(
6 | name = 'monotonic_align',
7 | ext_modules = cythonize("core.pyx"),
8 | include_dirs=[numpy.get_include()]
9 | )
10 |
--------------------------------------------------------------------------------
/Speech/vits_lib/text/symbols.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | '''
4 | Defines the set of symbols used in text input to the model.
5 | '''
6 | _pad = '_'
7 | _punctuation = ';:,.!?¡¿—…"«»“” '
8 | _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
9 | _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
10 |
11 |
12 | # Export all symbols:
13 | symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
14 |
15 | # Special symbol ids
16 | SPACE_ID = symbols.index(" ")
17 |
--------------------------------------------------------------------------------
/Speech/vits_lib/monotonic_align/__init__.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from .monotonic_align.core import maximum_path_c
4 |
5 |
6 | def maximum_path(neg_cent, mask):
7 | """ Cython optimized version.
8 | neg_cent: [b, t_t, t_s]
9 | mask: [b, t_t, t_s]
10 | """
11 | device = neg_cent.device
12 | dtype = neg_cent.dtype
13 | neg_cent = neg_cent.data.cpu().numpy().astype(np.float32)
14 | path = np.zeros(neg_cent.shape, dtype=np.int32)
15 |
16 | t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32)
17 | t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32)
18 | maximum_path_c(path, neg_cent, t_t_max, t_s_max)
19 | return torch.from_numpy(path).to(device=device, dtype=dtype)
20 |
--------------------------------------------------------------------------------
/Speech/vits_lib/text/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2017 Keith Ito
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in
11 | all copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 | THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/Dialogue/README.md:
--------------------------------------------------------------------------------
1 | # Introduction
2 | This is the code for text-based conversation generation.
3 |
4 | ## Installation
5 | Install dependencies via:
6 | ```
7 | cd ./Dialogue/
8 | pip install -r requirements.txt
9 | ```
10 |
11 | ## Prepare Data
12 | 1. If you want to generate conversations in the movie domain, please download [ml-1m.csv](https://drive.google.com/file/d/1iOum0fcgPzyvV5Mj8EuNdgtd0eLPz2qt/view?usp=sharing) and put it in `./data/ml-1m/`. The ml-1m.csv file contains movie information scraped from the IMDb website and is divided into 10 columns (see the loading sketch after this list):
13 |
14 | ` | user_id | item_id | rating | timestamp | movie | director | actors | country | title | genres |`
15 |
16 | 2. Corresponding templates are in the `./config/` directory. You can also add some new templates to enrich the conversation.
17 |
18 | 3. For convenience, we also summarize the templates we use in the [document](https://fssntlo70a.feishu.cn/sheets/shtcn5rtVsQa69LA1efCJpWeIEd).
19 |
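A minimal loading sketch (not part of the repo; the column names follow the list above, and whether the CSV ships with a header row is an assumption):

```
import pandas as pd

# Columns as documented above; header=0 assumes the first row is a header.
cols = ["user_id", "item_id", "rating", "timestamp", "movie",
        "director", "actors", "country", "title", "genres"]
df = pd.read_csv("./data/ml-1m/ml-1m.csv", names=cols, header=0)
print(df.head())
```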
20 | ## RUN
21 | Generate conversations in the e-commerce domain.
22 | ```
23 | python gen_coat.py
24 | ```
25 | Generate conversations in the movie domain.
26 | ```
27 | python gen_movie.py
28 | ```
29 | You can find the generated conversations in `./res/`, e.g., `dialogue_info_coat.json` or `dialogue_info_ml-1m.json`.
--------------------------------------------------------------------------------
/Speech/README.md:
--------------------------------------------------------------------------------
1 | # Introduction
2 | This is the code for voice-based conversation generation.
3 |
4 | ## Installation
5 | 1. Install dependencies via:
6 | ```
7 | cd ./Speech/
8 | pip install -r requirements.txt
9 | ```
10 | 2. Build Monotonic Alignment Search
11 | ```
12 | cd ./vits_lib/monotonic_align/
13 | python3 setup.py build_ext --inplace
14 | ```
15 |
16 | ## Pretrained model
17 | 1. Download [pretrained VITS models](https://drive.google.com/drive/folders/1ksarh-cJf3F5eKJjLVWY0X1j1qsQqiS2) from GoogleDrive. The pretrained models are provided by [VITS Repo](https://github.com/jaywalnut310/vits).
18 |
19 | 2. Put the pretrained models in the `./vits_lib/pretrain/` directory.
20 |
21 | ## RUN
22 | 1. Generate audio containing only the user's utterances in each conversation.
23 | ```
24 | python inference_user.py --dataset='xxx'
25 | ```
26 | where `xxx` is `coat` or `ml-1m`.
27 | 2. Generate audio containing full conversations (i.e., users and agents).
28 | ```
29 | python inference.py --dataset='xxx'
30 | ```
31 | where `xxx` is `coat` or `ml-1m`.
32 | 3. Note that, by default, the scripts generate the complete set of audio files for ml-1m; you can generate only as many audio files as you need.
33 | 4. All results are saved in the `./speech_res/` directory.
34 |
--------------------------------------------------------------------------------
/Speech/vits_lib/monotonic_align/core.pyx:
--------------------------------------------------------------------------------
1 | cimport cython
2 | from cython.parallel import prange
3 |
4 |
5 | @cython.boundscheck(False)
6 | @cython.wraparound(False)
7 | cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil:
8 | cdef int x
9 | cdef int y
10 | cdef float v_prev
11 | cdef float v_cur
12 | cdef float tmp
13 | cdef int index = t_x - 1
14 |
15 | for y in range(t_y):
16 | for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
17 | if x == y:
18 | v_cur = max_neg_val
19 | else:
20 | v_cur = value[y-1, x]
21 | if x == 0:
22 | if y == 0:
23 | v_prev = 0.
24 | else:
25 | v_prev = max_neg_val
26 | else:
27 | v_prev = value[y-1, x-1]
28 | value[y, x] += max(v_prev, v_cur)
29 |
30 | for y in range(t_y - 1, -1, -1):
31 | path[y, index] = 1
32 | if index != 0 and (index == y or value[y-1, index] < value[y-1, index-1]):
33 | index = index - 1
34 |
35 |
36 | @cython.boundscheck(False)
37 | @cython.wraparound(False)
38 | cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_ys, int[::1] t_xs) nogil:
39 | cdef int b = paths.shape[0]
40 | cdef int i
41 | for i in prange(b, nogil=True):
42 | maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i])
43 |
--------------------------------------------------------------------------------
/Speech/vits_lib/configs/ljs_base.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": {
3 | "log_interval": 200,
4 | "eval_interval": 1000,
5 | "seed": 1234,
6 | "epochs": 20000,
7 | "learning_rate": 2e-4,
8 | "betas": [0.8, 0.99],
9 | "eps": 1e-9,
10 | "batch_size": 64,
11 | "fp16_run": true,
12 | "lr_decay": 0.999875,
13 | "segment_size": 8192,
14 | "init_lr_ratio": 1,
15 | "warmup_epochs": 0,
16 | "c_mel": 45,
17 | "c_kl": 1.0
18 | },
19 | "data": {
20 | "training_files":"filelists/ljs_audio_text_train_filelist.txt.cleaned",
21 | "validation_files":"filelists/ljs_audio_text_val_filelist.txt.cleaned",
22 | "text_cleaners":["english_cleaners2"],
23 | "max_wav_value": 32768.0,
24 | "sampling_rate": 22050,
25 | "filter_length": 1024,
26 | "hop_length": 256,
27 | "win_length": 1024,
28 | "n_mel_channels": 80,
29 | "mel_fmin": 0.0,
30 | "mel_fmax": null,
31 | "add_blank": true,
32 | "n_speakers": 0,
33 | "cleaned_text": true
34 | },
35 | "model": {
36 | "inter_channels": 192,
37 | "hidden_channels": 192,
38 | "filter_channels": 768,
39 | "n_heads": 2,
40 | "n_layers": 6,
41 | "kernel_size": 3,
42 | "p_dropout": 0.1,
43 | "resblock": "1",
44 | "resblock_kernel_sizes": [3,7,11],
45 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
46 | "upsample_rates": [8,8,2,2],
47 | "upsample_initial_channel": 512,
48 | "upsample_kernel_sizes": [16,16,4,4],
49 | "n_layers_q": 3,
50 | "use_spectral_norm": false
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/Speech/vits_lib/configs/ljs_nosdp.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": {
3 | "log_interval": 200,
4 | "eval_interval": 1000,
5 | "seed": 1234,
6 | "epochs": 20000,
7 | "learning_rate": 2e-4,
8 | "betas": [0.8, 0.99],
9 | "eps": 1e-9,
10 | "batch_size": 64,
11 | "fp16_run": true,
12 | "lr_decay": 0.999875,
13 | "segment_size": 8192,
14 | "init_lr_ratio": 1,
15 | "warmup_epochs": 0,
16 | "c_mel": 45,
17 | "c_kl": 1.0
18 | },
19 | "data": {
20 | "training_files":"filelists/ljs_audio_text_train_filelist.txt.cleaned",
21 | "validation_files":"filelists/ljs_audio_text_val_filelist.txt.cleaned",
22 | "text_cleaners":["english_cleaners2"],
23 | "max_wav_value": 32768.0,
24 | "sampling_rate": 22050,
25 | "filter_length": 1024,
26 | "hop_length": 256,
27 | "win_length": 1024,
28 | "n_mel_channels": 80,
29 | "mel_fmin": 0.0,
30 | "mel_fmax": null,
31 | "add_blank": true,
32 | "n_speakers": 0,
33 | "cleaned_text": true
34 | },
35 | "model": {
36 | "inter_channels": 192,
37 | "hidden_channels": 192,
38 | "filter_channels": 768,
39 | "n_heads": 2,
40 | "n_layers": 6,
41 | "kernel_size": 3,
42 | "p_dropout": 0.1,
43 | "resblock": "1",
44 | "resblock_kernel_sizes": [3,7,11],
45 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
46 | "upsample_rates": [8,8,2,2],
47 | "upsample_initial_channel": 512,
48 | "upsample_kernel_sizes": [16,16,4,4],
49 | "n_layers_q": 3,
50 | "use_spectral_norm": false,
51 | "use_sdp": false
52 | }
53 | }
54 |
--------------------------------------------------------------------------------
/Speech/vits_lib/configs/vctk_base.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": {
3 | "log_interval": 200,
4 | "eval_interval": 1000,
5 | "seed": 1234,
6 | "epochs": 10000,
7 | "learning_rate": 2e-4,
8 | "betas": [0.8, 0.99],
9 | "eps": 1e-9,
10 | "batch_size": 64,
11 | "fp16_run": true,
12 | "lr_decay": 0.999875,
13 | "segment_size": 8192,
14 | "init_lr_ratio": 1,
15 | "warmup_epochs": 0,
16 | "c_mel": 45,
17 | "c_kl": 1.0
18 | },
19 | "data": {
20 | "training_files":"filelists/vctk_audio_sid_text_train_filelist.txt.cleaned",
21 | "validation_files":"filelists/vctk_audio_sid_text_val_filelist.txt.cleaned",
22 | "text_cleaners":["english_cleaners2"],
23 | "max_wav_value": 32768.0,
24 | "sampling_rate": 22050,
25 | "filter_length": 1024,
26 | "hop_length": 256,
27 | "win_length": 1024,
28 | "n_mel_channels": 80,
29 | "mel_fmin": 0.0,
30 | "mel_fmax": null,
31 | "add_blank": true,
32 | "n_speakers": 109,
33 | "cleaned_text": true
34 | },
35 | "model": {
36 | "inter_channels": 192,
37 | "hidden_channels": 192,
38 | "filter_channels": 768,
39 | "n_heads": 2,
40 | "n_layers": 6,
41 | "kernel_size": 3,
42 | "p_dropout": 0.1,
43 | "resblock": "1",
44 | "resblock_kernel_sizes": [3,7,11],
45 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
46 | "upsample_rates": [8,8,2,2],
47 | "upsample_initial_channel": 512,
48 | "upsample_kernel_sizes": [16,16,4,4],
49 | "n_layers_q": 3,
50 | "use_spectral_norm": false,
51 | "gin_channels": 256
52 | }
53 | }
54 |
--------------------------------------------------------------------------------
/Recommender/README.md:
--------------------------------------------------------------------------------
1 | # Introduction
2 | This is the code for the potential solution exploration on the two VCRSs datasets.
3 |
4 | ## Installation
5 | Install dependencies via:
6 | ```
7 | cd ./Recommender/
8 | pip install -r requirements.txt
9 | ```
10 |
11 | ## Prepare Data
12 | 1. Download datasets from Google Drive: [coat.tar.gz](https://drive.google.com/file/d/1FnpYhMaeskckxGheKjar0U4YHIdDKM6K/view?usp=share_link) and [ml-1m.tar.gz](https://drive.google.com/file/d/195ugsUrU51VMUjjI329M84qegtK2QuGC/view?usp=sharing)
13 |
14 | 2. Put the datasets in the `./data/` directory.
15 |
16 | 3. Extract them: `tar -zxvf xxx.tar.gz`
17 |
18 | ## Multi-task Classification Module (MCM)
19 | MCM aims to extract explicit semantic features (e.g., user age) from our created VCRS datasets.
20 | * Extract features on Coat dataset
21 | ```
22 | python run_classifier.py --dataset='coat' --batch_size=64 --hidden_size=1024 --dropout=0.2 --epochs=20 --lr=0.0001 --test=1
23 | ```
24 |
25 | * Extract features on ML-1M dataset
26 | ```
27 | python run_classifier.py --dataset='ml-1m' --batch_size=64 --hidden_size=1024 --dropout=0.2 --epochs=20 --lr=0.0001 --test=1
28 | ```
29 |
30 | * The trained model is saved as `clf_model_xxx.pt` (a quick sanity check is sketched below).
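A quick sanity check on the saved checkpoint (a sketch only; whether `run_classifier.py` stores the full model or a `state_dict` is an assumption to verify):

```
import torch

# Inspect what the checkpoint actually contains before reusing it.
# On torch >= 2.6 you may need weights_only=False if a full pickled model was saved.
ckpt = torch.load("clf_model_coat.pt", map_location="cpu")
print(type(ckpt))
```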
31 |
32 | ## Voice Feature Integration Module (VFIM)
33 | VFIM seeks to integrate the extracted voice-related features into the recommendation model for performance-enhanced recommendation; together with the MCM, it forms a two-phase fusion framework (a conceptual sketch follows below).
34 |
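A minimal conceptual sketch of the fusion step (illustrative only; the array names and shapes are assumptions, and the actual pipeline is implemented by `gen_label.py` and `fm_audio.py`):

```
import numpy as np

def fuse_features(interaction_feats, voice_feats):
    """Append voice-derived attribute features (e.g., predicted age/gender)
    to each user-item feature row before training the FM model."""
    return np.hstack([interaction_feats, voice_feats])

# Toy shapes: 4 samples, 10 interaction features, 3 voice-derived features.
fused = fuse_features(np.zeros((4, 10)), np.zeros((4, 3)))
print(fused.shape)  # (4, 13)
```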
35 | * Factorization Machine (FM) algorithm on Coat
36 | ```
37 | python gen_label.py --dataset=coat
38 | python fm_audio.py --dataset=coat
39 | ```
40 |
41 | * Factorization Machine (FM) algorithm on ML-1M
42 | ```
43 | python gen_label.py --dataset=ml-1m
44 | python fm_audio.py --dataset=ml-1m
45 | ```
--------------------------------------------------------------------------------
/Speech/vits_lib/text/__init__.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 | from vits_lib.text import cleaners
3 | from vits_lib.text.symbols import symbols
4 |
5 |
6 | # Mappings from symbol to numeric ID and vice versa:
7 | _symbol_to_id = {s: i for i, s in enumerate(symbols)}
8 | _id_to_symbol = {i: s for i, s in enumerate(symbols)}
9 |
10 |
11 | def text_to_sequence(text, cleaner_names):
12 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
13 | Args:
14 | text: string to convert to a sequence
15 | cleaner_names: names of the cleaner functions to run the text through
16 | Returns:
17 | List of integers corresponding to the symbols in the text
18 | '''
19 | sequence = []
20 |
21 | clean_text = _clean_text(text, cleaner_names)
22 | for symbol in clean_text:
23 | symbol_id = _symbol_to_id[symbol]
24 | sequence += [symbol_id]
25 | return sequence
26 |
27 |
28 | def cleaned_text_to_sequence(cleaned_text):
29 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
30 | Args:
31 | text: string to convert to a sequence
32 | Returns:
33 | List of integers corresponding to the symbols in the text
34 | '''
35 | sequence = [_symbol_to_id[symbol] for symbol in cleaned_text]
36 | return sequence
37 |
38 |
39 | def sequence_to_text(sequence):
40 | '''Converts a sequence of IDs back to a string'''
41 | result = ''
42 | for symbol_id in sequence:
43 | s = _id_to_symbol[symbol_id]
44 | result += s
45 | return result
46 |
47 |
48 | def _clean_text(text, cleaner_names):
49 | for name in cleaner_names:
50 | cleaner = getattr(cleaners, name)
51 | if not cleaner:
52 | raise Exception('Unknown cleaner: %s' % name)
53 | text = cleaner(text)
54 | return text
55 |
--------------------------------------------------------------------------------
/Evaluate/evaluate.py:
--------------------------------------------------------------------------------
1 | import fed
2 | import json
3 | import os
4 | import random
5 | import argparse
6 | import numpy as np
7 | from collections import defaultdict
8 |
9 |
10 | def load_data(dataset):
11 | with open(f"../Dialogue/res/{dataset}/dialogue_info_{dataset}.json",'r') as load_f:
12 | dialogues = json.load(load_f)
13 | contexts = []
14 | for _, dialogue in dialogues.items():
15 | dialogue_content = dialogue["content"]
16 | context = []
17 | for _, text in dialogue_content.items():
18 | context.append(text)
19 | context = ' <|endoftext|> '.join(context)
20 | context = '<|endoftext|> ' + context + ' <|endoftext|>'
21 | contexts.append(context)
22 |
23 | return contexts
24 |
25 |
26 | if __name__ == "__main__":
27 | parser = argparse.ArgumentParser(description='dialogue evaluate')
28 | parser.add_argument('--dataset',
29 | type=str,
30 | default='coat',
31 | help='select dataset, option: coat, ml-1m, opendialKG, redial, inspire')
32 | parser.add_argument('--eval_num',
33 | type=int,
34 | default=400,
35 | help='the number of dialogues')
36 | args = parser.parse_args()
37 | data = load_data(args.dataset)
38 | model, tokenizer = fed.load_models("microsoft/DialoGPT-large")
39 | print("model load successful")
40 |
41 | scores = []
42 | data = random.sample(data, args.eval_num)
43 | for conversation in data:
44 | scores.append(fed.evaluate(conversation, model, tokenizer))
45 |
46 | fed_scores = defaultdict(list)
47 | for result in scores:
48 | score_val = 0.0
49 | for key, val in result.items():
50 | fed_scores[key].append(val)
51 | score_val += val
52 | fed_scores['fed_overall'].append(score_val / len(result))
53 |
54 | save_dir = f'./res/{args.dataset}/'
55 | if not os.path.exists(save_dir):
56 | os.makedirs(save_dir)
57 | with open(save_dir + f'{args.dataset}_results.json', 'w') as f:
58 | json.dump(fed_scores, f)
59 |
--------------------------------------------------------------------------------
/Dialogue/config/recommend_pattern.py:
--------------------------------------------------------------------------------
1 | movie_rec = [
2 | "Okay, I will suggest some movies you may like.",
3 | "Okay, I will recommend some movies you may like.",
4 | "All right, I'll give you some movie recommendations.",
5 | "Okay, I can recommend some movies that you might enjoy.",
6 | "Based on your preferences, I think you might enjoy the following movies.",
7 | "I have some movie suggestions that might match your interests.",
8 | "Okay, I'll offer some movie recommendations for you.",
9 | "Here are some movies that I think you might like based on your preferences.",
10 | "Here are some movies that I think would be a good fit for your interests.",
11 | "Okay, I'll suggest some movies for you to watch.",
12 | "Great, I can suggest some movies that you might like.",
13 | "I'll suggest some movies based on your interests.",
14 | "Ok, I have some movie recommendations that you might like.",
15 | "Let me recommend some movies for you based on your preferences.",
16 | "Great, I have suggestions ideas for movies that you might enjoy.",
17 | "Now, I have some movies that I think you might like based on your preferences.",
18 | "Okay, I have some movie recommendations that might match your interests.",
19 | "I will offer some movie suggestions based on what you've told me."
20 | ]
21 |
22 | coat_rec = [
23 | "Okay, I will suggest some coats you may like.",
24 | "Okay, I will recommend some coats you may like.",
25 | "All right, I'll give you some coats recommendations.",
26 | "Okay, I can recommend some coats that you might enjoy.",
27 | "Based on your preferences, I think you might enjoy the following coats.",
28 | "I have some coats suggestions that might match your interests.",
29 | "Okay, I'll offer some coats recommendations for you.",
30 | "Here are some coats that I think you might like based on your preferences.",
31 | "Here are some coats that I think would be a good fit for your interests.",
32 | "Okay, I'll suggest some coats for you.",
33 | "Great, I can suggest some coats that you might like.",
34 | "I'll suggest some coats based on your interests.",
35 | "Ok, I have some coats recommendations that you might like.",
36 | "Let me recommend some coats for you based on your preferences.",
37 | "Great, I have suggestions ideas for coats that you might enjoy.",
38 | "Now, I have some coats that I think you might like based on your preferences.",
39 | "Okay, I have some coats recommendations that might match your interests.",
40 | "I will offer some coats suggestions based on what you've told me."
41 | ]
--------------------------------------------------------------------------------
/Dialogue/movie_utils.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 |
4 | def get_user_age(age):
5 | if age < 20:
6 | age_str = 'under 20'
7 | elif age >= 20 and age <= 30:
8 | age_str = '20-30'
9 | elif age > 30:
10 | age_str = 'over 30'
11 |
12 | return age_str
13 |
14 | def get_user_gender(gender):
15 | if gender == 'F':
16 | gender_str = 'women'
17 | elif gender == 'M':
18 | gender_str = 'men'
19 |
20 | return gender_str
21 |
22 | def get_item_actor(actors):
23 | actors = actors.strip('[').strip(']').replace("'","").split(',')
24 | actors = [actor.strip().strip('"') for actor in actors]
25 |
26 |
27 | return actors
28 |
29 | def get_item_genre(genre):
30 | genre = genre.split("|")
31 | genre = [gen.lower() for gen in genre]
32 |
33 | return genre
34 |
35 | def modify_country(country):
36 | country_dict = {'Australia': 'Australian', 'United States': 'American',
37 | 'Japan': 'Japanese', 'United Kingdom': 'British',
38 | 'Mexico': 'Mexican', 'Italy': 'Italian',
39 | 'France': 'French', 'Germany': 'German',
40 | 'Brazil': 'Brazilian', 'Spain': 'Spanish',
41 | 'Netherlands': 'Dutch', 'Canada': 'Canadian',
42 | 'Hong Kong': 'Hong Kong', 'Cuba': 'Cuban',
43 | 'Monaco': 'Monaco', 'Belgium': 'Belgian',
44 | 'Czech Republic': 'Czech', 'West Germany': 'West German',
45 | 'Ireland': 'Irish', 'Soviet Union': 'Soviet Union',
46 | 'New Zealand': 'New Zealand', 'China': 'Chinese',
47 | 'Luxembourg': 'Luxembourg', 'Taiwan': 'Taiwanese',
48 | 'South Africa': 'South African', 'Sweden': 'Swedish',
49 | 'Switzerland': 'Swiss', 'Denmark': 'Danish',
50 | 'Argentina': 'Argentine', 'Russia': 'Russian',
51 | 'Aruba': 'Aruba', 'India': 'Indian',
52 | 'Norway':'Norwegian', 'Federal Republic of Yugoslavia': 'Yugoslav',
53 | 'Iran': 'Iranian', 'Bhutan': 'Bhutanese',
54 | 'Vietnam': 'Vietnamese', 'Dominican Republic': 'Dominican',
55 | 'Hungary': 'Hungarian', 'Poland': 'Polish'}
56 |
57 | return country_dict[country]
58 |
59 |
60 | def check_in_english(utterance):
61 | for _, text in utterance.items():
62 | text = re.sub(r'[.,"\'-?:!;]', '', text)
63 | text = text.replace(' ','')
64 | if not text.encode('UTF-8').isalpha():
65 | return False
66 |
67 | return True
68 |
--------------------------------------------------------------------------------
/Speech/inference_user.py:
--------------------------------------------------------------------------------
1 | import IPython.display as ipd
2 | import argparse
3 | import os
4 | import json
5 | import random
6 | import numpy as np
7 | from tqdm import tqdm
8 | from pydub import AudioSegment
9 | from utils import *
10 | import logging
11 |
12 |
13 | if __name__ == '__main__':
14 | parser = argparse.ArgumentParser(description='dialogue')
15 | parser.add_argument('--dataset',
16 | type=str,
17 | default='ml-1m',
18 | help='select dataset, option: coat, ml-1m')
19 | parser.add_argument('--format',
20 | type=str,
21 | default='mp3',
22 | help='select format, option: wav, flac, mp3')
23 | parser.add_argument('--start_id',
24 | type=int,
25 | default=0)
26 | args = parser.parse_args()
27 | with open(f"../Dialogue/res/{args.dataset}/dialogue_info_{args.dataset}.json",'r') as load_f:
28 | dialogue = json.load(load_f)
29 | logger = logging.getLogger()
30 | logger.setLevel(logging.ERROR)
31 | vctk_to_speaker_id, total_speaker = preprocess_speaker_info()
32 | agent, user = load_vits_model()
33 | random.seed(2023)
34 | save_dir = f'./speech_res/{args.dataset}_{args.format}_user/'
35 | if not os.path.exists(save_dir):
36 | os.makedirs(save_dir)
37 | for k, v in tqdm(dialogue.items()):
38 | dia_id = int(k)
39 | content = v['content']
40 | user_id = v['user_id']
41 | item_id = v['item_id']
42 | gender = v['user_gender']
43 | age = v['user_age']
44 |
45 | speaker_list_id = selet_speaker_list_idx(age, gender)
46 | speaker_list = total_speaker[speaker_list_id]
47 | speaker = random.choice(speaker_list)
48 | speaker_id = vctk_to_speaker_id[speaker]
49 | speaker_speed = 1.2
50 |
51 | audio_list = []
52 | count = 0
53 | save_name = f'diaid{dia_id}_uid{user_id}_iid{item_id}_{age}_{gender}_{speaker}'
54 | tag = 0
55 | for uttr_name, uttr in content.items():
56 | if count % 2 == 0:
57 | s_audio = generate_user_speech(user, uttr, speaker_id, speaker_speed)
58 | if tag == 0:
59 | combine_audio = s_audio
60 | tag = 1
61 | else:
62 | combine_audio = np.hstack((combine_audio, s_audio))
63 | count += 1
64 | audio = ipd.Audio(combine_audio, rate=22050, normalize=False)
65 | with open(f'{save_dir}{save_name}.mp3', 'wb') as f:
66 | f.write(audio.data)
67 |
68 |
69 |
70 |
71 |
72 |
73 |
--------------------------------------------------------------------------------
/Speech/inference.py:
--------------------------------------------------------------------------------
1 | import IPython.display as ipd
2 | import argparse
3 | import os
4 | import json
5 | import random
6 | import numpy as np
7 | from tqdm import tqdm
8 | from pydub import AudioSegment
9 | from utils import *
10 | import logging
12 |
13 |
14 | if __name__ == '__main__':
15 | parser = argparse.ArgumentParser(description='dialogue')
16 | parser.add_argument('--dataset',
17 | type=str,
18 | default='ml-1m',
19 | help='select dataset, option: coat, ml-1m')
20 | parser.add_argument('--format',
21 | type=str,
22 | default='mp3',
23 | help='select format, option: wav, flac, mp3')
24 | args = parser.parse_args()
25 | with open(f"../Dialogue/res/{args.dataset}/dialogue_info_{args.dataset}.json",'r') as load_f:
26 | dialogue = json.load(load_f)
27 | random.seed(2023)
28 | logger = logging.getLogger()
29 | logger.setLevel(logging.ERROR)
30 | vctk_to_speaker_id, total_speaker = preprocess_speaker_info()
31 | agent, user = load_vits_model()
32 | zero_audio = np.zeros((20000,), dtype=np.float32)
33 | zero_audio = ipd.Audio(zero_audio, rate=user[0].data.sampling_rate, normalize=False)
34 | with open(f'zero.wav', 'wb') as f:
35 | f.write(zero_audio.data)
36 | zero_wav = AudioSegment.from_wav("./zero.wav")
37 | save_dir = f'./speech_res/{args.dataset}_{args.format}/'
38 | if not os.path.exists(save_dir):
39 | os.makedirs(save_dir)
40 | for k, v in tqdm(dialogue.items()):
41 | dia_id = int(k)
42 |
43 | content = v['content']
44 | user_id = v['user_id']
45 | item_id = v['item_id']
46 | gender = v['user_gender']
47 | age = v['user_age']
48 |
49 | speaker_list_id = selet_speaker_list_idx(age, gender)
50 | speaker_list = total_speaker[speaker_list_id]
51 | speaker = random.choice(speaker_list)
52 | speaker_id = vctk_to_speaker_id[speaker]
53 | speaker_speed = 1.2
54 |
55 | audio_list = []
56 | count = 0
57 | save_name = f'diaid{dia_id}_uid{user_id}_iid{item_id}_{age}_{gender}_{speaker}'
58 | for uttr_name, uttr in content.items():
59 | if (count % 2 == 0):
60 | s_audio = generate_user_speech_audio(user, uttr, speaker_id, speaker_speed)
61 | elif (count % 2 != 0):
62 | s_audio = generate_agent_speech_audio(agent, uttr)
63 | count += 1
64 | with open(f'{uttr_name}.wav', 'wb') as f:
65 | f.write(s_audio.data)
66 | audio_list.append(uttr_name)
67 | tag = 0
68 | for i, uttr_name in enumerate(audio_list):
69 | wav_file = AudioSegment.from_wav(f"./{uttr_name}.wav")
70 | if tag == 0:
71 | combine_wav = wav_file
72 | combine_wav = combine_wav + zero_wav
73 | tag = 1
74 | else:
75 | combine_wav = combine_wav + wav_file
76 | if (i != (len(audio_list) - 1)) and (i != (len(audio_list) - 3)):
77 | combine_wav = combine_wav + zero_wav
78 | os.remove(f"./{uttr_name}.wav")
79 | if args.format == 'wav':
80 | combine_wav.export(f"{save_dir}{save_name}.wav", format="wav")
81 | elif args.format == 'mp3':
82 | combine_wav.export(f"{save_dir}{save_name}.mp3", format="mp3")
83 | elif args.format == 'flac':
84 | combine_wav.export(f"{save_dir}{save_name}.flac", format="flac")
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
--------------------------------------------------------------------------------
/Dialogue/coat_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def age_map(age):
4 | if age == 0:
5 | age = 1
6 | elif age == 1:
7 | age = 2
8 | elif age == 2:
9 | age = 2
10 | elif age == 3:
11 | age = 2
12 | elif age == 4:
13 | age = 2
14 | elif age == 5:
15 | age = 0
16 |
17 | return age
18 |
19 | def get_item_index(feature, user_id):
20 | item_feature = feature[user_id]
21 | index = np.where(item_feature == 1)[0][ : -1]
22 | gender = index[0]
23 | jacket = index[1] - 2
24 | color = index[2] - 18
25 |
26 | return gender, jacket, color
27 |
28 | def get_user_age(age):
29 | if age == 0:
30 | age_str = 'under 20'
31 | elif age == 1:
32 | age_str = '20-30'
33 | elif age == 2:
34 | age_str = 'over 30'
35 |
36 | return age_str
37 |
38 | def get_user_gender(gender):
39 | if gender == 0:
40 | gender_str = 'men'
41 | elif gender == 1:
42 | gender_str = 'women'
43 | return gender_str
44 |
45 | def get_item_gender(gender):
46 | if gender == 0:
47 | gender_str = 'men'
48 | elif gender == 1:
49 | gender_str = 'women'
50 |
51 | return gender_str
52 |
53 | def get_item_type(jacket):
54 | if jacket == 0:
55 | jacket_str = 'bomber'
56 | elif jacket == 1:
57 | jacket_str = 'cropped'
58 | elif jacket == 2:
59 | jacket_str = 'field'
60 | elif jacket == 3:
61 | jacket_str = 'fleece'
62 | elif jacket == 4:
63 | jacket_str = 'insulated'
64 | elif jacket == 5:
65 | jacket_str = 'motorcycle'
66 | elif jacket == 6:
67 | jacket_str = 'other'
68 | elif jacket == 7:
69 | jacket_str = 'packable'
70 | elif jacket == 8:
71 | jacket_str = 'parkas'
72 | elif jacket == 9:
73 | jacket_str = 'pea'
74 | elif jacket == 10:
75 | jacket_str = 'rain'
76 | elif jacket == 11:
77 | jacket_str = 'shells'
78 | elif jacket == 12:
79 | jacket_str = 'track'
80 | elif jacket == 13:
81 | jacket_str = 'trench'
82 | elif jacket == 14:
83 | jacket_str = 'vests'
84 | elif jacket == 15:
85 | jacket_str = 'waterproof'
86 |
87 | return jacket_str
88 |
89 | def get_item_type_index(jacket_val):
90 | index_dict = {'bomber':0, 'cropped':1, 'field':2, 'fleece':3, 'insulated':4,
91 | 'motorcycle':5, 'other':6, 'packable':7, 'parkas':8, 'pea':9,
92 | 'rain':10, 'shells':11, 'track':12, 'trench':13, 'vests':14, 'waterproof':15}
93 |
94 | return index_dict[jacket_val]
95 |
96 | def get_item_color(color):
97 | if color == 0:
98 | color_str = 'beige'
99 | elif color == 1:
100 | color_str = 'black'
101 | elif color == 2:
102 | color_str = 'blue'
103 | elif color == 3:
104 | color_str = 'brown'
105 | elif color == 4:
106 | color_str = 'gray'
107 | elif color == 5:
108 | color_str = 'green'
109 | elif color == 6:
110 | color_str = 'multi'
111 | elif color == 7:
112 | color_str = 'navy'
113 | elif color == 8:
114 | color_str = 'olive'
115 | elif color == 9:
116 | color_str = 'other'
117 | elif color == 10:
118 | color_str = 'pink'
119 | elif color == 11:
120 | color_str = 'purple'
121 | elif color == 12:
122 | color_str = 'red'
123 |
124 | return color_str
125 |
126 | def get_item_color_index(color_val):
127 | index_dict = {'beige':0, 'black':1, 'blue':2, 'brown':3, 'gray':4,
128 | 'green':5, 'multi':6, 'navy':7, 'olive':8, 'other':9,
129 | 'pink':10, 'purple':11, 'red':12,}
130 |
131 | return index_dict[color_val]
--------------------------------------------------------------------------------
/Speech/data/speaker_info/speaker-info.txt:
--------------------------------------------------------------------------------
1 | ID AGE GENDER ACCENTS REGION COMMENTS
2 | p225 23 F English Southern England
3 | p226 22 M English Surrey
4 | p227 38 M English Cumbria
5 | p228 22 F English Southern England
6 | p229 23 F English Southern England
7 | p230 22 F English Stockton-on-tees
8 | p231 23 F English Southern England
9 | p232 23 M English Southern England
10 | p233 23 F English Staffordshire
11 | p234 22 F Scottish West Dumfries
12 | p236 23 F English Manchester
13 | p237 22 M Scottish Fife
14 | p238 22 F NorthernIrish Belfast
15 | p239 22 F English SW England
16 | p240 21 F English Southern England
17 | p241 21 M Scottish Perth
18 | p243 22 M English London
19 | p244 22 F English Manchester
20 | p245 25 M Irish Dublin
21 | p246 22 M Scottish Selkirk
22 | p247 22 M Scottish Argyll
23 | p248 23 F Indian
24 | p249 22 F Scottish Aberdeen
25 | p250 22 F English SE England
26 | p251 26 M Indian
27 | p252 22 M Scottish Edinburgh
28 | p253 22 F Welsh Cardiff
29 | p254 21 M English Surrey
30 | p255 19 M Scottish Galloway
31 | p256 24 M English Birmingham
32 | p257 24 F English Southern England
33 | p258 22 M English Southern England
34 | p259 23 M English Nottingham
35 | p260 21 M Scottish Orkney
36 | p261 26 F NorthernIrish Belfast
37 | p262 23 F Scottish Edinburgh
38 | p263 22 M Scottish Aberdeen
39 | p264 23 F Scottish West Lothian
40 | p265 23 F Scottish Ross
41 | p266 22 F Irish Athlone
42 | p267 23 F English Yorkshire
43 | p268 23 F English Southern England
44 | p269 20 F English Newcastle
45 | p270 21 M English Yorkshire
46 | p271 19 M Scottish Fife
47 | p272 23 M Scottish Edinburgh
48 | p273 23 M English Suffolk
49 | p274 22 M English Essex
50 | p275 23 M Scottish Midlothian
51 | p276 24 F English Oxford
52 | p277 23 F English NE England
53 | p278 22 M English Cheshire
54 | p279 23 M English Leicester
55 | p280 25 F Unknown France (mic2 files unavailable)
56 | p281 29 M Scottish Edinburgh
57 | p282 23 F English Newcastle
58 | p283 24 F Irish Cork
59 | p284 20 M Scottish Fife
60 | p285 21 M Scottish Edinburgh
61 | p286 23 M English Newcastle
62 | p287 23 M English York
63 | p288 22 F Irish Dublin
64 | p292 23 M NorthernIrish Belfast
65 | p293 22 F NorthernIrish Belfast
66 | p294 33 F American San Francisco
67 | p295 23 F Irish Dublin
68 | p297 20 F American New York
69 | p298 19 M Irish Tipperary
70 | p299 25 F American California
71 | p300 23 F American California
72 | p301 23 F American North Carolina
73 | p302 20 M Canadian Montreal
74 | p303 24 F Canadian Toronto
75 | p304 22 M NorthernIrish Belfast
76 | p305 19 F American Philadelphia
77 | p306 21 F American New York
78 | p307 23 F Canadian Ontario
79 | p308 18 F American Alabama
80 | p310 21 F American Tennessee
81 | p311 21 M American Iowa
82 | p312 19 F Canadian Hamilton
83 | p313 24 F Irish County Down
84 | p314 26 F SouthAfrican Cape Town
85 | p315 18 M American New England (Text and mic2 files unavailable)
86 | p316 20 M Canadian Alberta
87 | p317 23 F Canadian Hamilton
88 | p318 32 F American Napa
89 | p323 19 F SouthAfrican Pretoria
90 | p326 26 M Australian English Sydney
91 | p329 23 F American
92 | p330 26 F American
93 | p333 19 F American Indiana
94 | p334 18 M American Chicago
95 | p335 25 F NewZealand English
96 | p336 18 F SouthAfrican Johannesburg
97 | p339 21 F American Pennsylvania
98 | p340 18 F Irish Dublin
99 | p341 26 F American Ohio
100 | p343 27 F Canadian Alberta
101 | p345 22 M American Florida
102 | p347 26 M SouthAfrican Johannesburg
103 | p351 21 F NorthernIrish Derry
104 | p360 19 M American New Jersey
105 | p361 19 F American New Jersey
106 | p362 29 F American
107 | p363 22 M Canadian Toronto
108 | p364 23 M Irish Donegal
109 | p374 28 M Australian English
110 | p376 22 M Indian
111 | s5 22 F British (new speaker, more data would be released for training a speaker-dependent model)
112 |
--------------------------------------------------------------------------------
/Dialogue/config/thanks_pattern.py:
--------------------------------------------------------------------------------
1 | movie_agent = ["It was my pleasure to help.",
2 | "I'm happy to have been able to assist.",
3 | "No problem, I'm glad I could help.",
4 | "I'm glad I could make a difference.",
5 | "I'm glad I could be of service.",
6 | "It was my pleasure to assist you.",
7 | "I'm glad I could provide some assistance.",
8 | "You're welcome. I'm always happy to help with movie recommendations.",
9 | "It was my pleasure to provide some suggestions. I hope you enjoy the movies I recommended.",
10 | "I'm glad I could assist with your movie selection. Let me know if you have any other questions or need further recommendations.",
11 | "I'm always happy to assist with movie recommendations.",
12 | "I'm glad I could provide some movie recommendations for you. Let me know if you have any feedback or need further suggestions.",
13 | "I'm glad I could be of service with your movie selection. I hope you enjoy the movies I recommended.",
14 | "You're welcome.",
15 | "Hope I made some good suggestions.",
16 | "Glad to help.",
17 | ]
18 |
19 | movie_user = ["Thanks for all the help.",
20 | "Thanks for the advice, I'll check them.",
21 | "Thanks a lot for your recommendations.",
22 | "Thanks.",
23 | "Thank you.",
24 | "Thanks a lot for your help in finding something.",
25 | "Thank you so much.",
26 | "Thank you so much for the recommendations. I really appreciate your help.",
27 | "Thank you for your recommendation.",
28 | "I'm grateful for your suggestions.",
29 | "Thanks for the recommendations. I'm excited to try out some of the films you suggested.",
30 | "I really appreciate your help with movie recommendations.",
31 | "I appreciate your recommendations.",
32 | "Thanks for your help with movie suggestions.",
33 | "I'm grateful for your assistance with movie recommendations.",
34 | "I appreciate your help with movie suggestions.",
35 | ]
36 |
37 |
38 | coat_agent = ["It was my pleasure to help.",
39 | "I'm happy to have been able to assist.",
40 | "No problem, I'm glad I could help.",
41 | "I'm glad I could make a difference.",
42 | "I'm glad I could be of service.",
43 | "It was my pleasure to assist you.",
44 | "I'm glad I could provide some assistance.",
45 | "You're welcome. I'm always happy to help with coats recommendations.",
46 | "It was my pleasure to provide some suggestions. I hope you enjoy the coats I recommended.",
47 | "I'm glad I could assist with your coats selection. Let me know if you have any other questions or need further recommendations.",
48 | "I'm always happy to assist with coats recommendations.",
49 | "I'm glad I could provide some coats recommendations for you. Let me know if you have any feedback or need further suggestions.",
50 | "I'm glad I could be of service with your coats selection. I hope you enjoy the coats I recommended.",
51 | "You're welcome.",
52 | "Hope I made some good suggestions.",
53 | "Glad to help."
54 | ]
55 |
56 |
57 | coat_user = ["Thank you so much for the recommendations.",
58 | "Your suggestions were really helpful.",
59 | "I really appreciate your advice.",
60 | "Thanks for taking the time to help me out.",
61 | "I'm so grateful for your recommendations.",
62 | "Thank you for your kind assistance.",
63 | "Thanks for all the help.",
64 | "Thanks for the advice, I'll check them.",
65 | "Thanks a lot for your recommendations.",
66 | "Thanks.",
67 | "Thank you.",
68 | "Thanks a lot for your help in finding something.",
69 | "Thank you so much.",
70 | "Thank you so much for the recommendations. I really appreciate your help.",
71 | "Thank you for your recommendation.",
72 | "Thanks for the recommendations. I'm excited to try out some of the coats you suggested.",
73 | "I really appreciate your help with coats recommendations.",
74 | "I'm grateful for your assistance with coats recommendations.",
75 | "I appreciate your help with coats suggestions."
76 | ]
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
12 |
13 |
14 | Table of Contents
15 |
16 | - About The Project
17 | - Dataset Description
22 | - Potential Solution Exploration
23 | - Data Construction
24 | - Acknowledgement
25 |
26 |
27 |
28 | # About The Project
29 | This project is the code of the paper "[Towards Building Voice-based Conversational Recommender Systems: Datasets, Potential Solutions, and Prospects](https://arxiv.org/abs/2306.08219)". In this project, we aim to provide two voice-based conversational recommender system (VCRS) datasets in the e-commerce and movie domains.
30 |
31 |
32 |
33 | # Dataset Description
34 | You can download the datasets from Google Drive. The datasets consist of two parts: [coat.tar.gz](https://drive.google.com/file/d/1FnpYhMaeskckxGheKjar0U4YHIdDKM6K/view?usp=share_link) and [ml-1m.tar.gz](https://drive.google.com/file/d/195ugsUrU51VMUjjI329M84qegtK2QuGC/view?usp=sharing).
35 |
36 | ## Dataset files
37 | Each data file is an mp3 file and its name takes the form `diaidxx_uidxx_iidxx_xx_xx_xx.mp3`.
38 |
39 | For example, for file `diaid21_uid249_iid35_20-30_men_251.mp3`, its meaning is as follows:
40 | ```
41 | diaid21: corresponds to dialogue 21 in the text-based conversation dataset
42 | uid249: user id is 249
43 | iid35: item id is 35
44 | 20-30: user's age is between 20 and 30
45 | men: user's gender is male
46 | 251: corresponds to speaker p251 in vctk dataset
47 | ```
48 | Speaker information on the VCTK dataset can be found [here](www.udialogue.org/download/cstr-vctk-corpus.html).
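For convenience, a minimal sketch (not part of the repo) of recovering these fields from a file name; the dictionary keys are illustrative rather than an official schema:

```
def parse_filename(name):
    # e.g. "diaid21_uid249_iid35_20-30_men_251.mp3"
    stem = name.rsplit(".", 1)[0]
    dia_id, user_id, item_id, age, gender, speaker = stem.split("_")
    return {
        "dialogue_id": int(dia_id[len("diaid"):]),
        "user_id": int(user_id[len("uid"):]),
        "item_id": int(item_id[len("iid"):]),
        "age": age,                # age bucket, e.g. "20-30"
        "gender": gender,          # "men" or "women"
        "speaker": "p" + speaker,  # VCTK speaker, e.g. "p251"
    }

print(parse_filename("diaid21_uid249_iid35_20-30_men_251.mp3"))
```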
49 |
50 | ## Case study
51 | Here we provide a demo of a data file (i.e., `diaid21_uid249_iid35_20-30_men_251.mp3`) that contains text and audio dialogue between the user and the agent.
52 |
53 | https://github.com/hyllll/VCRS/assets/38367896/b2b1b4d7-e860-46f4-9d6f-24ac8e1a6192
54 |
55 | Note that since we currently only explore the impact of speech on VCRS from the user's perspective, only the user's speech is included in the provided dataset. If you want complete dialogue audio, you can generate it through the code we provide.
56 |
57 | # Potential Solution Exploration
58 | We propose to extract explicit semantic features from the voice data and then incorporate them into the recommendation model in a two-phase fusion manner.
59 |
60 | Please refer to [here](https://github.com/hyllll/VCRS/tree/main/Recommender) for how to run the code.
61 |
62 | # Data Construction
63 | Our VCRSs dataset creation task includes four steps: (1) backbone dataset selection; (2) text-based conversation generation; (3) voice-based conversation generation; and (4) quality evaluation.
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 | ## Backbone dataset selection
72 | We choose [Coat](www.cs.cornell.edu/~schnabts/mnar/) and [ML-1M](grouplens.org/datasets/movielens/1m/) as our backbone datasets. We use user-item interactions and item features to simulate a text-based conversation between users and agents for recommendation, and use user features to assign proper speakers.
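As an illustration of the speaker-assignment idea (a sketch only; the real logic lives in `Speech/utils.py`, and the age buckets below are assumptions mirroring the dialogue generation code), the VCTK `speaker-info.txt` under `Speech/data/speaker_info/` can be grouped by age and gender like this:

```
from collections import defaultdict

groups = defaultdict(list)  # (age_bucket, gender) -> list of VCTK speaker IDs
with open("Speech/data/speaker_info/speaker-info.txt") as f:
    next(f)  # skip the "ID AGE GENDER ..." header line
    for line in f:
        parts = line.split()
        if len(parts) < 3:
            continue
        sid, age, gender = parts[0], int(parts[1]), parts[2]
        bucket = "under 20" if age < 20 else "20-30" if age <= 30 else "over 30"
        groups[(bucket, gender)].append(sid)

print(sorted(groups))  # the available (age bucket, gender) combinations
```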
73 |
74 | ## Text-based conversation generation
75 | Please refer to [here](https://github.com/hyllll/VCRS/tree/main/Dialogue) for how to generate the text-based conversation and the code is in `./Dialogue/` directory.
76 |
77 | ## Voice-based conversation generation
78 | Please refer to [here](https://github.com/hyllll/VCRS/tree/main/Speech) for how to generate the voice-based conversation.
79 |
80 | ## Quality evaluation
81 | We adopt the fine-grained evaluation of dialogue (FED) metric to measure the quality of the generated text-based conversations.
82 | ### Installation
83 | ```
84 | pip install -r requirements.txt
85 | ```
86 | ### RUN
87 | 1. `cd ./Evaluate/`
88 | 2. Run `python evaluate.py --dataset='xxx'`, where `xxx` is `coat` or `ml-1m`.
89 | 3. All results are saved in the `./res/` directory; a sketch for summarising them follows below.
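A small sketch for summarising the saved scores (the file layout is assumed from `Evaluate/evaluate.py`, which stores one list of per-dialogue scores per FED dimension plus `fed_overall`):

```
import json

with open("./res/coat/coat_results.json") as f:
    fed_scores = json.load(f)

# Average each FED dimension over the evaluated dialogues.
for metric, values in fed_scores.items():
    print(f"{metric}: {sum(values) / len(values):.4f}")
```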
90 |
91 |
92 |
93 | # Acknowledgement
94 | * We convert text to audio using [VITS](https://github.com/jaywalnut310/vits), a state-of-the-art end-to-end text-to-speech (TTS) model.
95 |
96 | * We improve code efficiency by building on [conv_rec_sys](https://github.com/xxkkrr/conv_rec_sys).
97 |
98 | * We evaluate the text-based conversations with [FED](https://github.com/exe1023/DialEvalMetrics).
99 |
--------------------------------------------------------------------------------
/Speech/vits_lib/text/cleaners.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | '''
4 | Cleaners are transformations that run over the input text at both training and eval time.
5 |
6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use:
8 | 1. "english_cleaners" for English text
9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode)
11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
12 | the symbols in symbols.py to match your data).
13 | '''
14 |
15 | import re
16 | from unidecode import unidecode
17 | from phonemizer import phonemize
18 | import logging
19 | import sys
20 | from logging import Logger
21 |
22 |
23 | # Regular expression matching whitespace:
24 | _whitespace_re = re.compile(r'\s+')
25 |
26 | # List of (regular expression, replacement) pairs for abbreviations:
27 | _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
28 | ('mrs', 'misess'),
29 | ('mr', 'mister'),
30 | ('dr', 'doctor'),
31 | ('st', 'saint'),
32 | ('co', 'company'),
33 | ('jr', 'junior'),
34 | ('maj', 'major'),
35 | ('gen', 'general'),
36 | ('drs', 'doctors'),
37 | ('rev', 'reverend'),
38 | ('lt', 'lieutenant'),
39 | ('hon', 'honorable'),
40 | ('sgt', 'sergeant'),
41 | ('capt', 'captain'),
42 | ('esq', 'esquire'),
43 | ('ltd', 'limited'),
44 | ('col', 'colonel'),
45 | ('ft', 'fort'),
46 | ]]
47 |
48 |
49 | def expand_abbreviations(text):
50 | for regex, replacement in _abbreviations:
51 | text = re.sub(regex, replacement, text)
52 | return text
53 |
54 |
55 | def expand_numbers(text):
56 | return normalize_numbers(text)
57 |
58 |
59 | def lowercase(text):
60 | return text.lower()
61 |
62 |
63 | def collapse_whitespace(text):
64 | return re.sub(_whitespace_re, ' ', text)
65 |
66 |
67 | def convert_to_ascii(text):
68 | return unidecode(text)
69 |
70 |
71 | def basic_cleaners(text):
72 | '''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
73 | text = lowercase(text)
74 | text = collapse_whitespace(text)
75 | return text
76 |
77 |
78 | def transliteration_cleaners(text):
79 | '''Pipeline for non-English text that transliterates to ASCII.'''
80 | text = convert_to_ascii(text)
81 | text = lowercase(text)
82 | text = collapse_whitespace(text)
83 | return text
84 |
85 |
86 | def english_cleaners(text):
87 | '''Pipeline for English text, including abbreviation expansion.'''
88 | text = convert_to_ascii(text)
89 | text = lowercase(text)
90 | text = expand_abbreviations(text)
91 | phonemes = phonemize(text, language='en-us', backend='espeak', strip=True, words_mismatch='ignore', logger=get_logger())
92 | phonemes = collapse_whitespace(phonemes)
93 | return phonemes
94 |
95 |
96 | def english_cleaners2(text):
97 | '''Pipeline for English text, including abbreviation expansion. + punctuation + stress'''
98 | text = convert_to_ascii(text)
99 | text = lowercase(text)
100 | text = expand_abbreviations(text)
101 | phonemes = phonemize(text, language='en-us', backend='espeak', strip=True, preserve_punctuation=True, with_stress=True, words_mismatch='ignore', logger=get_logger())
102 | phonemes = collapse_whitespace(phonemes)
103 | return phonemes
104 |
105 |
106 |
107 | def get_logger(verbosity: str = 'quiet', name: str = 'phonemizer') -> Logger:
108 | """Returns a configured logging.Logger instance
109 | from https://github.com/bootphon/phonemizer/blob/master/phonemizer/logger.py
110 | The logger is configured to output messages on the standard error stream
111 | (stderr).
112 | Parameters
113 | ----------
114 | verbosity (str) : The level of verbosity, must be 'verbose' (displays
115 | debug/info and warning messages), 'normal' (warnings only) or 'quiet' (do
116 | not display anything).
117 | name (str) : The logger name, default to 'phonemizer'
118 | Raises
119 | ------
120 | RuntimeError if `verbosity` is not 'normal', 'verbose', or 'quiet'.
121 | """
122 | # make sure the verbosity argument is valid
123 | valid_verbosity = ['normal', 'verbose', 'quiet']
124 | if verbosity not in valid_verbosity:
125 | raise RuntimeError(
126 | f'verbosity is {verbosity} but must be in '
127 | f'{", ".join(valid_verbosity)}')
128 |
129 | logger = logging.getLogger(name)
130 |
131 | # setup output to stderr
132 | logger.handlers = []
133 | handler = logging.StreamHandler(sys.stderr)
134 |
135 | # setup verbosity level
136 | logger.setLevel(logging.ERROR)
137 | # print(verbosity)
138 | if verbosity == 'verbose':
139 | logger.setLevel(logging.DEBUG)
140 | elif verbosity == 'quiet':
141 | handler = logging.NullHandler()
142 |
143 | # setup messages format
144 | handler.setFormatter(logging.Formatter('[%(levelname)s] %(message)s'))
145 | logger.addHandler(handler)
146 | return logger
147 |
--------------------------------------------------------------------------------
/Speech/vits_lib/commons.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 |
7 |
8 | def init_weights(m, mean=0.0, std=0.01):
9 | classname = m.__class__.__name__
10 | if classname.find("Conv") != -1:
11 | m.weight.data.normal_(mean, std)
12 |
13 |
14 | def get_padding(kernel_size, dilation=1):
15 | return int((kernel_size*dilation - dilation)/2)
16 |
17 |
18 | def convert_pad_shape(pad_shape):
19 | l = pad_shape[::-1]
20 | pad_shape = [item for sublist in l for item in sublist]
21 | return pad_shape
22 |
23 |
24 | def intersperse(lst, item):
25 | result = [item] * (len(lst) * 2 + 1)
26 | result[1::2] = lst
27 | return result
28 |
29 |
30 | def kl_divergence(m_p, logs_p, m_q, logs_q):
31 | """KL(P||Q)"""
32 | kl = (logs_q - logs_p) - 0.5
33 | kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. * logs_q)
34 | return kl
35 |
36 |
37 | def rand_gumbel(shape):
38 | """Sample from the Gumbel distribution, protect from overflows."""
39 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
40 | return -torch.log(-torch.log(uniform_samples))
41 |
42 |
43 | def rand_gumbel_like(x):
44 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
45 | return g
46 |
47 |
48 | def slice_segments(x, ids_str, segment_size=4):
49 | ret = torch.zeros_like(x[:, :, :segment_size])
50 | for i in range(x.size(0)):
51 | idx_str = ids_str[i]
52 | idx_end = idx_str + segment_size
53 | ret[i] = x[i, :, idx_str:idx_end]
54 | return ret
55 |
56 |
57 | def rand_slice_segments(x, x_lengths=None, segment_size=4):
58 | b, d, t = x.size()
59 | if x_lengths is None:
60 | x_lengths = t
61 | ids_str_max = x_lengths - segment_size + 1
62 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
63 | ret = slice_segments(x, ids_str, segment_size)
64 | return ret, ids_str
65 |
66 |
67 | def get_timing_signal_1d(
68 | length, channels, min_timescale=1.0, max_timescale=1.0e4):
69 | position = torch.arange(length, dtype=torch.float)
70 | num_timescales = channels // 2
71 | log_timescale_increment = (
72 | math.log(float(max_timescale) / float(min_timescale)) /
73 | (num_timescales - 1))
74 | inv_timescales = min_timescale * torch.exp(
75 | torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment)
76 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
77 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
78 | signal = F.pad(signal, [0, 0, 0, channels % 2])
79 | signal = signal.view(1, channels, length)
80 | return signal
81 |
82 |
83 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
84 | b, channels, length = x.size()
85 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
86 | return x + signal.to(dtype=x.dtype, device=x.device)
87 |
88 |
89 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
90 | b, channels, length = x.size()
91 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
92 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
93 |
94 |
95 | def subsequent_mask(length):
96 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
97 | return mask
98 |
99 |
100 | @torch.jit.script
101 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
102 | n_channels_int = n_channels[0]
103 | in_act = input_a + input_b
104 | t_act = torch.tanh(in_act[:, :n_channels_int, :])
105 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
106 | acts = t_act * s_act
107 | return acts
108 |
109 |
110 | def convert_pad_shape(pad_shape):
111 | l = pad_shape[::-1]
112 | pad_shape = [item for sublist in l for item in sublist]
113 | return pad_shape
114 |
115 |
116 | def shift_1d(x):
117 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
118 | return x
119 |
120 |
121 | def sequence_mask(length, max_length=None):
122 | if max_length is None:
123 | max_length = length.max()
124 | x = torch.arange(max_length, dtype=length.dtype, device=length.device)
125 | return x.unsqueeze(0) < length.unsqueeze(1)
126 |
127 |
128 | def generate_path(duration, mask):
129 | """
130 | duration: [b, 1, t_x]
131 | mask: [b, 1, t_y, t_x]
132 | """
133 | device = duration.device
134 |
135 | b, _, t_y, t_x = mask.shape
136 | cum_duration = torch.cumsum(duration, -1)
137 |
138 | cum_duration_flat = cum_duration.view(b * t_x)
139 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
140 | path = path.view(b, t_x, t_y)
141 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
142 | path = path.unsqueeze(1).transpose(2,3) * mask
143 | return path
144 |
145 |
146 | def clip_grad_value_(parameters, clip_value, norm_type=2):
147 | if isinstance(parameters, torch.Tensor):
148 | parameters = [parameters]
149 | parameters = list(filter(lambda p: p.grad is not None, parameters))
150 | norm_type = float(norm_type)
151 | if clip_value is not None:
152 | clip_value = float(clip_value)
153 |
154 | total_norm = 0
155 | for p in parameters:
156 | param_norm = p.grad.data.norm(norm_type)
157 | total_norm += param_norm.item() ** norm_type
158 | if clip_value is not None:
159 | p.grad.data.clamp_(min=-clip_value, max=clip_value)
160 | total_norm = total_norm ** (1. / norm_type)
161 | return total_norm
162 |
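163 | # Minimal usage sketch of a few helpers above; the values are illustrative only.
164 | if __name__ == '__main__':
165 |     # intersperse inserts a blank token id between every symbol id
166 |     print(intersperse([1, 2, 3], 0))                     # [0, 1, 0, 2, 0, 3, 0]
167 |     # convert_pad_shape flattens a nested pad spec into F.pad's 1-D format
168 |     print(convert_pad_shape([[0, 0], [0, 0], [1, 0]]))   # [1, 0, 0, 0, 0, 0]
169 |     # sequence_mask builds a boolean mask from per-sample lengths
170 |     print(sequence_mask(torch.tensor([2, 4])))           # shape [2, 4]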
--------------------------------------------------------------------------------
/Evaluate/fed.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 |
4 | import numpy as np
5 | import torch
6 | from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler
7 | from torch.utils.data.distributed import DistributedSampler
8 |
9 | from transformers import AutoTokenizer, AutoModelWithLMHead
10 |
11 | #tokenizer = GPT2Tokenizer.from_pretrained('dialogpt')
12 | #model = GPT2LMHeadModel.from_pretrained('gpt2')
13 | #weights = torch.load("dialogpt/small_fs.pkl")
14 | #weights = {k.replace("module.", ""): v for k,v in weights.items()}
15 | #weights["lm_head.weight"] = weights["lm_head.decoder.weight"]
16 | #weights.pop("lm_head.decoder.weight",None)
17 | #model.load_state_dict(weights)
18 |
19 |
20 | def load_models(name="microsoft/DialoGPT-large"):
21 | tokenizer = AutoTokenizer.from_pretrained(name, mirror='tuna')
22 | model = AutoModelWithLMHead.from_pretrained(name, mirror='tuna')
23 | model.to("cuda")
24 | return model, tokenizer
25 |
26 | def score(text, tokenizer, model):
27 | if not text.startswith("<|endoftext|> "):
28 | text = "<|endoftext|> " + text
29 | input_ids = torch.tensor(tokenizer.encode(text)).unsqueeze(0) # Batch size 1
30 | tokenize_input = tokenizer.tokenize(text)
31 | #50256 is the token_id for <|endoftext|>
32 | tensor_input = torch.tensor([ tokenizer.convert_tokens_to_ids(tokenize_input)]).cuda()
33 | with torch.no_grad():
34 | outputs = model(tensor_input, labels=tensor_input)
35 | loss, logits = outputs[:2]
36 |
37 | return loss.item()
38 |
39 | def evaluate(conversation, model, tokenizer):
40 | scores = {}
41 | turn_level_utts = {
42 | "interesting": {
43 | "positive": ["Wow that is really interesting.", "That's really interesting!", "Cool! That sounds super interesting."],
44 | "negative": ["That's not very interesting.", "That's really boring.", "That was a really boring response."]
45 | },
46 | "engaging": {
47 | "positive": ["Wow! That's really cool!", "Tell me more!", "I'm really interested in learning more about this."],
48 | "negative": ["Let's change the topic.", "I don't really care. That's pretty boring.", "I want to talk about something else."]
49 | },
50 | "specific": {
51 | "positive": ["That's good to know. Cool!", "I see, that's interesting.", "That's a good point."],
52 | "negative": ["That's a very generic response.", "Not really relevant here.", "That's not really relevant here."]
53 | },
54 | "relevant": {
55 | "positive": [],
56 | "negative": ["That's not even related to what I said.", "Don't change the topic!", "Why are you changing the topic?"]
57 | },
58 | "correct": {
59 | "positive": [],
60 | "negative": ["You're not understanding me!", "I am so confused right now!", "I don't understand what you're saying."]
61 | },
62 | "semantically appropriate": {
63 | "positive": ["That makes sense!", "You have a good point."],
64 | "negative": ["That makes no sense!"]
65 | },
66 | "understandable": {
67 | "positive": ["That makes sense!", "You have a good point."],
68 | "negative": ["I don't understand at all!", "I'm so confused!", "That makes no sense!", "What does that even mean?"]
69 | },
70 | "fluent": {
71 | "positive": ["That makes sense!", "You have a good point."],
72 | "negative": ["Is that real English?", "I'm so confused right now!", "That makes no sense!"]
73 | },
74 | }
75 | for metric,utts in turn_level_utts.items():
76 | pos = utts["positive"]
77 | neg = utts["negative"]
78 |
79 | # Positive score
80 | high_score = 0
81 | for m in pos:
82 | hs = score(conversation + " <|endoftext|> " + m, tokenizer, model)
83 | high_score += hs
84 |
85 | high_score = high_score/max(len(pos), 1)
86 |
87 | # Negative score
88 | low_score = 0
89 | for m in neg:
90 | ls = score(conversation + " <|endoftext|> " + m, tokenizer, model)
91 | low_score += ls
92 | low_score = low_score/max(len(neg), 1)
93 |
94 | scores[metric] = (low_score - high_score)
95 |
96 | dialog_level_utts = {
97 | "coherent": {
98 | "positive": [],
99 | "negative": ["You're making no sense at all.", "You're changing the topic so much!", "You are so confusing."]
100 | },
101 | "error recovery": {
102 | "positive": [],
103 | "negative": ["I am so confused right now.", "You're really confusing.", "I don't understand what you're saying."]
104 | },
105 | "consistent": {
106 | "positive": [],
107 | "negative": ["That's not what you said earlier!", "Stop contradicting yourself!"],
108 | },
109 | "diverse": {
110 | "positive": [],
111 | "negative": ["Stop saying the same thing repeatedly.", "Why are you repeating yourself?", "Stop repeating yourself!"]
112 | },
113 | "depth": {
114 | "positive": [],
115 | "negative": ["Stop changing the topic so much.", "Don't change the topic!"],
116 | },
117 | "likeable": {
118 | "positive": ["I like you!", "You're super polite and fun to talk to", "Great talking to you."],
119 | "negative": ["You're not very nice.", "You're not very fun to talk to.", "I don't like you."]
120 | },
121 | "understand": {
122 | "positive": [],
123 | "negative": ["You're not understanding me!", "What are you trying to say?", "I don't understand what you're saying."]
124 | },
125 | "flexible": {
126 | "positive": ["You're very easy to talk to!", "Wow you can talk about a lot of things!"],
127 | "negative": ["I don't want to talk about that!", "Do you know how to talk about something else?"],
128 | },
129 | "informative": {
130 | "positive": ["Thanks for all the information!", "Wow that's a lot of information.", "You know a lot of facts!"],
131 | "negative": ["You're really boring.", "You don't really know much."],
132 | },
133 | "inquisitive": {
134 | "positive": ["You ask a lot of questions!", "That's a lot of questions!"],
135 | "negative": ["You don't ask many questions.", "You don't seem interested."],
136 | },
137 | }
138 | for metric,utts in dialog_level_utts.items():
139 | pos = utts["positive"]
140 | neg = utts["negative"]
141 |
142 | # Positive
143 | high_score = 0
144 | for m in pos:
145 | hs = score(conversation + " <|endoftext|> " + m, tokenizer, model)
146 | high_score += hs
147 |
148 | high_score = high_score/max(len(pos), 1)
149 |
150 | # Negative
151 | low_score = 0
152 | for m in neg:
153 | ls = score(conversation + " <|endoftext|> " + m, tokenizer, model)
154 | low_score += ls
155 | low_score = low_score/max(len(neg), 1)
156 |
157 | scores[metric] = (low_score - high_score)
158 |
159 | return scores
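160 | 
161 | # Minimal usage sketch (assumes a CUDA device and that the DialoGPT weights
162 | # download successfully); a conversation is one string whose turns are separated
163 | # by the <|endoftext|> token, matching how score() formats its inputs above.
164 | if __name__ == '__main__':
165 |     fed_model, fed_tokenizer = load_models("microsoft/DialoGPT-large")
166 |     dialogue = "Hi! Can you recommend a movie? <|endoftext|> Sure, do you like sci-fi films?"
167 |     print(evaluate(dialogue, fed_model, fed_tokenizer))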
--------------------------------------------------------------------------------
/Speech/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import vits_lib.commons as commons
3 | import vits_lib.utils as vits_utils
4 | from vits_lib.models import SynthesizerTrn
5 | from vits_lib.text.symbols import symbols
6 | from vits_lib.text import text_to_sequence
7 | import IPython.display as ipd
8 |
9 |
10 |
11 | def get_text(text, hps):
12 | text_norm = text_to_sequence(text, hps.data.text_cleaners)
13 | if hps.data.add_blank:
14 | text_norm = commons.intersperse(text_norm, 0)
15 | text_norm = torch.LongTensor(text_norm)
16 | return text_norm
17 |
18 | def get_vid_sid(line):
19 | '''
20 | Return:
21 | vctk_id: vctk id
22 | speaker_id: speaker id in vits
23 | '''
24 | line = line.split('|')
25 | vctk_id = int(line[0].split('/')[1][1:])
26 | speaker_id = int(line[1])
27 |
28 | return vctk_id, speaker_id
29 |
30 | def split_speaker(s_info):
31 |     # TODO: optimize the if-else chain
32 | under_20_male = []
33 | under_20_female = []
34 | between_20_30_male = []
35 | between_20_30_female = []
36 | above_30_male = []
37 | above_30_female = []
38 |
39 | for i, info in enumerate(s_info):
40 | if i == 0:
41 | continue
42 | else:
43 | gender = info[10:11]
44 | age = int(info[6:8])
45 | user_id = int(info[1:4])
46 | if age < 20 and gender == 'M':
47 | under_20_male.append(user_id)
48 | elif age < 20 and gender == 'F':
49 | under_20_female.append(user_id)
50 | elif age >= 20 and age <=30 and gender == 'M':
51 | between_20_30_male.append(user_id)
52 | elif age >= 20 and age <=30 and gender == 'F':
53 | between_20_30_female.append(user_id)
54 | elif age > 30 and gender == 'M':
55 | above_30_male.append(user_id)
56 | else:
57 | above_30_female.append(user_id)
58 |
59 | between_20_30_female.remove(5)
60 | under_20_male.remove(315)
61 |
62 | total_speaker = []
63 | total_speaker.append(under_20_male)
64 | total_speaker.append(under_20_female)
65 | total_speaker.append(between_20_30_male)
66 | total_speaker.append(between_20_30_female)
67 | total_speaker.append(above_30_male)
68 | total_speaker.append(above_30_female)
69 |
70 | return total_speaker
71 |
72 |
73 | def preprocess_speaker_info():
74 | with open('./data/speaker_info/speaker-info.txt') as f:
75 | lines = f.readlines()
76 | s_info = []
77 | for sub in lines:
78 | s_info.append(sub.replace("\n", ""))
79 |
80 | with open('./data/speaker_info/vctk_audio_sid_text_test_filelist.txt.cleaned.txt') as f:
81 | sid_text_test_filelist = f.readlines()
82 | with open('./data/speaker_info/vctk_audio_sid_text_train_filelist.txt.cleaned.txt') as f:
83 | sid_text_train_filelist = f.readlines()
84 | with open('./data/speaker_info/vctk_audio_sid_text_val_filelist.txt.cleaned.txt') as f:
85 | sid_text_val_filelist = f.readlines()
86 | filelist = sid_text_test_filelist + sid_text_train_filelist + sid_text_val_filelist
87 | vctk_to_speaker_id = {}
88 | for line in filelist:
89 | v_id, s_id = get_vid_sid(line)
90 | if v_id not in vctk_to_speaker_id.keys():
91 | vctk_to_speaker_id[v_id] = s_id
92 | total_speaker = split_speaker(s_info)
93 |
94 | return vctk_to_speaker_id, total_speaker
95 |
96 |
97 | def load_vits_model():
98 | hps_agent = vits_utils.get_hparams_from_file("./vits_lib/configs/ljs_base.json")
99 | net_agent = SynthesizerTrn(
100 | len(symbols),
101 | hps_agent.data.filter_length // 2 + 1,
102 | hps_agent.train.segment_size // hps_agent.data.hop_length,
103 | **hps_agent.model).cuda()
104 | _ = net_agent.eval()
105 | _ = vits_utils.load_checkpoint("./vits_lib/pretrain/pretrained_ljs.pth", net_agent, None)
106 |
107 | hps_user = vits_utils.get_hparams_from_file("./vits_lib/configs/vctk_base.json")
108 | net_user = SynthesizerTrn(
109 | len(symbols),
110 | hps_user.data.filter_length // 2 + 1,
111 | hps_user.train.segment_size // hps_user.data.hop_length,
112 | n_speakers=hps_user.data.n_speakers,
113 | **hps_user.model).cuda()
114 | _ = net_user.eval()
115 | _ = vits_utils.load_checkpoint("./vits_lib/pretrain/pretrained_vctk.pth", net_user, None)
116 |
117 | return (hps_agent, net_agent), (hps_user, net_user)
118 |
119 |
120 | def generate_agent_speech_audio(agent, text):
121 | hps_agent = agent[0]
122 | net_agent = agent[1]
123 |
124 | stn_tst = get_text(text, hps_agent)
125 | with torch.no_grad():
126 | x_tst = stn_tst.cuda().unsqueeze(0)
127 | x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cuda()
128 | audio = net_agent.infer(x_tst, x_tst_lengths, noise_scale=.667, noise_scale_w=0.8, length_scale=1.0)[0][0,0].data.cpu().float().numpy()
129 | audio = ipd.Audio(audio, rate=hps_agent.data.sampling_rate, normalize=False)
130 |
131 | return audio
132 |
133 |
134 | def generate_user_speech(user, text, sid, speaker_speed):
135 | hps_user = user[0]
136 | net_user = user[1]
137 |
138 | stn_tst = get_text(text, hps_user)
139 | with torch.no_grad():
140 | x_tst = stn_tst.cuda().unsqueeze(0)
141 | x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cuda()
142 | sid = torch.LongTensor([sid]).cuda()
143 | audio = net_user.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=.667, noise_scale_w=0.8, length_scale=speaker_speed)[0][0,0].data.cpu().float().numpy()
144 | # audio = ipd.Audio(audio, rate=hps_user.data.sampling_rate, normalize=False)
145 |
146 | return audio
147 |
148 |
149 | def generate_user_speech_audio(user, text, sid, speaker_speed):
150 | hps_user = user[0]
151 | net_user = user[1]
152 |
153 | stn_tst = get_text(text, hps_user)
154 | with torch.no_grad():
155 | x_tst = stn_tst.cuda().unsqueeze(0)
156 | x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cuda()
157 | sid = torch.LongTensor([sid]).cuda()
158 | audio = net_user.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=.667, noise_scale_w=0.8, length_scale=speaker_speed)[0][0,0].data.cpu().float().numpy()
159 | audio = ipd.Audio(audio, rate=hps_user.data.sampling_rate, normalize=False)
160 |
161 | return audio
162 |
163 |
164 |
165 | def selet_speaker_list_idx(age, gender):
166 | if age == 'under 20' and gender == 'men':
167 | return 0
168 | elif age == 'under 20' and gender == 'women':
169 | return 1
170 | elif age == '20-30' and gender == 'men':
171 | return 2
172 | elif age == '20-30' and gender == 'women':
173 | return 3
174 | elif age == 'over 30' and gender == 'men':
175 | return 4
176 | else:
177 | return 5
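178 | 
179 | # Minimal usage sketch (assumes the pretrained VITS checkpoints referenced in
180 | # load_vits_model() are under ./vits_lib/pretrain and a GPU is available);
181 | # the speaker choice and speaking speed below are illustrative.
182 | if __name__ == '__main__':
183 |     agent, user = load_vits_model()
184 |     vctk_to_sid, speaker_groups = preprocess_speaker_info()
185 |     group_idx = selet_speaker_list_idx('20-30', 'women')
186 |     sid = vctk_to_sid[speaker_groups[group_idx][0]]
187 |     wav = generate_user_speech(user, "Could you recommend a warm jacket?", sid, 1.0)
188 |     print(wav.shape)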
--------------------------------------------------------------------------------
/Speech/data/speaker_info/vctk_audio_sid_text_val_filelist.txt.cleaned.txt:
--------------------------------------------------------------------------------
1 | DUMMY2/p364/p364_240.wav|88|ɪt hɐd hˈæpənd tə hˌɪm.
2 | DUMMY2/p280/p280_148.wav|52|ɪt ɪz ˈoʊpən sˈiːzən ɑːnðɪ ˈoʊld fˈɜːm.
3 | DUMMY2/p231/p231_320.wav|50|haʊˈɛvɚ, hiː ɪz ɐ kˈoʊtʃ, ænd hiː ɹɪmˈeɪnz ɐ kˈoʊtʃ æt hˈɑːɹt.
4 | DUMMY2/p282/p282_129.wav|83|ɪt ɪz nˌɑːɾə jˈuːtˈɜːn.
5 | DUMMY2/p254/p254_015.wav|41|ðə ɡɹˈiːks jˈuːzd tʊ ɪmˈædʒɪn ðˌɐɾɪt wʌzɐ sˈaɪn fɹʌmðə ɡˈɑːdz tə foːɹtˈɛl wˈɔːɹ ɔːɹ hˈɛvi ɹˈeɪn.
6 | DUMMY2/p228/p228_285.wav|57|ðə sˈɔŋz ɑːɹ dʒˈʌst sˌoʊ ɡˈʊd.
7 | DUMMY2/p334/p334_307.wav|38|ɪf ðeɪ dˈoʊnt, ðeɪ kæn ɛkspˈɛkt ðɛɹ fˈʌndɪŋ təbi kˈʌt.
8 | DUMMY2/p287/p287_081.wav|77|aɪv nˈɛvɚ sˈiːn ˈɛnɪθˌɪŋ lˈaɪk ɪt.
9 | DUMMY2/p247/p247_083.wav|14|ɪt ɪz ɐ dʒˈɑːb kɹiːˈeɪʃən skˈiːm.
10 | DUMMY2/p264/p264_051.wav|65|wiː wɜː lˈiːdɪŋ baɪ tˈuː ɡˈoʊlz.
11 | DUMMY2/p335/p335_058.wav|49|lˈɛts sˈiː ðæt ˈɪnkɹiːs ˌoʊvɚ ðə jˈɪɹz.
12 | DUMMY2/p236/p236_225.wav|75|ðɛɹ ɪz nˈoʊ kwˈɪk fˈɪks.
13 | DUMMY2/p374/p374_353.wav|11|ænd ðæt bɹˈɪŋz ˌʌs tə ðə pˈɔɪnt.
14 | DUMMY2/p272/p272_076.wav|69|sˈaʊndz lˈaɪk ðə sˈɪksθ sˈɛns?
15 | DUMMY2/p271/p271_152.wav|27|ðə pətˈɪʃən wʌz fˈɔːɹməli pɹɪzˈɛntᵻd æt dˈaʊnɪŋ stɹˈiːt jˈɛstɚdˌeɪ.
16 | DUMMY2/p228/p228_127.wav|57|ðeɪv ɡɑːt tʊ ɐkˈaʊnt fɔːɹ ɪt.
17 | DUMMY2/p276/p276_223.wav|106|ɪts bˌɪn ɐ hˈʌmblɪŋ jˈɪɹ.
18 | DUMMY2/p262/p262_248.wav|45|ðə pɹˈɑːdʒɛkt hɐz ɔːlɹˌɛdi sɪkjˈʊɹd ðə səpˈoːɹt ʌv sˌɜː ʃˈɔːn kɑːnɚɹi.
19 | DUMMY2/p314/p314_086.wav|51|ðə tˈiːm ðɪs jˈɪɹ ɪz ɡˌoʊɪŋ plˈeɪsᵻz.
20 | DUMMY2/p225/p225_038.wav|101|dˈaɪvɪŋ ɪz nˈoʊ pˈɑːɹt ʌv fˈʊtbɔːl.
21 | DUMMY2/p279/p279_088.wav|25|ðə ʃˈɛɹhoʊldɚz wɪl vˈoʊt tə wˈaɪnd ˈʌp ðə kˈʌmpəni ˌɑːn fɹˈaɪdeɪ mˈɔːɹnɪŋ.
22 | DUMMY2/p272/p272_018.wav|69|ˈæɹɪstˌɑːɾəl θˈɔːt ðætðə ɹˈeɪnboʊ wʌz kˈɔːzd baɪ ɹɪflˈɛkʃən ʌvðə sˈʌnz ɹˈeɪz baɪ ðə ɹˈeɪn.
23 | DUMMY2/p256/p256_098.wav|90|ʃiː tˈoʊld ðə hˈɛɹəld.
24 | DUMMY2/p261/p261_218.wav|100|ˈɔːl wɪl biː ɹɪvˈiːld ɪn dˈuː kˈoːɹs.
25 | DUMMY2/p265/p265_063.wav|73|ɪt ʃˌʊdənt kˈʌm æz ɐ sɚpɹˈaɪz, bˌʌt ɪt dˈʌz.
26 | DUMMY2/p314/p314_042.wav|51|ɪt ɪz ˈɔːl ɐbˌaʊt pˈiːpəl bˌiːɪŋ ɐsˈɑːltᵻd, ɐbjˈuːsd.
27 | DUMMY2/p241/p241_188.wav|86|ˈaɪ wˈɪʃ ˈaɪ kʊd sˈeɪ sˈʌmθɪŋ.
28 | DUMMY2/p283/p283_111.wav|95|ɪts ɡˈʊd tə hæv ɐ vˈɔɪs.
29 | DUMMY2/p275/p275_006.wav|40|wˌɛn ðə sˈʌnlaɪt stɹˈaɪks ɹˈeɪndɹɑːps ɪnðɪ ˈɛɹ, ðeɪ ˈækt æz ɐ pɹˈɪzəm ænd fˈɔːɹm ɐ ɹˈeɪnboʊ.
30 | DUMMY2/p228/p228_092.wav|57|tədˈeɪ ˈaɪ kˌʊdənt ɹˈʌn ˈɑːn ɪt.
31 | DUMMY2/p295/p295_343.wav|92|ðɪ ˈætməsfˌɪɹ ɪz bˈɪznəslˌaɪk.
32 | DUMMY2/p228/p228_187.wav|57|ðeɪ wɪl ɹˈʌn ɐ mˈaɪl.
33 | DUMMY2/p294/p294_317.wav|104|ɪt dˈɪdnt pˌʊt mˌiː ˈɔf.
34 | DUMMY2/p231/p231_445.wav|50|ɪt sˈaʊndᵻd lˈaɪk ɐ bˈɑːm.
35 | DUMMY2/p272/p272_086.wav|69|tədˈeɪ ʃiː hɐzbɪn ɹɪlˈiːsd.
36 | DUMMY2/p255/p255_210.wav|31|ɪt wʌz wˈɜːθ ɐ fˈoʊɾəɡɹˌæf.
37 | DUMMY2/p229/p229_060.wav|67|ænd ɐ fˈɪlm mˈeɪkɚ wʌz bˈɔːɹn.
38 | DUMMY2/p260/p260_232.wav|81|ðə hˈoʊm ˈɑːfɪs wʊd nˌɑːt ɹɪlˈiːs ˌɛni fˈɜːðɚ diːtˈeɪlz ɐbˌaʊt ðə ɡɹˈuːp.
39 | DUMMY2/p245/p245_025.wav|59|dʒˈɑːnsən wʌz pɹˈɪɾi lˈoʊ.
40 | DUMMY2/p333/p333_185.wav|64|ðɪs ˈɛɹiə ɪz pˈɜːfɛkt fɔːɹ tʃˈɪldɹən.
41 | DUMMY2/p244/p244_242.wav|78|hiː ɪz ɐ mˈæn ʌvðə pˈiːpəl.
42 | DUMMY2/p376/p376_187.wav|71|"ɪt ɪz ɐ tˈɛɹəbəl lˈɔs."
43 | DUMMY2/p239/p239_156.wav|48|ɪt ɪz ɐ ɡˈʊd lˈaɪfstaɪl.
44 | DUMMY2/p307/p307_037.wav|22|hiː ɹɪlˈiːsd ɐ hˈæfdˈʌzən sˈoʊloʊ ˈælbəmz.
45 | DUMMY2/p305/p305_185.wav|54|ˈaɪ æm nˌɑːt ˈiːvən θˈɪŋkɪŋ ɐbˌaʊt ðˈæt.
46 | DUMMY2/p272/p272_081.wav|69|ɪt wʌz mˈædʒɪk.
47 | DUMMY2/p302/p302_297.wav|30|aɪm tɹˈaɪɪŋ tə stˈeɪ ˈoʊpən ˌɑːn ðˈæt.
48 | DUMMY2/p275/p275_320.wav|40|wiː ɑːɹ ɪnðɪ ˈɛnd ɡˈeɪm.
49 | DUMMY2/p239/p239_231.wav|48|ðˈɛn wiː wɪl fˈeɪs ðə dˈeɪnɪʃ tʃˈæmpiənz.
50 | DUMMY2/p268/p268_301.wav|87|ɪt wʌz ˈoʊnli lˈeɪɾɚ ðætðə kəndˈɪʃən wʌz dˌaɪəɡnˈoʊzd.
51 | DUMMY2/p336/p336_088.wav|98|ðeɪ fˈeɪld tə ɹˈiːtʃ ɐɡɹˈiːmənt jˈɛstɚdˌeɪ.
52 | DUMMY2/p278/p278_255.wav|10|ðeɪ mˌeɪd sˈʌtʃ dᵻsˈɪʒənz ɪn lˈʌndən.
53 | DUMMY2/p361/p361_132.wav|79|ðæt ɡɑːt mˌiː ˈaʊt.
54 | DUMMY2/p307/p307_146.wav|22|juː hˈoʊp hiː pɹɪvˈeɪlz.
55 | DUMMY2/p244/p244_147.wav|78|ðeɪ kʊd nˌɑːt ɪɡnˈoːɹ ðə wɪl ʌv pˈɑːɹləmənt, hiː klˈeɪmd.
56 | DUMMY2/p294/p294_283.wav|104|ðɪs ɪz ˌaʊɚɹ ʌnfˈɪnɪʃt bˈɪznəs.
57 | DUMMY2/p283/p283_300.wav|95|ˈaɪ wʊdhɐv ðə hˈæmɚɹ ɪnðə kɹˈaʊd.
58 | DUMMY2/p239/p239_079.wav|48|ˈaɪ kæn ˌʌndɚstˈænd ðə fɹʌstɹˈeɪʃənz ʌv ˌaʊɚ fˈænz.
59 | DUMMY2/p264/p264_009.wav|65|ðɛɹˈɪz , ɐkˈoːɹdɪŋ tə lˈɛdʒənd, ɐ bˈɔɪlɪŋ pˈɑːt ʌv ɡˈoʊld æt wˈʌn ˈɛnd.
60 | DUMMY2/p307/p307_348.wav|22|hiː dɪdnˌɑːt əpˈoʊz ðə dɪvˈoːɹs.
61 | DUMMY2/p304/p304_308.wav|72|wiː ɑːɹ ðə ɡˈeɪtweɪ tə dʒˈʌstɪs.
62 | DUMMY2/p281/p281_056.wav|36|nˈʌn hɐz ˈɛvɚ bˌɪn fˈaʊnd.
63 | DUMMY2/p267/p267_158.wav|0|wiː wɜː ɡˈɪvən ɐ wˈɔːɹm ænd fɹˈɛndli ɹɪsˈɛpʃən.
64 | DUMMY2/p300/p300_169.wav|102|hˌuː dˈuː ðiːz pˈiːpəl θˈɪŋk ðeɪ ɑːɹ?
65 | DUMMY2/p276/p276_177.wav|106|ðeɪ ɛɡzˈɪst ɪn nˈeɪm ɐlˈoʊn.
66 | DUMMY2/p228/p228_245.wav|57|ɪt ɪz ɐ pˈɑːlɪsi wˌɪtʃ hɐz ðə fˈʊl səpˈoːɹt ʌvðə mˈɪnɪstɚ.
67 | DUMMY2/p300/p300_303.wav|102|aɪm wˈʌndɚɹɪŋ wˌʌt juː fˈiːl ɐbˌaʊt ðə jˈʌŋɡəst.
68 | DUMMY2/p362/p362_247.wav|15|ðɪs wʊd ɡˈɪv skˈɑːtlənd ɐɹˈaʊnd ˈeɪt mˈɛmbɚz.
69 | DUMMY2/p326/p326_031.wav|28|juːnˈaɪɾᵻd wɜːɹ ɪn kəntɹˈoʊl wɪðˌaʊt ˈɔːlweɪz bˌiːɪŋ dˈɑːmɪnənt.
70 | DUMMY2/p361/p361_288.wav|79|ˈaɪ dɪdnˌɑːt θˈɪŋk ɪt wʌz vˈɛɹi pɹˈɑːpɚ.
71 | DUMMY2/p286/p286_145.wav|63|tˈaɪɡɚɹ ɪz nˌɑːt ðə nˈɔːɹm.
72 | DUMMY2/p234/p234_071.wav|3|ʃiː dˈɪd ðæt fɚðə ɹˈɛst ʌv hɜː lˈaɪf.
73 | DUMMY2/p263/p263_296.wav|39|ðə dᵻsˈɪʒən wʌz ɐnˈaʊnst æt ɪts ˈænjuːəl kˈɑːnfɹəns ɪn dˈʌnfɚmlˌaɪn.
74 | DUMMY2/p323/p323_228.wav|34|ʃiː bɪkˌeɪm ɐ hˈɛɹoʊˌɪn ʌv maɪ tʃˈaɪldhʊd.
75 | DUMMY2/p280/p280_346.wav|52|ɪt wʌzɐ bˈɪt lˈaɪk hˌævɪŋ tʃˈɪldɹən.
76 | DUMMY2/p333/p333_080.wav|64|bˌʌt ðə tɹˈædʒədi dɪdnˌɑːt stˈɑːp ðˈɛɹ.
77 | DUMMY2/p226/p226_268.wav|43|ðæt dᵻsˈɪʒən ɪz fɚðə bɹˈɪɾɪʃ pˈɑːɹləmənt ænd pˈiːpəl.
78 | DUMMY2/p362/p362_314.wav|15|ɪz ðæt ɹˈaɪt?
79 | DUMMY2/p240/p240_047.wav|93|ɪt ɪz sˌoʊ sˈæd.
80 | DUMMY2/p250/p250_207.wav|24|juː kʊd fˈiːl ðə hˈiːt.
81 | DUMMY2/p273/p273_176.wav|56|nˈiːðɚ sˈaɪd wʊd ɹɪvˈiːl ðə diːtˈeɪlz ʌvðɪ ˈɑːfɚ.
82 | DUMMY2/p316/p316_147.wav|85|ænd fɹˈæŋkli, ɪts bˌɪn ɐ wˈaɪl.
83 | DUMMY2/p265/p265_047.wav|73|ɪt ɪz juːnˈiːk.
84 | DUMMY2/p336/p336_353.wav|98|sˈʌmtaɪmz juː ɡˈɛt ðˌɛm, sˈʌmtaɪmz juː dˈoʊnt.
85 | DUMMY2/p230/p230_376.wav|35|ðɪs hˈæzənt hˈæpənd ɪn ɐ vˈækjuːm.
86 | DUMMY2/p308/p308_209.wav|107|ðɛɹ ɪz ɡɹˈeɪt pətˈɛnʃəl ˌɑːn ðɪs ɹˈɪvɚ.
87 | DUMMY2/p250/p250_442.wav|24|wiː hɐvnˌɑːt jˈɛt ɹɪsˈiːvd ɐ lˈɛɾɚ fɹʌmðɪ ˈaɪɹɪʃ.
88 | DUMMY2/p260/p260_037.wav|81|ɪts ɐ fˈækt.
89 | DUMMY2/p299/p299_345.wav|58|wɪɹ vˈɛɹi ɛksˈaɪɾᵻd ænd tʃˈælɪndʒd baɪ ðə pɹˈɑːdʒɛkt.
90 | DUMMY2/p269/p269_218.wav|94|ɐ ɡɹˈæmpiən pəlˈiːs spˈoʊksmən sˈɛd.
91 | DUMMY2/p306/p306_014.wav|12|tə ðə hˈiːbɹuːz ɪt wʌzɐ tˈoʊkən ðæt ðɛɹ wʊd biː nˈoʊmˌoːɹ jˌuːnɪvˈɜːsəl flˈʌdz.
92 | DUMMY2/p271/p271_292.wav|27|ɪts ɐ ɹˈɛkɚd lˈeɪbəl, nˌɑːɾə fˈɔːɹm ʌv mjˈuːzɪk.
93 | DUMMY2/p247/p247_225.wav|14|ˈaɪ æm kənsˈɪdɚd ɐ tˈiːneɪdʒɚ.
94 | DUMMY2/p294/p294_094.wav|104|ɪt ʃˌʊd biː ɐ kəndˈɪʃən ʌv ɛmplˈɔɪmənt.
95 | DUMMY2/p269/p269_031.wav|94|ɪz ðɪs ˈækjʊɹət?
96 | DUMMY2/p275/p275_116.wav|40|ɪts nˌɑːt fˈɛɹ.
97 | DUMMY2/p265/p265_006.wav|73|wˌɛn ðə sˈʌnlaɪt stɹˈaɪks ɹˈeɪndɹɑːps ɪnðɪ ˈɛɹ, ðeɪ ˈækt æz ɐ pɹˈɪzəm ænd fˈɔːɹm ɐ ɹˈeɪnboʊ.
98 | DUMMY2/p285/p285_072.wav|2|mˈɪstɚɹ ˈɜːvaɪn sˈɛd mˈɪstɚ ɹˈæfɚɾi wʌz nˈaʊ ɪn ɡˈʊd spˈɪɹɪts.
99 | DUMMY2/p270/p270_167.wav|8|wiː dˈɪd wˌʌt wiː hædtə dˈuː.
100 | DUMMY2/p360/p360_397.wav|60|ɪt ɪz ɐ ɹɪlˈiːf.
101 |
--------------------------------------------------------------------------------
/Dialogue/gen_coat.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import lightgbm as lgb
4 | import random
5 | import json
6 | import os
7 | from coat_utils import *
8 | from collections import defaultdict
9 | from config.coat_pattern import start_pattern, agent_pattern, user_pattern
10 | from config.thanks_pattern import coat_agent, coat_user
11 | from config.recommend_pattern import coat_rec
12 | from coat_attr import generate_gender_dialogue, generate_jacket_dialogue, generate_color_dialogue
13 |
14 | def load_data():
15 | coat = pd.read_csv("./data/coat/coat_info.csv")
16 | coat['age'] = coat['age'].apply(lambda age: age_map(age))
17 | item_feature = np.genfromtxt('./data/coat/item_features.ascii', dtype=None)
18 |
19 | return (coat, item_feature)
20 |
21 | def calculate_user_preference(data):
22 | coat = data[0]
23 | item_feature = data[1]
24 | user_record = defaultdict(set)
25 | item_matrix = np.zeros((300, 3))
26 | for _, row in coat.iterrows():
27 | user_id = int(row['user'])
28 | item_id = int(row['item'])
29 | user_record[user_id].add(item_id)
30 | gender, jacket, color = get_item_index(item_feature, item_id)
31 | item_matrix[item_id][0] = gender
32 | item_matrix[item_id][1] = jacket
33 | item_matrix[item_id][2] = color
34 |
35 | item_csv = pd.DataFrame(item_matrix)
36 | item_csv.insert(item_csv.shape[1], 'label', 0)
37 | item_csv.columns = ['gender','jacket', 'color', 'label']
38 | features_cols = ['gender', 'jacket', 'color']
39 | user_preference = defaultdict(list)
40 | for user in range(300):
41 | user_item_csv = item_csv.copy()
42 | record = list(user_record[user])
43 | for item in record:
44 | user_item_csv.loc[item, 'label'] = 1
45 | X = user_item_csv[features_cols]
46 | Y = user_item_csv.label
47 |
48 | cls = lgb.LGBMClassifier(importance_type='gain')
49 | cls.fit(X, Y)
50 |
51 | indices = np.argsort(cls.booster_.feature_importance(importance_type='gain'))
52 | feature = [features_cols[i] for i in indices]
53 |
54 | user_preference[user] = feature
55 |
56 | return user_preference
57 |
58 | def get_user_item_info(data):
59 | user_info = {}
60 | item_info = {}
61 | coat = data[0]
62 | item_feature = data[1]
63 | for _, row in coat.iterrows():
64 | user_id = int(row['user'])
65 | item_id = int(row['item'])
66 | user_info[user_id] = {}
67 | user_info[user_id]['age'] = get_user_age(row['age'])
68 | user_info[user_id]['gender'] = get_user_gender(row['gender'])
69 |
70 | if item_id not in item_info.keys():
71 | gender, jacket, color = get_item_index(item_feature, item_id)
72 | item_info[item_id] = {}
73 | item_info[item_id]['gender'] = get_item_gender(gender)
74 | item_info[item_id]['jacket'] = get_item_type(jacket)
75 | item_info[item_id]['color'] = get_item_color(color)
76 |
77 | return user_info, item_info
78 |
79 | def calculate_attr_weights(data):
80 | gender_all = {}
81 | jacket_all = {}
82 | color_all = {}
83 | coat = data[0]
84 | item_feature = data[1]
85 | for _, row in coat.iterrows():
86 | item_id = int(row['item'])
87 | gender, jacket, color = get_item_index(item_feature, item_id)
88 | if gender not in gender_all.keys():
89 | gender_all[gender] = 1
90 | else:
91 | gender_all[gender] += 1
92 |
93 | if jacket not in jacket_all.keys():
94 | jacket_all[jacket] = 1
95 | else:
96 | jacket_all[jacket] += 1
97 |
98 | if color not in color_all.keys():
99 | color_all[color] = 1
100 | else:
101 | color_all[color] += 1
102 |
103 | gender_weight = [gender_all[i] for i in sorted(gender_all)]
104 | jacket_weight = [jacket_all[i] for i in sorted(jacket_all)]
105 | color_weight = [color_all[i] for i in sorted(color_all)]
106 |
107 | return (gender_weight, jacket_weight, color_weight), (gender_all, jacket_all, color_all)
108 |
109 |
110 | if __name__ == '__main__':
111 | coat_data = load_data()
112 | user_preference = calculate_user_preference(coat_data)
113 | user_info, item_info = get_user_item_info(coat_data)
114 | weights, attr_counts = calculate_attr_weights(coat_data)
115 | print("data load complete")
116 | dialogue_info = {}
117 | dialogue_id = 0
118 | for _, row in coat_data[0].iterrows():
119 | user_id = int(row['user'])
120 | item_id = int(row['item'])
121 |
122 | new_dialogue = {}
123 | new_dialogue["user_id"] = user_id
124 | new_dialogue["item_id"] = item_id
125 |
126 | new_dialogue["user_gender"] = user_info[user_id]["gender"]
127 | new_dialogue["user_age"] = user_info[user_id]["age"]
128 | new_dialogue["content"] = {}
129 | new_dialogue["content"]["start"] = random.choice(start_pattern)
130 |
131 | dialouge_order = user_preference[user_id]
132 | tmp_new_dialogue = []
133 |
134 | for slot in dialouge_order:
135 | if slot == "gender":
136 | gender_val = item_info[item_id]["gender"]
137 | utterance = generate_gender_dialogue(agent_pattern, user_pattern, gender_val, attr_counts[0], weights[0])
138 | elif slot == "jacket":
139 | jacket_val = item_info[item_id]["jacket"]
140 | utterance = generate_jacket_dialogue(agent_pattern, user_pattern, jacket_val, attr_counts[1], weights[1])
141 | elif slot == "color":
142 | color_val = item_info[item_id]["color"]
143 | utterance = generate_color_dialogue(agent_pattern, user_pattern, color_val, attr_counts[2], weights[2])
144 | tmp_new_dialogue.append(utterance)
145 | end_dialogue = {}
146 | end_dialogue["rec"] = random.choice(coat_rec)
147 | end_dialogue["thanks_user"] = random.choice(coat_user)
148 | end_dialogue["thanks_agent"] = random.choice(coat_agent)
149 | tmp_new_dialogue.append(end_dialogue)
150 | print("finish:", dialogue_id)
151 | start_index = 0
152 | end_index = 0
153 | step = 0
154 | name = ["Q1", "A1", "Q2", "A2", "Q3", "A3", "Q4", "A4", "Q5", "A5", "Q6", "A6", "Q7", "A7", "Q8", "A8", "Q9", "A9", "Q10", "A10", "Q11", "A11", "Q12", "A12"]
155 | for dia in tmp_new_dialogue:
156 | end_index = len(dia) + end_index
157 | tmp_name = name[start_index : end_index]
158 | tmp_dia = []
159 | for _, v in dia.items():
160 | tmp_dia.append(v)
161 | for i, val in enumerate(tmp_name):
162 | new_dialogue["content"][val] = tmp_dia[i]
163 | start_index = end_index
164 | dialogue_info[dialogue_id] = new_dialogue
165 | dialogue_id = dialogue_id + 1
166 |
167 |
168 | res_path = './res/coat/'
169 | if not os.path.exists(res_path):
170 | os.makedirs(res_path)
171 | with open(res_path + 'dialogue_info_coat.json', 'w') as f:
172 | json.dump(dialogue_info, f, indent=4)
173 |
174 |
175 |
176 |
177 |
178 |
179 |
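180 | # Reader sketch (illustrative) for the file written above: each dialogue id maps
181 | # to "user_id", "item_id", "user_gender", "user_age" and a "content" dict keyed
182 | # "start", "Q1", "A1", ... in the per-user slot order.
183 | #   with open('./res/coat/dialogue_info_coat.json') as f:
184 | #       dialogues = json.load(f)
185 | #   print(dialogues['0']['content']['start'])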
--------------------------------------------------------------------------------
/Recommender/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | from collections import defaultdict
4 |
5 |
6 | def get_ur(df):
7 | """
8 |     Method for collecting each user's interacted items
9 | Parameters
10 | ----------
11 | df : pd.DataFrame, rating dataframe
12 |
13 | Returns
14 | -------
15 |     ur : dict, dictionary storing user-item interactions
16 | """
17 | ur = defaultdict(set)
18 | for _, row in df.iterrows():
19 | ur[int(row['user'])].add(int(row['item']))
20 |
21 | return ur
22 |
23 |
24 | def build_candidates_set(test_ur, train_ur, item_pool, candidates_num=1000):
25 | """
26 | method of building candidate items for ranking
27 | Parameters
28 | ----------
29 |     test_ur : dict, ground-truth user-item interactions in the test set
30 |     train_ur : dict, user-item interactions in the train set
31 | item_pool : the set of all items
32 | candidates_num : int, the number of candidates
33 | Returns
34 | -------
35 | test_ucands : dict, dictionary storing candidates for each user in test set
36 | """
37 | # random.seed(1)
38 | test_ucands = defaultdict(list)
39 | for k, v in test_ur.items():
40 | sample_num = candidates_num - len(v) if len(v) < candidates_num else 0
41 | sub_item_pool = item_pool - v - train_ur[k] # remove GT & interacted
42 | sample_num = min(len(sub_item_pool), sample_num)
43 | if sample_num == 0:
44 | samples = random.sample(v, candidates_num)
45 | test_ucands[k] = list(set(samples))
46 | else:
47 | samples = random.sample(sub_item_pool, sample_num)
48 | test_ucands[k] = list(v | set(samples))
49 |
50 | return test_ucands
51 |
52 |
53 | def precision_at_k(r, k):
54 | """
55 | Precision calculation method
56 | Parameters
57 | ----------
58 |     r : List, relevance scores of the ranked items
59 | k : int, top-K number
60 |
61 | Returns
62 | -------
63 | pre : float, precision value
64 | """
65 | assert k >= 1
66 | r = np.asarray(r)[:k] != 0
67 | if r.size != k:
68 | raise ValueError('Relevance score length < k')
69 | # return np.mean(r)
70 | pre = sum(r) / len(r)
71 |
72 | return pre
73 |
74 |
75 | def recall_at_k(rs, test_ur, k):
76 | """
77 | Recall calculation method
78 | Parameters
79 | ----------
80 | rs : Dict, {user : rank items} for test set
81 | test_ur : Dict, {user : items} for test set ground truth
82 | k : int, top-K number
83 |
84 | Returns
85 | -------
86 |     rec : float, recall value
87 | """
88 | assert k >= 1
89 | res = []
90 | for user in test_ur.keys():
91 | r = np.asarray(rs[user])[:k] != 0
92 | if r.size != k:
93 | raise ValueError('Relevance score length < k')
94 | if len(test_ur[user]) == 0:
95 | raise KeyError(f'Invalid User Index: {user}')
96 | res.append(sum(r) / len(test_ur[user]))
97 | rec = np.mean(res)
98 |
99 | return rec
100 |
101 |
102 | def mrr_at_k(rs, k):
103 | """
104 | Mean Reciprocal Rank calculation method
105 | Parameters
106 | ----------
107 | rs : Dict, {user : rank items} for test set
108 |     k : int, top-K number
109 |
110 | Returns
111 | -------
112 | mrr : float, MRR value
113 | """
114 | assert k >= 1
115 | res = 0
116 | for r in rs.values():
117 | r = np.asarray(r)[:k] != 0
118 | for index, item in enumerate(r):
119 | if item == 1:
120 | res += 1 / (index + 1)
121 | mrr = res / len(rs)
122 |
123 | return mrr
124 |
125 |
126 | def ap(r):
127 | """
128 | Average precision calculation method
129 | Parameters
130 | ----------
131 | r : List, Relevance scores (list or numpy) in rank order (first element is the first item)
132 |
133 | Returns
134 | -------
135 | a_p : float, Average precision value
136 | """
137 | r = np.asarray(r) != 0
138 | out = [precision_at_k(r, k + 1) for k in range(r.size) if r[k]]
139 | if not out:
140 | return 0.
141 | a_p = np.sum(out) / len(r)
142 |
143 | return a_p
144 |
145 |
146 | def map_at_k(rs):
147 | """
148 | Mean Average Precision calculation method
149 | Parameters
150 | ----------
151 | rs : Dict, {user : rank items} for test set
152 |
153 | Returns
154 | -------
155 | m_a_p : float, MAP value
156 | """
157 | m_a_p = np.mean([ap(r) for r in rs])
158 | return m_a_p
159 |
160 |
161 | def dcg_at_k(r, k):
162 | """
163 | Discounted Cumulative Gain calculation method
164 | Parameters
165 | ----------
166 | r : List, Relevance scores (list or numpy) in rank order
167 | (first element is the first item)
168 | k : int, top-K number
169 |
170 | Returns
171 | -------
172 | dcg : float, DCG value
173 | """
174 | assert k >= 1
175 | r = np.asfarray(r)[:k] != 0
176 | if r.size:
177 | dcg = np.sum(np.subtract(np.power(2, r), 1) / np.log2(np.arange(2, r.size + 2)))
178 | return dcg
179 | return 0.
180 |
181 |
182 | def ndcg_at_k(r, k):
183 | """
184 | Normalized Discounted Cumulative Gain calculation method
185 | Parameters
186 | ----------
187 | r : List, Relevance scores (list or numpy) in rank order
188 | (first element is the first item)
189 | k : int, top-K number
190 |
191 | Returns
192 | -------
193 | ndcg : float, NDCG value
194 | """
195 | assert k >= 1
196 | idcg = dcg_at_k(sorted(r, reverse=True), k)
197 | if not idcg:
198 | return 0.
199 | ndcg = dcg_at_k(r, k) / idcg
200 |
201 | return ndcg
202 |
203 |
204 | def hr_at_k(rs, test_ur):
205 | """
206 | Hit Ratio calculation method
207 | Parameters
208 | ----------
209 | rs : Dict, {user : rank items} for test set
210 | test_ur : (Deprecated) Dict, {user : items} for test set ground truth
211 |
212 | Returns
213 | -------
214 | hr : float, HR value
215 | """
216 | # another way for calculating hit rate
217 | # numer, denom = 0., 0.
218 | # for user in test_ur.keys():
219 | # numer += np.sum(rs[user])
220 | # denom += len(test_ur[user])
221 |
222 | # return numer / denom
223 | uhr = 0
224 | for r in rs.values():
225 | if np.sum(r) != 0:
226 | uhr += 1
227 | hr = uhr / len(rs)
228 |
229 | return hr
230 |
231 | def get_feature(feature_list):
232 |     if 'gender' in feature_list and 'age' in feature_list:
233 |         feature_num = 2
234 |     elif 'gender' in feature_list:
235 |         feature_num = 0
236 |     elif 'age' in feature_list:
237 |         feature_num = 1
238 |     else:
239 |         feature_num = 3
240 |
241 | return feature_num
242 |
243 | def get_user_info(df, feature_num):
244 | user_info = dict()
245 | for _, row in df.iterrows():
246 | if feature_num == 0:
247 | user_info[int(row['user'])] = [int(row['gender'])]
248 | elif feature_num == 1:
249 | user_info[int(row['user'])] = [int(row['age'])]
250 | elif feature_num == 2:
251 | user_info[int(row['user'])] = [int(row['gender']), int(row['age'])]
252 | return user_info
253 |
254 |
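255 | # Minimal usage sketch with toy relevance lists (1 = hit, 0 = miss); the expected
256 | # values in the comments are approximate and for illustration only.
257 | if __name__ == '__main__':
258 |     print(precision_at_k([1, 0, 1, 0], 2))               # 0.5
259 |     print(ndcg_at_k([1, 0, 1], 3))                        # ~0.92
260 |     print(mrr_at_k({0: [0, 1, 0]}, 3))                    # 0.5
261 |     print(hr_at_k({0: [0, 1, 0], 1: [0, 0, 0]}, None))    # 0.5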
--------------------------------------------------------------------------------
/Recommender/gen_label.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import pandas as pd
5 | from torch.utils.data import Dataset,DataLoader
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import torchaudio
9 | import tqdm as tqdm
10 | import argparse
11 | from transformers import Wav2Vec2Model, Wav2Vec2PreTrainedModel, AutoConfig
12 |
13 |
14 | class Wav2Vec2ClassificationModel(Wav2Vec2PreTrainedModel):
15 | def __init__(self, config, hidden_size, dropout):
16 | super().__init__(config)
17 |
18 | self.wav2vec2 = Wav2Vec2Model(config)
19 | self.hidden_size = hidden_size
20 | self.fc = nn.Linear(config.hidden_size, self.hidden_size)
21 | self.dropout = nn.Dropout(dropout)
22 | self.gender_fc = nn.Linear(self.hidden_size, 2)
23 | self.age_fc = nn.Linear(self.hidden_size, 3)
24 | self.tanh = nn.Tanh()
25 |
26 | self.init_weights()
27 |
28 | def freeze_feature_extractor(self):
29 | self.wav2vec2.feature_extractor._freeze_parameters()
30 |
31 | def merged_strategy(self, hidden_states):
32 | outputs = torch.mean(hidden_states, dim=1)
33 |
34 | return outputs
35 |
36 | def forward(
37 | self,
38 | input_values,
39 | attention_mask=None,
40 | output_attentions=None,
41 | output_hidden_states=None,
42 | return_dict=None,
43 | labels=None,
44 | ):
45 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
46 | with torch.no_grad():
47 | outputs = self.wav2vec2(
48 | input_values,
49 | attention_mask=attention_mask,
50 | output_attentions=output_attentions,
51 | output_hidden_states=output_hidden_states,
52 | return_dict=return_dict,
53 | )
54 |
55 | hidden_states = outputs[0]
56 | x = self.merged_strategy(hidden_states)
57 | x = self.dropout(x)
58 | x = self.fc(x)
59 | x = self.tanh(x)
60 | x = self.dropout(x)
61 | gender_logits = self.gender_fc(x)
62 | gender_logits = F.log_softmax(gender_logits, dim=-1)
63 | age_logits = self.age_fc(x)
64 | age_logits = F.log_softmax(age_logits, dim=-1)
65 |
66 | return age_logits, gender_logits
67 |
68 | def construct_data(input_dir):
69 | datalist = os.listdir(input_dir)
70 | user = []
71 | item = []
72 | gender = []
73 | age = []
74 | name = []
75 | for file in datalist:
76 | u_id = int(file.split('_')[1][3:])
77 | i_id = int(file.split('_')[2][3:])
78 | user.append(u_id)
79 | item.append(i_id)
80 |
81 | name.append(file)
82 | gender.append(file.split('_')[-2])
83 | age.append(file.split('_')[-3])
84 | data = pd.DataFrame({'user':user, 'item':item, 'gender':gender, 'age':age, 'audio_name':name})
85 |
86 | return data
87 |
88 | class GenDataset(Dataset):
89 | def __init__(self, df, audio_path):
90 | self.df = df
91 | self.audio_path = audio_path
92 | self.user = self.df.iloc[:,0].values
93 | self.item = self.df.iloc[:,1].values
94 | self.gender = self.df.iloc[:,2].values
95 | self.age = self.df.iloc[:,3].values
96 | self.audios = self.df.iloc[:,4].values
97 | self.age_labels = {'under 20':0, '20-30':1, 'over 30':2}
98 | self.gender_labels = {'women':0, 'men':1}
99 |
100 | def __len__(self):
101 | return len(self.df)
102 |
103 | def __getitem__(self, index):
104 | user = self.user[index]
105 | item = self.item[index]
106 | gender = self.gender_labels[self.gender[index]]
107 | age = self.age_labels[self.age[index]]
108 |
109 | audio_file_path = os.path.join(self.audio_path, self.audios[index])
110 | waveform, sample_rate = torchaudio.load(audio_file_path)
111 | waveform = self._resample(waveform, sample_rate)
112 |
113 | return user, item, gender, age, waveform
114 |
115 | def _resample(self, waveform, sample_rate):
116 | resampler = torchaudio.transforms.Resample(sample_rate,16000)
117 |
118 | return resampler(waveform)
119 |
120 | def pad_sequence(batch):
121 |     # Make all tensors in a batch the same length by padding with zeros
122 | batch = [item.t() for item in batch]
123 | batch = torch.nn.utils.rnn.pad_sequence(batch, batch_first=True, padding_value=0.)
124 | return batch.permute(0, 2, 1)
125 |
126 | def collate_fn(batch):
127 |
128 | users, items, audios, age_labels, gender_labels = [], [], [], [], []
129 |
130 | for user, item, gender, age, waveform in batch:
131 | audios += [waveform]
132 | users += [torch.tensor(user)]
133 | items += [torch.tensor(item)]
134 | age_labels += [torch.tensor(age)]
135 | gender_labels += [torch.tensor(gender)]
136 |
137 | audios = pad_sequence(audios)
138 | users = torch.stack(users)
139 | items = torch.stack(items)
140 | age_labels = torch.stack(age_labels)
141 | gender_labels = torch.stack(gender_labels)
142 |
143 | return users, items, gender_labels, age_labels, audios.squeeze(dim=1)
144 |
145 |
146 | def get_likely_index(tensor):
147 | # find most likely label index for each element in the batch
148 | return tensor.argmax(dim=-1)
149 |
150 | def number_of_correct(pred, target):
151 | # count number of correct predictions
152 | return pred.eq(target).sum().item()
153 |
154 | if __name__ == '__main__':
155 | parser = argparse.ArgumentParser(description='generate label on audio')
156 | parser.add_argument('--model',
157 | type=str,
158 | default='ml-1m')
159 | parser.add_argument('--dataset',
160 | type=str,
161 | default='coat')
162 | args = parser.parse_args()
163 |
164 | model = torch.load(f'./clf_model_{args.model}.pt')
165 | audio_path = f'./data/{args.dataset}/'
166 | data = construct_data(audio_path)
167 |
168 | pre_data = GenDataset(data, audio_path)
169 | pre_loader = DataLoader(
170 | pre_data,
171 | batch_size=4,
172 | shuffle=True,
173 | collate_fn=collate_fn,
174 | num_workers=4,
175 | pin_memory=False,
176 | )
177 |
178 | model.eval()
179 | age_correct = 0
180 | gender_correct = 0
181 | tag = 1
182 | for user, item, gender_label, age_label, waveform in tqdm.tqdm(pre_loader):
183 | waveform = waveform.to("cuda")
184 | age_label = age_label.to("cuda")
185 | gender_label = gender_label.to("cuda")
186 |
187 | age_logits, gender_logits = model(waveform)
188 |
189 | age_pred = get_likely_index(age_logits) # batch
190 | gender_pred = get_likely_index(gender_logits)
191 |
192 | age_correct += number_of_correct(age_pred, age_label)
193 | gender_correct += number_of_correct(gender_pred, gender_label)
194 |
195 | if tag == 1:
196 | users = user.numpy()
197 | items = item.numpy()
198 | genders = gender_pred.cpu().numpy()
199 | ages = age_pred.cpu().numpy()
200 | tag = 0
201 | else:
202 | users = np.hstack((users, user.numpy()))
203 | items = np.hstack((items, item.numpy()))
204 | genders = np.hstack((genders, gender_pred.cpu().numpy()))
205 | ages = np.hstack((ages, age_pred.cpu().numpy()))
206 |
207 |
208 | age_accu = age_correct / len(pre_loader.dataset) * 100.
209 | gender_accu = gender_correct / len(pre_loader.dataset) * 100.
210 |
211 | print(f"Age: {age_accu:.2f}% Gender: {gender_accu:.2f}%")
212 | result_save_path = './res/'
213 | if not os.path.exists(result_save_path):
214 | os.makedirs(result_save_path)
215 | data = pd.DataFrame({'user':list(users), 'item':list(items), 'gender':list(genders), 'age':list(ages)})
216 | data.to_csv(f'./res/{args.dataset}_predict.csv', index=False)
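217 | 
218 | # Example invocation (illustrative): expects a trained classifier checkpoint such
219 | # as ./clf_model_ml-1m.pt and synthesized audio files under ./data/coat/
220 | #   python gen_label.py --model ml-1m --dataset coat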
--------------------------------------------------------------------------------
/Speech/vits_lib/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import sys
4 | import argparse
5 | import logging
6 | import json
7 | import subprocess
8 | import numpy as np
9 | from scipy.io.wavfile import read
10 | import torch
11 |
12 | MATPLOTLIB_FLAG = False
13 |
14 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
15 | logger = logging
16 |
17 |
18 | def load_checkpoint(checkpoint_path, model, optimizer=None):
19 | assert os.path.isfile(checkpoint_path)
20 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
21 | iteration = checkpoint_dict['iteration']
22 | learning_rate = checkpoint_dict['learning_rate']
23 | if optimizer is not None:
24 | optimizer.load_state_dict(checkpoint_dict['optimizer'])
25 | saved_state_dict = checkpoint_dict['model']
26 | if hasattr(model, 'module'):
27 | state_dict = model.module.state_dict()
28 | else:
29 | state_dict = model.state_dict()
30 | new_state_dict= {}
31 | for k, v in state_dict.items():
32 | try:
33 | new_state_dict[k] = saved_state_dict[k]
34 | except:
35 | logger.info("%s is not in the checkpoint" % k)
36 | new_state_dict[k] = v
37 | if hasattr(model, 'module'):
38 | model.module.load_state_dict(new_state_dict)
39 | else:
40 | model.load_state_dict(new_state_dict)
41 | logger.info("Loaded checkpoint '{}' (iteration {})" .format(
42 | checkpoint_path, iteration))
43 | return model, optimizer, learning_rate, iteration
44 |
45 |
46 | def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
47 | logger.info("Saving model and optimizer state at iteration {} to {}".format(
48 | iteration, checkpoint_path))
49 | if hasattr(model, 'module'):
50 | state_dict = model.module.state_dict()
51 | else:
52 | state_dict = model.state_dict()
53 | torch.save({'model': state_dict,
54 | 'iteration': iteration,
55 | 'optimizer': optimizer.state_dict(),
56 | 'learning_rate': learning_rate}, checkpoint_path)
57 |
58 |
59 | def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050):
60 | for k, v in scalars.items():
61 | writer.add_scalar(k, v, global_step)
62 | for k, v in histograms.items():
63 | writer.add_histogram(k, v, global_step)
64 | for k, v in images.items():
65 | writer.add_image(k, v, global_step, dataformats='HWC')
66 | for k, v in audios.items():
67 | writer.add_audio(k, v, global_step, audio_sampling_rate)
68 |
69 |
70 | def latest_checkpoint_path(dir_path, regex="G_*.pth"):
71 | f_list = glob.glob(os.path.join(dir_path, regex))
72 | f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
73 | x = f_list[-1]
74 | print(x)
75 | return x
76 |
77 |
78 | def plot_spectrogram_to_numpy(spectrogram):
79 | global MATPLOTLIB_FLAG
80 | if not MATPLOTLIB_FLAG:
81 | import matplotlib
82 | matplotlib.use("Agg")
83 | MATPLOTLIB_FLAG = True
84 | mpl_logger = logging.getLogger('matplotlib')
85 | mpl_logger.setLevel(logging.WARNING)
86 | import matplotlib.pylab as plt
87 | import numpy as np
88 |
89 | fig, ax = plt.subplots(figsize=(10,2))
90 | im = ax.imshow(spectrogram, aspect="auto", origin="lower",
91 | interpolation='none')
92 | plt.colorbar(im, ax=ax)
93 | plt.xlabel("Frames")
94 | plt.ylabel("Channels")
95 | plt.tight_layout()
96 |
97 | fig.canvas.draw()
98 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
99 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
100 | plt.close()
101 | return data
102 |
103 |
104 | def plot_alignment_to_numpy(alignment, info=None):
105 | global MATPLOTLIB_FLAG
106 | if not MATPLOTLIB_FLAG:
107 | import matplotlib
108 | matplotlib.use("Agg")
109 | MATPLOTLIB_FLAG = True
110 | mpl_logger = logging.getLogger('matplotlib')
111 | mpl_logger.setLevel(logging.WARNING)
112 | import matplotlib.pylab as plt
113 | import numpy as np
114 |
115 | fig, ax = plt.subplots(figsize=(6, 4))
116 | im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower',
117 | interpolation='none')
118 | fig.colorbar(im, ax=ax)
119 | xlabel = 'Decoder timestep'
120 | if info is not None:
121 | xlabel += '\n\n' + info
122 | plt.xlabel(xlabel)
123 | plt.ylabel('Encoder timestep')
124 | plt.tight_layout()
125 |
126 | fig.canvas.draw()
127 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
128 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
129 | plt.close()
130 | return data
131 |
132 |
133 | def load_wav_to_torch(full_path):
134 | sampling_rate, data = read(full_path)
135 | return torch.FloatTensor(data.astype(np.float32)), sampling_rate
136 |
137 |
138 | def load_filepaths_and_text(filename, split="|"):
139 | with open(filename, encoding='utf-8') as f:
140 | filepaths_and_text = [line.strip().split(split) for line in f]
141 | return filepaths_and_text
142 |
143 |
144 | def get_hparams(init=True):
145 | parser = argparse.ArgumentParser()
146 | parser.add_argument('-c', '--config', type=str, default="./configs/base.json",
147 | help='JSON file for configuration')
148 | parser.add_argument('-m', '--model', type=str, required=True,
149 | help='Model name')
150 |
151 | args = parser.parse_args()
152 | model_dir = os.path.join("./logs", args.model)
153 |
154 | if not os.path.exists(model_dir):
155 | os.makedirs(model_dir)
156 |
157 | config_path = args.config
158 | config_save_path = os.path.join(model_dir, "config.json")
159 | if init:
160 | with open(config_path, "r") as f:
161 | data = f.read()
162 | with open(config_save_path, "w") as f:
163 | f.write(data)
164 | else:
165 | with open(config_save_path, "r") as f:
166 | data = f.read()
167 | config = json.loads(data)
168 |
169 | hparams = HParams(**config)
170 | hparams.model_dir = model_dir
171 | return hparams
172 |
173 |
174 | def get_hparams_from_dir(model_dir):
175 | config_save_path = os.path.join(model_dir, "config.json")
176 | with open(config_save_path, "r") as f:
177 | data = f.read()
178 | config = json.loads(data)
179 |
180 | hparams =HParams(**config)
181 | hparams.model_dir = model_dir
182 | return hparams
183 |
184 |
185 | def get_hparams_from_file(config_path):
186 | with open(config_path, "r") as f:
187 | data = f.read()
188 | config = json.loads(data)
189 |
190 | hparams =HParams(**config)
191 | return hparams
192 |
193 |
194 | def check_git_hash(model_dir):
195 | source_dir = os.path.dirname(os.path.realpath(__file__))
196 | if not os.path.exists(os.path.join(source_dir, ".git")):
197 | logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format(
198 | source_dir
199 | ))
200 | return
201 |
202 | cur_hash = subprocess.getoutput("git rev-parse HEAD")
203 |
204 | path = os.path.join(model_dir, "githash")
205 | if os.path.exists(path):
206 | saved_hash = open(path).read()
207 | if saved_hash != cur_hash:
208 | logger.warn("git hash values are different. {}(saved) != {}(current)".format(
209 | saved_hash[:8], cur_hash[:8]))
210 | else:
211 | open(path, "w").write(cur_hash)
212 |
213 |
214 | def get_logger(model_dir, filename="train.log"):
215 | global logger
216 | logger = logging.getLogger(os.path.basename(model_dir))
217 | logger.setLevel(logging.DEBUG)
218 |
219 | formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
220 | if not os.path.exists(model_dir):
221 | os.makedirs(model_dir)
222 | h = logging.FileHandler(os.path.join(model_dir, filename))
223 | h.setLevel(logging.DEBUG)
224 | h.setFormatter(formatter)
225 | logger.addHandler(h)
226 | return logger
227 |
228 |
229 | class HParams():
230 | def __init__(self, **kwargs):
231 | for k, v in kwargs.items():
232 | if type(v) == dict:
233 | v = HParams(**v)
234 | self[k] = v
235 |
236 | def keys(self):
237 | return self.__dict__.keys()
238 |
239 | def items(self):
240 | return self.__dict__.items()
241 |
242 | def values(self):
243 | return self.__dict__.values()
244 |
245 | def __len__(self):
246 | return len(self.__dict__)
247 |
248 | def __getitem__(self, key):
249 | return getattr(self, key)
250 |
251 | def __setitem__(self, key, value):
252 | return setattr(self, key, value)
253 |
254 | def __contains__(self, key):
255 | return key in self.__dict__
256 |
257 | def __repr__(self):
258 | return self.__dict__.__repr__()
259 |
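260 | # Minimal usage sketch: HParams wraps a nested config dict so values can be read
261 | # either as attributes or by key; the config values below are illustrative.
262 | if __name__ == '__main__':
263 |     hps = HParams(**{'data': {'sampling_rate': 22050}, 'train': {'batch_size': 16}})
264 |     print(hps.data.sampling_rate)      # 22050
265 |     print(hps['train']['batch_size'])  # 16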
--------------------------------------------------------------------------------
/Dialogue/gen_movie.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import lightgbm as lgb
4 | import random
5 | import json
6 | import os
7 | from collections import defaultdict
8 | from movie_utils import *
9 | from config.movie_pattern import start_pattern, agent_pattern, user_pattern
10 | from config.thanks_pattern import movie_agent, movie_user
11 | from config.recommend_pattern import movie_rec
12 | from movie_attr import generate_country_dialogue, generate_genre_dialogue, generate_director_dialogue, generate_actor_dialogue
13 |
14 |
15 | def load_data():
16 | movies = pd.read_csv("./data/ml-1m/ml-1m.csv")
17 | users = pd.read_csv("./data/ml-1m/users.dat", sep='::', names=['user_id', 'gender', 'age', 'occupation', 'zip'], engine='python')
18 |
19 | return (movies, users)
20 |
21 | def get_user_item_info(data):
22 | movies = data[0]
23 | users = data[1]
24 | user_record = defaultdict(set)
25 | user_info = {}
26 | item_info = {}
27 | for _, row in movies.iterrows():
28 | user_id = int(row['user_id'])
29 | item_id = int(row['movie_id'])
30 | user_record[user_id].add(item_id)
31 |
32 | if user_id not in user_info.keys():
33 | user_info[user_id] = {}
34 | age = list(users[users['user_id'] == user_id]['age'])[0]
35 | gender = list(users[users['user_id'] == user_id]['gender'])[0]
36 | user_info[user_id]['age'] = get_user_age(age)
37 | user_info[user_id]['gender'] = get_user_gender(gender)
38 |
39 | if item_id not in item_info.keys():
40 | item_info[item_id] = {}
41 | item_info[item_id]['director'] = row['director']
42 | item_info[item_id]['country'] = modify_country(row['country'])
43 | item_info[item_id]['actor'] = get_item_actor(row['actors'])
44 | item_info[item_id]['genre'] = get_item_genre(row['genres'])
45 |
46 |
47 | return user_record, user_info, item_info
48 |
49 | def calculate_user_preference(record, data):
50 | movies = data[0]
51 | columns = ["movie_id", "director", "actors", "country", "genres"]
52 | features_cols = ["director", "actors", "country", "genres"]
53 | user_preference = defaultdict(list)
54 |
55 | df_movie = movies[columns]
56 | df_movie = df_movie.drop_duplicates(keep='first')
57 | df_movie = df_movie.reset_index(drop=True)
58 | df_movie['director'] = pd.Categorical(df_movie['director']).codes
59 | df_movie['actors'] = pd.Categorical(df_movie['actors']).codes
60 | df_movie['country'] = pd.Categorical(df_movie['country']).codes
61 | df_movie['genres'] = pd.Categorical(df_movie['genres']).codes
62 | df_movie.insert(df_movie.shape[1], 'label', 0)
63 |
64 | for user in record:
65 | user_item_csv = df_movie.copy()
66 |         user_items = list(record[user])
67 |         for movie in user_items:
68 | user_item_csv.loc[user_item_csv['movie_id']==movie, 'label'] = 1
69 | X = user_item_csv[features_cols]
70 | Y = user_item_csv.label
71 |
72 | cls = lgb.LGBMClassifier(importance_type='gain')
73 | cls.fit(X, Y)
74 |
75 | indices = np.argsort(cls.booster_.feature_importance(importance_type='gain'))
76 | feature = [features_cols[i] for i in indices]
77 |
78 | user_preference[user] = feature
79 |
80 | return user_preference
81 |
82 | def calculate_attr_weights(info, movie):
83 | genre_all = {}
84 | country_all = {}
85 | actor_all = {}
86 | director_all = {}
87 |
88 | for _, row in movie.iterrows():
89 | item_id = int(row['movie_id'])
90 |
91 | genres = info[item_id]['genre']
92 | for genre in genres:
93 | if genre not in genre_all.keys():
94 | genre_all[genre] = 1
95 | else:
96 | genre_all[genre] += 1
97 |
98 | country = info[item_id]['country']
99 | if country not in country_all.keys():
100 | country_all[country] = 1
101 | else:
102 | country_all[country] += 1
103 |
104 | director = info[item_id]['director']
105 | if director not in director_all.keys():
106 | director_all[director] = 1
107 | else:
108 | director_all[director] += 1
109 |
110 | actors = info[item_id]['actor']
111 | for actor in actors:
112 | if actor not in actor_all.keys():
113 | actor_all[actor] = 1
114 | else:
115 | actor_all[actor] += 1
116 |
117 | other_directors = [k for k,v in director_all.items() if v < 50] #402, 10415
118 | other_actors = [k for k,v in actor_all.items() if v < 50] #1267, 33893
119 |
120 | genre_weight = [genre_all[i] for i in sorted(genre_all)]
121 | country_weight = [country_all[i] for i in sorted(country_all)]
122 | director_weight = [director_all[i] for i in sorted(director_all)]
123 | actor_weight = [actor_all[i] for i in sorted(actor_all)]
124 |
125 | return (genre_weight, country_weight, director_weight, actor_weight), (genre_all, country_all, director_all, actor_all), (other_directors, other_actors)
126 |
127 |
128 | if __name__ == '__main__':
129 | movie_data = load_data()
130 | user_record, user_info, item_info = get_user_item_info(movie_data)
131 | # user_preference = calculate_user_preference(user_record, movie_data)
132 | user_preference = np.load("./data/ml-1m/user_preference.npy", allow_pickle=True).item()
133 |     weights, attr_counts, other_attr = calculate_attr_weights(item_info, movie_data[0])
134 | print("data load complete")
135 | dialogue_info = {}
136 | dialogue_id = 0
137 | for idx, row in movie_data[0].iterrows():
138 | user_id = int(row['user_id'])
139 | item_id = int(row['movie_id'])
140 |
141 | new_dialogue = {}
142 | new_dialogue["user_id"] = user_id
143 | new_dialogue["item_id"] = item_id
144 |
145 | new_dialogue["user_gender"] = user_info[user_id]["gender"]
146 | new_dialogue["user_age"] = user_info[user_id]["age"]
147 | new_dialogue["content"] = {}
148 | new_dialogue["content"]["start"] = random.choice(start_pattern)
149 |
150 |         dialogue_order = user_preference[user_id]
151 | tmp_new_dialogue = []
152 |
153 |         for slot in dialogue_order:
154 | english = True
155 | if slot == "country":
156 | country_val = item_info[item_id]["country"]
157 | utterance = generate_country_dialogue(agent_pattern, user_pattern, country_val)
158 | elif slot == "genres":
159 | genre_val = item_info[item_id]["genre"]
160 | utterance = generate_genre_dialogue(agent_pattern, user_pattern, genre_val, attr_counts[0], weights[0])
161 | elif slot == "director":
162 | director_val = item_info[item_id]["director"]
163 | utterance = generate_director_dialogue(agent_pattern, user_pattern, director_val, attr_counts[2], weights[2], other_attr[0])
164 | english = check_in_english(utterance)
165 | if not english:
166 | break
167 | elif slot == "actors":
168 | actor_val = item_info[item_id]["actor"]
169 | utterance = generate_actor_dialogue(agent_pattern, user_pattern, actor_val, attr_counts[3], weights[3], other_attr[1])
170 | english = check_in_english(utterance)
171 | if not english:
172 | break
173 | tmp_new_dialogue.append(utterance)
174 |
175 | if not english:
176 | continue
177 | end_dialogue = {}
178 | end_dialogue["rec"] = random.choice(movie_rec)
179 | end_dialogue["thanks_user"] = random.choice(movie_user)
180 | end_dialogue["thanks_agent"] = random.choice(movie_agent)
181 | tmp_new_dialogue.append(end_dialogue)
182 | print("finish:", dialogue_id)
183 | start_index = 0
184 | end_index = 0
185 | step = 0
186 | name = ["Q1", "A1", "Q2", "A2", "Q3", "A3", "Q4", "A4", "Q5", "A5", "Q6", "A6", "Q7", "A7", "Q8", "A8", "Q9", "A9", "Q10", "A10", "Q11", "A11", "Q12", "A12"]
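        # flatten each generated utterance dict into alternating Q/A turns (Q1, A1, Q2, A2, ...)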
187 | for dia in tmp_new_dialogue:
188 | end_index = len(dia) + end_index
189 | tmp_name = name[start_index : end_index]
190 | tmp_dia = []
191 | for _, v in dia.items():
192 | tmp_dia.append(v)
193 | for i, val in enumerate(tmp_name):
194 | new_dialogue["content"][val] = tmp_dia[i]
195 | start_index = end_index
196 | dialogue_info[dialogue_id] = new_dialogue
197 | dialogue_id = dialogue_id + 1
198 |
199 | res_path = './res/ml-1m/'
200 | if not os.path.exists(res_path):
201 | os.makedirs(res_path)
202 | with open(res_path + 'dialogue_info_ml-1m.json', 'w') as f:
203 | json.dump(dialogue_info, f, indent=4)
204 |
205 |
--------------------------------------------------------------------------------
/Recommender/FMRecommender.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import pandas as pd
4 | from tqdm import tqdm
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | import torch.backends.cudnn as cudnn
11 |
12 | class PointFM(nn.Module):
13 | def __init__(self,
14 | user_num,
15 | item_num,
16 | factors=84,
17 | epochs=20,
18 | lr=0.001,
19 | reg_1 = 0.001,
20 | reg_2 = 0.001,
21 | loss_type='CL',
22 | optimizer='sgd',
23 | initializer='normal',
24 | gpuid='0',
25 | feature=1,
26 | early_stop=True):
27 | """
28 | Point-wise FM Recommender Class
29 | Parameters
30 | ----------
31 | user_num : int, the number of users
32 | item_num : int, the number of items
33 |         factors : int, the number of latent factors
34 | epochs : int, number of training epochs
35 | lr : float, learning rate
36 | reg_1 : float, first-order regularization term
37 | reg_2 : float, second-order regularization term
38 | loss_type : str, loss function type
39 | optimizer : str, optimization method for training the algorithms
40 | initializer: str, parameter initializer
41 | gpuid : str, GPU ID
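        feature : int, side-feature mode (0: gender only, 1: age only, 2: gender and age)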
42 | early_stop : bool, whether to activate early stop mechanism
43 |
44 | """
45 | super(PointFM, self).__init__()
46 |
47 | os.environ['CUDA_VISIBLE_DEVICES'] = gpuid
48 | cudnn.benchmark = True
49 |
50 | self.epochs = epochs
51 | self.lr = lr
52 | self.reg_1 = reg_1
53 | self.reg_2 = reg_2
54 | self.feature = feature
55 |
56 | self.embed_user = nn.Embedding(user_num, factors)
57 | self.embed_item = nn.Embedding(item_num, factors)
58 |
59 | self.u_bias = nn.Embedding(user_num, 1)
60 | self.i_bias = nn.Embedding(item_num, 1)
61 |
62 | if self.feature == 0:
63 | self.embed_gender = nn.Embedding(2, factors)
64 | self.g_bias = nn.Embedding(2, 1)
65 | elif self.feature == 1:
66 | self.embed_age = nn.Embedding(3, factors)
67 | self.a_bias = nn.Embedding(3, 1)
68 | elif self.feature == 2:
69 | self.embed_gender = nn.Embedding(2, factors)
70 | self.g_bias = nn.Embedding(2, 1)
71 | self.embed_age = nn.Embedding(3, factors)
72 | self.a_bias = nn.Embedding(3, 1)
73 |
74 | self.bias_ = nn.Parameter(torch.tensor([0.0]))
75 |
76 | # init weight
77 | nn.init.normal_(self.embed_user.weight)
78 | nn.init.normal_(self.embed_item.weight)
79 | if self.feature == 0:
80 | nn.init.normal_(self.embed_gender.weight)
81 | nn.init.constant_(self.g_bias.weight, 0.0)
82 | elif self.feature == 1:
83 | nn.init.normal_(self.embed_age.weight)
84 | nn.init.constant_(self.a_bias.weight, 0.0)
85 | elif self.feature == 2:
86 | nn.init.normal_(self.embed_gender.weight)
87 | nn.init.constant_(self.g_bias.weight, 0.0)
88 | nn.init.normal_(self.embed_age.weight)
89 | nn.init.constant_(self.a_bias.weight, 0.0)
90 |
91 |
92 | nn.init.constant_(self.u_bias.weight, 0.0)
93 | nn.init.constant_(self.i_bias.weight, 0.0)
94 |
95 | self.loss_type = loss_type
96 | self.optimizer = optimizer
97 | self.early_stop = early_stop
98 |
99 | def forward(self, user, item, gender=None, age=None):
100 | embed_user = self.embed_user(user)
101 | embed_item = self.embed_item(item)
102 |
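        # second-order FM interaction between the user and item embeddings, plus bias terms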
103 | pred = (embed_user * embed_item).sum(dim=-1, keepdim=True)
104 | pred += self.u_bias(user) + self.i_bias(item) + self.bias_
105 | if self.feature == 0:
106 | embed_gender = self.embed_gender(gender)
107 | pred += (embed_gender * embed_user).sum(dim=-1, keepdim=True) + (embed_gender * embed_item).sum(dim=-1, keepdim=True)
108 | pred += self.g_bias(gender)
109 | elif self.feature == 1:
110 | embed_age = self.embed_age(age)
111 | pred += (embed_age * embed_user).sum(dim=-1, keepdim=True) + (embed_age * embed_item).sum(dim=-1, keepdim=True)
112 | pred += self.a_bias(age)
113 | elif self.feature == 2:
114 | embed_gender = self.embed_gender(gender)
115 | embed_age = self.embed_age(age)
116 | pred += (embed_gender * embed_user).sum(dim=-1, keepdim=True) + (embed_gender * embed_item).sum(dim=-1, keepdim=True)
117 | pred += (embed_age * embed_user).sum(dim=-1, keepdim=True) + (embed_age * embed_item).sum(dim=-1, keepdim=True)
118 | pred += (embed_age * embed_gender).sum(dim=-1, keepdim=True)
119 | pred += self.g_bias(gender) + self.a_bias(age)
120 |
121 | return pred.view(-1)
122 |
123 | def fit(self, train_loader):
124 |         # train on GPU when available, otherwise fall back to CPU
125 |         device = 'cuda' if torch.cuda.is_available() else 'cpu'
126 |         self.to(device)
128 |
129 | optimizer = optim.SGD(self.parameters(), lr=self.lr)
130 |
131 | if self.loss_type == 'CL':
132 | criterion = nn.BCEWithLogitsLoss(reduction='sum')
133 | elif self.loss_type == 'SL':
134 | criterion = nn.MSELoss(reduction='sum')
135 | else:
136 | raise ValueError(f'Invalid loss type: {self.loss_type}')
137 |
138 | last_loss = 0.
139 | for epoch in range(1, self.epochs + 1):
140 | self.train()
141 |
142 | current_loss = 0.
143 | # set process bar display
144 | pbar = tqdm(train_loader)
145 | pbar.set_description(f'[Epoch {epoch:03d}]')
146 | for user, item, gender, age, label in pbar:
147 |                 user = user.to(device)
148 |                 item = item.to(device)
149 |                 label = label.to(device)
150 |                 if self.feature == 0:
151 |                     gender = gender.to(device)
152 |                 elif self.feature == 1:
153 |                     age = age.to(device)
154 |                 elif self.feature == 2:
155 |                     gender = gender.to(device)
156 |                     age = age.to(device)
157 |
158 | self.zero_grad()
159 | if self.feature == 0:
160 | prediction = self.forward(user, item, gender=gender)
161 | elif self.feature == 1:
162 | prediction = self.forward(user, item, age=age)
163 | elif self.feature == 2:
164 | prediction = self.forward(user, item, gender=gender, age=age)
165 | else:
166 | prediction = self.forward(user, item)
167 |
168 | loss = criterion(prediction, label)
169 | loss += self.reg_1 * (self.embed_item.weight.norm(p=1) + self.embed_user.weight.norm(p=1))
170 | loss += self.reg_2 * (self.embed_item.weight.norm() + self.embed_user.weight.norm())
171 |
172 | if self.feature == 0:
173 | loss += self.reg_1 * (self.embed_gender.weight.norm(p=1))
174 | loss += self.reg_2 * (self.embed_gender.weight.norm())
175 | elif self.feature == 1:
176 | loss += self.reg_1 * (self.embed_age.weight.norm(p=1))
177 | loss += self.reg_2 * (self.embed_age.weight.norm())
178 | elif self.feature == 2:
179 | loss += self.reg_1 * (self.embed_gender.weight.norm(p=1) + self.embed_age.weight.norm(p=1))
180 | loss += self.reg_2 * (self.embed_gender.weight.norm() + self.embed_age.weight.norm())
181 |
182 | if torch.isnan(loss):
183 |                     raise ValueError('Loss is NaN or Inf: the current settings do not fit the recommender')
184 |
185 | loss.backward()
186 | optimizer.step()
187 |
188 | pbar.set_postfix(loss=loss.item())
189 | current_loss += loss.item()
190 |
191 | self.eval()
192 | delta_loss = float(current_loss - last_loss)
193 | if (abs(delta_loss) < 1e-5) and self.early_stop:
194 | print('Satisfy early stop mechanism')
195 | break
196 | else:
197 | last_loss = current_loss
198 |
199 | def predict(self, u, i, g=None, a=None):
200 | if self.feature == 0:
201 | pred = self.forward(u, i, gender=g).cpu()
202 | elif self.feature == 1:
203 | pred = self.forward(u, i, age=a).cpu()
204 | elif self.feature == 2:
205 | pred = self.forward(u, i, gender=g, age=a).cpu()
206 | else:
207 | pred = self.forward(u, i).cpu()
208 |
209 | return pred
210 |
--------------------------------------------------------------------------------
/Speech/vits_lib/transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 |
4 | import numpy as np
5 |
6 |
7 | DEFAULT_MIN_BIN_WIDTH = 1e-3
8 | DEFAULT_MIN_BIN_HEIGHT = 1e-3
9 | DEFAULT_MIN_DERIVATIVE = 1e-3
10 |
11 |
12 | def piecewise_rational_quadratic_transform(inputs,
13 | unnormalized_widths,
14 | unnormalized_heights,
15 | unnormalized_derivatives,
16 | inverse=False,
17 | tails=None,
18 | tail_bound=1.,
19 | min_bin_width=DEFAULT_MIN_BIN_WIDTH,
20 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
21 | min_derivative=DEFAULT_MIN_DERIVATIVE):
22 |
23 | if tails is None:
24 | spline_fn = rational_quadratic_spline
25 | spline_kwargs = {}
26 | else:
27 | spline_fn = unconstrained_rational_quadratic_spline
28 | spline_kwargs = {
29 | 'tails': tails,
30 | 'tail_bound': tail_bound
31 | }
32 |
33 | outputs, logabsdet = spline_fn(
34 | inputs=inputs,
35 | unnormalized_widths=unnormalized_widths,
36 | unnormalized_heights=unnormalized_heights,
37 | unnormalized_derivatives=unnormalized_derivatives,
38 | inverse=inverse,
39 | min_bin_width=min_bin_width,
40 | min_bin_height=min_bin_height,
41 | min_derivative=min_derivative,
42 | **spline_kwargs
43 | )
44 | return outputs, logabsdet
45 |
46 |
47 | def searchsorted(bin_locations, inputs, eps=1e-6):
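    # return the index of the bin containing each input; the last bin edge is
    # nudged by eps so values sitting on the right boundary fall into the final bin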
48 | bin_locations[..., -1] += eps
49 | return torch.sum(
50 | inputs[..., None] >= bin_locations,
51 | dim=-1
52 | ) - 1
53 |
54 |
55 | def unconstrained_rational_quadratic_spline(inputs,
56 | unnormalized_widths,
57 | unnormalized_heights,
58 | unnormalized_derivatives,
59 | inverse=False,
60 | tails='linear',
61 | tail_bound=1.,
62 | min_bin_width=DEFAULT_MIN_BIN_WIDTH,
63 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
64 | min_derivative=DEFAULT_MIN_DERIVATIVE):
65 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
66 | outside_interval_mask = ~inside_interval_mask
67 |
68 | outputs = torch.zeros_like(inputs)
69 | logabsdet = torch.zeros_like(inputs)
70 |
71 | if tails == 'linear':
72 | unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
73 | constant = np.log(np.exp(1 - min_derivative) - 1)
74 | unnormalized_derivatives[..., 0] = constant
75 | unnormalized_derivatives[..., -1] = constant
76 |
77 | outputs[outside_interval_mask] = inputs[outside_interval_mask]
78 | logabsdet[outside_interval_mask] = 0
79 | else:
80 | raise RuntimeError('{} tails are not implemented.'.format(tails))
81 |
82 | outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline(
83 | inputs=inputs[inside_interval_mask],
84 | unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
85 | unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
86 | unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
87 | inverse=inverse,
88 | left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound,
89 | min_bin_width=min_bin_width,
90 | min_bin_height=min_bin_height,
91 | min_derivative=min_derivative
92 | )
93 |
94 | return outputs, logabsdet
95 |
96 | def rational_quadratic_spline(inputs,
97 | unnormalized_widths,
98 | unnormalized_heights,
99 | unnormalized_derivatives,
100 | inverse=False,
101 | left=0., right=1., bottom=0., top=1.,
102 | min_bin_width=DEFAULT_MIN_BIN_WIDTH,
103 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
104 | min_derivative=DEFAULT_MIN_DERIVATIVE):
105 | if torch.min(inputs) < left or torch.max(inputs) > right:
106 | raise ValueError('Input to a transform is not within its domain')
107 |
108 | num_bins = unnormalized_widths.shape[-1]
109 |
110 | if min_bin_width * num_bins > 1.0:
111 | raise ValueError('Minimal bin width too large for the number of bins')
112 | if min_bin_height * num_bins > 1.0:
113 | raise ValueError('Minimal bin height too large for the number of bins')
114 |
115 | widths = F.softmax(unnormalized_widths, dim=-1)
116 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
117 | cumwidths = torch.cumsum(widths, dim=-1)
118 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0)
119 | cumwidths = (right - left) * cumwidths + left
120 | cumwidths[..., 0] = left
121 | cumwidths[..., -1] = right
122 | widths = cumwidths[..., 1:] - cumwidths[..., :-1]
123 |
124 | derivatives = min_derivative + F.softplus(unnormalized_derivatives)
125 |
126 | heights = F.softmax(unnormalized_heights, dim=-1)
127 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
128 | cumheights = torch.cumsum(heights, dim=-1)
129 | cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0)
130 | cumheights = (top - bottom) * cumheights + bottom
131 | cumheights[..., 0] = bottom
132 | cumheights[..., -1] = top
133 | heights = cumheights[..., 1:] - cumheights[..., :-1]
134 |
135 | if inverse:
136 | bin_idx = searchsorted(cumheights, inputs)[..., None]
137 | else:
138 | bin_idx = searchsorted(cumwidths, inputs)[..., None]
139 |
140 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
141 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
142 |
143 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
144 | delta = heights / widths
145 | input_delta = delta.gather(-1, bin_idx)[..., 0]
146 |
147 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
148 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
149 |
150 | input_heights = heights.gather(-1, bin_idx)[..., 0]
151 |
152 | if inverse:
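        # invert the rational-quadratic map inside each bin by solving the quadratic
        # a*root**2 + b*root + c = 0 for the bin-local coordinate `root`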
153 | a = (((inputs - input_cumheights) * (input_derivatives
154 | + input_derivatives_plus_one
155 | - 2 * input_delta)
156 | + input_heights * (input_delta - input_derivatives)))
157 | b = (input_heights * input_derivatives
158 | - (inputs - input_cumheights) * (input_derivatives
159 | + input_derivatives_plus_one
160 | - 2 * input_delta))
161 | c = - input_delta * (inputs - input_cumheights)
162 |
163 | discriminant = b.pow(2) - 4 * a * c
164 | assert (discriminant >= 0).all()
165 |
166 | root = (2 * c) / (-b - torch.sqrt(discriminant))
167 | outputs = root * input_bin_widths + input_cumwidths
168 |
169 | theta_one_minus_theta = root * (1 - root)
170 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
171 | * theta_one_minus_theta)
172 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2)
173 | + 2 * input_delta * theta_one_minus_theta
174 | + input_derivatives * (1 - root).pow(2))
175 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
176 |
177 | return outputs, -logabsdet
178 | else:
179 | theta = (inputs - input_cumwidths) / input_bin_widths
180 | theta_one_minus_theta = theta * (1 - theta)
181 |
182 | numerator = input_heights * (input_delta * theta.pow(2)
183 | + input_derivatives * theta_one_minus_theta)
184 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
185 | * theta_one_minus_theta)
186 | outputs = input_cumheights + numerator / denominator
187 |
188 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2)
189 | + 2 * input_delta * theta_one_minus_theta
190 | + input_derivatives * (1 - theta).pow(2))
191 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
192 |
193 | return outputs, logabsdet
194 |
--------------------------------------------------------------------------------
/Recommender/run_classifier.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import numpy as np
4 | import torch
5 | import torchaudio
6 | import argparse
7 | import time
8 | import tqdm
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import torch.optim as optim
12 | from torch.utils.tensorboard import SummaryWriter
13 | from torch.utils.data import Dataset,DataLoader
14 | from transformers import Wav2Vec2Model, Wav2Vec2PreTrainedModel, AutoConfig
15 | from sklearn.model_selection import train_test_split
16 |
17 |
18 |
19 | class Logger(object):
20 | def __init__(self, filename='default.log', stream=sys.stdout):
21 | self.terminal = stream
22 | self.log = open(filename, 'a')
23 |
24 | def write(self, message):
25 | self.terminal.write(message)
26 | self.log.write(message)
27 |
28 | def flush(self):
29 | pass
30 |
31 | class AudioDataset(Dataset):
32 | def __init__(self, datalist, audio_path, target_sample_rate, transformation=None):
33 | self.datalist = datalist
34 | self.audio_path = audio_path
35 | self.target_sample_rate = target_sample_rate
36 | self.transformation = None
37 | if transformation:
38 | self.transformation = transformation
39 |
40 |
41 | def __len__(self):
42 | return len(self.datalist)
43 |
44 | def __getitem__(self,idx):
45 | audio_file_path = os.path.join(self.audio_path, self.datalist[idx])
46 | audio_name = self.datalist[idx]
47 | age_label, gender_label = self._get_label(audio_name)
48 | waveform, sample_rate = torchaudio.load(audio_file_path)
49 | if sample_rate != self.target_sample_rate:
50 | waveform = self._resample(waveform, sample_rate)
51 |
52 | return waveform, age_label, gender_label
53 |
54 |
55 | def _get_label(self, audio_name):
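        # age and gender labels are parsed from the third- and second-to-last
        # '_'-separated fields of the audio file name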
56 | name_list = audio_name.split('_')
57 | age_labels = {'under 20':0, '20-30':1, 'over 30':2}
58 | gender_labels = {'women':0, 'men':1}
59 |
60 | age = age_labels[name_list[-3]]
61 | gender = gender_labels[name_list[-2]]
62 |
63 | return age, gender
64 |
65 | def _resample(self, waveform, sample_rate):
66 | resampler = torchaudio.transforms.Resample(sample_rate,self.target_sample_rate)
67 |
68 | return resampler(waveform)
69 |
70 |
71 | class Wav2Vec2ClassificationModel(Wav2Vec2PreTrainedModel):
72 | def __init__(self, config, hidden_size, dropout):
73 | super().__init__(config)
74 |
75 | self.wav2vec2 = Wav2Vec2Model(config)
76 | self.hidden_size = hidden_size
77 | self.fc = nn.Linear(config.hidden_size, self.hidden_size)
78 | self.dropout = nn.Dropout(dropout)
79 | self.gender_fc = nn.Linear(self.hidden_size, 2)
80 | self.age_fc = nn.Linear(self.hidden_size, 3)
81 | self.tanh = nn.Tanh()
82 | self.relu = nn.ReLU()
83 |
84 | self.init_weights()
85 |
86 | def freeze_feature_extractor(self):
87 | self.wav2vec2.feature_extractor._freeze_parameters()
88 |
89 | def merged_strategy(self, hidden_states):
90 | outputs = torch.mean(hidden_states, dim=1)
91 |
92 | return outputs
93 |
94 | def forward(
95 | self,
96 | input_values,
97 | attention_mask=None,
98 | output_attentions=None,
99 | output_hidden_states=None,
100 | return_dict=None,
101 | labels=None,
102 | ):
103 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
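        # the wav2vec2 backbone runs under no_grad, so only the pooled classification
        # heads (fc / age_fc / gender_fc) receive gradients during training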
104 | with torch.no_grad():
105 | outputs = self.wav2vec2(
106 | input_values,
107 | attention_mask=attention_mask,
108 | output_attentions=output_attentions,
109 | output_hidden_states=output_hidden_states,
110 | return_dict=return_dict,
111 | )
112 | hidden_states = outputs[0]
113 | x = self.merged_strategy(hidden_states)
114 | x = self.fc(x)
115 | x = self.relu(x)
116 | x = self.dropout(x)
117 | gender_logits = self.gender_fc(x)
118 | gender_logits = F.log_softmax(gender_logits, dim=-1)
119 | age_logits = self.age_fc(x)
120 | age_logits = F.log_softmax(age_logits, dim=-1)
121 |
122 | return age_logits, gender_logits
123 |
124 |
125 | def pad_sequence(batch):
126 |     # Make all tensors in a batch the same length by padding with zeros
127 | batch = [item.t() for item in batch]
128 | batch = torch.nn.utils.rnn.pad_sequence(batch, batch_first=True, padding_value=0.)
129 | return batch.permute(0, 2, 1)
130 |
131 | def collate_fn(batch):
132 |
133 | audios, age_labels, gender_labels = [], [], []
134 |
135 | for waveform, age, gender in batch:
136 | audios += [waveform]
137 | age_labels += [torch.tensor(age)]
138 | gender_labels += [torch.tensor(gender)]
139 |
140 | audios = pad_sequence(audios)
141 | age_labels = torch.stack(age_labels)
142 | gender_labels = torch.stack(gender_labels)
143 |
144 | return audios.squeeze(dim=1), age_labels, gender_labels
145 |
146 | def train_single_epoch(model, dataloader, optimizer, device):
147 | model.train()
148 | losses = []
149 | for waveform, age_label, gender_label in tqdm.tqdm(dataloader):
150 | waveform = waveform.to(device)
151 | age_label = age_label.to(device)
152 | gender_label = gender_label.to(device)
153 |
154 | age_logits, gender_logits = model(waveform)
155 |
156 | age_loss = F.nll_loss(age_logits, age_label)
157 | gender_loss = F.nll_loss(gender_logits, gender_label)
158 | loss = age_loss + gender_loss
159 |
160 | optimizer.zero_grad()
161 | loss.backward()
162 | optimizer.step()
163 | losses.append(loss.item())
164 | return losses
165 |
166 | def get_likely_index(tensor):
167 | # find most likely label index for each element in the batch
168 | return tensor.argmax(dim=-1)
169 |
170 | def number_of_correct(pred, target):
171 | # count number of correct predictions
172 | return pred.eq(target).sum().item()
173 |
174 | def test_val_single_epoch(model, dataloader, device):
175 | model.eval()
176 | age_correct = 0
177 | gender_correct = 0
178 | for waveform, age_label, gender_label in tqdm.tqdm(dataloader):
179 | waveform = waveform.to(device)
180 | age_label = age_label.to(device)
181 | gender_label = gender_label.to(device)
182 |
183 | age_logits, gender_logits = model(waveform) # batch * output
184 |
185 | age_pred = get_likely_index(age_logits) # batch
186 | gender_pred = get_likely_index(gender_logits)
187 |
188 | age_correct += number_of_correct(age_pred, age_label)
189 | gender_correct += number_of_correct(gender_pred, gender_label)
190 |
191 | age_accu = age_correct / len(dataloader.dataset) * 100.
192 | gender_accu = gender_correct / len(dataloader.dataset) * 100.
193 |
194 | return age_accu, gender_accu
195 |
196 | def train(model, train_loader, val_loader, optimizer, device, epochs):
197 | for epoch in tqdm.tqdm(range(epochs)):
198 | losses = train_single_epoch(model, train_loader, optimizer, device)
199 | loss = np.mean(losses)
200 | age_accu, gender_accu = test_val_single_epoch(model, val_loader, device)
201 | writer.add_scalar('loss', loss, epoch+1)
202 | writer.add_scalar('age_accuracy', age_accu, epoch+1)
203 | writer.add_scalar('gender_accuracy', gender_accu, epoch+1)
204 | print(f"\nTrain Epoch: {epoch + 1} Loss: {loss:.6f} Age: {age_accu:.2f}% Gender: {gender_accu:.2f}%")
205 | print('Finished Training')
206 |
207 |
208 | if __name__ == '__main__':
209 | parser = argparse.ArgumentParser(description='speech classification')
210 | parser.add_argument('--batch_size',
211 | type=int,
212 | default=4,)
213 | parser.add_argument('--hidden_size',
214 | type=int,
215 | default=1024)
216 | parser.add_argument('--dropout',
217 | type=float,
218 | default=0.2)
219 | parser.add_argument('--epochs',
220 | type=int,
221 | default=20)
222 | parser.add_argument('--lr',
223 | type=float,
224 | default=0.0001)
225 | parser.add_argument('--test',
226 | type=int,
227 | default=1)
228 | parser.add_argument('--dataset',
229 | type=str,
230 | default='ml-1m')
231 | args = parser.parse_args()
232 |
233 | cur_time = time.strftime("%Y-%m-%d-%H_%M_%S",time.localtime(time.time()))[5:]
234 |
235 | save_tb_log_path = f'./tb_log/{cur_time}_{args.batch_size}'
236 |     os.makedirs('./tb_log/', exist_ok=True)
237 |     os.makedirs('./log/', exist_ok=True)  # the Logger below appends to ./log/
238 | writer = SummaryWriter(save_tb_log_path, flush_secs=30)
239 |
240 | save_log = f'./log/{cur_time}_{args.lr}_{args.dropout}_{args.batch_size}.log'
241 | sys.stdout = Logger(save_log, sys.stdout)
242 |
243 | seed = 2023
244 | np.random.seed(seed)
245 | torch.manual_seed(seed)
246 |
247 | input_dir = f'./data/{args.dataset}/'
248 | datalist = os.listdir(input_dir)
249 |
250 | ratio_train = 0.8
251 | ratio_val = 0.1
252 | ratio_test = 0.1
253 |
254 | remaining, test_set = train_test_split(datalist, test_size=ratio_test, random_state=2023)
255 | ratio_remaining = 1 - ratio_test
256 | ratio_val_adjusted = ratio_val / ratio_remaining
257 | train_set, val_set = train_test_split(remaining, test_size=ratio_val_adjusted, random_state=2023)
258 |
259 |
260 | if torch.cuda.is_available():
261 | device='cuda'
262 | torch.cuda.manual_seed_all(seed)
263 | torch.backends.cudnn.benchmark = False
264 | else:
265 | device='cpu'
266 |
267 | train_set = AudioDataset(train_set, input_dir, 16000)
268 | val_set = AudioDataset(val_set, input_dir, 16000)
269 | if args.test:
270 | test_set = AudioDataset(test_set, input_dir, 16000)
271 |
272 | if device == "cuda":
273 | num_workers = 1
274 | pin_memory = False
275 | else:
276 | num_workers = 0
277 | pin_memory = False
278 |
279 | train_loader = DataLoader(
280 | train_set,
281 | batch_size=args.batch_size,
282 | shuffle=True,
283 | collate_fn=collate_fn,
284 | num_workers=num_workers,
285 | pin_memory=pin_memory,
286 | )
287 |
288 | val_loader = DataLoader(
289 | val_set,
290 | batch_size=args.batch_size,
291 | shuffle=False,
292 | drop_last=False,
293 | collate_fn=collate_fn,
294 | num_workers=num_workers,
295 | pin_memory=pin_memory,
296 | )
297 |
298 | if args.test:
299 | test_loader = DataLoader(
300 | test_set,
301 | batch_size=args.batch_size,
302 | shuffle=False,
303 | drop_last=False,
304 | collate_fn=collate_fn,
305 | num_workers=num_workers,
306 | pin_memory=pin_memory,
307 | )
308 |
309 | model_name_or_path = "lighteternal/wav2vec2-large-xlsr-53-greek"
310 | config = AutoConfig.from_pretrained(
311 | model_name_or_path,
312 | finetuning_task="wav2vec2_clf",
313 | )
314 |
315 | clf_model = Wav2Vec2ClassificationModel.from_pretrained(
316 | model_name_or_path,
317 | config=config,
318 | hidden_size=args.hidden_size,
319 | dropout=args.dropout,
320 | )
321 |
322 | clf_model = clf_model.to(device)
323 |
324 | optimizer = optim.Adam(clf_model.parameters(), lr=args.lr)
325 | clf_model.freeze_feature_extractor()
326 |
327 | train(clf_model, train_loader, val_loader, optimizer, device, args.epochs)
328 |
329 | if args.test:
330 | age_accu, gender_accu = test_val_single_epoch(clf_model, test_loader, device)
331 | torch.save(clf_model, f'./clf_model_{args.dataset}.pt')
332 | print(f"Test set: Age: {age_accu:.2f}% Gender: {gender_accu:.2f}%")
333 |
334 |
--------------------------------------------------------------------------------
/Recommender/run.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | import datetime
5 | import numpy as np
6 | import pandas as pd
7 | import scipy.sparse as sp
8 | import torch.utils.data as data
9 | from tqdm import tqdm
10 | from sklearn.model_selection import train_test_split
11 | from utils import get_ur, build_candidates_set, precision_at_k, recall_at_k, map_at_k, hr_at_k, ndcg_at_k, mrr_at_k
12 | from FMRecommender import PointFM
13 |
14 |
15 | class Sample(object):
16 | def __init__(self, user_num, item_num, feature_num, num_ng=4):
17 | self.user_num = user_num
18 | self.item_num = item_num
19 | self.num_ng = num_ng
20 | self.feature_num = feature_num
21 |
22 | def transform(self, data, is_training=True):
23 | if not is_training:
24 | neg_set = []
25 | for _, row in data.iterrows():
26 | u = int(row['user'])
27 | i = int(row['item'])
28 | r = row['rating']
29 | js = []
30 | if self.feature_num == 0:
31 | neg_set.append([u, i, r, js])
32 | else:
33 | g = int(row['gender'])
34 | a = int(row['age'])
35 | neg_set.append([u, i, g, a, r, js])
36 | return neg_set
37 |
38 | user_num = self.user_num
39 | item_num = self.item_num
40 | pair_pos = sp.dok_matrix((user_num, item_num), dtype=np.float32)
41 | for _, row in data.iterrows():
42 | pair_pos[int(row['user']), int(row['item'])] = 1.0
43 | neg_set = []
44 | for _, row in data.iterrows():
45 | u = int(row['user'])
46 | i = int(row['item'])
47 | r = row['rating']
48 | js = []
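            # rejection-sample num_ng negative items this user has not interacted with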
49 | for _ in range(self.num_ng):
50 | j = np.random.randint(item_num)
51 | while (u, j) in pair_pos:
52 | j = np.random.randint(item_num)
53 | js.append(j)
54 | if self.feature_num == 0:
55 | neg_set.append([u, i, r, js])
56 | else:
57 | g = int(row['gender'])
58 | a = int(row['age'])
59 | neg_set.append([u, i, g, a, r, js])
60 | return neg_set
61 |
62 |
63 | class FMData(data.Dataset):
64 | def __init__(self, neg_set, feature_num, is_training=True, neg_label_val=0.):
65 | super(FMData, self).__init__()
66 | self.features_fill = []
67 | self.labels_fill = []
68 | self.feature_num = feature_num
69 | self.neg_label = neg_label_val
70 |
71 | if self.feature_num == 0:
72 | self.init_normal(neg_set, is_training)
73 | else:
74 | self.init_double_feature(neg_set, is_training)
75 |
76 | def __len__(self):
77 | return len(self.labels_fill)
78 |
79 | def __getitem__(self, index):
80 | features = self.features_fill
81 | labels = self.labels_fill
82 | user = features[index][0]
83 | item = features[index][1]
84 | label = labels[index]
85 | if self.feature_num == 0:
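            # item is repeated as a placeholder for the unused gender/age slots,
            # so every batch always unpacks into the same five fields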
86 | return user, item, item, item, label
87 | else:
88 | gender = features[index][2]
89 | age = features[index][3]
90 | return user, item, gender, age, label
91 |
92 | def init_normal(self, neg_set, is_training=True):
93 | for u, i, r, js in neg_set:
94 | self.features_fill.append([int(u), int(i)])
95 | self.labels_fill.append(r)
96 |
97 | if is_training:
98 | for j in js:
99 | self.features_fill.append([int(u), int(j)])
100 | self.labels_fill.append(self.neg_label)
101 | self.labels_fill = np.array(self.labels_fill, dtype=np.float32)
102 |
103 | def init_double_feature(self, neg_set, is_training=True):
104 | for u, i, g, a, r, js in neg_set:
105 | self.features_fill.append([int(u), int(i), int(g), int(a)])
106 | self.labels_fill.append(r)
107 |
108 | if is_training:
109 | for j in js:
110 | self.features_fill.append([int(u), int(j), int(g), int(a)])
111 | self.labels_fill.append(self.neg_label)
112 | self.labels_fill = np.array(self.labels_fill, dtype=np.float32)
113 |
114 |
115 |
116 | def age_map(age):
117 | age_dict = {'under 20':0, '20-30':1, 'over 30':2}
118 |
119 | return age_dict[age]
120 |
121 | def gender_map_ml(gender):
122 | gender_dict = {'women':0, 'men':1}
123 |
124 | return gender_dict[gender]
125 |
126 | def gender_map_coat(gender):
127 | gender_dict = {'women':1, 'men':0}
128 |
129 | return gender_dict[gender]
130 |
131 |
132 | def load_data(path):
133 | df = pd.read_csv(path)
134 | df['user'] = pd.Categorical(df['user']).codes
135 | df['item'] = pd.Categorical(df['item']).codes
136 | user_num = df['user'].nunique()
137 | item_num = df['item'].nunique()
138 |
139 | return df, len(df), user_num, item_num
140 |
141 | def get_user_info(df, feature_num):
142 | user_info = dict()
143 | for _, row in df.iterrows():
144 | if feature_num == 0:
145 | user_info[int(row['user'])] = [int(row['item'])]
146 | else:
147 | user_info[int(row['user'])] = [int(row['gender']), int(row['age'])]
148 |
149 | return user_info
150 |
151 |
152 |
153 | if __name__ == '__main__':
154 | parser = argparse.ArgumentParser(description='fm recommender')
155 | parser.add_argument('--feature',
156 | type=str,
157 | default='user,item,gender,age,rating,normal')
158 | parser.add_argument('--val',
159 | type=int,
160 | default=1)
161 | parser.add_argument('--num_ng',
162 | type=int,
163 | default=5)
164 | parser.add_argument('--factors',
165 | type=int,
166 | default=16)
167 | parser.add_argument('--epochs',
168 | type=int,
169 | default=1)
170 | parser.add_argument('--lr',
171 | type=float,
172 | default=0.01)
173 | parser.add_argument('--batch_size',
174 | type=int,
175 | default=256)
176 | parser.add_argument('--dataset',
177 | type=str,
178 | default='coat')
179 | args = parser.parse_args()
180 | time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
181 | s_time = time[5:].replace(' ', '-').replace(':', '-')
182 | # seed = 1
183 | # np.random.seed(seed)
184 | # torch.manual_seed(seed)
185 | path = f'/home/workshop/dataset/lhy/Speech/recommend/data/{args.dataset}_mp3.csv'
186 | df, record_num, user_num, item_num = load_data(path)
187 | df.insert(df.shape[1], 'rating', 1)
188 | feature_list = args.feature.split(',')
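    # feature_num selects the variant: 0 = no side features (wo_ag), 1 = gender+age
    # without the 'normal' flag (ag_audio), 2 = gender+age with 'normal' (ag_normal)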
189 | if 'gender' not in feature_list and 'age' not in feature_list:
190 | feature_num = 0
191 | elif 'gender' in feature_list and 'age' in feature_list and 'normal' not in feature_list:
192 | feature_num = 1
193 | elif 'gender' in feature_list and 'age' in feature_list and 'normal' in feature_list:
194 | feature_num = 2
195 | feature_list = ['user', 'item', 'age', 'gender', 'rating']
196 | df = df[feature_list]
197 | if feature_num != 0:
198 | df['age'] = df['age'].apply(lambda age: age_map(age))
199 | if args.dataset == 'coat':
200 | df['gender'] = df['gender'].apply(lambda gender: gender_map_coat(gender))
201 | else:
202 | df['gender'] = df['gender'].apply(lambda gender: gender_map_ml(gender))
203 |
204 |
205 | user_info = get_user_info(df, feature_num)
206 |
207 | ratio_train = 0.8
208 | ratio_val = 0.1
209 | ratio_test = 0.2
210 |
211 | train_set, test_set = train_test_split(df, test_size=ratio_test, random_state=2023)
212 | if args.val == 1:
213 | ratio_remaining = 1 - ratio_test
214 | ratio_val_adjusted = ratio_val / ratio_remaining
215 | train_set, val_set = train_test_split(train_set, test_size=ratio_val_adjusted, random_state=2023)
216 |
217 | test_ur = get_ur(test_set)
218 | total_train_ur = get_ur(train_set)
219 | item_pool = set(range(item_num))
220 | sampler = Sample(user_num, item_num, feature_num, args.num_ng)
221 | neg_set = sampler.transform(train_set, is_training=True)
222 | print("data sample complete")
223 | train_dataset = FMData(neg_set, feature_num, is_training=True)
224 | model = PointFM(user_num, item_num, args.factors, args.epochs, args.lr, feature=feature_num)
225 | if torch.cuda.is_available():
226 | # torch.cuda.manual_seed_all(seed)
227 | # torch.backends.cudnn.benchmark = False
228 | model = model.to("cuda")
229 | train_loader = data.DataLoader(
230 | train_dataset,
231 | batch_size=args.batch_size,
232 | shuffle=True,
233 | num_workers=4
234 | )
235 | model.fit(train_loader)
236 | print('Start Calculating Metrics......')
237 | test_ucands = build_candidates_set(test_ur, total_train_ur, item_pool, 1000)
238 | print('')
239 | print('Generate recommend list...')
240 | print('')
241 | preds = {}
242 | for u in tqdm(test_ucands.keys()):
243 | if feature_num == 0:
244 | tmp = pd.DataFrame({
245 | 'user': [u for _ in test_ucands[u]],
246 | 'item': test_ucands[u],
247 |                 'rating': [0. for _ in test_ucands[u]], # placeholder label, never used
248 | })
249 | else:
250 | gender = user_info[u][0]
251 | age = user_info[u][1]
252 | tmp = pd.DataFrame({
253 | 'user': [u for _ in test_ucands[u]],
254 | 'item': test_ucands[u],
255 | 'gender': [gender for _ in test_ucands[u]],
256 | 'age': [age for _ in test_ucands[u]],
257 |                 'rating': [0. for _ in test_ucands[u]], # placeholder label, never used
258 | })
259 | tmp_neg_set = sampler.transform(tmp, is_training=False)
260 | tmp_dataset = FMData(tmp_neg_set, feature_num, is_training=False)
261 | tmp_loader = data.DataLoader(
262 | tmp_dataset,
263 | batch_size=1000,
264 | shuffle=False,
265 | num_workers=4
266 | )
267 | for user, item, gender, age, label in tmp_loader:
268 | user = user.cuda()
269 | item = item.cuda()
270 | label = label.cuda()
271 |
272 | if feature_num != 0:
273 | gender = gender.cuda()
274 | age = age.cuda()
275 | prediction = model.predict(user, item, g=gender, a=age)
276 | else:
277 | prediction = model.predict(user, item)
278 | _, indices = torch.topk(prediction, 50)
279 | top_n = torch.take(torch.tensor(test_ucands[u]), indices).cpu().numpy()
280 | preds[u] = top_n
281 |
282 | for u in preds.keys():
283 | preds[u] = [1 if i in test_ur[u] else 0 for i in preds[u]]
284 |
285 | print('Save metric@k result to res folder...')
286 | if feature_num == 0:
287 | result_save_path = f'./res/fm/{args.dataset}/wo_ag/'
288 | elif feature_num == 1:
289 | result_save_path = f'./res/fm/{args.dataset}/ag_audio/'
290 | elif feature_num == 2:
291 | result_save_path = f'./res/fm/{args.dataset}/ag_normal/'
292 | if not os.path.exists(result_save_path):
293 | os.makedirs(result_save_path)
294 | res = pd.DataFrame({'metric@K': ['pre', 'rec', 'hr', 'map', 'mrr', 'ndcg']})
295 | for k in [1, 5, 10, 20, 30, 50]:
296 | tmp_preds = preds.copy()
297 | tmp_preds = {key: rank_list[:k] for key, rank_list in tmp_preds.items()}
298 |
299 | pre_k = np.mean([precision_at_k(r, k) for r in tmp_preds.values()])
300 | rec_k = recall_at_k(tmp_preds, test_ur, k)
301 | hr_k = hr_at_k(tmp_preds, test_ur)
302 | map_k = map_at_k(tmp_preds.values())
303 | mrr_k = mrr_at_k(tmp_preds, k)
304 | ndcg_k = np.mean([ndcg_at_k(r, k) for r in tmp_preds.values()])
305 |
306 | if k == 10:
307 | print(f'Precision@{k}: {pre_k:.4f}')
308 | print(f'Recall@{k}: {rec_k:.4f}')
309 | print(f'HR@{k}: {hr_k:.4f}')
310 | print(f'MAP@{k}: {map_k:.4f}')
311 | print(f'MRR@{k}: {mrr_k:.4f}')
312 | print(f'NDCG@{k}: {ndcg_k:.4f}')
313 |
314 | res[k] = np.array([pre_k, rec_k, hr_k, map_k, mrr_k, ndcg_k])
315 |
316 | common_prefix = f'{s_time}_{args.num_ng}_{args.factors}_{args.lr}_{args.batch_size}_{args.epochs}'
317 |
318 | res.to_csv(
319 | f'{result_save_path}{common_prefix}_results.csv',
320 | index=False
321 | )
--------------------------------------------------------------------------------
/Speech/vits_lib/attentions.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 |
7 | import vits_lib.commons as commons
8 | import vits_lib.modules as modules
9 | from vits_lib.modules import LayerNorm
10 |
11 |
12 | class Encoder(nn.Module):
13 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, **kwargs):
14 | super().__init__()
15 | self.hidden_channels = hidden_channels
16 | self.filter_channels = filter_channels
17 | self.n_heads = n_heads
18 | self.n_layers = n_layers
19 | self.kernel_size = kernel_size
20 | self.p_dropout = p_dropout
21 | self.window_size = window_size
22 |
23 | self.drop = nn.Dropout(p_dropout)
24 | self.attn_layers = nn.ModuleList()
25 | self.norm_layers_1 = nn.ModuleList()
26 | self.ffn_layers = nn.ModuleList()
27 | self.norm_layers_2 = nn.ModuleList()
28 | for i in range(self.n_layers):
29 | self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size))
30 | self.norm_layers_1.append(LayerNorm(hidden_channels))
31 | self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
32 | self.norm_layers_2.append(LayerNorm(hidden_channels))
33 |
34 | def forward(self, x, x_mask):
35 | attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
36 | x = x * x_mask
37 | for i in range(self.n_layers):
38 | y = self.attn_layers[i](x, x, attn_mask)
39 | y = self.drop(y)
40 | x = self.norm_layers_1[i](x + y)
41 |
42 | y = self.ffn_layers[i](x, x_mask)
43 | y = self.drop(y)
44 | x = self.norm_layers_2[i](x + y)
45 | x = x * x_mask
46 | return x
47 |
48 |
49 | class Decoder(nn.Module):
50 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., proximal_bias=False, proximal_init=True, **kwargs):
51 | super().__init__()
52 | self.hidden_channels = hidden_channels
53 | self.filter_channels = filter_channels
54 | self.n_heads = n_heads
55 | self.n_layers = n_layers
56 | self.kernel_size = kernel_size
57 | self.p_dropout = p_dropout
58 | self.proximal_bias = proximal_bias
59 | self.proximal_init = proximal_init
60 |
61 | self.drop = nn.Dropout(p_dropout)
62 | self.self_attn_layers = nn.ModuleList()
63 | self.norm_layers_0 = nn.ModuleList()
64 | self.encdec_attn_layers = nn.ModuleList()
65 | self.norm_layers_1 = nn.ModuleList()
66 | self.ffn_layers = nn.ModuleList()
67 | self.norm_layers_2 = nn.ModuleList()
68 | for i in range(self.n_layers):
69 | self.self_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, proximal_bias=proximal_bias, proximal_init=proximal_init))
70 | self.norm_layers_0.append(LayerNorm(hidden_channels))
71 | self.encdec_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout))
72 | self.norm_layers_1.append(LayerNorm(hidden_channels))
73 | self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True))
74 | self.norm_layers_2.append(LayerNorm(hidden_channels))
75 |
76 | def forward(self, x, x_mask, h, h_mask):
77 | """
78 | x: decoder input
79 | h: encoder output
80 | """
81 | self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
82 | encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
83 | x = x * x_mask
84 | for i in range(self.n_layers):
85 | y = self.self_attn_layers[i](x, x, self_attn_mask)
86 | y = self.drop(y)
87 | x = self.norm_layers_0[i](x + y)
88 |
89 | y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
90 | y = self.drop(y)
91 | x = self.norm_layers_1[i](x + y)
92 |
93 | y = self.ffn_layers[i](x, x_mask)
94 | y = self.drop(y)
95 | x = self.norm_layers_2[i](x + y)
96 | x = x * x_mask
97 | return x
98 |
99 |
100 | class MultiHeadAttention(nn.Module):
101 | def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False):
102 | super().__init__()
103 | assert channels % n_heads == 0
104 |
105 | self.channels = channels
106 | self.out_channels = out_channels
107 | self.n_heads = n_heads
108 | self.p_dropout = p_dropout
109 | self.window_size = window_size
110 | self.heads_share = heads_share
111 | self.block_length = block_length
112 | self.proximal_bias = proximal_bias
113 | self.proximal_init = proximal_init
114 | self.attn = None
115 |
116 | self.k_channels = channels // n_heads
117 | self.conv_q = nn.Conv1d(channels, channels, 1)
118 | self.conv_k = nn.Conv1d(channels, channels, 1)
119 | self.conv_v = nn.Conv1d(channels, channels, 1)
120 | self.conv_o = nn.Conv1d(channels, out_channels, 1)
121 | self.drop = nn.Dropout(p_dropout)
122 |
123 | if window_size is not None:
124 | n_heads_rel = 1 if heads_share else n_heads
125 | rel_stddev = self.k_channels**-0.5
126 | self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
127 | self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
128 |
129 | nn.init.xavier_uniform_(self.conv_q.weight)
130 | nn.init.xavier_uniform_(self.conv_k.weight)
131 | nn.init.xavier_uniform_(self.conv_v.weight)
132 | if proximal_init:
133 | with torch.no_grad():
134 | self.conv_k.weight.copy_(self.conv_q.weight)
135 | self.conv_k.bias.copy_(self.conv_q.bias)
136 |
137 | def forward(self, x, c, attn_mask=None):
138 | q = self.conv_q(x)
139 | k = self.conv_k(c)
140 | v = self.conv_v(c)
141 |
142 | x, self.attn = self.attention(q, k, v, mask=attn_mask)
143 |
144 | x = self.conv_o(x)
145 | return x
146 |
147 | def attention(self, query, key, value, mask=None):
148 | # reshape [b, d, t] -> [b, n_h, t, d_k]
149 | b, d, t_s, t_t = (*key.size(), query.size(2))
150 | query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
151 | key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
152 | value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
153 |
154 | scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
155 | if self.window_size is not None:
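      # add learned relative-position logits (window of +/- window_size) to the content-based scores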
156 | assert t_s == t_t, "Relative attention is only available for self-attention."
157 | key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
158 |       rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings)
159 | scores_local = self._relative_position_to_absolute_position(rel_logits)
160 | scores = scores + scores_local
161 | if self.proximal_bias:
162 | assert t_s == t_t, "Proximal bias is only available for self-attention."
163 | scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
164 | if mask is not None:
165 | scores = scores.masked_fill(mask == 0, -1e4)
166 | if self.block_length is not None:
167 | assert t_s == t_t, "Local attention is only available for self-attention."
168 | block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
169 | scores = scores.masked_fill(block_mask == 0, -1e4)
170 | p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
171 | p_attn = self.drop(p_attn)
172 | output = torch.matmul(p_attn, value)
173 | if self.window_size is not None:
174 | relative_weights = self._absolute_position_to_relative_position(p_attn)
175 | value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
176 | output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
177 | output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
178 | return output, p_attn
179 |
180 | def _matmul_with_relative_values(self, x, y):
181 | """
182 | x: [b, h, l, m]
183 | y: [h or 1, m, d]
184 | ret: [b, h, l, d]
185 | """
186 | ret = torch.matmul(x, y.unsqueeze(0))
187 | return ret
188 |
189 | def _matmul_with_relative_keys(self, x, y):
190 | """
191 | x: [b, h, l, d]
192 | y: [h or 1, m, d]
193 | ret: [b, h, l, m]
194 | """
195 | ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
196 | return ret
197 |
198 | def _get_relative_embeddings(self, relative_embeddings, length):
199 | max_relative_position = 2 * self.window_size + 1
200 | # Pad first before slice to avoid using cond ops.
201 | pad_length = max(length - (self.window_size + 1), 0)
202 | slice_start_position = max((self.window_size + 1) - length, 0)
203 | slice_end_position = slice_start_position + 2 * length - 1
204 | if pad_length > 0:
205 | padded_relative_embeddings = F.pad(
206 | relative_embeddings,
207 | commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
208 | else:
209 | padded_relative_embeddings = relative_embeddings
210 | used_relative_embeddings = padded_relative_embeddings[:,slice_start_position:slice_end_position]
211 | return used_relative_embeddings
212 |
213 | def _relative_position_to_absolute_position(self, x):
214 | """
215 | x: [b, h, l, 2*l-1]
216 | ret: [b, h, l, l]
217 | """
218 | batch, heads, length, _ = x.size()
219 | # Concat columns of pad to shift from relative to absolute indexing.
220 | x = F.pad(x, commons.convert_pad_shape([[0,0],[0,0],[0,0],[0,1]]))
221 |
222 | # Concat extra elements so to add up to shape (len+1, 2*len-1).
223 | x_flat = x.view([batch, heads, length * 2 * length])
224 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0,0],[0,0],[0,length-1]]))
225 |
226 | # Reshape and slice out the padded elements.
227 | x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:]
228 | return x_final
229 |
230 | def _absolute_position_to_relative_position(self, x):
231 | """
232 | x: [b, h, l, l]
233 | ret: [b, h, l, 2*l-1]
234 | """
235 | batch, heads, length, _ = x.size()
236 |     # pad along column
237 | x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]]))
238 | x_flat = x.view([batch, heads, length**2 + length*(length -1)])
239 | # add 0's in the beginning that will skew the elements after reshape
240 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
241 | x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:]
242 | return x_final
243 |
244 | def _attention_bias_proximal(self, length):
245 | """Bias for self-attention to encourage attention to close positions.
246 | Args:
247 | length: an integer scalar.
248 | Returns:
249 | a Tensor with shape [1, 1, length, length]
250 | """
251 | r = torch.arange(length, dtype=torch.float32)
252 | diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
253 | return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
254 |
255 |
256 | class FFN(nn.Module):
257 | def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None, causal=False):
258 | super().__init__()
259 | self.in_channels = in_channels
260 | self.out_channels = out_channels
261 | self.filter_channels = filter_channels
262 | self.kernel_size = kernel_size
263 | self.p_dropout = p_dropout
264 | self.activation = activation
265 | self.causal = causal
266 |
267 | if causal:
268 | self.padding = self._causal_padding
269 | else:
270 | self.padding = self._same_padding
271 |
272 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
273 | self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
274 | self.drop = nn.Dropout(p_dropout)
275 |
276 | def forward(self, x, x_mask):
277 | x = self.conv_1(self.padding(x * x_mask))
278 | if self.activation == "gelu":
279 | x = x * torch.sigmoid(1.702 * x)
280 | else:
281 | x = torch.relu(x)
282 | x = self.drop(x)
283 | x = self.conv_2(self.padding(x * x_mask))
284 | return x * x_mask
285 |
286 | def _causal_padding(self, x):
287 | if self.kernel_size == 1:
288 | return x
289 | pad_l = self.kernel_size - 1
290 | pad_r = 0
291 | padding = [[0, 0], [0, 0], [pad_l, pad_r]]
292 | x = F.pad(x, commons.convert_pad_shape(padding))
293 | return x
294 |
295 | def _same_padding(self, x):
296 | if self.kernel_size == 1:
297 | return x
298 | pad_l = (self.kernel_size - 1) // 2
299 | pad_r = self.kernel_size // 2
300 | padding = [[0, 0], [0, 0], [pad_l, pad_r]]
301 | x = F.pad(x, commons.convert_pad_shape(padding))
302 | return x
303 |
--------------------------------------------------------------------------------
/Speech/vits_lib/modules.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import math
3 | import numpy as np
4 | import scipy
5 | import torch
6 | from torch import nn
7 | from torch.nn import functional as F
8 |
9 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
10 | from torch.nn.utils import weight_norm, remove_weight_norm
11 |
12 | import vits_lib.commons as commons
13 | from vits_lib.commons import init_weights, get_padding
14 | from vits_lib.transforms import piecewise_rational_quadratic_transform
15 |
16 |
17 | LRELU_SLOPE = 0.1
18 |
19 |
20 | class LayerNorm(nn.Module):
21 | def __init__(self, channels, eps=1e-5):
22 | super().__init__()
23 | self.channels = channels
24 | self.eps = eps
25 |
26 | self.gamma = nn.Parameter(torch.ones(channels))
27 | self.beta = nn.Parameter(torch.zeros(channels))
28 |
29 | def forward(self, x):
30 | x = x.transpose(1, -1)
31 | x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
32 | return x.transpose(1, -1)
33 |
34 |
35 | class ConvReluNorm(nn.Module):
36 | def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
37 | super().__init__()
38 | self.in_channels = in_channels
39 | self.hidden_channels = hidden_channels
40 | self.out_channels = out_channels
41 | self.kernel_size = kernel_size
42 | self.n_layers = n_layers
43 | self.p_dropout = p_dropout
44 |     assert n_layers > 1, "Number of layers should be larger than 1."
45 |
46 | self.conv_layers = nn.ModuleList()
47 | self.norm_layers = nn.ModuleList()
48 | self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2))
49 | self.norm_layers.append(LayerNorm(hidden_channels))
50 | self.relu_drop = nn.Sequential(
51 | nn.ReLU(),
52 | nn.Dropout(p_dropout))
53 | for _ in range(n_layers-1):
54 | self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2))
55 | self.norm_layers.append(LayerNorm(hidden_channels))
56 | self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
57 | self.proj.weight.data.zero_()
58 | self.proj.bias.data.zero_()
59 |
60 | def forward(self, x, x_mask):
61 | x_org = x
62 | for i in range(self.n_layers):
63 | x = self.conv_layers[i](x * x_mask)
64 | x = self.norm_layers[i](x)
65 | x = self.relu_drop(x)
66 | x = x_org + self.proj(x)
67 | return x * x_mask
68 |
69 |
70 | class DDSConv(nn.Module):
71 | """
72 |   Dilated and Depth-Separable Convolution
73 | """
74 | def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
75 | super().__init__()
76 | self.channels = channels
77 | self.kernel_size = kernel_size
78 | self.n_layers = n_layers
79 | self.p_dropout = p_dropout
80 |
81 | self.drop = nn.Dropout(p_dropout)
82 | self.convs_sep = nn.ModuleList()
83 | self.convs_1x1 = nn.ModuleList()
84 | self.norms_1 = nn.ModuleList()
85 | self.norms_2 = nn.ModuleList()
86 | for i in range(n_layers):
87 | dilation = kernel_size ** i
88 | padding = (kernel_size * dilation - dilation) // 2
89 | self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size,
90 | groups=channels, dilation=dilation, padding=padding
91 | ))
92 | self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
93 | self.norms_1.append(LayerNorm(channels))
94 | self.norms_2.append(LayerNorm(channels))
95 |
96 | def forward(self, x, x_mask, g=None):
97 | if g is not None:
98 | x = x + g
99 | for i in range(self.n_layers):
100 | y = self.convs_sep[i](x * x_mask)
101 | y = self.norms_1[i](y)
102 | y = F.gelu(y)
103 | y = self.convs_1x1[i](y)
104 | y = self.norms_2[i](y)
105 | y = F.gelu(y)
106 | y = self.drop(y)
107 | x = x + y
108 | return x * x_mask
109 |
110 |
111 | class WN(torch.nn.Module):
112 | def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
113 | super(WN, self).__init__()
114 | assert(kernel_size % 2 == 1)
115 |     self.hidden_channels = hidden_channels
116 |     self.kernel_size = kernel_size
117 | self.dilation_rate = dilation_rate
118 | self.n_layers = n_layers
119 | self.gin_channels = gin_channels
120 | self.p_dropout = p_dropout
121 |
122 | self.in_layers = torch.nn.ModuleList()
123 | self.res_skip_layers = torch.nn.ModuleList()
124 | self.drop = nn.Dropout(p_dropout)
125 |
126 | if gin_channels != 0:
127 | cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1)
128 | self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
129 |
130 | for i in range(n_layers):
131 | dilation = dilation_rate ** i
132 | padding = int((kernel_size * dilation - dilation) / 2)
133 | in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size,
134 | dilation=dilation, padding=padding)
135 | in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
136 | self.in_layers.append(in_layer)
137 |
138 | # last one is not necessary
139 | if i < n_layers - 1:
140 | res_skip_channels = 2 * hidden_channels
141 | else:
142 | res_skip_channels = hidden_channels
143 |
144 | res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
145 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
146 | self.res_skip_layers.append(res_skip_layer)
147 |
148 | def forward(self, x, x_mask, g=None, **kwargs):
149 | output = torch.zeros_like(x)
150 | n_channels_tensor = torch.IntTensor([self.hidden_channels])
151 |
152 | if g is not None:
153 | g = self.cond_layer(g)
154 |
155 | for i in range(self.n_layers):
156 | x_in = self.in_layers[i](x)
157 | if g is not None:
158 | cond_offset = i * 2 * self.hidden_channels
159 | g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:]
160 | else:
161 | g_l = torch.zeros_like(x_in)
162 |
163 | acts = commons.fused_add_tanh_sigmoid_multiply(
164 | x_in,
165 | g_l,
166 | n_channels_tensor)
167 | acts = self.drop(acts)
168 |
169 | res_skip_acts = self.res_skip_layers[i](acts)
170 | if i < self.n_layers - 1:
171 | res_acts = res_skip_acts[:,:self.hidden_channels,:]
172 | x = (x + res_acts) * x_mask
173 | output = output + res_skip_acts[:,self.hidden_channels:,:]
174 | else:
175 | output = output + res_skip_acts
176 | return output * x_mask
177 |
178 | def remove_weight_norm(self):
179 | if self.gin_channels != 0:
180 | torch.nn.utils.remove_weight_norm(self.cond_layer)
181 | for l in self.in_layers:
182 | torch.nn.utils.remove_weight_norm(l)
183 | for l in self.res_skip_layers:
184 | torch.nn.utils.remove_weight_norm(l)
185 |
186 |
187 | class ResBlock1(torch.nn.Module):
188 | def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
189 | super(ResBlock1, self).__init__()
190 | self.convs1 = nn.ModuleList([
191 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
192 | padding=get_padding(kernel_size, dilation[0]))),
193 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
194 | padding=get_padding(kernel_size, dilation[1]))),
195 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
196 | padding=get_padding(kernel_size, dilation[2])))
197 | ])
198 | self.convs1.apply(init_weights)
199 |
200 | self.convs2 = nn.ModuleList([
201 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
202 | padding=get_padding(kernel_size, 1))),
203 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
204 | padding=get_padding(kernel_size, 1))),
205 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
206 | padding=get_padding(kernel_size, 1)))
207 | ])
208 | self.convs2.apply(init_weights)
209 |
210 | def forward(self, x, x_mask=None):
211 | for c1, c2 in zip(self.convs1, self.convs2):
212 | xt = F.leaky_relu(x, LRELU_SLOPE)
213 | if x_mask is not None:
214 | xt = xt * x_mask
215 | xt = c1(xt)
216 | xt = F.leaky_relu(xt, LRELU_SLOPE)
217 | if x_mask is not None:
218 | xt = xt * x_mask
219 | xt = c2(xt)
220 | x = xt + x
221 | if x_mask is not None:
222 | x = x * x_mask
223 | return x
224 |
225 | def remove_weight_norm(self):
226 | for l in self.convs1:
227 | remove_weight_norm(l)
228 | for l in self.convs2:
229 | remove_weight_norm(l)
230 |
231 |
232 | class ResBlock2(torch.nn.Module):
233 | def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
234 | super(ResBlock2, self).__init__()
235 | self.convs = nn.ModuleList([
236 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
237 | padding=get_padding(kernel_size, dilation[0]))),
238 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
239 | padding=get_padding(kernel_size, dilation[1])))
240 | ])
241 | self.convs.apply(init_weights)
242 |
243 | def forward(self, x, x_mask=None):
244 | for c in self.convs:
245 | xt = F.leaky_relu(x, LRELU_SLOPE)
246 | if x_mask is not None:
247 | xt = xt * x_mask
248 | xt = c(xt)
249 | x = xt + x
250 | if x_mask is not None:
251 | x = x * x_mask
252 | return x
253 |
254 | def remove_weight_norm(self):
255 | for l in self.convs:
256 | remove_weight_norm(l)
257 |
258 |
259 | class Log(nn.Module):
260 | def forward(self, x, x_mask, reverse=False, **kwargs):
261 | if not reverse:
262 | y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
263 | logdet = torch.sum(-y, [1, 2])
264 | return y, logdet
265 | else:
266 | x = torch.exp(x) * x_mask
267 | return x
268 |
269 |
270 | class Flip(nn.Module):
271 | def forward(self, x, *args, reverse=False, **kwargs):
272 | x = torch.flip(x, [1])
273 | if not reverse:
274 | logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
275 | return x, logdet
276 | else:
277 | return x
278 |
279 |
280 | class ElementwiseAffine(nn.Module):
281 | def __init__(self, channels):
282 | super().__init__()
283 | self.channels = channels
284 | self.m = nn.Parameter(torch.zeros(channels,1))
285 | self.logs = nn.Parameter(torch.zeros(channels,1))
286 |
287 | def forward(self, x, x_mask, reverse=False, **kwargs):
288 | if not reverse:
289 | y = self.m + torch.exp(self.logs) * x
290 | y = y * x_mask
291 | logdet = torch.sum(self.logs * x_mask, [1,2])
292 | return y, logdet
293 | else:
294 | x = (x - self.m) * torch.exp(-self.logs) * x_mask
295 | return x
296 |
297 |
298 | class ResidualCouplingLayer(nn.Module):
299 | def __init__(self,
300 | channels,
301 | hidden_channels,
302 | kernel_size,
303 | dilation_rate,
304 | n_layers,
305 | p_dropout=0,
306 | gin_channels=0,
307 | mean_only=False):
308 | assert channels % 2 == 0, "channels should be divisible by 2"
309 | super().__init__()
310 | self.channels = channels
311 | self.hidden_channels = hidden_channels
312 | self.kernel_size = kernel_size
313 | self.dilation_rate = dilation_rate
314 | self.n_layers = n_layers
315 | self.half_channels = channels // 2
316 | self.mean_only = mean_only
317 |
318 | self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
319 | self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels)
320 | self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
321 | self.post.weight.data.zero_()
322 | self.post.bias.data.zero_()
323 |
324 | def forward(self, x, x_mask, g=None, reverse=False):
325 | x0, x1 = torch.split(x, [self.half_channels]*2, 1)
326 | h = self.pre(x0) * x_mask
327 | h = self.enc(h, x_mask, g=g)
328 | stats = self.post(h) * x_mask
329 | if not self.mean_only:
330 | m, logs = torch.split(stats, [self.half_channels]*2, 1)
331 | else:
332 | m = stats
333 | logs = torch.zeros_like(m)
334 |
335 | if not reverse:
336 | x1 = m + x1 * torch.exp(logs) * x_mask
337 | x = torch.cat([x0, x1], 1)
338 | logdet = torch.sum(logs, [1,2])
339 | return x, logdet
340 | else:
341 | x1 = (x1 - m) * torch.exp(-logs) * x_mask
342 | x = torch.cat([x0, x1], 1)
343 | return x
344 |
345 |
346 | class ConvFlow(nn.Module):
347 | def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0):
348 | super().__init__()
349 | self.in_channels = in_channels
350 | self.filter_channels = filter_channels
351 | self.kernel_size = kernel_size
352 | self.n_layers = n_layers
353 | self.num_bins = num_bins
354 | self.tail_bound = tail_bound
355 | self.half_channels = in_channels // 2
356 |
357 | self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
358 | self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.)
359 | self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1)
360 | self.proj.weight.data.zero_()
361 | self.proj.bias.data.zero_()
362 |
363 | def forward(self, x, x_mask, g=None, reverse=False):
364 | x0, x1 = torch.split(x, [self.half_channels]*2, 1)
365 | h = self.pre(x0)
366 | h = self.convs(h, x_mask, g=g)
367 | h = self.proj(h) * x_mask
368 |
369 | b, c, t = x0.shape
370 | h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
371 |
372 | unnormalized_widths = h[..., :self.num_bins] / math.sqrt(self.filter_channels)
373 | unnormalized_heights = h[..., self.num_bins:2*self.num_bins] / math.sqrt(self.filter_channels)
374 | unnormalized_derivatives = h[..., 2 * self.num_bins:]
375 |
376 | x1, logabsdet = piecewise_rational_quadratic_transform(x1,
377 | unnormalized_widths,
378 | unnormalized_heights,
379 | unnormalized_derivatives,
380 | inverse=reverse,
381 | tails='linear',
382 | tail_bound=self.tail_bound
383 | )
384 |
385 | x = torch.cat([x0, x1], 1) * x_mask
386 | logdet = torch.sum(logabsdet * x_mask, [1,2])
387 | if not reverse:
388 | return x, logdet
389 | else:
390 | return x
391 |
--------------------------------------------------------------------------------
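
Note on the flow modules above (Log, Flip, ElementwiseAffine, ResidualCouplingLayer, ConvFlow): they share one convention, the forward direction returns (y, logdet) while reverse=True returns only the inverted tensor. The minimal sketch below shows how such modules compose into an invertible stack; the shapes, hyper-parameters and import path are illustrative assumptions for this note, not values taken from the repository's configs.

    # sketch only: compose two flow steps and check invertibility (assumes vits_lib is importable)
    import torch
    from vits_lib.modules import ResidualCouplingLayer, Flip

    channels, hidden, t = 192, 192, 50                     # assumed sizes
    x = torch.randn(2, channels, t)                        # [batch, channels, time]
    x_mask = torch.ones(2, 1, t)

    flows = [
        ResidualCouplingLayer(channels, hidden, kernel_size=5,
                              dilation_rate=1, n_layers=4, mean_only=True),
        Flip(),
    ]

    # forward: accumulate the log-determinant contributed by each step
    y, logdet_tot = x, 0
    for f in flows:
        y, logdet = f(y, x_mask)
        logdet_tot = logdet_tot + logdet

    # reverse: run the steps backwards; reverse=True returns the tensor only
    x_rec = y
    for f in reversed(flows):
        x_rec = f(x_rec, x_mask, reverse=True)
    # x_rec matches x up to numerical error
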
/Recommender/fm_audio.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import datetime
4 | import torch
5 | import pandas as pd
6 | import numpy as np
7 | import scipy.sparse as sp
8 | import torch.utils.data as data
9 | from tqdm import tqdm
10 | from FMRecommender import PointFM
11 | from utils import precision_at_k, recall_at_k, map_at_k, hr_at_k, ndcg_at_k, mrr_at_k
12 | from utils import get_ur, get_feature, build_candidates_set, get_user_info
13 |
14 |
15 | def age_map(age):
16 | if age < 20:
17 | return 0
18 |     elif 20 <= age <= 30:
19 | return 1
20 | elif age > 30:
21 | return 2
22 |
23 | def vote_user_info(df, u, mode):
24 | user_df = df[df['user'] == u]
25 | if mode == 0:
26 | gender_series = user_df['gender'].value_counts()
27 | gender = int(gender_series.idxmax())
28 | return gender
29 | elif mode == 1:
30 | age_series = user_df['age'].value_counts()
31 | age = int(age_series.idxmax())
32 | return age
33 | else:
34 | gender_series = user_df['gender'].value_counts()
35 | age_series = user_df['age'].value_counts()
36 |
37 | gender = int(gender_series.idxmax())
38 | age = int(age_series.idxmax())
39 |
40 | return gender, age
41 |
42 | def load_data(path):
43 | df = pd.read_csv(path)
44 | df['user'] = pd.Categorical(df['user']).codes
45 | df['item'] = pd.Categorical(df['item']).codes
46 | user_num = df['user'].nunique()
47 | item_num = df['item'].nunique()
48 | print(user_num, item_num)
49 | return df, len(df), user_num, item_num
50 |
51 | def split_testset(data, test_size=0.2):
52 | split_idx = int(np.ceil(len(data) * (1 - test_size)))
53 | train, test = data.iloc[:split_idx, :].copy(), data.iloc[split_idx:, :].copy()
54 |
55 | return train, test
56 |
57 | class Sample(object):
58 | def __init__(self, user_num, item_num, feature_num, num_ng=4) -> None:
59 | self.user_num = user_num
60 | self.item_num = item_num
61 | self.num_ng = num_ng
62 | self.feature_num = feature_num
63 |
64 | def transform(self, data, is_training=True):
65 | if not is_training:
66 | neg_set = []
67 | for _, row in data.iterrows():
68 | u = int(row['user'])
69 | i = int(row['item'])
70 | r = row['rating']
71 | js = []
72 | if self.feature_num == 0:
73 | g = int(row['gender'])
74 | neg_set.append([u, i, g, r, js])
75 | elif self.feature_num == 1:
76 | a = int(row['age'])
77 | neg_set.append([u, i, a, r, js])
78 | elif self.feature_num == 2:
79 | g = int(row['gender'])
80 | a = int(row['age'])
81 | neg_set.append([u, i, g, a, r, js])
82 | else:
83 | neg_set.append([u, i, r, js])
84 | return neg_set
85 |
86 | user_num = self.user_num
87 | item_num = self.item_num
88 | pair_pos = sp.dok_matrix((user_num, item_num), dtype=np.float32)
89 | for _, row in data.iterrows():
90 | pair_pos[int(row['user']), int(row['item'])] = 1.0
91 | neg_set = []
92 | for _, row in data.iterrows():
93 | u = int(row['user'])
94 | i = int(row['item'])
95 | r = row['rating']
96 | js = []
97 | for _ in range(self.num_ng):
98 | j = np.random.randint(item_num)
99 | while (u, j) in pair_pos:
100 | j = np.random.randint(item_num)
101 | js.append(j)
102 | if self.feature_num == 0:
103 | g = int(row['gender'])
104 | neg_set.append([u, i, g, r, js])
105 | elif self.feature_num == 1:
106 | a = int(row['age'])
107 | neg_set.append([u, i, a, r, js])
108 | elif self.feature_num == 2:
109 | g = int(row['gender'])
110 | a = int(row['age'])
111 | neg_set.append([u, i, g, a, r, js])
112 | else:
113 | neg_set.append([u, i, r, js])
114 |
115 | return neg_set
116 |
117 |
118 | class FMData(data.Dataset):
119 | def __init__(self, neg_set, feature_num, is_training=True, neg_label_val=0.) -> None:
120 | super(FMData, self).__init__()
121 | self.features_fill = []
122 | self.labels_fill = []
123 | self.feature_num = feature_num
124 | self.neg_label = neg_label_val
125 |
126 | if self.feature_num == 0:
127 | self.init_gender(neg_set, is_training)
128 | elif self.feature_num == 1:
129 | self.init_age(neg_set, is_training)
130 | elif self.feature_num == 2:
131 | self.init_gender_age(neg_set, is_training)
132 | else:
133 | self.init_normal(neg_set, is_training)
134 |
135 | def __len__(self):
136 | return len(self.labels_fill)
137 |
138 | def __getitem__(self, index):
139 | features = self.features_fill
140 | labels = self.labels_fill
141 | user = features[index][0]
142 | item = features[index][1]
143 | label = labels[index]
144 | if self.feature_num == 0:
145 | gender = features[index][2]
146 |             return user, item, gender, gender, label  # gender duplicated so every branch yields a fixed 5-tuple
147 | elif self.feature_num == 1:
148 | age = features[index][2]
149 | return user, item, age, age, label
150 | elif self.feature_num == 2:
151 | gender = features[index][2]
152 | age = features[index][3]
153 | return user, item, gender, age, label
154 | else:
155 |             return user, item, item, item, label  # item repeated as placeholder attributes to keep the 5-tuple shape
156 |
157 | def init_gender(self, neg_set, is_training=True):
158 | for u, i, g, r, js in neg_set:
159 | self.features_fill.append([int(u), int(i), int(g)])
160 | self.labels_fill.append(r)
161 |
162 | if is_training:
163 | for j in js:
164 | self.features_fill.append([int(u), int(j), int(g)])
165 | self.labels_fill.append(self.neg_label)
166 | self.labels_fill = np.array(self.labels_fill, dtype=np.float32)
167 |
168 | def init_age(self, neg_set, is_training=True):
169 | for u, i, a, r, js in neg_set:
170 | self.features_fill.append([int(u), int(i), int(a)])
171 | self.labels_fill.append(r)
172 |
173 | if is_training:
174 | for j in js:
175 | self.features_fill.append([int(u), int(j), int(a)])
176 | self.labels_fill.append(self.neg_label)
177 | self.labels_fill = np.array(self.labels_fill, dtype=np.float32)
178 |
179 | def init_gender_age(self, neg_set, is_training=True):
180 | for u, i, g, a, r, js in neg_set:
181 | self.features_fill.append([int(u), int(i), int(g), int(a)])
182 | self.labels_fill.append(r)
183 |
184 | if is_training:
185 | for j in js:
186 | self.features_fill.append([int(u), int(j), int(g), int(a)])
187 | self.labels_fill.append(self.neg_label)
188 | self.labels_fill = np.array(self.labels_fill, dtype=np.float32)
189 |
190 | def init_normal(self, neg_set, is_training=True):
191 | for u, i, r, js in neg_set:
192 | self.features_fill.append([int(u), int(i)])
193 | self.labels_fill.append(r)
194 |
195 | if is_training:
196 | for j in js:
197 | self.features_fill.append([int(u), int(j)])
198 | self.labels_fill.append(self.neg_label)
199 | self.labels_fill = np.array(self.labels_fill, dtype=np.float32)
200 |
201 |
202 | if __name__ == '__main__':
203 | parser = argparse.ArgumentParser(description='fm recommender')
204 | parser.add_argument('--feature',
205 | type=str,
206 | default='user,item,rating')
207 | parser.add_argument('--val',
208 | type=int,
209 | default=0)
210 | parser.add_argument('--num_ng',
211 | type=int,
212 | default=5)
213 | parser.add_argument('--factors',
214 | type=int,
215 | default=32)
216 | parser.add_argument('--epochs',
217 | type=int,
218 | default=50)
219 | parser.add_argument('--lr',
220 | type=float,
221 | default=0.01)
222 | parser.add_argument('--batch_size',
223 | type=int,
224 | default=64)
225 | parser.add_argument('--dataset',
226 | type=str,
227 | default='ml-1m')
228 | args = parser.parse_args()
229 | time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
230 | s_time = time[5:].replace(' ', '-').replace(':', '-')
231 | path = f'./res/{args.dataset}_predict.csv'
232 | df, record_num, user_num, item_num = load_data(path)
233 | df.insert(df.shape[1], 'rating', 1)
234 | feature_list = args.feature.split(',')
235 | feature_num = get_feature(feature_list)
236 | df_data = df[feature_list]
237 | user_info = get_user_info(df_data, feature_num)
238 | train_set, test_set = split_testset(df_data, 0.2)
239 | if args.val:
240 | train_set, test_set = split_testset(train_set, 0.1)
241 | test_ur = get_ur(test_set)
242 | total_train_ur = get_ur(train_set)
243 | item_pool = set(range(item_num))
244 | sampler = Sample(user_num, item_num, feature_num, args.num_ng)
245 | neg_set = sampler.transform(train_set, is_training=True)
246 | print("data sample complete")
247 | train_dataset = FMData(neg_set, feature_num, is_training=True)
248 | model = PointFM(user_num, item_num, args.factors, args.epochs, args.lr, feature=feature_num)
249 | train_loader = data.DataLoader(
250 | train_dataset,
251 | batch_size=args.batch_size,
252 | shuffle=True,
253 | num_workers=4
254 | )
255 | model.fit(train_loader)
256 | print('Start Calculating Metrics......')
257 | test_ucands = build_candidates_set(test_ur, total_train_ur, item_pool, 1000)
258 | print('')
259 | print('Generate recommend list...')
260 | print('')
261 | preds = {}
262 | for u in tqdm(test_ucands.keys()):
263 | if feature_num == 0:
264 | gender = vote_user_info(df_data, u, 0)
265 | tmp = pd.DataFrame({
266 | 'user': [u for _ in test_ucands[u]],
267 | 'item': test_ucands[u],
268 | 'gender': [gender for _ in test_ucands[u]],
269 |                 'rating': [0. for _ in test_ucands[u]], # placeholder rating; not used at inference
270 | })
271 | elif feature_num == 1:
272 | age = vote_user_info(df_data, u, 1)
273 | tmp = pd.DataFrame({
274 | 'user': [u for _ in test_ucands[u]],
275 | 'item': test_ucands[u],
276 | 'age': [age for _ in test_ucands[u]],
277 |                 'rating': [0. for _ in test_ucands[u]], # placeholder rating; not used at inference
278 | })
279 | elif feature_num == 2:
280 | gender, age = vote_user_info(df_data, u, 2)
281 | tmp = pd.DataFrame({
282 | 'user': [u for _ in test_ucands[u]],
283 | 'item': test_ucands[u],
284 | 'gender': [gender for _ in test_ucands[u]],
285 | 'age': [age for _ in test_ucands[u]],
286 |                 'rating': [0. for _ in test_ucands[u]], # placeholder rating; not used at inference
287 | })
288 | else:
289 | tmp = pd.DataFrame({
290 | 'user': [u for _ in test_ucands[u]],
291 | 'item': test_ucands[u],
292 |                 'rating': [0. for _ in test_ucands[u]], # placeholder rating; not used at inference
293 | })
294 | tmp_neg_set = sampler.transform(tmp, is_training=False)
295 | tmp_dataset = FMData(tmp_neg_set, feature_num, is_training=False)
296 | tmp_loader = data.DataLoader(
297 | tmp_dataset,
298 | batch_size=1000,
299 | shuffle=False,
300 | num_workers=4
301 | )
302 | for user, item, gender, age, label in tmp_loader:
303 | user = user.cuda()
304 | item = item.cuda()
305 | label = label.cuda()
306 | if feature_num == 0:
307 | gender = gender.cuda()
308 | elif feature_num == 1:
309 | age = age.cuda()
310 | elif feature_num == 2:
311 | gender = gender.cuda()
312 | age = age.cuda()
313 |
314 | if feature_num == 0:
315 | prediction = model.predict(user, item, g=gender)
316 | elif feature_num == 1:
317 | prediction = model.predict(user, item, a=age)
318 | elif feature_num == 2:
319 | prediction = model.predict(user, item, g=gender, a=age)
320 | else:
321 | prediction = model.predict(user, item)
322 | _, indices = torch.topk(prediction, 50)
323 |             top_n = torch.take(torch.tensor(test_ucands[u]), indices.cpu()).numpy()  # move GPU indices to CPU to match the candidate tensor
324 | preds[u] = top_n
325 |
326 | for u in preds.keys():
327 | preds[u] = [1 if i in test_ur[u] else 0 for i in preds[u]]
328 |
329 | print('Save metric@k result to res folder...')
330 | if feature_num == 0:
331 | result_save_path = f'./res/fm_{args.dataset}_audio/gender/'
332 | elif feature_num == 1:
333 | result_save_path = f'./res/fm_{args.dataset}_audio/age/'
334 | elif feature_num == 2:
335 | result_save_path = f'./res/fm_{args.dataset}_audio/gender_age/'
336 | else:
337 | result_save_path = f'./res/fm_{args.dataset}_audio/no_ag/'
338 | if not os.path.exists(result_save_path):
339 | os.makedirs(result_save_path)
340 | res = pd.DataFrame({'metric@K': ['pre', 'rec', 'hr', 'map', 'mrr', 'ndcg']})
341 | for k in [1, 5, 10, 20, 30, 50]:
342 | tmp_preds = preds.copy()
343 | tmp_preds = {key: rank_list[:k] for key, rank_list in tmp_preds.items()}
344 |
345 | pre_k = np.mean([precision_at_k(r, k) for r in tmp_preds.values()])
346 | rec_k = recall_at_k(tmp_preds, test_ur, k)
347 | hr_k = hr_at_k(tmp_preds, test_ur)
348 | map_k = map_at_k(tmp_preds.values())
349 | mrr_k = mrr_at_k(tmp_preds, k)
350 | ndcg_k = np.mean([ndcg_at_k(r, k) for r in tmp_preds.values()])
351 |
352 | if k == 10:
353 | print(f'Precision@{k}: {pre_k:.4f}')
354 | print(f'Recall@{k}: {rec_k:.4f}')
355 | print(f'HR@{k}: {hr_k:.4f}')
356 | print(f'MAP@{k}: {map_k:.4f}')
357 | print(f'MRR@{k}: {mrr_k:.4f}')
358 | print(f'NDCG@{k}: {ndcg_k:.4f}')
359 |
360 | res[k] = np.array([pre_k, rec_k, hr_k, map_k, mrr_k, ndcg_k])
361 |
362 | common_prefix = f'{args.num_ng}_{args.factors}_{args.lr}_{args.batch_size}_{args.val}'
363 |
364 | res.to_csv(
365 | f'{result_save_path}{common_prefix}_results.csv',
366 | index=False
367 | )
368 |
--------------------------------------------------------------------------------
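
fm_audio.py pairs a point-wise negative sampler (Sample) with a Dataset wrapper (FMData) that always yields a fixed 5-tuple, padding the unused attribute slots, so the same training loop serves the gender / age / gender+age / plain variants. The sketch below walks that pipeline over a toy interaction table; the toy values are made up and the import assumes the snippet is run from the Recommender directory; it is not part of the repository.

    # sketch only: toy run of the sampling + dataset pipeline defined above
    import pandas as pd
    import torch.utils.data as data
    from fm_audio import Sample, FMData          # assumes Recommender/ is the working directory

    toy = pd.DataFrame({
        'user':   [0, 0, 1, 1],
        'item':   [0, 1, 1, 2],
        'gender': [0, 0, 1, 1],
        'rating': [1, 1, 1, 1],
    })

    sampler = Sample(user_num=2, item_num=3, feature_num=0, num_ng=2)   # feature_num=0 -> gender feature
    neg_set = sampler.transform(toy, is_training=True)                  # rows: [u, i, g, r, [sampled negatives]]
    train_dataset = FMData(neg_set, feature_num=0, is_training=True)    # positives keep r, negatives get 0.

    loader = data.DataLoader(train_dataset, batch_size=4, shuffle=True)
    for user, item, gender, _, label in loader:                         # 4th slot repeats gender (placeholder)
        pass  # each batch would be fed to PointFM.fit / predict
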
/Dialogue/coat_attr.py:
--------------------------------------------------------------------------------
1 | import random
2 | from coat_utils import *
3 |
4 | gender_tag = {'0': ['00', '01', '02'],
5 | '1': ['10', '11']}
6 |
7 | jacket_tag = {'0': ['00', '01', '02'],
8 | '1': ['10', '11'],
9 | '2': ['20'],
10 | '3': ['30']}
11 |
12 | color_tag = {'0': ['00', '01', '02'],
13 | '1': ['10', '11'],
14 | '2': ['20'],
15 | '3': ['30']}
16 |
17 | def index_attr_pattern(utterance_pattern, tags, pattern_mode=None):
18 | '''
19 | return:
20 | [
21 | {
22 | "tag":['00'],
23 | "nl": "Do you want coats for men or women?"
24 | },
25 | {
26 | "tag":['00'],
27 | "nl": "Do you want men's coats or women's coats?"
28 | },
29 | ]
30 | '''
31 | utterances = []
32 |     if pattern_mode is not None:
33 | tag = tags[str(pattern_mode)]
34 | for utterance in utterance_pattern:
35 | if set(utterance["tag"]) & set(tag):
36 | utterances.append(utterance)
37 | else:
38 | for utterance in utterance_pattern:
39 | if set(utterance["tag"]) & set(tags):
40 | utterances.append(utterance)
41 |
42 | return utterances
43 |
44 | def check_repeat(tmp):
45 | tmp_set = set(tmp)
46 | if len(tmp) == len(tmp_set):
47 | return True
48 | else:
49 | return False
50 |
51 | def generate_gender_dialogue(agent_pattern, user_pattern, gender_val, gender_all, gender_weight):
52 | gender_dialogue = {}
53 | pattern_mode = random.choices([0, 1], weights=[12, 88], k=1)[0]
54 | agent_pattern = index_attr_pattern(agent_pattern["gender"], gender_tag, pattern_mode)
55 | utterance = random.choice(agent_pattern)
56 | utterance_content = utterance['nl']
57 | utterance_tag = utterance['tag']
58 | if pattern_mode == 0:
59 | gender_dialogue['Q'] = utterance_content
60 | utterance = random.choice(index_attr_pattern(user_pattern["gender"], utterance_tag))
61 | utterance_content = utterance['nl']
62 | utterance_content = utterance_content.replace('$gender$', gender_val, 1)
63 | gender_dialogue['A'] = utterance_content[0].upper() + utterance_content[1:]
64 | elif pattern_mode == 1:
65 | gender_tmp = random.choices(sorted(gender_all), weights=gender_weight, k=1)
66 | gender_tmp = get_item_gender(gender_tmp[0])
67 | gender_dialogue['Q'] = utterance_content.replace('$gender$', gender_tmp, 1)
68 | if gender_tmp == gender_val:
69 | utterance = random.choice(index_attr_pattern(user_pattern["gender_pos"], utterance_tag))
70 | else:
71 | utterance = random.choice(index_attr_pattern(user_pattern["gender_neg"], utterance_tag))
72 | gender_dialogue['A'] = utterance['nl'].replace('$gender$', gender_tmp, 1)
73 |
74 | return gender_dialogue
75 |
76 |
77 | def gen_pattern_mode_one(attr, u_content, u_tag, g_tag, other, info):
78 | '''
79 | u_content: Do you like $jacket$ coats?
80 |     g_tag: 0 or 1, whether the question asks about the ground-truth slot value
81 | '''
82 | val = info[0]
83 | all = info[1]
84 | weight = info[2]
85 | user_pattern = info[3]
86 | dialogue = {}
87 |     tmp_other = other  # aliases the caller's list; appends below extend it in place
88 | slot = '$' + attr + '$'
89 | if attr == 'jacket':
90 | pos = "jacket_pos"
91 | neg = "jacket_neg"
92 | add_val = get_item_type_index(val)
93 | get_item_func = get_item_type
94 | elif attr == 'color':
95 | pos = "color_pos"
96 | neg = "color_neg"
97 | add_val = get_item_color_index(val)
98 | get_item_func = get_item_color
99 | if g_tag:
100 | slot_val = val
101 | dialogue['Q'] = u_content.replace(slot, val, 1)
102 | utterance = random.choice(index_attr_pattern(user_pattern[pos], u_tag))
103 | answer = utterance['nl']
104 | if slot in answer:
105 | answer = answer.replace(slot, slot_val, 1)
106 | dialogue['A'] = answer
107 | tmp = []
108 | else:
109 | tmp_other.append(add_val)
110 |
111 | while True:
112 | tmp = random.choices(sorted(all), weights=weight, k=1)
113 | if not set(tmp_other) & set(tmp):
114 | break
115 | dialogue['Q'] = u_content.replace(slot, get_item_func(tmp[0]), 1)
116 | utterance = random.choice(index_attr_pattern(user_pattern[neg], u_tag))
117 | answer = utterance['nl']
118 | if slot in answer:
119 | answer = answer.replace(slot, get_item_func(tmp[0]), 1)
120 | dialogue['A'] = answer[0].upper() + answer[1:]
121 |
122 | return dialogue['Q'], dialogue['A'], tmp
123 |
124 |
125 | def gen_pattern_mode_two(attr, u_content, u_tag, g_tag, other, info):
126 | '''
127 | u_content: Do you prefer $jacket$ or $jacket$ coats?
128 |     g_tag: 0 or 1, whether the ground-truth slot value appears among the offered options
129 | '''
130 | val = info[0]
131 | all = info[1]
132 | weight = info[2]
133 | user_pattern = info[3]
134 | dialogue = {}
135 | tmp_other = other
136 | slot = '$' + attr + '$'
137 | if attr == 'jacket':
138 | pos = "jacket"
139 | neg = "jacket_neg"
140 | add_val = get_item_type_index(val)
141 | get_item_func = get_item_type
142 | elif attr == 'color':
143 | pos = "color"
144 | neg = "color_neg"
145 | add_val = get_item_color_index(val)
146 | get_item_func = get_item_color
147 | if g_tag:
148 | tmp_other.append(add_val)
149 | while True:
150 | tmp = random.choices(sorted(all), weights=weight, k=1)
151 | if not set(tmp_other) & set(tmp):
152 | break
153 | order = random.randint(0, 1)
154 | if order:
155 | u_content = u_content.replace(slot, get_item_func(tmp[0]), 1)
156 | u_content = u_content.replace(slot, val, 1)
157 | else:
158 | u_content = u_content.replace(slot, val, 1)
159 | u_content = u_content.replace(slot, get_item_func(tmp[0]), 1)
160 | dialogue['Q'] = u_content
161 | utterance = random.choice(index_attr_pattern(user_pattern[pos], u_tag))
162 | utterance_content = utterance['nl']
163 | utterance_content = utterance_content.replace(slot, val, 1)
164 | dialogue['A'] = utterance_content[0].upper() + utterance_content[1:]
165 | else:
166 | tmp_other.append(add_val)
167 | while True:
168 | tmp = random.choices(sorted(all), weights=weight, k=2)
169 | if (not set(tmp_other) & set(tmp)) and (check_repeat(tmp)):
170 | break
171 | random.shuffle(tmp)
172 | for v in tmp:
173 | u_content = u_content.replace(slot, get_item_func(v), 1)
174 | dialogue['Q'] = u_content
175 | utterance = random.choice(index_attr_pattern(user_pattern[neg], u_tag))
176 | dialogue['A'] = utterance['nl'][0].upper() + utterance['nl'][1:]
177 |
178 | return dialogue['Q'], dialogue['A'], tmp
179 |
180 |
181 | def select_gen_pattern(mode):
182 | if mode == 1:
183 | func = gen_pattern_mode_one
184 | else:
185 | func = gen_pattern_mode_two
186 |
187 | return func
188 |
189 |
190 | def get_agent_content_tag(attr, pattern, mode, tag):
191 | pattern = index_attr_pattern(pattern[attr], tag, mode)
192 | utterance = random.choice(pattern)
193 | utterance_content = utterance['nl']
194 | utterance_tag = utterance['tag']
195 |
196 | return utterance_content, utterance_tag
197 |
198 |
199 | def generate_jacket_dialogue(agent_pattern, user_pattern, jacket_val, jacket_all, jacket_weight):
200 | jacket_dialogue = {}
201 | jacket_info = (jacket_val, jacket_all, jacket_weight, user_pattern)
202 | jacket_other = [6]
203 | if jacket_val == 'other':
204 | pattern_mode = 3
205 | else:
206 | pattern_mode = random.choices([0, 1, 2], weights=[10, 45, 45], k=1)[0]
207 | utterance_content, utterance_tag = get_agent_content_tag("jacket", agent_pattern, pattern_mode, jacket_tag)
208 | if pattern_mode == 0:
209 | jacket_dialogue['Q'] = utterance_content
210 | utterance = random.choice(index_attr_pattern(user_pattern["jacket"], utterance_tag))
211 | utterance_content = utterance['nl']
212 | utterance_content = utterance_content.replace('$jacket$', jacket_val, 1)
213 | jacket_dialogue['A'] = utterance_content[0].upper() + utterance_content[1:]
214 | elif pattern_mode == 1 or pattern_mode == 2:
215 | rounds = random.randint(1,3)
216 | gen_func = select_gen_pattern(pattern_mode)
217 | if rounds == 1:
218 | jacket_dialogue['Q'], jacket_dialogue['A'], _ = gen_func('jacket', utterance_content, utterance_tag, 1, jacket_other, jacket_info)
219 | elif rounds >= 2:
220 | jacket_dialogue['Q'], jacket_dialogue['A'], added_attr = gen_func('jacket', utterance_content, utterance_tag, 0, jacket_other, jacket_info)
221 | last_pm = pattern_mode
222 | pattern_mode = random.randint(1,2)
223 | gen_func = select_gen_pattern(pattern_mode)
224 | jacket_other = jacket_other + added_attr
225 | if rounds == 2:
226 | g_tag = 1
227 | else:
228 | g_tag = 0
229 | utterance_content, utterance_tag = get_agent_content_tag("jacket", agent_pattern, pattern_mode, jacket_tag)
230 | if last_pm == pattern_mode:
231 | jacket_dialogue['Q1'], jacket_dialogue['A1'], added_attr = gen_func('jacket', utterance_content, utterance_tag, g_tag, jacket_other, jacket_info)
232 | else:
233 | jacket_dialogue['Q1'], jacket_dialogue['A1'], added_attr = gen_func('jacket', utterance_content, utterance_tag, g_tag, jacket_other, jacket_info)
234 | if rounds == 3:
235 | g_tag = 1
236 | last_pm = pattern_mode
237 | pattern_mode = random.randint(1,2)
238 | gen_func = select_gen_pattern(pattern_mode)
239 | jacket_other = jacket_other + added_attr
240 | utterance_content, utterance_tag = get_agent_content_tag("jacket", agent_pattern, pattern_mode, jacket_tag)
241 | if last_pm == pattern_mode:
242 | jacket_dialogue['Q2'], jacket_dialogue['A2'], _ = gen_func('jacket', utterance_content, utterance_tag, g_tag, jacket_other, jacket_info)
243 | else:
244 | jacket_dialogue['Q2'], jacket_dialogue['A2'], _ = gen_func('jacket', utterance_content, utterance_tag, g_tag, jacket_other, jacket_info)
245 | elif pattern_mode == 3:
246 | rounds = random.randint(1,2)
247 | while True:
248 | added_attr = random.choices(sorted(jacket_all), weights=jacket_weight, k=3)
249 | if (not set(jacket_other) & set(added_attr)) and (check_repeat(added_attr)):
250 | break
251 | random.shuffle(added_attr)
252 | for v in added_attr:
253 | utterance_content = utterance_content.replace('$jacket$', get_item_type(v), 1)
254 | jacket_dialogue['Q'] = utterance_content
255 | utterance = random.choice(index_attr_pattern(user_pattern['jacket_neg'], utterance_tag))
256 | jacket_dialogue['A'] = utterance['nl'][0].upper() + utterance['nl'][1:]
257 | if rounds == 2:
258 | pattern_mode = random.randint(1,3)
259 | utterance_content, utterance_tag = get_agent_content_tag("jacket", agent_pattern, pattern_mode, jacket_tag)
260 | jacket_other = jacket_other + added_attr
261 | if pattern_mode != 3:
262 | gen_func = select_gen_pattern(pattern_mode)
263 | jacket_dialogue['Q1'], jacket_dialogue['A1'], _ = gen_func('jacket', utterance_content, utterance_tag, 0, jacket_other, jacket_info)
264 | else:
265 | while True:
266 | added_attr = random.choices(sorted(jacket_all), weights=jacket_weight, k=3)
267 | if (not set(jacket_other) & set(added_attr)) and (check_repeat(added_attr)):
268 | break
269 | random.shuffle(added_attr)
270 | for v in added_attr:
271 | utterance_content = utterance_content.replace('$jacket$', get_item_type(v), 1)
272 | jacket_dialogue['Q1'] = utterance_content
273 | utterance = random.choice(index_attr_pattern(user_pattern['jacket_neg'], utterance_tag))
274 | jacket_dialogue['A1'] = utterance['nl'][0].upper() + utterance['nl'][1:]
275 |
276 | return jacket_dialogue
277 |
278 |
279 | def generate_color_dialogue(agent_pattern, user_pattern, color_val, color_all, color_weight):
280 | color_dialogue = {}
281 | color_info = (color_val, color_all, color_weight, user_pattern)
282 | color_other = [9]
283 | if color_val == 'other':
284 | pattern_mode = 3
285 | else:
286 | pattern_mode = random.choices([0, 1, 2], weights=[20, 40, 40], k=1)[0]
287 | utterance_content, utterance_tag = get_agent_content_tag("color", agent_pattern, pattern_mode, color_tag)
288 | if pattern_mode == 0:
289 | color_dialogue['Q'] = utterance_content
290 | utterance = random.choice(index_attr_pattern(user_pattern["color"], utterance_tag))
291 | utterance_content = utterance['nl']
292 | utterance_content = utterance_content.replace('$color$', color_val, 1)
293 | color_dialogue['A'] = utterance_content[0].upper() + utterance_content[1:]
294 | elif pattern_mode == 1 or pattern_mode == 2:
295 | rounds = random.randint(1,3)
296 | gen_func = select_gen_pattern(pattern_mode)
297 | if rounds == 1:
298 | color_dialogue['Q'], color_dialogue['A'], _ = gen_func('color', utterance_content, utterance_tag, 1, color_other, color_info)
299 | elif rounds >= 2:
300 | color_dialogue['Q'], color_dialogue['A'], added_attr = gen_func('color', utterance_content, utterance_tag, 0, color_other, color_info)
301 | last_pm = pattern_mode
302 | pattern_mode = random.randint(1,2)
303 | gen_func = select_gen_pattern(pattern_mode)
304 | color_other = color_other + added_attr
305 | if rounds == 2:
306 | g_tag = 1
307 | else:
308 | g_tag = 0
309 | utterance_content, utterance_tag = get_agent_content_tag("color", agent_pattern, pattern_mode, color_tag)
310 | if last_pm == pattern_mode:
311 | color_dialogue['Q1'], color_dialogue['A1'], added_attr = gen_func('color', utterance_content, utterance_tag, g_tag, color_other, color_info)
312 | else:
313 | color_dialogue['Q1'], color_dialogue['A1'], added_attr = gen_func('color', utterance_content, utterance_tag, g_tag, color_other, color_info)
314 | if rounds == 3:
315 | g_tag = 1
316 | last_pm = pattern_mode
317 | pattern_mode = random.randint(1,2)
318 | gen_func = select_gen_pattern(pattern_mode)
319 | color_other = color_other + added_attr
320 | utterance_content, utterance_tag = get_agent_content_tag("color", agent_pattern, pattern_mode, color_tag)
321 | if last_pm == pattern_mode:
322 | color_dialogue['Q2'], color_dialogue['A2'], _ = gen_func('color', utterance_content, utterance_tag, g_tag, color_other, color_info)
323 | else:
324 | color_dialogue['Q2'], color_dialogue['A2'], _ = gen_func('color', utterance_content, utterance_tag, g_tag, color_other, color_info)
325 | elif pattern_mode == 3:
326 | rounds = random.randint(1,2)
327 | while True:
328 | added_attr = random.choices(sorted(color_all), weights=color_weight, k=3)
329 | if (not set(color_other) & set(added_attr)) and (check_repeat(added_attr)):
330 | break
331 | random.shuffle(added_attr)
332 | for v in added_attr:
333 | utterance_content = utterance_content.replace('$color$', get_item_color(v), 1)
334 | color_dialogue['Q'] = utterance_content
335 | utterance = random.choice(index_attr_pattern(user_pattern['color_neg'], utterance_tag))
336 | color_dialogue['A'] = utterance['nl'][0].upper() + utterance['nl'][1:]
337 | if rounds == 2:
338 | pattern_mode = random.randint(1,3)
339 | utterance_content, utterance_tag = get_agent_content_tag("color", agent_pattern, pattern_mode, color_tag)
340 | color_other = color_other + added_attr
341 | if pattern_mode != 3:
342 | gen_func = select_gen_pattern(pattern_mode)
343 | color_dialogue['Q1'], color_dialogue['A1'], _ = gen_func('color', utterance_content, utterance_tag, 0, color_other, color_info)
344 | else:
345 | while True:
346 | added_attr = random.choices(sorted(color_all), weights=color_weight, k=3)
347 | if (not set(color_other) & set(added_attr)) and (check_repeat(added_attr)):
348 | break
349 | random.shuffle(added_attr)
350 | for v in added_attr:
351 | utterance_content = utterance_content.replace('$color$', get_item_color(v), 1)
352 | color_dialogue['Q1'] = utterance_content
353 | utterance = random.choice(index_attr_pattern(user_pattern['color_neg'], utterance_tag))
354 | color_dialogue['A1'] = utterance['nl'][0].upper() + utterance['nl'][1:]
355 |
356 | return color_dialogue
357 |
--------------------------------------------------------------------------------
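
coat_attr.py consumes agent/user utterance patterns that are lists of {'tag': [...], 'nl': '...'} entries keyed by attribute (the real patterns live under Dialogue/config). The sketch below drives generate_gender_dialogue with hand-written stand-in patterns just to show the expected structure; every pattern string and value here is hypothetical and not taken from the repository.

    # sketch only: hypothetical patterns to exercise generate_gender_dialogue
    # (assumes the snippet is run from the Dialogue directory so coat_utils resolves)
    import random
    from coat_attr import generate_gender_dialogue

    agent_pattern = {
        'gender': [
            {'tag': ['00'], 'nl': 'Do you want coats for men or women?'},
            {'tag': ['10'], 'nl': 'Are you looking for $gender$ coats?'},
        ],
    }
    user_pattern = {
        'gender':     [{'tag': ['00'], 'nl': 'I want $gender$ coats.'}],
        'gender_pos': [{'tag': ['10'], 'nl': 'Yes, $gender$ coats please.'}],
        'gender_neg': [{'tag': ['10'], 'nl': 'No, not $gender$ coats.'}],
    }

    random.seed(0)
    qa = generate_gender_dialogue(agent_pattern, user_pattern,
                                  gender_val="women's", gender_all=[0, 1],
                                  gender_weight=[1, 1])
    print(qa)   # one question/answer pair, e.g. {'Q': '...', 'A': '...'}
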