├── .gitignore ├── lib ├── colortext.py ├── decorator.py ├── glb_var.py └── json_util.py ├── model ├── __init__.py └── framework │ ├── deep_wiid.py │ ├── csiid.py │ ├── base.py │ └── autoencoder.py ├── config ├── bird │ ├── bird_encoder_svm_v1.json │ ├── bird_encoder_svm_v5_env0.json │ ├── bird_encoder_svm_v5_env1.json │ ├── bird_encoder_svm_v6_env0.json │ ├── bird_encoder_svm_v7_env0.json │ ├── bird_encoder_svm_v8_env0.json │ ├── bird_v1_wo_aug.json │ ├── bird_v5_env0_wo_aug.json │ ├── bird_v5_env1_wo_aug.json │ ├── bird_v6_env0_wo_aug.json │ ├── bird_v7_env0_wo_aug.json │ ├── bird_v8_env0_wo_aug.json │ ├── bird_encoder_v2_wo_aug+al.json │ ├── bird_encoder_v4_env0_wo_aug.json │ ├── bird_encoder_v4_env1_wo_aug.json │ ├── bird_encoder_v5_env0_wo_aug.json │ ├── bird_encoder_v5_env1_wo_aug.json │ ├── bird_encoder_v6_env0_wo_aug.json │ ├── bird_encoder_v7_env0_wo_aug.json │ ├── bird_encoder_v8_env0_wo_aug.json │ ├── bird_encoder_v2_wo_aug.json │ ├── bird_v1_aug.json │ ├── bird_encoder_v2_wo_al.json │ ├── bird_encoder_v4_env0.json │ └── bird_encoder_v4_env1.json ├── csiid │ ├── csiid_v1.json │ ├── csiid_v4_env0.json │ └── csiid_v4_env1.json ├── wiau │ ├── wiau_id_v1.json │ ├── wiau_id_v4_env0.json │ ├── wiau_id_v4_env1.json │ ├── wiau_v1.json │ ├── wiau_v5_env0.json │ ├── wiau_v5_env1.json │ ├── wiau_v6_env0.json │ ├── wiau_v7_env0.json │ └── wiau_v8_env0.json ├── gateid │ ├── gateid_v1.json │ ├── gateid_v4_env0.json │ └── gateid_v4_env1.json ├── caution │ ├── caution_v3.json │ ├── caution_v6_env0.json │ ├── caution_v5_env1.json │ ├── caution_v5_env0.json │ ├── caution_v7_env0.json │ ├── caution_v8_env0.json │ ├── caution_encoder_v3.json │ ├── caution_encoder_v5_env0.json │ ├── caution_encoder_v5_env1.json │ ├── caution_encoder_v6_env0.json │ ├── caution_encoder_v7_env0.json │ └── caution_encoder_v8_env0.json ├── autoencoder │ ├── cae_v1.json │ ├── cae_v5_env0.json │ ├── cae_v5_env1.json │ ├── cae_v6_env0.json │ ├── cae_v7_env0.json │ ├── cae_v8_env0.json │ ├── 
ae_v1.json │ ├── ae_v5_env0.json │ ├── ae_v5_env1.json │ ├── ae_v6_env0.json │ ├── ae_v7_env0.json │ └── ae_v8_env0.json ├── deep_wiid │ ├── deep_wiid_v1.json │ ├── deep_wiid_v4_env0.json │ ├── deep_wiid_v4_env1.json │ └── deep_wiid_al_v2.json ├── gait_enhance │ ├── gait_enhance_v1.json │ ├── gait_enhance_v4_env0.json │ ├── gait_enhance_v4_env1.json │ └── gait_enhance_al_v2.json ├── dcs_gait │ ├── dcs_gait_encoder_v1.json │ └── dcs_gait_v2.json └── wiai_id │ ├── wiai_id_v4_env0.json │ ├── wiai_id_v4_env1.json │ └── wiai_id_v2.json ├── requirements.txt ├── README.md ├── data └── augmentation.py └── executor.py /.gitignore: -------------------------------------------------------------------------------- 1 | # IDEs 2 | .vscode 3 | 4 | # Python 5 | *.ipynb 6 | *.pyc 7 | 8 | #file 9 | *.log 10 | *.mat 11 | *.txt 12 | !requirements.txt 13 | *.mlx 14 | *.rar 15 | *.zip 16 | 17 | #Data 18 | __pycache__/ 19 | cache/ 20 | runs/ 21 | model/pretrained 22 | 23 | 24 | -------------------------------------------------------------------------------- /lib/colortext.py: -------------------------------------------------------------------------------- 1 | # @Time : 2023.05.16 2 | # @Author : Darrius Lei 3 | # @Email : darrius.lei@outlook.com 4 | 5 | BLACK = '\033[0;30m'; 6 | RED = '\033[0;31m'; 7 | GREEN = '\033[0;32m'; 8 | YELLOW = '\033[0;33m'; 9 | BLUE = '\033[0;34m'; 10 | PURPLE = '\033[0;35m'; 11 | CYAN = '\033[0;36m'; 12 | WHITE = '\033[0;37m'; 13 | GRAY = '\033[0;38m'; 14 | 15 | RESET = '\033[0m'; 16 | 17 | RED_TRIANGLE = f"{RED}▲{RESET}"; 18 | -------------------------------------------------------------------------------- /lib/decorator.py: -------------------------------------------------------------------------------- 1 | # @Time : 2023.07.08 2 | # @Author : Darrius Lei 3 | # @Email : darrius.lei@outlook.com 4 | 5 | from typing import Any 6 | from lib import glb_var, util 7 | import time 8 | 9 | logger = glb_var.get_value('logger'); 10 | 11 | class Decorator(object): 12 | 
'''Abstract Decorator class 13 | ''' 14 | def __init__(self, func) -> None: 15 | self.func = func; 16 | 17 | def __call__(self, *args: Any, **kwds: Any) -> Any: 18 | '''Method needs to be called after being implemented''' 19 | raise NotImplementedError; 20 | 21 | class Timer(Decorator): 22 | '''Timer for func''' 23 | def __init__(self, func) -> None: 24 | super().__init__(func); 25 | 26 | def __call__(self, *args: Any, **kwds: Any) -> Any: 27 | t = time.time(); 28 | result = self.func(*args, **kwds); 29 | logger.info(f'The time consumption of {self.func.__name__}: {util.s2hms(time.time() - t)}'); 30 | return result -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from lib import glb_var, util, json_util 4 | 5 | device = glb_var.get_value('device'); 6 | logger = glb_var.get_value('logger'); 7 | 8 | __all__ = ['generate_model', 'load_model']; 9 | 10 | framework_dir = os.path.join(os.path.dirname(__file__), 'framework') 11 | util.load_modules_from_directory(framework_dir, 'model.framework') 12 | 13 | def generate_model(model_cfg): 14 | model_class = glb_var.query_model(model_cfg['name']) 15 | if model_class is None: 16 | raise RuntimeError(f"Unable to find model class named {model_cfg['name']}, please check configuration") 17 | 18 | try: 19 | model = model_class(model_cfg); 20 | except Exception as e: 21 | raise RuntimeError(f"Error instantiating model {model_cfg['name']} : {e}") 22 | num_params = sum(p.numel() for p in model.parameters()) 23 | logger.info(f'The number of parameters of the [{model.name}]: {num_params / 10 ** 6:.6f} M\n'); 24 | return model.to(device); 25 | 26 | def load_model(load_dir): 27 | ''' 28 | if A/B/model.pth then load_dir = A/B/ 29 | ''' 30 | if os.path.isfile(load_dir): 31 | raise RuntimeError('Please use the directory of the file instead of its path'); 32 | cfg = 
json_util.jsonload(load_dir + '/config.json'); 33 | model = generate_model(cfg['model']); 34 | model.load(load_dir); 35 | return model.to(device); -------------------------------------------------------------------------------- /lib/glb_var.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/4/12 2 | # @Author : Junwei Lei 3 | # @Email : darrius.lei@outlook.com 4 | 5 | from lib.callback import CustomException 6 | import pydash as ps 7 | from lib import util 8 | ''' 9 | Store global variables for the project 10 | 11 | Methods: 12 | -------- 13 | 14 | __init__():None 15 | initialization 16 | 17 | set_value(key, value):None 18 | set global variables 19 | 20 | get_value(key):value 21 | retrieve global variables 22 | 23 | Examples: 24 | --------- 25 | main.py 26 | >>> import lib.glb_var 27 | >>> glb_var.set_value('a', 1); 28 | >>> submain; 29 | 30 | sub.py 31 | >>> import lib.glb_var 32 | >>> print(glb_var.get_value('a')); 33 | 34 | ''' 35 | def __init__(): 36 | global glb_dict; 37 | glb_dict = dict(); 38 | glb_dict['model'] = {}; 39 | 40 | def set_values(dict, keys = None, except_type = None): 41 | util.set_attr(glb_dict, dict, keys = keys, except_type = except_type); 42 | 43 | def set_value(key, value): 44 | #set global var 45 | glb_dict[key] = value; 46 | 47 | def register_model(key, value): 48 | glb_dict['model'][key] = value; 49 | 50 | def query_model(key): 51 | return glb_dict['model'][key] if key in glb_dict['model'].keys() else None; 52 | 53 | def get_value(key): 54 | try: 55 | return glb_dict[key]; 56 | except KeyError: 57 | raise CustomException(f'The retrieved key [{key}] does not exist'); -------------------------------------------------------------------------------- /config/bird/bird_encoder_svm_v1.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":false, 4 | "known_p_num":10, 5 | "known_env_num":2, 6 | "data":{ 7 | 
"dataset":"v1", 8 | "train":{ 9 | "dir":"data/v1/alltrain.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":48, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "test":{ 20 | "dir":"data/v1/test.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":24, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | } 30 | }, 31 | "train":{ 32 | "is_DA":false, 33 | "max_epoch":100, 34 | "valid_start_epoch":20, 35 | "valid_step":5, 36 | "stop_train_step_valid_not_improve":6, 37 | "valid_metrics":"loss", 38 | "valid_metrics_less":true, 39 | "optimizer":{ 40 | "lr":1e-5, 41 | "weight_decay":1e-4 42 | } 43 | }, 44 | "model":{ 45 | "name":"BIRDEncoderSVM", 46 | "pretrained_enc_dir":"model/pretrained/bird_encoder_opt2/end", 47 | "do":256, 48 | "abnormality_rate":0.01 49 | } 50 | } -------------------------------------------------------------------------------- /config/bird/bird_encoder_svm_v5_env0.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":false, 4 | "known_p_num":9, 5 | "known_env_num":1, 6 | "data":{ 7 | "dataset":"v5", 8 | "train":{ 9 | "dir":"data/v5/train_env0.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":48, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "test":{ 20 | "dir":"data/v5/test_env0.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":24, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | } 30 | }, 31 | "train":{ 32 | "is_DA":false, 33 | "max_epoch":100, 34 | "valid_start_epoch":5, 35 | "valid_step":1, 36 | "stop_train_step_valid_not_improve":30, 37 | "valid_metrics":"loss", 38 | "valid_metrics_less":true, 39 | "optimizer":{ 40 | "lr":1e-5, 41 | "weight_decay":1e-4 42 | } 43 | }, 44 | "model":{ 45 | 
"name":"BIRDEncoderSVM", 46 | "pretrained_enc_dir":"model/pretrained/BIRDEncoder_v5_env0_opt1/end", 47 | "do":256, 48 | "abnormality_rate":0.3 49 | } 50 | } -------------------------------------------------------------------------------- /config/bird/bird_encoder_svm_v5_env1.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":false, 4 | "known_p_num":9, 5 | "known_env_num":1, 6 | "data":{ 7 | "dataset":"v5", 8 | "train":{ 9 | "dir":"data/v5/train_env0.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":48, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "test":{ 20 | "dir":"data/v5/test_env0.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":24, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | } 30 | }, 31 | "train":{ 32 | "is_DA":false, 33 | "max_epoch":100, 34 | "valid_start_epoch":5, 35 | "valid_step":1, 36 | "stop_train_step_valid_not_improve":30, 37 | "valid_metrics":"loss", 38 | "valid_metrics_less":true, 39 | "optimizer":{ 40 | "lr":1e-5, 41 | "weight_decay":1e-4 42 | } 43 | }, 44 | "model":{ 45 | "name":"BIRDEncoderSVM", 46 | "pretrained_enc_dir":"model/pretrained/BIRDEncoder_v5_env0_opt1/end", 47 | "do":256, 48 | "abnormality_rate":0.3 49 | } 50 | } -------------------------------------------------------------------------------- /config/bird/bird_encoder_svm_v6_env0.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":false, 4 | "known_p_num":8, 5 | "known_env_num":1, 6 | "data":{ 7 | "dataset":"v6", 8 | "train":{ 9 | "dir":"data/v6/train_env0.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":48, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "test":{ 20 | 
"dir":"data/v6/test_env0.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":24, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | } 30 | }, 31 | "train":{ 32 | "is_DA":false, 33 | "max_epoch":100, 34 | "valid_start_epoch":5, 35 | "valid_step":1, 36 | "stop_train_step_valid_not_improve":30, 37 | "valid_metrics":"loss", 38 | "valid_metrics_less":true, 39 | "optimizer":{ 40 | "lr":1e-5, 41 | "weight_decay":1e-4 42 | } 43 | }, 44 | "model":{ 45 | "name":"BIRDEncoderSVM", 46 | "pretrained_enc_dir":"model/pretrained/BIRDEncoder_v6_env0_opt2/best", 47 | "do":256, 48 | "abnormality_rate":0.3 49 | } 50 | } -------------------------------------------------------------------------------- /config/bird/bird_encoder_svm_v7_env0.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":false, 4 | "known_p_num":7, 5 | "known_env_num":1, 6 | "data":{ 7 | "dataset":"v7", 8 | "train":{ 9 | "dir":"data/v7/train_env0.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":48, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "test":{ 20 | "dir":"data/v7/test_env0.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":24, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | } 30 | }, 31 | "train":{ 32 | "is_DA":false, 33 | "max_epoch":100, 34 | "valid_start_epoch":5, 35 | "valid_step":1, 36 | "stop_train_step_valid_not_improve":30, 37 | "valid_metrics":"loss", 38 | "valid_metrics_less":true, 39 | "optimizer":{ 40 | "lr":1e-5, 41 | "weight_decay":1e-4 42 | } 43 | }, 44 | "model":{ 45 | "name":"BIRDEncoderSVM", 46 | "pretrained_enc_dir":"model/pretrained/BIRDEncoder_v7_env0_opt2/end", 47 | "do":256, 48 | "abnormality_rate":0.3 49 | } 50 | } -------------------------------------------------------------------------------- 
/config/bird/bird_encoder_svm_v8_env0.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":false, 4 | "known_p_num":6, 5 | "known_env_num":1, 6 | "data":{ 7 | "dataset":"v8", 8 | "train":{ 9 | "dir":"data/v8/train_env0.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":48, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "test":{ 20 | "dir":"data/v8/test_env0.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":24, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | } 30 | }, 31 | "train":{ 32 | "is_DA":false, 33 | "max_epoch":100, 34 | "valid_start_epoch":5, 35 | "valid_step":1, 36 | "stop_train_step_valid_not_improve":30, 37 | "valid_metrics":"loss", 38 | "valid_metrics_less":true, 39 | "optimizer":{ 40 | "lr":1e-5, 41 | "weight_decay":1e-4 42 | } 43 | }, 44 | "model":{ 45 | "name":"BIRDEncoderSVM", 46 | "pretrained_enc_dir":"model/pretrained/BIRDEncoder_v8_env0_opt1/best", 47 | "do":256, 48 | "abnormality_rate":0.3 49 | } 50 | } -------------------------------------------------------------------------------- /config/csiid/csiid_v1.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":10, 5 | "known_env_num":2, 6 | "data":{ 7 | "dataset":"v1", 8 | "train":{ 9 | "dir":"data/v1/train.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":64, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v1/valid.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":24, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v1/test_legal.mat", 32 | "loader":{ 33 | 
"num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":100, 45 | "valid_start_epoch":20, 46 | "valid_step":5, 47 | "stop_train_step_valid_not_improve":6, 48 | "valid_metrics":"acc", 49 | "valid_metrics_less":false, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"CSIID" 57 | } 58 | } -------------------------------------------------------------------------------- /config/wiau/wiau_id_v1.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":10, 5 | "known_env_num":2, 6 | "data":{ 7 | "dataset":"v1", 8 | "train":{ 9 | "dir":"data/v1/train.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":64, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v1/valid.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":24, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v1/test_legal.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":100, 45 | "valid_start_epoch":20, 46 | "valid_step":5, 47 | "stop_train_step_valid_not_improve":6, 48 | "valid_metrics":"acc", 49 | "valid_metrics_less":false, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"WIAUId" 57 | } 58 | } -------------------------------------------------------------------------------- /config/csiid/csiid_v4_env0.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":10, 5 | "known_env_num":2, 6 | "data":{ 7 | "dataset":"v4", 8 | "train":{ 9 | "dir":"data/v4/train_env0.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":64, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v4/valid_env0.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":24, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v4/test_env0.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":100, 45 | "valid_start_epoch":5, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":30, 48 | "valid_metrics":"acc", 49 | "valid_metrics_less":false, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"CSIID" 57 | } 58 | } -------------------------------------------------------------------------------- /config/csiid/csiid_v4_env1.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":10, 5 | "known_env_num":2, 6 | "data":{ 7 | "dataset":"v4", 8 | "train":{ 9 | "dir":"data/v4/train_env1.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":64, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v4/valid_env1.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":24, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | 
"dir":"data/v4/test_env1.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":100, 45 | "valid_start_epoch":5, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":30, 48 | "valid_metrics":"acc", 49 | "valid_metrics_less":false, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"CSIID" 57 | } 58 | } -------------------------------------------------------------------------------- /config/wiau/wiau_id_v4_env0.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":10, 5 | "known_env_num":2, 6 | "data":{ 7 | "dataset":"v4", 8 | "train":{ 9 | "dir":"data/v4/train_env0.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":64, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v4/valid_env0.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":24, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v4/test_env0.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":100, 45 | "valid_start_epoch":5, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":30, 48 | "valid_metrics":"acc", 49 | "valid_metrics_less":false, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"WIAUId" 57 | } 58 | } -------------------------------------------------------------------------------- /config/wiau/wiau_id_v4_env1.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":10, 5 | "known_env_num":2, 6 | "data":{ 7 | "dataset":"v4", 8 | "train":{ 9 | "dir":"data/v4/train_env1.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":64, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v4/valid_env1.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":24, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v4/test_env1.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":100, 45 | "valid_start_epoch":5, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":30, 48 | "valid_metrics":"acc", 49 | "valid_metrics_less":false, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"WIAUId" 57 | } 58 | } -------------------------------------------------------------------------------- /config/gateid/gateid_v1.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":10, 5 | "known_env_num":2, 6 | "data":{ 7 | "dataset":"v1", 8 | "train":{ 9 | "dir":"data/v1/train.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":16, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v1/valid.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":24, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | 
"dir":"data/v1/test_legal.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":150, 45 | "valid_start_epoch":20, 46 | "valid_step":5, 47 | "stop_train_step_valid_not_improve":6, 48 | "valid_metrics":"acc", 49 | "valid_metrics_less":false, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"GateID", 57 | "window_size":40 58 | } 59 | } -------------------------------------------------------------------------------- /config/gateid/gateid_v4_env0.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":10, 5 | "known_env_num":2, 6 | "data":{ 7 | "dataset":"v4", 8 | "train":{ 9 | "dir":"data/v4/train_env0.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":16, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v4/valid_env0.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":24, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v4/test_env0.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":150, 45 | "valid_start_epoch":5, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":30, 48 | "valid_metrics":"acc", 49 | "valid_metrics_less":false, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"GateID", 57 | "window_size":40 58 | } 59 | } 
-------------------------------------------------------------------------------- /config/gateid/gateid_v4_env1.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":10, 5 | "known_env_num":2, 6 | "data":{ 7 | "dataset":"v4", 8 | "train":{ 9 | "dir":"data/v4/train_env1.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":16, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v4/valid_env1.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":24, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v4/test_env1.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":150, 45 | "valid_start_epoch":5, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":30, 48 | "valid_metrics":"acc", 49 | "valid_metrics_less":false, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"GateID", 57 | "window_size":40 58 | } 59 | } -------------------------------------------------------------------------------- /config/wiau/wiau_v1.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":10, 5 | "known_env_num":2, 6 | "data":{ 7 | "dataset":"v1", 8 | "train":{ 9 | "dir":"data/v1/train.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":64, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v1/valid.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 
| "batch_size":24, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v1/test.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":100, 45 | "valid_start_epoch":20, 46 | "valid_step":5, 47 | "stop_train_step_valid_not_improve":6, 48 | "valid_metrics":"acc", 49 | "valid_metrics_less":false, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"WIAU", 57 | "p":0.3, 58 | "d":10, 59 | "lambda_reg":0.01 60 | } 61 | } -------------------------------------------------------------------------------- /config/wiau/wiau_v5_env0.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":9, 5 | "known_env_num":1, 6 | "data":{ 7 | "dataset":"v5", 8 | "train":{ 9 | "dir":"data/v5/train_env0.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":32, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v5/valid_env0.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":24, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v5/test_env0.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":100, 45 | "valid_start_epoch":5, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":30, 48 | "valid_metrics":"acc", 49 | "valid_metrics_less":false, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | 
"model":{ 56 | "name":"WIAU", 57 | "p":0.3, 58 | "d":10, 59 | "lambda_reg":0.01 60 | } 61 | } -------------------------------------------------------------------------------- /config/wiau/wiau_v5_env1.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":9, 5 | "known_env_num":1, 6 | "data":{ 7 | "dataset":"v5", 8 | "train":{ 9 | "dir":"data/v5/train_env1.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":32, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v5/valid_env1.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":24, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v5/test_env1.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":100, 45 | "valid_start_epoch":5, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":30, 48 | "valid_metrics":"acc", 49 | "valid_metrics_less":false, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"WIAU", 57 | "p":0.3, 58 | "d":10, 59 | "lambda_reg":0.01 60 | } 61 | } -------------------------------------------------------------------------------- /config/wiau/wiau_v6_env0.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":8, 5 | "known_env_num":1, 6 | "data":{ 7 | "dataset":"v6", 8 | "train":{ 9 | "dir":"data/v6/train_env0.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":32, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 
| } 18 | }, 19 | "valid":{ 20 | "dir":"data/v6/valid_env0.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":24, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v6/test_env0.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":100, 45 | "valid_start_epoch":5, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":30, 48 | "valid_metrics":"acc", 49 | "valid_metrics_less":false, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"WIAU", 57 | "p":0.3, 58 | "d":10, 59 | "lambda_reg":0.01 60 | } 61 | } -------------------------------------------------------------------------------- /config/wiau/wiau_v7_env0.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":7, 5 | "known_env_num":1, 6 | "data":{ 7 | "dataset":"v7", 8 | "train":{ 9 | "dir":"data/v7/train_env0.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":32, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v7/valid_env0.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":24, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v7/test_env0.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":100, 45 | "valid_start_epoch":5, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":30, 48 | 
"valid_metrics":"acc", 49 | "valid_metrics_less":false, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"WIAU", 57 | "p":0.3, 58 | "d":10, 59 | "lambda_reg":0.01 60 | } 61 | } -------------------------------------------------------------------------------- /config/wiau/wiau_v8_env0.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":6, 5 | "known_env_num":1, 6 | "data":{ 7 | "dataset":"v8", 8 | "train":{ 9 | "dir":"data/v8/train_env0.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":32, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v8/valid_env0.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":24, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v8/test_env0.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":100, 45 | "valid_start_epoch":5, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":30, 48 | "valid_metrics":"acc", 49 | "valid_metrics_less":false, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"WIAU", 57 | "p":0.3, 58 | "d":10, 59 | "lambda_reg":0.01 60 | } 61 | } -------------------------------------------------------------------------------- /config/caution/caution_v3.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":8, 5 | "known_env_num":2, 6 | "data":{ 7 | "dataset":"v3", 8 | "train":{ 9 | "dir":"data/v3/support.mat", 10 | "loader":{ 11 | 
"num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":1, 14 | "drop_last":false, 15 | "shuffle":false, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v3/valid.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":8, 25 | "drop_last":false, 26 | "shuffle":false, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v3/test.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":50, 45 | "valid_metrics":"loss", 46 | "valid_metrics_less":true, 47 | "optimizer":{ 48 | "lr":1e-4, 49 | "weight_decay":1e-4 50 | } 51 | }, 52 | "model":{ 53 | "name":"Caution", 54 | "update_start_epoch":5, 55 | "update_step":10, 56 | "pretrained_enc_dir":"model/pretrained/caution_encoder_opt1/end", 57 | "initial_threshold":0.5, 58 | "num_iterations":50, 59 | "num_thresholds":20, 60 | "threshold_step":0.05 61 | } 62 | } -------------------------------------------------------------------------------- /config/caution/caution_v6_env0.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":6, 5 | "known_env_num":1, 6 | "data":{ 7 | "dataset":"v6", 8 | "train":{ 9 | "dir":"data/v6/support_caution_env0.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":1, 14 | "drop_last":false, 15 | "shuffle":false, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v6/valid_caution_env0.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":8, 25 | "drop_last":false, 26 | "shuffle":false, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v6/test_env0.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | 
"drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":50, 45 | "valid_metrics":"loss", 46 | "valid_metrics_less":true, 47 | "optimizer":{ 48 | "lr":1e-4, 49 | "weight_decay":1e-4 50 | } 51 | }, 52 | "model":{ 53 | "name":"Caution", 54 | "update_start_epoch":5, 55 | "update_step":10, 56 | "pretrained_enc_dir":"", 57 | "initial_threshold":0.75, 58 | "num_iterations":100, 59 | "num_thresholds":200, 60 | "threshold_step":0.001 61 | } 62 | } -------------------------------------------------------------------------------- /config/caution/caution_v5_env1.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":7, 5 | "known_env_num":1, 6 | "data":{ 7 | "dataset":"v5", 8 | "train":{ 9 | "dir":"data/v5/support_caution_env1.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":1, 14 | "drop_last":false, 15 | "shuffle":false, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v5/valid_caution_env1.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":8, 25 | "drop_last":false, 26 | "shuffle":false, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v5/test_env1.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":50, 45 | "valid_metrics":"loss", 46 | "valid_metrics_less":true, 47 | "optimizer":{ 48 | "lr":1e-4, 49 | "weight_decay":1e-4 50 | } 51 | }, 52 | "model":{ 53 | "name":"Caution", 54 | "update_start_epoch":5, 55 | "update_step":10, 56 | "pretrained_enc_dir":"model/pretrained/CautionEncoder_v5_env1_opt1/end", 57 | "initial_threshold":0.5, 58 | "num_iterations":50, 59 | "num_thresholds":20, 60 | 
"threshold_step":0.05 61 | } 62 | } -------------------------------------------------------------------------------- /config/caution/caution_v5_env0.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":7, 5 | "known_env_num":1, 6 | "data":{ 7 | "dataset":"v5", 8 | "train":{ 9 | "dir":"data/v5/support_caution_env0.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":1, 14 | "drop_last":false, 15 | "shuffle":false, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v5/valid_caution_env0.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":8, 25 | "drop_last":false, 26 | "shuffle":false, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v5/test_env0.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":50, 45 | "valid_metrics":"loss", 46 | "valid_metrics_less":true, 47 | "optimizer":{ 48 | "lr":1e-4, 49 | "weight_decay":1e-4 50 | } 51 | }, 52 | "model":{ 53 | "name":"Caution", 54 | "update_start_epoch":5, 55 | "update_step":10, 56 | "pretrained_enc_dir":"model/pretrained/CautionEncoder_v5_env0_opt1/end", 57 | "initial_threshold":0.75, 58 | "num_iterations":100, 59 | "num_thresholds":200, 60 | "threshold_step":0.001 61 | } 62 | } -------------------------------------------------------------------------------- /config/caution/caution_v7_env0.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":5, 5 | "known_env_num":1, 6 | "data":{ 7 | "dataset":"v7", 8 | "train":{ 9 | "dir":"data/v7/support_caution_env0.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | 
"batch_size":1, 14 | "drop_last":false, 15 | "shuffle":false, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v7/valid_caution_env0.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":8, 25 | "drop_last":false, 26 | "shuffle":false, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v7/test_env0.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":50, 45 | "valid_metrics":"loss", 46 | "valid_metrics_less":true, 47 | "optimizer":{ 48 | "lr":1e-4, 49 | "weight_decay":1e-4 50 | } 51 | }, 52 | "model":{ 53 | "name":"Caution", 54 | "update_start_epoch":5, 55 | "update_step":10, 56 | "pretrained_enc_dir":"model/pretrained/CautionEncoder_v7_env0_opt1/best", 57 | "initial_threshold":0.75, 58 | "num_iterations":100, 59 | "num_thresholds":200, 60 | "threshold_step":0.001 61 | } 62 | } -------------------------------------------------------------------------------- /config/caution/caution_v8_env0.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":4, 5 | "known_env_num":1, 6 | "data":{ 7 | "dataset":"v8", 8 | "train":{ 9 | "dir":"data/v8/support_caution_env0.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":1, 14 | "drop_last":false, 15 | "shuffle":false, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v8/valid_caution_env0.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":8, 25 | "drop_last":false, 26 | "shuffle":false, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v8/test_env0.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | 
"shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":50, 45 | "valid_metrics":"loss", 46 | "valid_metrics_less":true, 47 | "optimizer":{ 48 | "lr":1e-4, 49 | "weight_decay":1e-4 50 | } 51 | }, 52 | "model":{ 53 | "name":"Caution", 54 | "update_start_epoch":5, 55 | "update_step":10, 56 | "pretrained_enc_dir":"model/pretrained/CautionEncoder_v8_env0_opt1/best", 57 | "initial_threshold":0.75, 58 | "num_iterations":100, 59 | "num_thresholds":200, 60 | "threshold_step":0.001 61 | } 62 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | asttokens==2.4.1 3 | comm==0.2.2 4 | contourpy==1.3.0 5 | cycler==0.12.1 6 | debugpy==1.8.5 7 | decorator==5.1.1 8 | executing==2.0.1 9 | feedparser==6.0.11 10 | filelock==3.13.1 11 | fonttools==4.53.1 12 | fsspec==2024.2.0 13 | grpcio==1.66.2 14 | h5py==3.11.0 15 | ipykernel==6.29.5 16 | ipython==8.26.0 17 | ipywidgets==8.1.5 18 | jedi==0.19.1 19 | Jinja2==3.1.3 20 | joblib==1.4.2 21 | jupyter_client==8.6.2 22 | jupyter_core==5.7.2 23 | jupyterlab_widgets==3.0.13 24 | kiwisolver==1.4.5 25 | Markdown==3.7 26 | MarkupSafe==2.1.5 27 | matplotlib==3.9.2 28 | matplotlib-inline==0.1.7 29 | mpmath==1.3.0 30 | nest-asyncio==1.6.0 31 | networkx==3.2.1 32 | numpy==1.26.4 33 | nvidia-cublas-cu12==12.4.2.65 34 | nvidia-cuda-cupti-cu12==12.4.99 35 | nvidia-cuda-nvrtc-cu12==12.4.99 36 | nvidia-cuda-runtime-cu12==12.4.99 37 | nvidia-cudnn-cu12==9.1.0.70 38 | nvidia-cufft-cu12==11.2.0.44 39 | nvidia-curand-cu12==10.3.5.119 40 | nvidia-cusolver-cu12==11.6.0.99 41 | nvidia-cusparse-cu12==12.3.0.142 42 | nvidia-nccl-cu12==2.20.5 43 | nvidia-nvjitlink-cu12==12.4.99 44 | nvidia-nvtx-cu12==12.4.99 45 | packaging==24.1 46 | parso==0.8.4 47 | pexpect==4.9.0 48 | pillow==10.4.0 49 | platformdirs==4.2.2 50 | prettytable==3.11.0 51 | 
prompt_toolkit==3.0.47 52 | protobuf==5.28.2 53 | psutil==6.0.0 54 | ptyprocess==0.7.0 55 | pure_eval==0.2.3 56 | pydash==8.0.3 57 | Pygments==2.18.0 58 | pyparsing==3.1.4 59 | python-dateutil==2.9.0.post0 60 | pyzmq==26.2.0 61 | scikit-learn==1.5.2 62 | scipy==1.14.1 63 | sgmllib3k==1.0.0 64 | six==1.16.0 65 | stack-data==0.6.3 66 | sympy==1.12 67 | tensorboard==2.18.0 68 | tensorboard-data-server==0.7.2 69 | threadpoolctl==3.5.0 70 | torch==2.4.0+cu124 71 | torchaudio==2.4.0+cu124 72 | torchvision==0.19.0+cu124 73 | tornado==6.4.1 74 | tqdm==4.66.5 75 | traitlets==5.14.3 76 | triton==3.0.0 77 | typing_extensions==4.12.2 78 | wcwidth==0.2.13 79 | Werkzeug==3.0.4 80 | widgetsnbextension==4.0.13 81 | -------------------------------------------------------------------------------- /config/autoencoder/cae_v1.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":10, 5 | "known_env_num":2, 6 | "data":{ 7 | "dataset":"v1", 8 | "train":{ 9 | "dir":"data/v1/train.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":16, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v1/valid.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":16, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v1/test.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":100, 45 | "valid_start_epoch":20, 46 | "valid_step":5, 47 | "stop_train_step_valid_not_improve":6, 48 | "valid_metrics":"loss", 49 | "valid_metrics_less":true, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | 
"name":"CAE", 57 | "loss":"nmse", 58 | "encoder_hid_layers":[ 59 | 128, 60 | 1024, 61 | 1024 62 | ], 63 | "decoder_hid_layers":[ 64 | 1024, 65 | 1024, 66 | 128 67 | ] 68 | } 69 | } -------------------------------------------------------------------------------- /config/autoencoder/cae_v5_env0.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":9, 5 | "known_env_num":1, 6 | "data":{ 7 | "dataset":"v5", 8 | "train":{ 9 | "dir":"data/v5/train_env0.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":8, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v5/valid_env0.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":16, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v5/test_env0.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":100, 45 | "valid_start_epoch":5, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":30, 48 | "valid_metrics":"loss", 49 | "valid_metrics_less":true, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"CAE", 57 | "loss":"nmse", 58 | "encoder_hid_layers":[ 59 | 128, 60 | 1024, 61 | 1024 62 | ], 63 | "decoder_hid_layers":[ 64 | 1024, 65 | 1024, 66 | 128 67 | ] 68 | } 69 | } -------------------------------------------------------------------------------- /config/autoencoder/cae_v5_env1.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":9, 5 | "known_env_num":1, 6 | "data":{ 7 | "dataset":"v5", 8 | 
"train":{ 9 | "dir":"data/v5/train_env1.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":8, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v5/valid_env1.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":16, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v5/test_env1.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":100, 45 | "valid_start_epoch":5, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":30, 48 | "valid_metrics":"loss", 49 | "valid_metrics_less":true, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"CAE", 57 | "loss":"nmse", 58 | "encoder_hid_layers":[ 59 | 128, 60 | 1024, 61 | 1024 62 | ], 63 | "decoder_hid_layers":[ 64 | 1024, 65 | 1024, 66 | 128 67 | ] 68 | } 69 | } -------------------------------------------------------------------------------- /config/autoencoder/cae_v6_env0.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":8, 5 | "known_env_num":1, 6 | "data":{ 7 | "dataset":"v6", 8 | "train":{ 9 | "dir":"data/v6/train_env0.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":8, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v6/valid_env0.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":16, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v6/test_env0.mat", 32 | "loader":{ 33 | 
"num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":8, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":100, 45 | "valid_start_epoch":5, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":30, 48 | "valid_metrics":"loss", 49 | "valid_metrics_less":true, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"CAE", 57 | "loss":"nmse", 58 | "encoder_hid_layers":[ 59 | 128, 60 | 1024, 61 | 1024 62 | ], 63 | "decoder_hid_layers":[ 64 | 1024, 65 | 1024, 66 | 128 67 | ] 68 | } 69 | } -------------------------------------------------------------------------------- /config/autoencoder/cae_v7_env0.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":7, 5 | "known_env_num":1, 6 | "data":{ 7 | "dataset":"v7", 8 | "train":{ 9 | "dir":"data/v7/train_env0.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":8, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v7/valid_env0.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":16, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v7/test_env0.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":8, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":100, 45 | "valid_start_epoch":5, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":30, 48 | "valid_metrics":"loss", 49 | "valid_metrics_less":true, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"CAE", 57 | "loss":"nmse", 58 | 
"encoder_hid_layers":[ 59 | 128, 60 | 1024, 61 | 1024 62 | ], 63 | "decoder_hid_layers":[ 64 | 1024, 65 | 1024, 66 | 128 67 | ] 68 | } 69 | } -------------------------------------------------------------------------------- /config/autoencoder/cae_v8_env0.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":6, 5 | "known_env_num":1, 6 | "data":{ 7 | "dataset":"v8", 8 | "train":{ 9 | "dir":"data/v8/train_env0.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":8, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v8/valid_env0.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":16, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v8/test_env0.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":8, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":100, 45 | "valid_start_epoch":5, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":30, 48 | "valid_metrics":"loss", 49 | "valid_metrics_less":true, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"CAE", 57 | "loss":"nmse", 58 | "encoder_hid_layers":[ 59 | 128, 60 | 1024, 61 | 1024 62 | ], 63 | "decoder_hid_layers":[ 64 | 1024, 65 | 1024, 66 | 128 67 | ] 68 | } 69 | } -------------------------------------------------------------------------------- /config/deep_wiid/deep_wiid_v1.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":10, 5 | "known_env_num":2, 6 | "data":{ 7 | "dataset":"v1", 8 | "train":{ 9 | "dir":"data/v1/train.mat", 10 
| "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":64, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v1/valid.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":24, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v1/test_legal.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":100, 45 | "valid_start_epoch":20, 46 | "valid_step":5, 47 | "stop_train_step_valid_not_improve":6, 48 | "valid_metrics":"acc", 49 | "valid_metrics_less":false, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"DeepWiID", 57 | "gru_cfg":{ 58 | "hidden_size":256, 59 | "num_layers":4, 60 | "batch_first":true 61 | }, 62 | "tail_net":{ 63 | "hid_layers":[ 64 | 128, 65 | 32 66 | ], 67 | "activation_fn":"tanh", 68 | "drop_out":0.5, 69 | "end_with_softmax":true 70 | } 71 | } 72 | } -------------------------------------------------------------------------------- /config/deep_wiid/deep_wiid_v4_env0.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":10, 5 | "known_env_num":2, 6 | "data":{ 7 | "dataset":"v4", 8 | "train":{ 9 | "dir":"data/v4/train_env0.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":64, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v4/valid_env0.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":24, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | 
"dir":"data/v4/test_env0.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":100, 45 | "valid_start_epoch":5, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":30, 48 | "valid_metrics":"acc", 49 | "valid_metrics_less":false, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"DeepWiID", 57 | "gru_cfg":{ 58 | "hidden_size":256, 59 | "num_layers":4, 60 | "batch_first":true 61 | }, 62 | "tail_net":{ 63 | "hid_layers":[ 64 | 128, 65 | 32 66 | ], 67 | "activation_fn":"tanh", 68 | "drop_out":0.5, 69 | "end_with_softmax":true 70 | } 71 | } 72 | } -------------------------------------------------------------------------------- /config/deep_wiid/deep_wiid_v4_env1.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":10, 5 | "known_env_num":2, 6 | "data":{ 7 | "dataset":"v4", 8 | "train":{ 9 | "dir":"data/v4/train_env1.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":64, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v4/valid_env1.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":24, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v4/test_env1.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":100, 45 | "valid_start_epoch":5, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":30, 48 | "valid_metrics":"acc", 49 | 
"valid_metrics_less":false, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"DeepWiID", 57 | "gru_cfg":{ 58 | "hidden_size":256, 59 | "num_layers":4, 60 | "batch_first":true 61 | }, 62 | "tail_net":{ 63 | "hid_layers":[ 64 | 128, 65 | 32 66 | ], 67 | "activation_fn":"tanh", 68 | "drop_out":0.5, 69 | "end_with_softmax":true 70 | } 71 | } 72 | } -------------------------------------------------------------------------------- /config/gait_enhance/gait_enhance_v1.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":10, 5 | "known_env_num":2, 6 | "data":{ 7 | "dataset":"v1", 8 | "train":{ 9 | "dir":"data/v1/train.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":96, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v1/valid.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":24, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v1/test.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":150, 45 | "valid_start_epoch":20, 46 | "valid_step":5, 47 | "stop_train_step_valid_not_improve":6, 48 | "valid_metrics":"acc", 49 | "valid_metrics_less":false, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"GaitEnhance", 57 | "window_size":40, 58 | "blks_cfg":{ 59 | "hid_layers":[ 60 | 32, 61 | 64, 62 | 128 63 | ] 64 | }, 65 | "dropout_rate":0.2, 66 | "output_layer":{ 67 | "hid_layers":[ 68 | 128, 69 | 64 70 | ], 71 | "activation_fn":"tanh", 72 | "drop_out":0.5, 73 | "end_with_softmax":true 74 | } 75 
| } 76 | } -------------------------------------------------------------------------------- /config/gait_enhance/gait_enhance_v4_env0.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":10, 5 | "known_env_num":2, 6 | "data":{ 7 | "dataset":"v4", 8 | "train":{ 9 | "dir":"data/v4/train_env0.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":96, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v4/valid_env0.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":24, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v4/test_env0.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":150, 45 | "valid_start_epoch":5, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":30, 48 | "valid_metrics":"acc", 49 | "valid_metrics_less":false, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"GaitEnhance", 57 | "window_size":40, 58 | "blks_cfg":{ 59 | "hid_layers":[ 60 | 32, 61 | 64, 62 | 128 63 | ] 64 | }, 65 | "dropout_rate":0.2, 66 | "output_layer":{ 67 | "hid_layers":[ 68 | 128, 69 | 64 70 | ], 71 | "activation_fn":"tanh", 72 | "drop_out":0.5, 73 | "end_with_softmax":true 74 | } 75 | } 76 | } -------------------------------------------------------------------------------- /config/gait_enhance/gait_enhance_v4_env1.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":10, 5 | "known_env_num":2, 6 | "data":{ 7 | "dataset":"v4", 8 | "train":{ 9 | 
"dir":"data/v4/train_env1.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":96, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v4/valid_env1.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":24, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v4/test_env1.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":150, 45 | "valid_start_epoch":5, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":30, 48 | "valid_metrics":"acc", 49 | "valid_metrics_less":false, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"GaitEnhance", 57 | "window_size":40, 58 | "blks_cfg":{ 59 | "hid_layers":[ 60 | 32, 61 | 64, 62 | 128 63 | ] 64 | }, 65 | "dropout_rate":0.2, 66 | "output_layer":{ 67 | "hid_layers":[ 68 | 128, 69 | 64 70 | ], 71 | "activation_fn":"tanh", 72 | "drop_out":0.5, 73 | "end_with_softmax":true 74 | } 75 | } 76 | } -------------------------------------------------------------------------------- /config/caution/caution_encoder_v3.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":8, 5 | "known_env_num":2, 6 | "data":{ 7 | "dataset":"v3", 8 | "train":{ 9 | "dir":"data/v3/support.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":1, 14 | "drop_last":false, 15 | "shuffle":false, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v3/query.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":8, 25 | "drop_last":false, 26 | "shuffle":true, 27 
| "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v3/test_legal.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":50, 45 | "valid_start_epoch":1e+5, 46 | "valid_step":5, 47 | "stop_train_step_valid_not_improve":6, 48 | "valid_metrics":"loss", 49 | "valid_metrics_less":true, 50 | "optimizer":{ 51 | "lr":1e-4, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"CautionEncoder", 57 | "do":256, 58 | "update_start_epoch":5, 59 | "update_step":10, 60 | "cnn":{ 61 | "hid_layer":[ 62 | 32, 63 | 128, 64 | 128 65 | ], 66 | "activation_fn":"leaky_relu" 67 | }, 68 | "FeatureLayer":{ 69 | "hid_layers":[ 70 | 512 71 | ], 72 | "activation_fn":"tanh", 73 | "drop_out":0.2, 74 | "end_with_softmax":false 75 | } 76 | } 77 | } -------------------------------------------------------------------------------- /config/caution/caution_encoder_v5_env0.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":7, 5 | "known_env_num":1, 6 | "data":{ 7 | "dataset":"v5", 8 | "train":{ 9 | "dir":"data/v5/support_caution_env0.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":1, 14 | "drop_last":false, 15 | "shuffle":false, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v5/query_caution_env0.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":8, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v5/test_env0_pre.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | 
"max_epoch":100, 45 | "valid_start_epoch":1e+5, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":5, 48 | "valid_metrics":"loss", 49 | "valid_metrics_less":true, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"CautionEncoder", 57 | "do":256, 58 | "update_start_epoch":5, 59 | "update_step":10, 60 | "cnn":{ 61 | "hid_layer":[ 62 | 32, 63 | 128, 64 | 128 65 | ], 66 | "activation_fn":"leaky_relu" 67 | }, 68 | "FeatureLayer":{ 69 | "hid_layers":[ 70 | 512 71 | ], 72 | "activation_fn":"tanh", 73 | "drop_out":0.2, 74 | "end_with_softmax":false 75 | } 76 | } 77 | } -------------------------------------------------------------------------------- /config/caution/caution_encoder_v5_env1.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":7, 5 | "known_env_num":1, 6 | "data":{ 7 | "dataset":"v5", 8 | "train":{ 9 | "dir":"data/v5/support_caution_env1.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":1, 14 | "drop_last":false, 15 | "shuffle":false, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v5/query_caution_env1.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":8, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v5/test_env1_pre.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":100, 45 | "valid_start_epoch":1e+5, 46 | "valid_step":5, 47 | "stop_train_step_valid_not_improve":6, 48 | "valid_metrics":"loss", 49 | "valid_metrics_less":true, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"CautionEncoder", 57 | 
"do":256, 58 | "update_start_epoch":5, 59 | "update_step":10, 60 | "cnn":{ 61 | "hid_layer":[ 62 | 32, 63 | 128, 64 | 128 65 | ], 66 | "activation_fn":"leaky_relu" 67 | }, 68 | "FeatureLayer":{ 69 | "hid_layers":[ 70 | 512 71 | ], 72 | "activation_fn":"tanh", 73 | "drop_out":0.2, 74 | "end_with_softmax":false 75 | } 76 | } 77 | } -------------------------------------------------------------------------------- /config/caution/caution_encoder_v6_env0.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":6, 5 | "known_env_num":1, 6 | "data":{ 7 | "dataset":"v6", 8 | "train":{ 9 | "dir":"data/v6/support_caution_env0.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":1, 14 | "drop_last":false, 15 | "shuffle":false, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v6/query_caution_env0.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":8, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v6/test_env0_pre.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":100, 45 | "valid_start_epoch":1e+5, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":5, 48 | "valid_metrics":"loss", 49 | "valid_metrics_less":true, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"CautionEncoder", 57 | "do":256, 58 | "update_start_epoch":5, 59 | "update_step":10, 60 | "cnn":{ 61 | "hid_layer":[ 62 | 32, 63 | 128, 64 | 128 65 | ], 66 | "activation_fn":"leaky_relu" 67 | }, 68 | "FeatureLayer":{ 69 | "hid_layers":[ 70 | 512 71 | ], 72 | "activation_fn":"tanh", 73 | "drop_out":0.2, 74 | 
"end_with_softmax":false 75 | } 76 | } 77 | } -------------------------------------------------------------------------------- /config/caution/caution_encoder_v7_env0.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":5, 5 | "known_env_num":1, 6 | "data":{ 7 | "dataset":"v7", 8 | "train":{ 9 | "dir":"data/v7/support_caution_env0.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":1, 14 | "drop_last":false, 15 | "shuffle":false, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v7/query_caution_env0.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":8, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v7/test_env0_pre.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":100, 45 | "valid_start_epoch":1e+5, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":5, 48 | "valid_metrics":"loss", 49 | "valid_metrics_less":true, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"CautionEncoder", 57 | "do":256, 58 | "update_start_epoch":5, 59 | "update_step":10, 60 | "cnn":{ 61 | "hid_layer":[ 62 | 32, 63 | 128, 64 | 128 65 | ], 66 | "activation_fn":"leaky_relu" 67 | }, 68 | "FeatureLayer":{ 69 | "hid_layers":[ 70 | 512 71 | ], 72 | "activation_fn":"tanh", 73 | "drop_out":0.2, 74 | "end_with_softmax":false 75 | } 76 | } 77 | } -------------------------------------------------------------------------------- /config/caution/caution_encoder_v8_env0.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | 
"gpu_is_available":true, 4 | "known_p_num":4, 5 | "known_env_num":1, 6 | "data":{ 7 | "dataset":"v8", 8 | "train":{ 9 | "dir":"data/v8/support_caution_env0.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":1, 14 | "drop_last":false, 15 | "shuffle":false, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v8/query_caution_env0.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":8, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v8/test_env0_pre.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":100, 45 | "valid_start_epoch":1e+5, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":5, 48 | "valid_metrics":"loss", 49 | "valid_metrics_less":true, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"CautionEncoder", 57 | "do":256, 58 | "update_start_epoch":5, 59 | "update_step":10, 60 | "cnn":{ 61 | "hid_layer":[ 62 | 32, 63 | 128, 64 | 128 65 | ], 66 | "activation_fn":"leaky_relu" 67 | }, 68 | "FeatureLayer":{ 69 | "hid_layers":[ 70 | 512 71 | ], 72 | "activation_fn":"tanh", 73 | "drop_out":0.2, 74 | "end_with_softmax":false 75 | } 76 | } 77 | } -------------------------------------------------------------------------------- /config/autoencoder/ae_v1.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":10, 5 | "known_env_num":2, 6 | "data":{ 7 | "dataset":"v1", 8 | "train":{ 9 | "dir":"data/v1/train.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":48, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 
19 | "valid":{ 20 | "dir":"data/v1/valid.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":8, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v1/test.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":100, 45 | "valid_start_epoch":20, 46 | "valid_step":5, 47 | "stop_train_step_valid_not_improve":6, 48 | "valid_metrics":"loss", 49 | "valid_metrics_less":true, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"AE", 57 | "loss":"nmse", 58 | "do":256, 59 | "encoder_cfg":{ 60 | "hid_layers":[ 61 | 1024, 62 | 1024, 63 | 512 64 | ], 65 | "activation_fn":"relu", 66 | "drop_out":0.5, 67 | "end_with_softmax":false 68 | }, 69 | "decoder_cfg":{ 70 | "hid_layers":[ 71 | 512, 72 | 1024, 73 | 1024 74 | ], 75 | "activation_fn":"relu", 76 | "drop_out":0.5, 77 | "end_with_softmax":false 78 | } 79 | } 80 | } -------------------------------------------------------------------------------- /config/autoencoder/ae_v5_env0.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":9, 5 | "known_env_num":1, 6 | "data":{ 7 | "dataset":"v5", 8 | "train":{ 9 | "dir":"data/v5/train_env0.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":24, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v5/valid_env0.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":8, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v5/test_env0.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | 
"pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":100, 45 | "valid_start_epoch":5, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":30, 48 | "valid_metrics":"loss", 49 | "valid_metrics_less":true, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"AE", 57 | "loss":"nmse", 58 | "do":256, 59 | "encoder_cfg":{ 60 | "hid_layers":[ 61 | 1024, 62 | 1024, 63 | 512 64 | ], 65 | "activation_fn":"relu", 66 | "drop_out":0.5, 67 | "end_with_softmax":false 68 | }, 69 | "decoder_cfg":{ 70 | "hid_layers":[ 71 | 512, 72 | 1024, 73 | 1024 74 | ], 75 | "activation_fn":"relu", 76 | "drop_out":0.5, 77 | "end_with_softmax":false 78 | } 79 | } 80 | } -------------------------------------------------------------------------------- /config/autoencoder/ae_v5_env1.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":9, 5 | "known_env_num":1, 6 | "data":{ 7 | "dataset":"v5", 8 | "train":{ 9 | "dir":"data/v5/train_env1.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":24, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v5/valid_env1.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":8, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v5/test_env1.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":100, 45 | "valid_start_epoch":5, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":30, 48 | 
"valid_metrics":"loss", 49 | "valid_metrics_less":true, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"AE", 57 | "loss":"nmse", 58 | "do":256, 59 | "encoder_cfg":{ 60 | "hid_layers":[ 61 | 1024, 62 | 1024, 63 | 512 64 | ], 65 | "activation_fn":"relu", 66 | "drop_out":0.5, 67 | "end_with_softmax":false 68 | }, 69 | "decoder_cfg":{ 70 | "hid_layers":[ 71 | 512, 72 | 1024, 73 | 1024 74 | ], 75 | "activation_fn":"relu", 76 | "drop_out":0.5, 77 | "end_with_softmax":false 78 | } 79 | } 80 | } -------------------------------------------------------------------------------- /config/autoencoder/ae_v6_env0.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":8, 5 | "known_env_num":1, 6 | "data":{ 7 | "dataset":"v6", 8 | "train":{ 9 | "dir":"data/v6/train_env0.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":24, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v6/valid_env0.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":8, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v6/test_env0.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":100, 45 | "valid_start_epoch":5, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":30, 48 | "valid_metrics":"loss", 49 | "valid_metrics_less":true, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"AE", 57 | "loss":"nmse", 58 | "do":256, 59 | "encoder_cfg":{ 60 | "hid_layers":[ 61 | 1024, 62 | 1024, 63 | 512 64 | ], 65 | 
"activation_fn":"relu", 66 | "drop_out":0.5, 67 | "end_with_softmax":false 68 | }, 69 | "decoder_cfg":{ 70 | "hid_layers":[ 71 | 512, 72 | 1024, 73 | 1024 74 | ], 75 | "activation_fn":"relu", 76 | "drop_out":0.5, 77 | "end_with_softmax":false 78 | } 79 | } 80 | } -------------------------------------------------------------------------------- /config/autoencoder/ae_v7_env0.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":7, 5 | "known_env_num":1, 6 | "data":{ 7 | "dataset":"v7", 8 | "train":{ 9 | "dir":"data/v7/train_env0.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":24, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v7/valid_env0.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":8, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v7/test_env0.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":100, 45 | "valid_start_epoch":5, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":30, 48 | "valid_metrics":"loss", 49 | "valid_metrics_less":true, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"AE", 57 | "loss":"nmse", 58 | "do":256, 59 | "encoder_cfg":{ 60 | "hid_layers":[ 61 | 1024, 62 | 1024, 63 | 512 64 | ], 65 | "activation_fn":"relu", 66 | "drop_out":0.5, 67 | "end_with_softmax":false 68 | }, 69 | "decoder_cfg":{ 70 | "hid_layers":[ 71 | 512, 72 | 1024, 73 | 1024 74 | ], 75 | "activation_fn":"relu", 76 | "drop_out":0.5, 77 | "end_with_softmax":false 78 | } 79 | } 80 | } 
-------------------------------------------------------------------------------- /config/autoencoder/ae_v8_env0.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":6, 5 | "known_env_num":1, 6 | "data":{ 7 | "dataset":"v8", 8 | "train":{ 9 | "dir":"data/v8/train_env0.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":24, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v8/valid_env0.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":8, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v8/test_env0.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":100, 45 | "valid_start_epoch":5, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":30, 48 | "valid_metrics":"loss", 49 | "valid_metrics_less":true, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"AE", 57 | "loss":"nmse", 58 | "do":256, 59 | "encoder_cfg":{ 60 | "hid_layers":[ 61 | 1024, 62 | 1024, 63 | 512 64 | ], 65 | "activation_fn":"relu", 66 | "drop_out":0.5, 67 | "end_with_softmax":false 68 | }, 69 | "decoder_cfg":{ 70 | "hid_layers":[ 71 | 512, 72 | 1024, 73 | 1024 74 | ], 75 | "activation_fn":"relu", 76 | "drop_out":0.5, 77 | "end_with_softmax":false 78 | } 79 | } 80 | } -------------------------------------------------------------------------------- /lib/json_util.py: -------------------------------------------------------------------------------- 1 | # @Time : 2023.03.03 2 | # @Author : Darrius Lei 3 | # @Email : darrius.lei@outlook.com 4 | import json, os 5 | 6 | 
def jsonload(src):
    '''Read a json file and convert it to a python object (normally a dict).
    Suitable for files containing only one json string.

    Parameters:
    -----------

    src: str
        path of the json file

    Returns:
    --------

    json_dict: dict
        the decoded json document

    Raises:
    -------

    AssertionError
        if ``src`` does not exist
    '''
    assert os.path.exists(src), f'json file not found: {src}';
    # ``with`` guarantees the handle is closed; JSON text is UTF-8 (RFC 8259),
    # so do not depend on the platform's default locale encoding.
    with open(src, 'r', encoding = 'utf-8') as f:
        return json.load(f);

def jsonparse(src):
    '''Lazily read a file holding one json string per line and yield each one
    decoded. Blank (whitespace-only) lines are skipped.

    Parameters:
    -----------

    src: str
        path of the json file

    Yields:
    -------

    json_dict: dict
        the decoded json value of each non-blank line

    Raises:
    -------

    AssertionError
        if ``src`` does not exist (raised on the first ``next()``)

    Example:
    ---------

    >>> from json_util import jsonparse
    >>> src = './file.json';
    >>> for item in jsonparse(src):
    ...     print(item);
    '''
    assert os.path.exists(src), f'json file not found: {src}';
    # The file is closed once the generator is exhausted or closed.
    with open(src, 'r', encoding = 'utf-8') as f:
        for line in f:
            # A blank line (e.g. a doubled trailing newline) is not a json
            # document; skip it instead of raising JSONDecodeError.
            if not line.strip():
                continue;
            yield json.loads(line);

def jsonlen(src):
    '''Count the json objects (dicts) in a file holding one json string per line.

    Parameters:
    -----------

    src: str
        path of the json file

    Returns:
    --------

    cnt: int
        number of lines whose json value is an object (dict)
    '''
    # Reuse jsonparse so both functions agree on blank-line handling.
    return sum(1 for item in jsonparse(src) if isinstance(item, dict));

def dict2jsonstr(dict):
    '''Serialize ``dict`` to a json string indented with 4 spaces.
    '''
    return json.dumps(dict, indent = 4);

def jsonsave(dict, tgt):
    '''Save the python type dict as a json file (4-space indented).

    Parameters:
    -----------
    dict: dict
        the object to serialize

    tgt: str
        Destination path (including file name)
    '''
    # Serialize first so a non-serializable object does not truncate ``tgt``.
    json_data = dict2jsonstr(dict);
    with open(tgt, 'w', encoding = 'utf-8') as json_file:
        json_file.write(json_data);

# 95 | -------------------------------------------------------------------------------- /config/dcs_gait/dcs_gait_encoder_v1.json:
-------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":10, 5 | "known_env_num":2, 6 | "data":{ 7 | "dataset":"v1", 8 | "train":{ 9 | "dir":"data/v1/train.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":32, 14 | "drop_last":true, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v1/valid.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":32, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v1/test_legal.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":32, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":300, 45 | "valid_start_epoch":1, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":10, 48 | "valid_metrics":"acc", 49 | "valid_metrics_less":false, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"DCSGaitEncoder", 57 | "do":256, 58 | "cnn_cfg":{ 59 | "hid_layers":[ 60 | 16, 61 | 32 62 | ] 63 | }, 64 | "tail_net_cfg":{ 65 | "hid_layers":[ 66 | 1024 67 | ], 68 | "activation_fn":"tanh", 69 | "drop_out":0.5, 70 | "end_with_softmax":false 71 | }, 72 | "p_classifier":{ 73 | "hid_layers":[ 74 | 64 75 | ], 76 | "activation_fn":"tanh", 77 | "drop_out":0.5, 78 | "end_with_softmax":true 79 | } 80 | } 81 | } -------------------------------------------------------------------------------- /config/wiai_id/wiai_id_v4_env0.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":10, 5 | "known_env_num":2, 6 | "data":{ 7 | "dataset":"v4", 8 | "train":{ 9 | "dir":"data/v4/train_env0.mat", 10 | "loader":{ 11 | 
"num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":4, 14 | "drop_last":true, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v4/valid_env0.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":8, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v4/test_env0.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":8, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":100, 45 | "valid_start_epoch":1, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":1, 48 | "valid_metrics":"acc", 49 | "valid_metrics_less":false, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"WiAiId", 57 | "do":256, 58 | "alpha":0.1, 59 | "beta":0.5, 60 | "gamma":0.5, 61 | "MHA_cfg":{ 62 | "d_q":64, 63 | "d_k":64, 64 | "d_v":64, 65 | "n_heads":3 66 | }, 67 | "MultiScaleCNN_cfg":{ 68 | "h_layers":[ 69 | 84 70 | ], 71 | "hid_layers":[ 72 | 64, 73 | 32, 74 | 8, 75 | 1 76 | ] 77 | }, 78 | "p_classifier":{ 79 | "hid_layers":[ 80 | 128, 81 | 32 82 | ] 83 | }, 84 | "env_classifier":{ 85 | "hid_layers":[ 86 | 32 87 | ] 88 | } 89 | 90 | } 91 | } -------------------------------------------------------------------------------- /config/wiai_id/wiai_id_v4_env1.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":10, 5 | "known_env_num":2, 6 | "data":{ 7 | "dataset":"v4", 8 | "train":{ 9 | "dir":"data/v4/train_env1.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":4, 14 | "drop_last":true, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v4/valid_env1.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | 
"pin_memory":true, 24 | "batch_size":8, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v4/test_env1.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":8, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":100, 45 | "valid_start_epoch":1, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":1, 48 | "valid_metrics":"acc", 49 | "valid_metrics_less":false, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"WiAiId", 57 | "do":256, 58 | "alpha":0.1, 59 | "beta":0.5, 60 | "gamma":0.5, 61 | "MHA_cfg":{ 62 | "d_q":64, 63 | "d_k":64, 64 | "d_v":64, 65 | "n_heads":3 66 | }, 67 | "MultiScaleCNN_cfg":{ 68 | "h_layers":[ 69 | 84 70 | ], 71 | "hid_layers":[ 72 | 64, 73 | 32, 74 | 8, 75 | 1 76 | ] 77 | }, 78 | "p_classifier":{ 79 | "hid_layers":[ 80 | 128, 81 | 32 82 | ] 83 | }, 84 | "env_classifier":{ 85 | "hid_layers":[ 86 | 32 87 | ] 88 | } 89 | 90 | } 91 | } -------------------------------------------------------------------------------- /config/dcs_gait/dcs_gait_v2.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":10, 5 | "known_env_num":2, 6 | "data":{ 7 | "dataset":"v4", 8 | "train":{ 9 | "dir":"data/v2/train.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":16, 14 | "drop_last":false, 15 | "shuffle":false, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "target":{ 20 | "dir":"data/v2/target.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":16, 25 | "drop_last":false, 26 | "shuffle":false, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "valid":{ 31 | "dir":"data/v2/valid.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | 
"batch_size":8, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | }, 41 | "test":{ 42 | "dir":"data/v2/test_legal.mat", 43 | "loader":{ 44 | "num_workers":4, 45 | "pin_memory":true, 46 | "batch_size":8, 47 | "drop_last":false, 48 | "shuffle":true, 49 | "prefetch_factor":4 50 | } 51 | } 52 | }, 53 | "train":{ 54 | "is_DA":true, 55 | "max_epoch":100, 56 | "valid_start_epoch":1, 57 | "valid_step":1, 58 | "stop_train_step_valid_not_improve":30, 59 | "valid_metrics":"acc", 60 | "valid_metrics_less":false, 61 | "optimizer":{ 62 | "lr":1e-6, 63 | "weight_decay":1e-4 64 | } 65 | }, 66 | "model":{ 67 | "name":"DCSGait", 68 | "pretrained_enc_dir":"model/pretrained/DCSGaitEncoder_opt1/best", 69 | "alpha":0.5, 70 | "num_iterations":50, 71 | "MHA_cfg":{ 72 | "d_q":64, 73 | "d_k":64, 74 | "d_v":64, 75 | "n_heads":3 76 | }, 77 | "p_classifier":{ 78 | "hid_layers":[ 79 | 64 80 | ], 81 | "activation_fn":"tanh", 82 | "drop_out":0.5, 83 | "end_with_softmax":true 84 | } 85 | } 86 | } -------------------------------------------------------------------------------- /config/deep_wiid/deep_wiid_al_v2.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":10, 5 | "known_env_num":2, 6 | "data":{ 7 | "dataset":"v2", 8 | "train":{ 9 | "dir":"data/v2/train.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":8, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v2/valid.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":8, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "target":{ 31 | "dir":"data/v2/target.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":8, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | }, 41 | 
"test":{ 42 | "dir":"data/v2/test_legal.mat", 43 | "loader":{ 44 | "num_workers":4, 45 | "pin_memory":true, 46 | "batch_size":8, 47 | "drop_last":false, 48 | "shuffle":true, 49 | "prefetch_factor":4 50 | } 51 | } 52 | }, 53 | "train":{ 54 | "is_DA":true, 55 | "max_epoch":200, 56 | "valid_start_epoch":1, 57 | "valid_step":1, 58 | "stop_train_step_valid_not_improve":30, 59 | "valid_metrics":"acc", 60 | "valid_metrics_less":false, 61 | "optimizer":{ 62 | "lr":1e-5, 63 | "weight_decay":1e-4 64 | } 65 | }, 66 | "model":{ 67 | "name":"DeepWiIDAL", 68 | "gru_cfg":{ 69 | "hidden_size":256, 70 | "num_layers":4, 71 | "batch_first":true 72 | }, 73 | "tail_net":{ 74 | "hid_layers":[ 75 | 128, 76 | 32 77 | ], 78 | "activation_fn":"tanh", 79 | "drop_out":0.5, 80 | "end_with_softmax":true 81 | }, 82 | "env_classifier":{ 83 | "hid_layers":[ 84 | 128, 85 | 64 86 | ], 87 | "activation_fn":"mish", 88 | "drop_out":0.5, 89 | "end_with_softmax":true 90 | }, 91 | "lambda_":1 92 | } 93 | } -------------------------------------------------------------------------------- /config/gait_enhance/gait_enhance_al_v2.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":10, 5 | "known_env_num":2, 6 | "data":{ 7 | "dataset":"v2", 8 | "train":{ 9 | "dir":"data/v2/train.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":8, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v2/valid.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":8, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "target":{ 31 | "dir":"data/v2/target.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":8, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | }, 41 | "test":{ 42 | 
"dir":"data/v2/test_legal.mat", 43 | "loader":{ 44 | "num_workers":4, 45 | "pin_memory":true, 46 | "batch_size":8, 47 | "drop_last":false, 48 | "shuffle":true, 49 | "prefetch_factor":4 50 | } 51 | } 52 | }, 53 | "train":{ 54 | "is_DA":true, 55 | "max_epoch":150, 56 | "valid_start_epoch":1, 57 | "valid_step":1, 58 | "stop_train_step_valid_not_improve":30, 59 | "valid_metrics":"acc", 60 | "valid_metrics_less":false, 61 | "optimizer":{ 62 | "lr":1e-5, 63 | "weight_decay":1e-4 64 | } 65 | }, 66 | "model":{ 67 | "name":"GaitEnhanceAL", 68 | "window_size":40, 69 | "blks_cfg":{ 70 | "hid_layers":[ 71 | 32, 72 | 64, 73 | 128 74 | ] 75 | }, 76 | "dropout_rate":0.2, 77 | "output_layer":{ 78 | "hid_layers":[ 79 | 128, 80 | 64 81 | ], 82 | "activation_fn":"tanh", 83 | "drop_out":0.5, 84 | "end_with_softmax":true 85 | }, 86 | "env_classifier":{ 87 | "hid_layers":[ 88 | 128, 89 | 64 90 | ], 91 | "activation_fn":"mish", 92 | "drop_out":0.5, 93 | "end_with_softmax":true 94 | }, 95 | "lambda_":1 96 | } 97 | } -------------------------------------------------------------------------------- /config/wiai_id/wiai_id_v2.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":10, 5 | "known_env_num":2, 6 | "data":{ 7 | "dataset":"v4", 8 | "train":{ 9 | "dir":"data/v2/train.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":4, 14 | "drop_last":true, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "target":{ 20 | "dir":"data/v2/target.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":4, 25 | "drop_last":true, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "valid":{ 31 | "dir":"data/v2/valid.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":8, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | }, 41 | "test":{ 42 | 
"dir":"data/v2/test_legal.mat", 43 | "loader":{ 44 | "num_workers":4, 45 | "pin_memory":true, 46 | "batch_size":8, 47 | "drop_last":false, 48 | "shuffle":true, 49 | "prefetch_factor":4 50 | } 51 | } 52 | }, 53 | "train":{ 54 | "is_DA":true, 55 | "max_epoch":30, 56 | "valid_start_epoch":1, 57 | "valid_step":1, 58 | "stop_train_step_valid_not_improve":10, 59 | "valid_metrics":"acc", 60 | "valid_metrics_less":false, 61 | "optimizer":{ 62 | "lr":1e-5, 63 | "weight_decay":1e-4 64 | } 65 | }, 66 | "model":{ 67 | "name":"WiAiId", 68 | "do":256, 69 | "alpha":0.1, 70 | "beta":0.5, 71 | "gamma":0.5, 72 | "MHA_cfg":{ 73 | "d_q":64, 74 | "d_k":64, 75 | "d_v":64, 76 | "n_heads":3 77 | }, 78 | "MultiScaleCNN_cfg":{ 79 | "h_layers":[ 80 | 84 81 | ], 82 | "hid_layers":[ 83 | 64, 84 | 32, 85 | 8, 86 | 1 87 | ] 88 | }, 89 | "p_classifier":{ 90 | "hid_layers":[ 91 | 128, 92 | 32 93 | ] 94 | }, 95 | "env_classifier":{ 96 | "hid_layers":[ 97 | 32 98 | ] 99 | } 100 | 101 | } 102 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # IDLab 2 | 3 | Code for Identity recognition and intrusion detection in wireless sensing. 
4 | 5 | **Dataset acquisition:** 6 | Not yet public; the dataset will be released later. 7 | 8 | ## Command 9 | 10 | ### usage 11 | 12 | ```shell 13 | usage: executor.py [-h] [--config CONFIG] [--saved_config SAVED_CONFIG] [--auto_shutdown AUTO_SHUTDOWN] 14 | ``` 15 | 16 | ### options 17 | 18 | ```shell 19 | -h, --help show this help message and exit 20 | --config CONFIG, -cfg CONFIG 21 | config for run 22 | --saved_config SAVED_CONFIG, -sc SAVED_CONFIG 23 | path for saved config to test 24 | --auto_shutdown AUTO_SHUTDOWN, -as AUTO_SHUTDOWN 25 | automatic shutdown after program completion 26 | ``` 27 | 28 | ### quick start 29 | ```shell 30 | python executor.py -cfg=path/to/config 31 | ``` 32 | 33 | ## Model 34 | 35 | - [x] Gait-Enhance[[1](#ref1)] 36 | - [x] Deep-WiID[[2](#ref2)] 37 | - [x] Caution[[3](#ref3)] 38 | - [x] CSIID[[4](#ref4)] 39 | - [x] Gate-ID[[5](#ref5)] 40 | - [x] WiAU[[6](#ref6)] 41 | - [x] WiAi-ID[[7](#ref7)] 42 | - [x] DCS-Gait[[8](#ref8)] 43 | - [x] Bird 44 | 45 | 46 | 47 | ## References 48 | 49 | 1. Yang J, Liu Y, Wu Y, et al. Gait-Enhance: Robust Gait Recognition of Complex Walking Patterns Based on WiFi CSI[C]//2023 IEEE Smart World Congress (SWC). IEEE, 2023: 1-9. 50 | 2. Zhou Z, Liu C, Yu X, et al. Deep-WiID: WiFi-based contactless human identification via deep learning[C]//2019 IEEE SmartWorld, Ubiquitous Intelligence & Computing, Advanced & Trusted Computing, Scalable Computing & Communications, Cloud & Big Data Computing, Internet of People and Smart City Innovation (SmartWorld/SCALCOM/UIC/ATC/CBDCom/IOP/SCI). IEEE, 2019: 877-884. 51 | 3. Wang D, Yang J, Cui W, et al. CAUTION: A Robust WiFi-based human authentication system via few-shot open-set recognition[J]. IEEE Internet of Things Journal, 2022, 9(18): 17323-17333. 52 | 4. Wang D, Zhou Z, Yu X, et al. CSIID: WiFi-based human identification via deep learning[C]//2019 14th International Conference on Computer Science & Education (ICCSE). IEEE, 2019: 326-330. 53 | 5. Zhang J, Wei B, Wu F, et al.
Gate-ID: WiFi-based human identification irrespective of walking directions in smart home[J]. IEEE Internet of Things Journal, 2020, 8(9): 7610-7624. 54 | 6. Lin C, Hu J, Sun Y, et al. WiAU: An accurate device-free authentication system with ResNet[C]//2018 15th Annual IEEE International Conference on Sensing, Communication, and Networking (SECON). IEEE, 2018: 1-9. 55 | 7. Liang Y, Wu W, Li H, et al. WiAi-ID: Wi-Fi-Based Domain Adaptation for Appearance-Independent Passive Person Identification[J]. IEEE Internet of Things Journal, 2023, 11(1): 1012-1027. 56 | 8. Liang Y, Wu W, Li H, et al. DCS-Gait: A Class-Level Domain Adaptation Approach for Cross-Scene and Cross-State Gait Recognition Using Wi-Fi CSI[J]. IEEE Transactions on Information Forensics and Security, 2024. -------------------------------------------------------------------------------- /model/framework/deep_wiid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from lib import util, glb_var 4 | from model.framework.base import Net 5 | from model import net_util 6 | 7 | class DeepWiID(Net): 8 | def __init__(self, model_cfg) -> None: 9 | super().__init__(model_cfg); 10 | #GRU 11 | gru_cfg = model_cfg['gru_cfg']; 12 | gru_cfg['input_size'] = np.prod(self.dim_in[2:]); 13 | self.gru = torch.nn.GRU(**gru_cfg); 14 | #Tail 15 | #tail_net 16 | tail_net_cfg = model_cfg['tail_net']; 17 | tail_net_cfg['dim_in'] = gru_cfg['hidden_size']; 18 | tail_net_cfg['dim_out'] = self.known_p_num; 19 | #There is only one layer of network, and these two parameters are not needed 20 | tail_net_cfg['activation_fn'] = net_util.get_activation_fn(tail_net_cfg['activation_fn']) if 'activation_fn' in tail_net_cfg.keys() else None; 21 | tail_net_cfg['drop_out'] = tail_net_cfg['drop_out'] if 'drop_out' in tail_net_cfg.keys() else None; 22 | self.tail_net = util.get_func_rets(net_util.get_mlp_net, tail_net_cfg); 23 | 24 | self.is_Intrusion_Detection = False; 
25 | 26 | @torch.no_grad() 27 | def encoder(self, amps): 28 | #amps:[B, T, R * F] 29 | amps = amps.reshape(amps.shape[0], amps.shape[1], -1); 30 | #[B, T, d] 31 | features, _ = self.gru(amps); 32 | #[B, d] 33 | features = features.mean(dim = 1); 34 | return features; 35 | 36 | def p_classify(self, amps): 37 | #amps:[B, T, R * F] 38 | #[B, d] 39 | features = self.encoder(amps); 40 | id_probs = self.tail_net(features); 41 | return id_probs; 42 | 43 | def cal_loss(self, amps, ids, envs): 44 | id_probs = self.p_classify(amps); 45 | return torch.nn.CrossEntropyLoss()(id_probs, ids); 46 | 47 | class DeepWiIDAL(DeepWiID): 48 | def __init__(self, model_cfg): 49 | super().__init__(model_cfg); 50 | 51 | #env layer 52 | env_cfg = model_cfg['env_classifier']; 53 | env_cfg['activation_fn'] = net_util.get_activation_fn(env_cfg['activation_fn']); 54 | env_cfg['dim_in'] = model_cfg['tail_net']['dim_in']; 55 | env_cfg['dim_out'] = self.known_env_num; 56 | self.env_layer = util.get_func_rets(net_util.get_mlp_net, env_cfg); 57 | 58 | def cal_loss(self, amps, ids, envs, amps_t = None, ids_t = None, envs_t = None): 59 | features = self.encoder(amps); 60 | id_probs = self.tail_net(features); 61 | loss_id = torch.nn.CrossEntropyLoss()(id_probs, ids); 62 | 63 | if amps_t is None: 64 | return loss_id; 65 | 66 | env_probs = self.env_layer(net_util.GradientReversalF.apply(features, self.lambda_)); 67 | loss_env = torch.nn.CrossEntropyLoss()(env_probs, envs); 68 | 69 | loss_t = loss_id + loss_env; 70 | 71 | feature_t = self.encoder(amps_t); 72 | id_probs_t = self.tail_net(feature_t); 73 | id_loss_t = torch.nn.CrossEntropyLoss()(id_probs_t, ids_t); 74 | 75 | return loss_t + id_loss_t; 76 | 77 | glb_var.register_model('DeepWiID', DeepWiID); 78 | glb_var.register_model('DeepWiIDAL', DeepWiIDAL); -------------------------------------------------------------------------------- /config/bird/bird_v1_wo_aug.json: -------------------------------------------------------------------------------- 1 | { 2 
| "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":10, 5 | "known_env_num":2, 6 | "data":{ 7 | "dataset":"v1", 8 | "train":{ 9 | "dir":"data/v1/train.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":24, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v1/valid.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":8, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v1/test.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":50, 45 | "valid_start_epoch":1, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":30, 48 | "valid_metrics":"loss", 49 | "valid_metrics_less":true, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"BIRD", 57 | "pretrained_enc_dir":"model/pretrained/BIRDEncoder_opt4/end", 58 | "loss":"mse", 59 | "do":256, 60 | "d":16, 61 | "t":64, 62 | "TLinear1":{ 63 | "hid_layers":[ 64 | 8, 65 | 32 66 | ], 67 | "activation_fn":"mish", 68 | "drop_out":0.5, 69 | "end_with_softmax":false 70 | }, 71 | "csinet":{ 72 | "n_refine":4, 73 | "hid_layers":[ 74 | 8, 75 | 16, 76 | 32, 77 | 32, 78 | 16, 79 | 8 80 | ], 81 | "activation_fn":"relu" 82 | }, 83 | "FLinear":{ 84 | "hid_layers":[ 85 | 16, 86 | 32 87 | ], 88 | "activation_fn":"mish", 89 | "drop_out":0.5, 90 | "end_with_softmax":false 91 | }, 92 | "RLinear":{ 93 | "hid_layers":[ 94 | 16, 95 | 8 96 | ], 97 | "activation_fn":"mish", 98 | "drop_out":0.5, 99 | "end_with_softmax":false 100 | }, 101 | "TLinear2":{ 102 | "hid_layers":[ 103 | 256, 104 | 1024, 105 | 4096 106 | ], 107 | "activation_fn":"mish", 108 | "drop_out":0.5, 109 | "end_with_softmax":false 110 | } 111 | 
} 112 | } -------------------------------------------------------------------------------- /config/bird/bird_v5_env0_wo_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":9, 5 | "known_env_num":1, 6 | "data":{ 7 | "dataset":"v5", 8 | "train":{ 9 | "dir":"data/v5/train_env0.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":24, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v5/valid_env0.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":8, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v5/test_env0.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":100, 45 | "valid_start_epoch":5, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":10, 48 | "valid_metrics":"loss", 49 | "valid_metrics_less":true, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"BIRD", 57 | "pretrained_enc_dir":"model/pretrained/BIRDEncoder_v5_env0_opt1/end", 58 | "loss":"mse", 59 | "do":256, 60 | "d":16, 61 | "t":64, 62 | "TLinear1":{ 63 | "hid_layers":[ 64 | 8, 65 | 32 66 | ], 67 | "activation_fn":"mish", 68 | "drop_out":0.5, 69 | "end_with_softmax":false 70 | }, 71 | "csinet":{ 72 | "n_refine":4, 73 | "hid_layers":[ 74 | 8, 75 | 16, 76 | 32, 77 | 32, 78 | 16, 79 | 8 80 | ], 81 | "activation_fn":"relu" 82 | }, 83 | "FLinear":{ 84 | "hid_layers":[ 85 | 16, 86 | 32 87 | ], 88 | "activation_fn":"mish", 89 | "drop_out":0.5, 90 | "end_with_softmax":false 91 | }, 92 | "RLinear":{ 93 | "hid_layers":[ 94 | 16, 95 | 8 96 | ], 97 | 
"activation_fn":"mish", 98 | "drop_out":0.5, 99 | "end_with_softmax":false 100 | }, 101 | "TLinear2":{ 102 | "hid_layers":[ 103 | 256, 104 | 1024, 105 | 4096 106 | ], 107 | "activation_fn":"mish", 108 | "drop_out":0.5, 109 | "end_with_softmax":false 110 | } 111 | } 112 | } -------------------------------------------------------------------------------- /config/bird/bird_v5_env1_wo_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":9, 5 | "known_env_num":1, 6 | "data":{ 7 | "dataset":"v5", 8 | "train":{ 9 | "dir":"data/v5/train_env1.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":24, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v5/valid_env1.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":8, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v5/test_env1.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":300, 45 | "valid_start_epoch":1, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":90, 48 | "valid_metrics":"loss", 49 | "valid_metrics_less":true, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"BIRD", 57 | "pretrained_enc_dir":"model/pretrained/BIRDEncoder_v5_env1_opt1/end", 58 | "loss":"mse", 59 | "do":256, 60 | "d":16, 61 | "t":64, 62 | "TLinear1":{ 63 | "hid_layers":[ 64 | 8, 65 | 32 66 | ], 67 | "activation_fn":"mish", 68 | "drop_out":0.5, 69 | "end_with_softmax":false 70 | }, 71 | "csinet":{ 72 | "n_refine":4, 73 | "hid_layers":[ 74 | 8, 75 | 16, 76 | 32, 77 | 32, 78 | 16, 79 | 8 80 | ], 81 | 
"activation_fn":"relu" 82 | }, 83 | "FLinear":{ 84 | "hid_layers":[ 85 | 16, 86 | 32 87 | ], 88 | "activation_fn":"mish", 89 | "drop_out":0.5, 90 | "end_with_softmax":false 91 | }, 92 | "RLinear":{ 93 | "hid_layers":[ 94 | 16, 95 | 8 96 | ], 97 | "activation_fn":"mish", 98 | "drop_out":0.5, 99 | "end_with_softmax":false 100 | }, 101 | "TLinear2":{ 102 | "hid_layers":[ 103 | 256, 104 | 1024, 105 | 4096 106 | ], 107 | "activation_fn":"mish", 108 | "drop_out":0.5, 109 | "end_with_softmax":false 110 | } 111 | } 112 | } -------------------------------------------------------------------------------- /config/bird/bird_v6_env0_wo_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":8, 5 | "known_env_num":1, 6 | "data":{ 7 | "dataset":"v6", 8 | "train":{ 9 | "dir":"data/v6/train_env0.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":24, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v6/valid_env0.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":8, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v6/test_env0.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":100, 45 | "valid_start_epoch":5, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":10, 48 | "valid_metrics":"loss", 49 | "valid_metrics_less":true, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"BIRD", 57 | "pretrained_enc_dir":"model/pretrained/BIRDEncoder_v6_env0_opt2/best", 58 | "loss":"mse", 59 | "do":256, 60 | "d":16, 61 | "t":64, 62 | "TLinear1":{ 63 | 
"hid_layers":[ 64 | 8, 65 | 32 66 | ], 67 | "activation_fn":"mish", 68 | "drop_out":0.5, 69 | "end_with_softmax":false 70 | }, 71 | "csinet":{ 72 | "n_refine":4, 73 | "hid_layers":[ 74 | 8, 75 | 16, 76 | 32, 77 | 32, 78 | 16, 79 | 8 80 | ], 81 | "activation_fn":"relu" 82 | }, 83 | "FLinear":{ 84 | "hid_layers":[ 85 | 16, 86 | 32 87 | ], 88 | "activation_fn":"mish", 89 | "drop_out":0.5, 90 | "end_with_softmax":false 91 | }, 92 | "RLinear":{ 93 | "hid_layers":[ 94 | 16, 95 | 8 96 | ], 97 | "activation_fn":"mish", 98 | "drop_out":0.5, 99 | "end_with_softmax":false 100 | }, 101 | "TLinear2":{ 102 | "hid_layers":[ 103 | 256, 104 | 1024, 105 | 4096 106 | ], 107 | "activation_fn":"mish", 108 | "drop_out":0.5, 109 | "end_with_softmax":false 110 | } 111 | } 112 | } -------------------------------------------------------------------------------- /config/bird/bird_v7_env0_wo_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":7, 5 | "known_env_num":1, 6 | "data":{ 7 | "dataset":"v7", 8 | "train":{ 9 | "dir":"data/v7/train_env0.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":16, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v7/valid_env0.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":8, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v7/test_env0.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":100, 45 | "valid_start_epoch":1, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":10, 48 | "valid_metrics":"loss", 49 | "valid_metrics_less":true, 50 | "optimizer":{ 51 | 
"lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"BIRD", 57 | "pretrained_enc_dir":"model/pretrained/BIRDEncoder_v7_env0_opt2/end", 58 | "loss":"mse", 59 | "do":256, 60 | "d":16, 61 | "t":64, 62 | "TLinear1":{ 63 | "hid_layers":[ 64 | 8, 65 | 32 66 | ], 67 | "activation_fn":"mish", 68 | "drop_out":0.5, 69 | "end_with_softmax":false 70 | }, 71 | "csinet":{ 72 | "n_refine":4, 73 | "hid_layers":[ 74 | 8, 75 | 16, 76 | 32, 77 | 32, 78 | 16, 79 | 8 80 | ], 81 | "activation_fn":"relu" 82 | }, 83 | "FLinear":{ 84 | "hid_layers":[ 85 | 16, 86 | 32 87 | ], 88 | "activation_fn":"mish", 89 | "drop_out":0.5, 90 | "end_with_softmax":false 91 | }, 92 | "RLinear":{ 93 | "hid_layers":[ 94 | 16, 95 | 8 96 | ], 97 | "activation_fn":"mish", 98 | "drop_out":0.5, 99 | "end_with_softmax":false 100 | }, 101 | "TLinear2":{ 102 | "hid_layers":[ 103 | 256, 104 | 1024, 105 | 4096 106 | ], 107 | "activation_fn":"mish", 108 | "drop_out":0.5, 109 | "end_with_softmax":false 110 | } 111 | } 112 | } -------------------------------------------------------------------------------- /config/bird/bird_v8_env0_wo_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":6, 5 | "known_env_num":1, 6 | "data":{ 7 | "dataset":"v8", 8 | "train":{ 9 | "dir":"data/v8/train_env0.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":16, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v8/valid_env0.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":8, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v8/test_env0.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":24, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 
41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":100, 45 | "valid_start_epoch":1, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":10, 48 | "valid_metrics":"loss", 49 | "valid_metrics_less":true, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"BIRD", 57 | "pretrained_enc_dir":"model/pretrained/BIRDEncoder_v8_env0_opt1/best", 58 | "loss":"mse", 59 | "do":256, 60 | "d":16, 61 | "t":64, 62 | "TLinear1":{ 63 | "hid_layers":[ 64 | 8, 65 | 32 66 | ], 67 | "activation_fn":"mish", 68 | "drop_out":0.5, 69 | "end_with_softmax":false 70 | }, 71 | "csinet":{ 72 | "n_refine":4, 73 | "hid_layers":[ 74 | 8, 75 | 16, 76 | 32, 77 | 32, 78 | 16, 79 | 8 80 | ], 81 | "activation_fn":"relu" 82 | }, 83 | "FLinear":{ 84 | "hid_layers":[ 85 | 16, 86 | 32 87 | ], 88 | "activation_fn":"mish", 89 | "drop_out":0.5, 90 | "end_with_softmax":false 91 | }, 92 | "RLinear":{ 93 | "hid_layers":[ 94 | 16, 95 | 8 96 | ], 97 | "activation_fn":"mish", 98 | "drop_out":0.5, 99 | "end_with_softmax":false 100 | }, 101 | "TLinear2":{ 102 | "hid_layers":[ 103 | 256, 104 | 1024, 105 | 4096 106 | ], 107 | "activation_fn":"mish", 108 | "drop_out":0.5, 109 | "end_with_softmax":false 110 | } 111 | } 112 | } -------------------------------------------------------------------------------- /data/augmentation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def data_augmentation(batch, aug_cfg, device): 5 | def _augmentation(amps, ids, envs, s_aug_cfg, device): 6 | if s_aug_cfg['type'].lower() == 'gaussian_noise': 7 | amps_aug, ids_aug, envs_aug = add_gaussian_noise(amps, ids, envs, s_aug_cfg, device); 8 | elif s_aug_cfg['type'].lower() == 'multipath_fading': 9 | amps_aug, ids_aug, envs_aug = mutipath_fading(amps, ids, envs, s_aug_cfg, device); 10 | elif s_aug_cfg['type'].lower() == 'freq_selective_fading': 11 | amps_aug, ids_aug, envs_aug = 
freq_selective_fading(amps, ids, envs, s_aug_cfg, device); 12 | elif s_aug_cfg['type'].lower() == 'antenna_sequence': 13 | amps_aug, ids_aug, envs_aug = antenna_sequence(amps, ids, envs, s_aug_cfg, device); 14 | elif s_aug_cfg['type'].lower() == 'subcarrier_mask': 15 | amps_aug, ids_aug, envs_aug = subcarrier_mask(amps, ids, envs, s_aug_cfg, device); 16 | else: 17 | raise RuntimeError; 18 | 19 | return amps_aug, ids_aug, envs_aug 20 | amps, ids, envs = batch; 21 | new_amps, new_ids, new_envs = amps.clone(), ids.clone(), envs.clone(); 22 | for s_aug_cfg in aug_cfg.values(): 23 | amps_aug, ids_aug, envs_aug = _augmentation(amps.clone(), ids, envs, s_aug_cfg, device); 24 | new_amps = torch.cat((new_amps, amps_aug), dim = 0); 25 | new_ids = torch.cat((new_ids, ids_aug), dim = 0); 26 | new_envs = torch.cat((new_envs, envs_aug), dim = 0); 27 | return new_amps, new_ids, new_envs 28 | 29 | 30 | def add_gaussian_noise(amps, ids, envs, s_aug_cfg, device): 31 | snr_db = s_aug_cfg['snr_db']; 32 | amps_n = amps + torch.normal(0, torch.sqrt(torch.mean(amps**2) / (10**(snr_db/10))), amps.shape, dtype = torch.float32, device = device); 33 | return amps_n, ids, envs; 34 | 35 | def mutipath_fading(amps, ids, envs, s_aug_cfg, device): 36 | delay = s_aug_cfg['delay']; 37 | coef = s_aug_cfg['coef']; 38 | n = ids.shape[0]; 39 | mp = np.random.choice(len(delay), n); 40 | amps_mp = amps.clone(); 41 | for i in range(n): 42 | d = delay[mp[i]]; 43 | c = coef[mp[i]]; 44 | for j in range(len(d)): 45 | amps_mp[i, d[j]:, :] = amps_mp[i, d[j]:, :] + c[j] * amps[i, :-d[j], :]; 46 | return amps_mp, ids, envs; 47 | 48 | def freq_selective_fading(amps, ids, envs, s_aug_cfg, device): 49 | #amps:[N, T, R, F] 50 | num_f = amps.shape[-1]; 51 | scopes = s_aug_cfg['scope']; 52 | scope = scopes[np.random.choice(len(scopes), 1, replace = False)[0]]; 53 | coefs = torch.linspace(start = scope[0], end = scope[1], steps = num_f, device = device); 54 | return amps * coefs, ids, envs; 55 | 56 | def 
antenna_sequence(amps, ids, envs, s_aug_cfg, device): 57 | #amps:[N, T, R, F] 58 | R = amps.shape[2]; 59 | ant_idxs = np.random.choice(R, R, replace = False); 60 | amps_ant = amps[:, :, ant_idxs, :]; 61 | return amps_ant, ids, envs; 62 | 63 | def subcarrier_mask(amps, ids, envs, s_aug_cfg, device): 64 | #amps:[N, T, R, F] 65 | n_sc = int(s_aug_cfg['ratio'] * amps.shape[-1]); 66 | sc_idxs = np.random.choice(amps.shape[-1], n_sc, replace = False); 67 | amps[:, :, :, sc_idxs] = torch.zeros_like(amps[:, :, :, :n_sc]); 68 | return amps, ids, envs; 69 | -------------------------------------------------------------------------------- /config/bird/bird_encoder_v2_wo_aug+al.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":10, 5 | "known_env_num":2, 6 | "data":{ 7 | "dataset":"v2", 8 | "train":{ 9 | "dir":"data/v2/train.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":8, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v2/valid.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":8, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v2/test_legal.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":8, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":300, 45 | "valid_start_epoch":1, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":30, 48 | "valid_metrics":"acc", 49 | "valid_metrics_less":false, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"BIRDEncoder", 57 | "is_norm_first":true, 58 | "d":64, 59 | "t":64, 60 | "do":256, 61 | "TLinear":{ 62 | "hid_layers":[ 63 | 4096, 64 | 
1024, 65 | 256 66 | ], 67 | "activation_fn":"mish", 68 | "drop_out":0.5, 69 | "end_with_softmax":false 70 | }, 71 | "csinet":{ 72 | "n_refine":2, 73 | "hid_layers":[ 74 | 8, 75 | 16, 76 | 32, 77 | 32, 78 | 16, 79 | 8 80 | ], 81 | "activation_fn":"relu" 82 | }, 83 | "trans2":{ 84 | "max_len":600, 85 | "is_norm_first":true, 86 | "d_fc":2048, 87 | "n_heads": 4, 88 | "n_layers": 5 89 | }, 90 | "tail_net":{ 91 | "hid_layers":[ 92 | 1024, 93 | 512, 94 | 256 95 | ], 96 | "activation_fn":"mish", 97 | "drop_out":0.5, 98 | "end_with_softmax":false 99 | }, 100 | "p_classifier":{ 101 | "hid_layers":[ 102 | 128, 103 | 64 104 | ], 105 | "activation_fn":"mish", 106 | "drop_out":0.5, 107 | "end_with_softmax":true 108 | }, 109 | "env_classifier":{ 110 | "hid_layers":[ 111 | 128, 112 | 64 113 | ], 114 | "activation_fn":"mish", 115 | "drop_out":0.5, 116 | "end_with_softmax":true 117 | }, 118 | "lambda_":1 119 | } 120 | } -------------------------------------------------------------------------------- /config/bird/bird_encoder_v4_env0_wo_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":10, 5 | "known_env_num":2, 6 | "data":{ 7 | "dataset":"v4", 8 | "train":{ 9 | "dir":"data/v4/train_env0.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":8, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v4/valid_env0.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":8, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v4/test_env0.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":8, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":300, 45 | 
"valid_start_epoch":1, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":50, 48 | "valid_metrics":"acc", 49 | "valid_metrics_less":false, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"BIRDEncoder", 57 | "is_norm_first":true, 58 | "d":64, 59 | "t":64, 60 | "do":256, 61 | "TLinear":{ 62 | "hid_layers":[ 63 | 4096, 64 | 1024, 65 | 256 66 | ], 67 | "activation_fn":"mish", 68 | "drop_out":0.5, 69 | "end_with_softmax":false 70 | }, 71 | "csinet":{ 72 | "n_refine":2, 73 | "hid_layers":[ 74 | 8, 75 | 16, 76 | 32, 77 | 32, 78 | 16, 79 | 8 80 | ], 81 | "activation_fn":"relu" 82 | }, 83 | "trans2":{ 84 | "max_len":600, 85 | "is_norm_first":true, 86 | "d_fc":2048, 87 | "n_heads": 4, 88 | "n_layers": 5 89 | }, 90 | "tail_net":{ 91 | "hid_layers":[ 92 | 1024, 93 | 512, 94 | 256 95 | ], 96 | "activation_fn":"mish", 97 | "drop_out":0.5, 98 | "end_with_softmax":false 99 | }, 100 | "p_classifier":{ 101 | "hid_layers":[ 102 | 128, 103 | 64 104 | ], 105 | "activation_fn":"mish", 106 | "drop_out":0.5, 107 | "end_with_softmax":true 108 | }, 109 | "env_classifier":{ 110 | "hid_layers":[ 111 | 128, 112 | 64 113 | ], 114 | "activation_fn":"mish", 115 | "drop_out":0.5, 116 | "end_with_softmax":true 117 | }, 118 | "lambda_":-1 119 | } 120 | } -------------------------------------------------------------------------------- /config/bird/bird_encoder_v4_env1_wo_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":10, 5 | "known_env_num":2, 6 | "data":{ 7 | "dataset":"v4", 8 | "train":{ 9 | "dir":"data/v4/train_env1.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":8, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v4/valid_env1.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":8, 25 
| "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v4/test_env1.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":8, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":300, 45 | "valid_start_epoch":5, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":30, 48 | "valid_metrics":"acc", 49 | "valid_metrics_less":false, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"BIRDEncoder", 57 | "is_norm_first":true, 58 | "d":64, 59 | "t":64, 60 | "do":256, 61 | "TLinear":{ 62 | "hid_layers":[ 63 | 4096, 64 | 1024, 65 | 256 66 | ], 67 | "activation_fn":"mish", 68 | "drop_out":0.5, 69 | "end_with_softmax":false 70 | }, 71 | "csinet":{ 72 | "n_refine":2, 73 | "hid_layers":[ 74 | 8, 75 | 16, 76 | 32, 77 | 32, 78 | 16, 79 | 8 80 | ], 81 | "activation_fn":"relu" 82 | }, 83 | "trans2":{ 84 | "max_len":600, 85 | "is_norm_first":true, 86 | "d_fc":2048, 87 | "n_heads": 4, 88 | "n_layers": 5 89 | }, 90 | "tail_net":{ 91 | "hid_layers":[ 92 | 1024, 93 | 512, 94 | 256 95 | ], 96 | "activation_fn":"mish", 97 | "drop_out":0.5, 98 | "end_with_softmax":false 99 | }, 100 | "p_classifier":{ 101 | "hid_layers":[ 102 | 128, 103 | 64 104 | ], 105 | "activation_fn":"mish", 106 | "drop_out":0.5, 107 | "end_with_softmax":true 108 | }, 109 | "env_classifier":{ 110 | "hid_layers":[ 111 | 128, 112 | 64 113 | ], 114 | "activation_fn":"mish", 115 | "drop_out":0.5, 116 | "end_with_softmax":true 117 | }, 118 | "lambda_":-1 119 | } 120 | } -------------------------------------------------------------------------------- /config/bird/bird_encoder_v5_env0_wo_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":9, 5 | "known_env_num":1, 6 | "data":{ 7 | 
"dataset":"v5", 8 | "train":{ 9 | "dir":"data/v5/train_env0.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":8, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v5/valid_env0.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":8, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v5/test_env0_pre.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":8, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":100, 45 | "valid_start_epoch":1, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":30, 48 | "valid_metrics":"acc", 49 | "valid_metrics_less":false, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"BIRDEncoder", 57 | "is_norm_first":true, 58 | "d":64, 59 | "t":64, 60 | "do":256, 61 | "TLinear":{ 62 | "hid_layers":[ 63 | 4096, 64 | 1024, 65 | 256 66 | ], 67 | "activation_fn":"mish", 68 | "drop_out":0.5, 69 | "end_with_softmax":false 70 | }, 71 | "csinet":{ 72 | "n_refine":2, 73 | "hid_layers":[ 74 | 8, 75 | 16, 76 | 32, 77 | 32, 78 | 16, 79 | 8 80 | ], 81 | "activation_fn":"relu" 82 | }, 83 | "trans2":{ 84 | "max_len":600, 85 | "is_norm_first":true, 86 | "d_fc":2048, 87 | "n_heads": 4, 88 | "n_layers": 5 89 | }, 90 | "tail_net":{ 91 | "hid_layers":[ 92 | 1024, 93 | 512, 94 | 256 95 | ], 96 | "activation_fn":"mish", 97 | "drop_out":0.5, 98 | "end_with_softmax":false 99 | }, 100 | "p_classifier":{ 101 | "hid_layers":[ 102 | 128, 103 | 64 104 | ], 105 | "activation_fn":"mish", 106 | "drop_out":0.5, 107 | "end_with_softmax":true 108 | }, 109 | "env_classifier":{ 110 | "hid_layers":[ 111 | 128, 112 | 64 113 | ], 114 | "activation_fn":"mish", 115 | "drop_out":0.5, 116 | 
"end_with_softmax":true 117 | }, 118 | "lambda_":-1 119 | } 120 | } -------------------------------------------------------------------------------- /config/bird/bird_encoder_v5_env1_wo_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":9, 5 | "known_env_num":1, 6 | "data":{ 7 | "dataset":"v5", 8 | "train":{ 9 | "dir":"data/v5/train_env1.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":8, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v5/valid_env1.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":8, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v5/test_env1_pre.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":8, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":300, 45 | "valid_start_epoch":1, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":50, 48 | "valid_metrics":"acc", 49 | "valid_metrics_less":false, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"BIRDEncoder", 57 | "is_norm_first":true, 58 | "d":64, 59 | "t":64, 60 | "do":256, 61 | "TLinear":{ 62 | "hid_layers":[ 63 | 4096, 64 | 1024, 65 | 256 66 | ], 67 | "activation_fn":"mish", 68 | "drop_out":0.5, 69 | "end_with_softmax":false 70 | }, 71 | "csinet":{ 72 | "n_refine":2, 73 | "hid_layers":[ 74 | 8, 75 | 16, 76 | 32, 77 | 32, 78 | 16, 79 | 8 80 | ], 81 | "activation_fn":"relu" 82 | }, 83 | "trans2":{ 84 | "max_len":600, 85 | "is_norm_first":true, 86 | "d_fc":2048, 87 | "n_heads": 4, 88 | "n_layers": 5 89 | }, 90 | "tail_net":{ 91 | "hid_layers":[ 92 | 1024, 93 | 512, 94 | 256 95 | ], 96 | 
"activation_fn":"mish", 97 | "drop_out":0.5, 98 | "end_with_softmax":false 99 | }, 100 | "p_classifier":{ 101 | "hid_layers":[ 102 | 128, 103 | 64 104 | ], 105 | "activation_fn":"mish", 106 | "drop_out":0.5, 107 | "end_with_softmax":true 108 | }, 109 | "env_classifier":{ 110 | "hid_layers":[ 111 | 128, 112 | 64 113 | ], 114 | "activation_fn":"mish", 115 | "drop_out":0.5, 116 | "end_with_softmax":true 117 | }, 118 | "lambda_":-1 119 | } 120 | } -------------------------------------------------------------------------------- /config/bird/bird_encoder_v6_env0_wo_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":8, 5 | "known_env_num":1, 6 | "data":{ 7 | "dataset":"v6", 8 | "train":{ 9 | "dir":"data/v6/train_env0.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":8, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v6/valid_env0.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":8, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v6/test_env0_pre.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":8, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":300, 45 | "valid_start_epoch":1, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":30, 48 | "valid_metrics":"acc", 49 | "valid_metrics_less":false, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"BIRDEncoder", 57 | "is_norm_first":true, 58 | "d":64, 59 | "t":64, 60 | "do":256, 61 | "TLinear":{ 62 | "hid_layers":[ 63 | 4096, 64 | 1024, 65 | 256 66 | ], 67 | "activation_fn":"mish", 68 | "drop_out":0.5, 69 | 
"end_with_softmax":false 70 | }, 71 | "csinet":{ 72 | "n_refine":2, 73 | "hid_layers":[ 74 | 8, 75 | 16, 76 | 32, 77 | 32, 78 | 16, 79 | 8 80 | ], 81 | "activation_fn":"relu" 82 | }, 83 | "trans2":{ 84 | "max_len":600, 85 | "is_norm_first":true, 86 | "d_fc":2048, 87 | "n_heads": 4, 88 | "n_layers": 5 89 | }, 90 | "tail_net":{ 91 | "hid_layers":[ 92 | 1024, 93 | 512, 94 | 256 95 | ], 96 | "activation_fn":"mish", 97 | "drop_out":0.5, 98 | "end_with_softmax":false 99 | }, 100 | "p_classifier":{ 101 | "hid_layers":[ 102 | 128, 103 | 64 104 | ], 105 | "activation_fn":"mish", 106 | "drop_out":0.5, 107 | "end_with_softmax":true 108 | }, 109 | "env_classifier":{ 110 | "hid_layers":[ 111 | 128, 112 | 64 113 | ], 114 | "activation_fn":"mish", 115 | "drop_out":0.5, 116 | "end_with_softmax":true 117 | }, 118 | "lambda_":-1 119 | } 120 | } -------------------------------------------------------------------------------- /config/bird/bird_encoder_v7_env0_wo_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":7, 5 | "known_env_num":1, 6 | "data":{ 7 | "dataset":"v7", 8 | "train":{ 9 | "dir":"data/v7/train_env0.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":8, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v7/valid_env0.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":8, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "test":{ 31 | "dir":"data/v7/test_env0_pre.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":8, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":300, 45 | "valid_start_epoch":1, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":30, 48 | 
"valid_metrics":"acc", 49 | "valid_metrics_less":false, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"BIRDEncoder", 57 | "is_norm_first":true, 58 | "d":64, 59 | "t":64, 60 | "do":256, 61 | "TLinear":{ 62 | "hid_layers":[ 63 | 4096, 64 | 1024, 65 | 256 66 | ], 67 | "activation_fn":"mish", 68 | "drop_out":0.5, 69 | "end_with_softmax":false 70 | }, 71 | "csinet":{ 72 | "n_refine":2, 73 | "hid_layers":[ 74 | 8, 75 | 16, 76 | 32, 77 | 32, 78 | 16, 79 | 8 80 | ], 81 | "activation_fn":"relu" 82 | }, 83 | "trans2":{ 84 | "max_len":600, 85 | "is_norm_first":true, 86 | "d_fc":2048, 87 | "n_heads": 4, 88 | "n_layers": 5 89 | }, 90 | "tail_net":{ 91 | "hid_layers":[ 92 | 1024, 93 | 512, 94 | 256 95 | ], 96 | "activation_fn":"mish", 97 | "drop_out":0.5, 98 | "end_with_softmax":false 99 | }, 100 | "p_classifier":{ 101 | "hid_layers":[ 102 | 128, 103 | 64 104 | ], 105 | "activation_fn":"mish", 106 | "drop_out":0.5, 107 | "end_with_softmax":true 108 | }, 109 | "env_classifier":{ 110 | "hid_layers":[ 111 | 128, 112 | 64 113 | ], 114 | "activation_fn":"mish", 115 | "drop_out":0.5, 116 | "end_with_softmax":true 117 | }, 118 | "lambda_":-1 119 | } 120 | } -------------------------------------------------------------------------------- /config/bird/bird_encoder_v8_env0_wo_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":6, 5 | "known_env_num":1, 6 | "data":{ 7 | "dataset":"v8", 8 | "train":{ 9 | "dir":"data/v8/train_env0.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":8, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v8/valid_env0.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":8, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | 
"test":{ 31 | "dir":"data/v8/test_env0_pre.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":8, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | } 41 | }, 42 | "train":{ 43 | "is_DA":false, 44 | "max_epoch":300, 45 | "valid_start_epoch":1, 46 | "valid_step":1, 47 | "stop_train_step_valid_not_improve":30, 48 | "valid_metrics":"acc", 49 | "valid_metrics_less":false, 50 | "optimizer":{ 51 | "lr":1e-5, 52 | "weight_decay":1e-4 53 | } 54 | }, 55 | "model":{ 56 | "name":"BIRDEncoder", 57 | "is_norm_first":true, 58 | "d":64, 59 | "t":64, 60 | "do":256, 61 | "TLinear":{ 62 | "hid_layers":[ 63 | 4096, 64 | 1024, 65 | 256 66 | ], 67 | "activation_fn":"mish", 68 | "drop_out":0.5, 69 | "end_with_softmax":false 70 | }, 71 | "csinet":{ 72 | "n_refine":2, 73 | "hid_layers":[ 74 | 8, 75 | 16, 76 | 32, 77 | 32, 78 | 16, 79 | 8 80 | ], 81 | "activation_fn":"relu" 82 | }, 83 | "trans2":{ 84 | "max_len":600, 85 | "is_norm_first":true, 86 | "d_fc":2048, 87 | "n_heads": 4, 88 | "n_layers": 5 89 | }, 90 | "tail_net":{ 91 | "hid_layers":[ 92 | 1024, 93 | 512, 94 | 256 95 | ], 96 | "activation_fn":"mish", 97 | "drop_out":0.5, 98 | "end_with_softmax":false 99 | }, 100 | "p_classifier":{ 101 | "hid_layers":[ 102 | 128, 103 | 64 104 | ], 105 | "activation_fn":"mish", 106 | "drop_out":0.5, 107 | "end_with_softmax":true 108 | }, 109 | "env_classifier":{ 110 | "hid_layers":[ 111 | 128, 112 | 64 113 | ], 114 | "activation_fn":"mish", 115 | "drop_out":0.5, 116 | "end_with_softmax":true 117 | }, 118 | "lambda_":-1 119 | } 120 | } -------------------------------------------------------------------------------- /executor.py: -------------------------------------------------------------------------------- 1 | import argparse, logging, os, torch, platform, subprocess, webbrowser, time 2 | from torch.utils.tensorboard import SummaryWriter 3 | from lib import glb_var, json_util, util, callback, colortext 4 | 5 | if __name__ == 
'__main__': 6 | glb_var.__init__(); 7 | parse = argparse.ArgumentParser(); 8 | parse.add_argument('--config', '-cfg', type = str, default = None, help = 'config for run'); 9 | parse.add_argument('--saved_config', '-sc', type = str, default = None, help = 'path for saved config to test') 10 | parse.add_argument('--auto_shutdown', '-as', type = bool, default = False, help = 'automatic shutdown after program completion') 11 | 12 | args = parse.parse_args(); 13 | is_train, is_test = False, False; 14 | if args.config is not None: 15 | config = json_util.jsonload(args.config); 16 | is_train = True; 17 | is_test = True; 18 | elif args.saved_config is not None: 19 | config = json_util.jsonload(args.saved_config); 20 | save_dir = os.path.dirname(args.saved_config) + '/'; 21 | config['save_dir'] = save_dir; 22 | is_test = True; 23 | 24 | from lib.callback import Logger 25 | DATASET = config['data']['dataset']; 26 | INFO_LEVEL = logging.INFO; 27 | MODEL_NAME = config['model']['name']; 28 | 29 | if not os.path.exists(f'./cache/logger/{DATASET}/'): 30 | os.makedirs(f'./cache/logger/{DATASET}/'); 31 | if 'save_dir' not in locals(): 32 | save_dir = f'./cache/save/{DATASET}/{MODEL_NAME}_{util.get_date()}_{util.get_time()}/'; 33 | config['save_dir'] = save_dir; 34 | if not os.path.exists(save_dir): 35 | os.makedirs(save_dir); 36 | logger = Logger( 37 | level = INFO_LEVEL, 38 | filename = f'./cache/logger/{DATASET}/{MODEL_NAME}_{util.get_date()}_{util.get_time()}.log', 39 | ).get_log(); 40 | logger.debug(f'save dir:{save_dir}'); 41 | if not os.path.exists(save_dir + 'tb'): 42 | os.makedirs(save_dir + 'tb'); 43 | glb_var.set_value('logger', logger); 44 | if config['gpu_is_available']: 45 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'); 46 | else: 47 | device = torch.device('cpu'); 48 | glb_var.set_value('device', device); 49 | util.kill_process_on_port(6006); 50 | tb_writer = SummaryWriter(log_dir = save_dir + 'tb'); 51 | 52 | tb_process = 
subprocess.Popen(["tensorboard", "--logdir", save_dir + 'tb', "--port", "6006", "--reload_interval", "5"]) 53 | time.sleep(5); 54 | try: 55 | subprocess.run(["explorer.exe", "http://localhost:6006"]) 56 | except webbrowser.Error: 57 | webbrowser.open("http://localhost:6006") 58 | 59 | glb_var.set_value('tb_writer', tb_writer); 60 | util.set_seed(config['seed']); 61 | callback.set_custom_tqdm_warning(); 62 | 63 | from t2 import train, test 64 | try: 65 | if is_train: 66 | config = train.train_model(config); 67 | if is_test: 68 | test.test_model(config); 69 | except KeyboardInterrupt: 70 | logger.info(colortext.YELLOW + "Exiting program..." + colortext.RESET) 71 | 72 | tb_writer.close(); 73 | if args.auto_shutdown: 74 | logger.info('Automatic shutdown.'); 75 | tb_process.terminate(); 76 | if platform.system().lower() == 'linux' and not util.is_wsl(): 77 | os.system("shutdown -h now"); 78 | elif util.is_wsl() : 79 | os.system("/mnt/c/Windows/System32/cmd.exe /c shutdown /s /t 1") 80 | else:#windows 81 | os.system("shutdown /s /t 1"); 82 | 83 | logger.info("\nTensorBoard is still running.") 84 | try: 85 | input(colortext.RED + "Press Enter to stop TensorBoard and exit..." 
+ colortext.RESET); 86 | finally: 87 | tb_process.terminate(); -------------------------------------------------------------------------------- /model/framework/csiid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from lib import glb_var, util 3 | from model.framework.base import Net 4 | from model import net_util 5 | 6 | class CSIID(Net): 7 | def __init__(self, model_cfg): 8 | super().__init__(model_cfg); 9 | 10 | # First convolutional layer 11 | self.conv1 = torch.nn.Sequential( 12 | torch.nn.Conv2d(3, 30, kernel_size=(100, 3), stride=(1, 3)), 13 | torch.nn.BatchNorm2d(30), 14 | torch.nn.ReLU() 15 | ) 16 | 17 | # Second convolutional layer 18 | self.conv2 = torch.nn.Sequential( 19 | torch.nn.Conv2d(30, 30, kernel_size=(5, 3), stride=(2, 1), padding=(2, 1)), 20 | torch.nn.BatchNorm2d(30), 21 | torch.nn.ReLU() 22 | ) 23 | 24 | # Third convolutional layer 25 | self.conv3 = torch.nn.Sequential( 26 | torch.nn.Conv2d(30, 30, kernel_size=(3, 3), stride=(1, 1), padding=(1, 0)), 27 | torch.nn.BatchNorm2d(30), 28 | torch.nn.ReLU() 29 | ) 30 | 31 | # LSTM layer 32 | self.lstm = torch.nn.LSTM(input_size=480, hidden_size=128, num_layers=5, batch_first=True) 33 | 34 | # Fully connected layer 35 | self.fc = torch.nn.Linear(128, self.known_p_num) 36 | 37 | # Softmax 38 | self.softmax = torch.nn.Softmax(dim=1) 39 | 40 | @torch.no_grad() 41 | def encoder(self, amps): 42 | 43 | x = amps.permute(0, 2, 1, 3) 44 | 45 | # Convolutional layers 46 | x = self.conv1(x) 47 | x = self.conv2(x) 48 | x = self.conv3(x) 49 | 50 | # Crop to 203*22*30 51 | 52 | 53 | # Reshape for LSTM input (batch_size, seq_len, input_size) 54 | x = x.permute(0, 2, 1, 3) 55 | x = x.reshape(x.size(0), x.size(1), -1) 56 | 57 | # LSTM layer 58 | x, _ = self.lstm(x) 59 | x = x[:, -1, :] # Take the output from the last time step 60 | return x; 61 | 62 | def p_classifier(self, features): 63 | x = self.fc(features) 64 | 65 | # Softmax 66 | x = self.softmax(x) 67 
| return x; 68 | 69 | def p_classify(self, x): 70 | x = self.encoder(x); 71 | x = self.p_classifier(x); 72 | return x 73 | 74 | def cal_loss(self, amps, ids, envs): 75 | id_probs = self.p_classify(amps); 76 | return torch.nn.CrossEntropyLoss()(id_probs, ids); 77 | 78 | class CSIIDAL(CSIID): 79 | def __init__(self, model_cfg): 80 | super().__init__(model_cfg); 81 | self.lambda_ = model_cfg['lambda_']; 82 | 83 | #env layer 84 | env_cfg = model_cfg['env_classifier']; 85 | env_cfg['activation_fn'] = net_util.get_activation_fn(env_cfg['activation_fn']); 86 | env_cfg['dim_in'] = 128; 87 | env_cfg['dim_out'] = self.known_env_num; 88 | self.env_layer = util.get_func_rets(net_util.get_mlp_net, env_cfg); 89 | 90 | def cal_loss(self, amps, ids, envs, amps_t = None, ids_t = None, envs_t = None): 91 | features = self.encoder(amps); 92 | id_probs = self.p_classifier(features); 93 | loss_id = torch.nn.CrossEntropyLoss()(id_probs, ids); 94 | 95 | if amps_t is None: 96 | return loss_id; 97 | 98 | env_probs = self.env_layer(net_util.GradientReversalF.apply(features, self.lambda_)); 99 | loss_env = torch.nn.CrossEntropyLoss()(env_probs, envs); 100 | 101 | loss_t = loss_id + loss_env; 102 | 103 | feature_t = self.encoder(amps_t); 104 | id_probs_t = self.p_classifier(feature_t); 105 | id_loss_t = torch.nn.CrossEntropyLoss()(id_probs_t, ids_t); 106 | 107 | return loss_t + id_loss_t; 108 | 109 | glb_var.register_model('CSIID', CSIID); 110 | glb_var.register_model('CSIIDAL', CSIIDAL); 111 | -------------------------------------------------------------------------------- /config/bird/bird_encoder_v2_wo_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":10, 5 | "known_env_num":2, 6 | "data":{ 7 | "dataset":"v2", 8 | "train":{ 9 | "dir":"data/v2/train.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":8, 14 | "drop_last":false, 15 | "shuffle":true, 
16 | "prefetch_factor":4 17 | } 18 | }, 19 | "valid":{ 20 | "dir":"data/v2/valid.mat", 21 | "loader":{ 22 | "num_workers":4, 23 | "pin_memory":true, 24 | "batch_size":8, 25 | "drop_last":false, 26 | "shuffle":true, 27 | "prefetch_factor":4 28 | } 29 | }, 30 | "target":{ 31 | "dir":"data/v2/target.mat", 32 | "loader":{ 33 | "num_workers":4, 34 | "pin_memory":true, 35 | "batch_size":8, 36 | "drop_last":false, 37 | "shuffle":true, 38 | "prefetch_factor":4 39 | } 40 | }, 41 | "test":{ 42 | "dir":"data/v2/test_legal.mat", 43 | "loader":{ 44 | "num_workers":4, 45 | "pin_memory":true, 46 | "batch_size":8, 47 | "drop_last":false, 48 | "shuffle":true, 49 | "prefetch_factor":4 50 | } 51 | } 52 | }, 53 | "train":{ 54 | "is_DA":true, 55 | "max_epoch":300, 56 | "valid_start_epoch":1, 57 | "valid_step":1, 58 | "stop_train_step_valid_not_improve":30, 59 | "valid_metrics":"acc", 60 | "valid_metrics_less":false, 61 | "optimizer":{ 62 | "lr":1e-5, 63 | "weight_decay":1e-4 64 | } 65 | }, 66 | "model":{ 67 | "name":"BIRDEncoder", 68 | "is_norm_first":true, 69 | "d":64, 70 | "t":64, 71 | "do":256, 72 | "TLinear":{ 73 | "hid_layers":[ 74 | 4096, 75 | 1024, 76 | 256 77 | ], 78 | "activation_fn":"mish", 79 | "drop_out":0.5, 80 | "end_with_softmax":false 81 | }, 82 | "csinet":{ 83 | "n_refine":2, 84 | "hid_layers":[ 85 | 8, 86 | 16, 87 | 32, 88 | 32, 89 | 16, 90 | 8 91 | ], 92 | "activation_fn":"relu" 93 | }, 94 | "trans2":{ 95 | "max_len":600, 96 | "is_norm_first":true, 97 | "d_fc":2048, 98 | "n_heads": 4, 99 | "n_layers": 5 100 | }, 101 | "tail_net":{ 102 | "hid_layers":[ 103 | 1024, 104 | 512, 105 | 256 106 | ], 107 | "activation_fn":"mish", 108 | "drop_out":0.5, 109 | "end_with_softmax":false 110 | }, 111 | "p_classifier":{ 112 | "hid_layers":[ 113 | 128, 114 | 64 115 | ], 116 | "activation_fn":"mish", 117 | "drop_out":0.5, 118 | "end_with_softmax":true 119 | }, 120 | "env_classifier":{ 121 | "hid_layers":[ 122 | 128, 123 | 64 124 | ], 125 | "activation_fn":"mish", 126 | 
"drop_out":0.5, 127 | "end_with_softmax":true 128 | }, 129 | "lambda_":1 130 | } 131 | } -------------------------------------------------------------------------------- /config/bird/bird_v1_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":10, 5 | "known_env_num":2, 6 | "data":{ 7 | "dataset":"v1", 8 | "train":{ 9 | "dir":"data/v1/train.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":24, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | }, 18 | "augmentation":{ 19 | "gaussian_noise":{ 20 | "type":"gaussian_noise", 21 | "snr_db":30 22 | }, 23 | "multipath_fading":{ 24 | "type":"multipath_fading", 25 | "delay":[ 26 | [50], 27 | [50, 70], 28 | [50, 70, 80] 29 | ], 30 | "coef":[ 31 | [0.8], 32 | [0.8, 0.6], 33 | [0.8, 0.6, 0.5] 34 | ] 35 | }, 36 | "freq_selective_fading":{ 37 | "type":"freq_selective_fading", 38 | "scope":[ 39 | [0.8, 0.3], 40 | [0.7, 0.2], 41 | [0.6, 0.1], 42 | [0.8, 0.1] 43 | ] 44 | }, 45 | "antenna_sequence":{ 46 | "type":"antenna_sequence" 47 | }, 48 | "subcarrier_mask":{ 49 | "type":"subcarrier_mask", 50 | "ratio":0.1 51 | } 52 | } 53 | }, 54 | "valid":{ 55 | "dir":"data/v1/valid.mat", 56 | "loader":{ 57 | "num_workers":4, 58 | "pin_memory":true, 59 | "batch_size":8, 60 | "drop_last":false, 61 | "shuffle":true, 62 | "prefetch_factor":4 63 | } 64 | }, 65 | "test":{ 66 | "dir":"data/v1/test.mat", 67 | "loader":{ 68 | "num_workers":4, 69 | "pin_memory":true, 70 | "batch_size":8, 71 | "drop_last":false, 72 | "shuffle":true, 73 | "prefetch_factor":4 74 | } 75 | } 76 | }, 77 | "train":{ 78 | "is_DA":false, 79 | "max_epoch":100, 80 | "valid_start_epoch":20, 81 | "valid_step":5, 82 | "stop_train_step_valid_not_improve":6, 83 | "valid_metrics":"loss", 84 | "valid_metrics_less":true, 85 | "optimizer":{ 86 | "lr":1e-5, 87 | "weight_decay":1e-4 88 | } 89 | }, 90 | "model":{ 91 | 
"name":"BIRD", 92 | "pretrained_enc_dir":"model/pretrained/BIRDEncoder_opt4/end", 93 | "loss":"nmse", 94 | "do":256, 95 | "d":16, 96 | "t":64, 97 | "TLinear1":{ 98 | "hid_layers":[ 99 | 8, 100 | 32 101 | ], 102 | "activation_fn":"mish", 103 | "drop_out":0.5, 104 | "end_with_softmax":false 105 | }, 106 | "csinet":{ 107 | "n_refine":4, 108 | "hid_layers":[ 109 | 8, 110 | 16, 111 | 32, 112 | 32, 113 | 16, 114 | 8 115 | ], 116 | "activation_fn":"relu" 117 | }, 118 | "FLinear":{ 119 | "hid_layers":[ 120 | 16, 121 | 32 122 | ], 123 | "activation_fn":"mish", 124 | "drop_out":0.5, 125 | "end_with_softmax":false 126 | }, 127 | "RLinear":{ 128 | "hid_layers":[ 129 | 16, 130 | 8 131 | ], 132 | "activation_fn":"mish", 133 | "drop_out":0.5, 134 | "end_with_softmax":false 135 | }, 136 | "TLinear2":{ 137 | "hid_layers":[ 138 | 256, 139 | 1024, 140 | 4096 141 | ], 142 | "activation_fn":"mish", 143 | "drop_out":0.5, 144 | "end_with_softmax":false 145 | } 146 | } 147 | } -------------------------------------------------------------------------------- /model/framework/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from lib import glb_var, util 3 | 4 | logger = glb_var.get_value('logger'); 5 | 6 | class Net(torch.nn.Module): 7 | '''Abstract Net class to define the API methods 8 | ''' 9 | def __init__(self, model_cfg) -> None: 10 | super().__init__(); 11 | util.set_attr(self, model_cfg, except_type = dict); 12 | self.is_Intrusion_Detection = False; 13 | 14 | def _init_para(self, module): 15 | if isinstance(module, torch.nn.Embedding): 16 | module.weight.data.normal_(mean=0.0, std=1/module.embedding_dim) 17 | elif isinstance(module, torch.nn.LayerNorm): 18 | module.bias.data.fill_(1.0) 19 | module.weight.data.fill_(1.0) 20 | elif isinstance(module, torch.nn.Linear): 21 | module.weight.data.normal_() 22 | if module.bias is not None: 23 | module.bias.data.fill_(1.0) 24 | 25 | def forward(self): 26 | '''Defines the forward 
pass of the network. Must be implemented by subclasses.''' 27 | logger.error('Method needs to be called after being implemented'); 28 | raise NotImplementedError; 29 | 30 | def encoder(self): 31 | '''Defines the encoder part of the network. Must be implemented by subclasses.''' 32 | logger.error('Method needs to be called after being implemented'); 33 | raise NotImplementedError; 34 | 35 | @torch.no_grad() 36 | def cal_accuracy(self, amps, ids, envs) -> float: 37 | ''' 38 | Calculates the accuracy of the model's predictions. 39 | 40 | If `is_Intrusion_Detection` is `False`, it calculates classification accuracy by comparing 41 | the predicted class with the ground truth (`ids`). If `is_Intrusion_Detection` is `True`, 42 | it calculates accuracy for intrusion detection tasks. 43 | 44 | Returns: 45 | --------- 46 | acc : float 47 | The accuracy of the model's predictions. 48 | ''' 49 | #ids:[N] 50 | #envs:[N] 51 | if not self.is_Intrusion_Detection: 52 | #p 53 | id_pred = self.p_classify(amps).argmax(dim = -1); 54 | acc = (id_pred == ids).cpu().float().mean().item(); 55 | else: 56 | intrude_pred = self.intrusion_detection(amps); 57 | intrude_gt = ids >= self.known_p_num; 58 | acc = (intrude_gt == intrude_pred).cpu().float().mean().item(); 59 | return acc; 60 | 61 | def p_classify(self, amps): 62 | '''Defines the id classiify part of the network. 
Must be implemented by subclasses.''' 63 | logger.error('Method needs to be called after being implemented'); 64 | raise NotImplementedError; 65 | 66 | def train_epoch_hook(self, trainer, epoch): 67 | '''Hook for executing custom logic during each training epoch.''' 68 | pass 69 | 70 | def valid_epoch_hook(self, trainer, epoch): 71 | '''Hook for executing custom logic during each validation epoch.''' 72 | pass 73 | 74 | def pre_test_hook(self, tester): 75 | '''Hook for executing custom logic before testing the model.''' 76 | pass 77 | 78 | def cal_loss(self, amps, ids, envs, amps_t = None, ids_t = None, envs_t = None): 79 | '''Calculate the loss''' 80 | logger.error('Method needs to be called after being implemented'); 81 | raise NotImplementedError; 82 | 83 | def conventional_train(self, X, Y): 84 | '''Only applicable to traditional methods''' 85 | logger.error('Method needs to be called after being implemented'); 86 | raise NotImplementedError; 87 | 88 | def save(self, save_dir): 89 | ''' 90 | General model saving method for pytorch model 91 | 92 | Saves the model's state dictionary to the specified directory. 93 | 94 | The model's parameters are saved as `model_state_dict.pth`. 95 | ''' 96 | torch.save(self.state_dict(), save_dir + '/model_state_dict.pth'); 97 | logger.info(f"Model saved to {save_dir}\n"); 98 | 99 | def load(self, load_dir): 100 | '''General method for loading models 101 | 102 | Loads the model's state dictionary from the specified directory. 103 | 104 | The model's parameters are loaded from `model_state_dict.pth`. 
105 | ''' 106 | self.load_state_dict(torch.load(load_dir + '/model_state_dict.pth', weights_only = True)); 107 | logger.info(f"Model loaded from {load_dir}\n") -------------------------------------------------------------------------------- /model/framework/autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from lib import util, glb_var 3 | import numpy as np 4 | from model import net_util 5 | from model.framework.base import Net 6 | 7 | device = glb_var.get_value('device'); 8 | logger = glb_var.get_value('logger'); 9 | 10 | class AE(Net): 11 | def __init__(self, model_cfg) -> None: 12 | super().__init__(model_cfg); 13 | 14 | #encoder 15 | encoder_cfg = model_cfg['encoder_cfg']; 16 | encoder_cfg['activation_fn'] = net_util.get_activation_fn(encoder_cfg['activation_fn']); 17 | encoder_cfg['dim_in'] = np.prod(self.dim_in[1:]); 18 | encoder_cfg['dim_out'] = self.do; 19 | self.encnet = util.get_func_rets(net_util.get_mlp_net, encoder_cfg); 20 | 21 | #decoder 22 | decoder_cfg = model_cfg['decoder_cfg']; 23 | decoder_cfg['activation_fn'] = net_util.get_activation_fn(decoder_cfg['activation_fn']); 24 | decoder_cfg['dim_in'] = self.do; 25 | decoder_cfg['dim_out'] = np.prod(self.dim_in[1:]); 26 | self.decnet = util.get_func_rets(net_util.get_mlp_net, decoder_cfg); 27 | 28 | self.loss_func = net_util.get_loss_func(model_cfg['loss']); 29 | self.threshold = .0; 30 | self.thresholds = dict(); 31 | 32 | self.is_Intrusion_Detection = True; 33 | 34 | def encoder(self, amps): 35 | return self.encnet(amps.flatten(1, -1)); 36 | 37 | def reconstruct(self, amps): 38 | #amps[B, T, R, F] 39 | feature = self.encoder(amps); 40 | csi_rcst = self.decnet(feature).reshape(amps.shape); 41 | return csi_rcst; 42 | 43 | def cal_loss(self, amps, ids, envs, keep_batch = False): 44 | #amps:[N, T, R, F] 45 | #ids:[N] 46 | #envs:[N] 47 | csi_rcst = self.reconstruct(amps); 48 | loss = self.loss_func(csi_rcst, amps, keep_batch); 49 | 
return loss; 50 | 51 | @torch.no_grad() 52 | def update_thresholds(self, loader, threshold_percents): 53 | self.thresholds = dict(); 54 | loader.disable_aug(); 55 | self.eval(); 56 | rcst_errors = torch.zeros(0, device = device); 57 | for amps, ids, envs in iter(loader): 58 | rcst_errors = torch.cat(( 59 | rcst_errors, 60 | self.cal_loss(amps, ids, envs, keep_batch = True) 61 | )); 62 | for percent in threshold_percents: 63 | self.thresholds[percent] = np.percentile(rcst_errors.cpu(), percent) 64 | logger.debug(f'Updated threshold({percent:.2f}%): {self.threshold}'); 65 | 66 | def set_threshold(self, percent): 67 | self.threshold = self.thresholds[percent]; 68 | 69 | def intrusion_detection(self, amps): 70 | ''' 71 | Rrturns: 72 | -------- 73 | Is it an intruder. 74 | ''' 75 | rcst_errors = self.cal_loss(amps, None, None, keep_batch = True); 76 | return rcst_errors >= self.threshold 77 | 78 | class CAE(AE): 79 | def __init__(self, model_cfg) -> None: 80 | Net.__init__(self, model_cfg); 81 | 82 | #encoder 83 | layers = [self.dim_in[2]] + model_cfg['encoder_hid_layers']; 84 | self.encnet = torch.nn.Sequential(*[ 85 | torch.nn.Sequential( 86 | torch.nn.Conv2d(layers[i], layers[i + 1], kernel_size = 3, stride = 2, padding = 1), 87 | torch.nn.ReLU() 88 | ) 89 | for i in range(len(layers) - 1) 90 | ]); 91 | 92 | #decoder 93 | layers = model_cfg['decoder_hid_layers'] + [self.dim_in[2]]; 94 | self.decnet = torch.nn.Sequential(*[ 95 | torch.nn.Sequential( 96 | torch.nn.ConvTranspose2d(layers[i], layers[i + 1], kernel_size=3, stride=2, padding=1, output_padding=(1, 1)), 97 | torch.nn.Sigmoid() if i + 2 == len(layers) else torch.nn.ReLU() 98 | ) 99 | for i in range(len(layers) - 1) 100 | ]); 101 | 102 | self.loss_func = net_util.get_loss_func(model_cfg['loss']); 103 | self.threshold = .0; 104 | self.thresholds = dict(); 105 | 106 | self.is_Intrusion_Detection = True; 107 | 108 | def encoder(self, amps): 109 | return self.encnet(amps.permute(0, 2, 1, 3)); 110 | 111 | def 
reconstruct(self, amps): 112 | #amps[B, T, R, F] 113 | feature = self.encoder(amps); 114 | csi_rcst = self.decnet(feature).permute(0, 2, 1, 3); 115 | return csi_rcst; 116 | 117 | glb_var.register_model('AE', AE); 118 | glb_var.register_model('CAE', CAE); -------------------------------------------------------------------------------- /config/bird/bird_encoder_v2_wo_al.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2025, 3 | "gpu_is_available":true, 4 | "known_p_num":10, 5 | "known_env_num":2, 6 | "data":{ 7 | "dataset":"v2", 8 | "train":{ 9 | "dir":"data/v2/train.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":8, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | }, 18 | "augmentation":{ 19 | "gaussian_noise":{ 20 | "type":"gaussian_noise", 21 | "snr_db":30 22 | }, 23 | "multipath_fading":{ 24 | "type":"multipath_fading", 25 | "delay":[ 26 | [50], 27 | [50, 70], 28 | [50, 70, 80] 29 | ], 30 | "coef":[ 31 | [0.8], 32 | [0.8, 0.6], 33 | [0.8, 0.6, 0.5] 34 | ] 35 | }, 36 | "freq_selective_fading":{ 37 | "type":"freq_selective_fading", 38 | "scope":[ 39 | [0.8, 0.3], 40 | [0.7, 0.2], 41 | [0.6, 0.1], 42 | [0.8, 0.1] 43 | ] 44 | }, 45 | "antenna_sequence":{ 46 | "type":"antenna_sequence" 47 | }, 48 | "subcarrier_mask":{ 49 | "type":"subcarrier_mask", 50 | "ratio":0.1 51 | } 52 | } 53 | }, 54 | "valid":{ 55 | "dir":"data/v2/valid.mat", 56 | "loader":{ 57 | "num_workers":4, 58 | "pin_memory":true, 59 | "batch_size":8, 60 | "drop_last":false, 61 | "shuffle":true, 62 | "prefetch_factor":4 63 | } 64 | }, 65 | "test":{ 66 | "dir":"data/v2/test_legal.mat", 67 | "loader":{ 68 | "num_workers":4, 69 | "pin_memory":true, 70 | "batch_size":8, 71 | "drop_last":false, 72 | "shuffle":true, 73 | "prefetch_factor":4 74 | } 75 | } 76 | }, 77 | "train":{ 78 | "is_DA":false, 79 | "max_epoch":100, 80 | "valid_start_epoch":1, 81 | "valid_step":1, 82 | 
"stop_train_step_valid_not_improve":10, 83 | "valid_metrics":"acc", 84 | "valid_metrics_less":false, 85 | "optimizer":{ 86 | "lr":1e-5, 87 | "weight_decay":1e-4 88 | } 89 | }, 90 | "model":{ 91 | "name":"BIRDEncoder", 92 | "is_norm_first":true, 93 | "d":64, 94 | "t":64, 95 | "do":256, 96 | "TLinear":{ 97 | "hid_layers":[ 98 | 4096, 99 | 1024, 100 | 256 101 | ], 102 | "activation_fn":"mish", 103 | "drop_out":0.5, 104 | "end_with_softmax":false 105 | }, 106 | "csinet":{ 107 | "n_refine":2, 108 | "hid_layers":[ 109 | 8, 110 | 16, 111 | 32, 112 | 32, 113 | 16, 114 | 8 115 | ], 116 | "activation_fn":"relu" 117 | }, 118 | "trans2":{ 119 | "max_len":600, 120 | "is_norm_first":true, 121 | "d_fc":2048, 122 | "n_heads": 4, 123 | "n_layers": 5 124 | }, 125 | "tail_net":{ 126 | "hid_layers":[ 127 | 1024, 128 | 512, 129 | 256 130 | ], 131 | "activation_fn":"mish", 132 | "drop_out":0.5, 133 | "end_with_softmax":false 134 | }, 135 | "p_classifier":{ 136 | "hid_layers":[ 137 | 128, 138 | 64 139 | ], 140 | "activation_fn":"mish", 141 | "drop_out":0.5, 142 | "end_with_softmax":true 143 | }, 144 | "env_classifier":{ 145 | "hid_layers":[ 146 | 128, 147 | 64 148 | ], 149 | "activation_fn":"mish", 150 | "drop_out":0.5, 151 | "end_with_softmax":true 152 | }, 153 | "lambda_":0 154 | } 155 | } -------------------------------------------------------------------------------- /config/bird/bird_encoder_v4_env0.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":10, 5 | "known_env_num":2, 6 | "data":{ 7 | "dataset":"v4", 8 | "train":{ 9 | "dir":"data/v4/train_env0.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":8, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | }, 18 | "augmentation":{ 19 | "gaussian_noise":{ 20 | "type":"gaussian_noise", 21 | "snr_db":30 22 | }, 23 | "multipath_fading":{ 24 | "type":"multipath_fading", 25 | 
"delay":[ 26 | [50], 27 | [50, 70], 28 | [50, 70, 80] 29 | ], 30 | "coef":[ 31 | [0.8], 32 | [0.8, 0.6], 33 | [0.8, 0.6, 0.5] 34 | ] 35 | }, 36 | "freq_selective_fading":{ 37 | "type":"freq_selective_fading", 38 | "scope":[ 39 | [0.8, 0.3], 40 | [0.7, 0.2], 41 | [0.6, 0.1], 42 | [0.8, 0.1] 43 | ] 44 | }, 45 | "antenna_sequence":{ 46 | "type":"antenna_sequence" 47 | }, 48 | "subcarrier_mask":{ 49 | "type":"subcarrier_mask", 50 | "ratio":0.1 51 | } 52 | } 53 | }, 54 | "valid":{ 55 | "dir":"data/v4/valid_env0.mat", 56 | "loader":{ 57 | "num_workers":4, 58 | "pin_memory":true, 59 | "batch_size":8, 60 | "drop_last":false, 61 | "shuffle":true, 62 | "prefetch_factor":4 63 | } 64 | }, 65 | "test":{ 66 | "dir":"data/v4/test_env0.mat", 67 | "loader":{ 68 | "num_workers":4, 69 | "pin_memory":true, 70 | "batch_size":8, 71 | "drop_last":false, 72 | "shuffle":true, 73 | "prefetch_factor":4 74 | } 75 | } 76 | }, 77 | "train":{ 78 | "is_DA":false, 79 | "max_epoch":300, 80 | "valid_start_epoch":1, 81 | "valid_step":1, 82 | "stop_train_step_valid_not_improve":50, 83 | "valid_metrics":"acc", 84 | "valid_metrics_less":false, 85 | "optimizer":{ 86 | "lr":1e-5, 87 | "weight_decay":1e-4 88 | } 89 | }, 90 | "model":{ 91 | "name":"BIRDEncoder", 92 | "is_norm_first":true, 93 | "d":64, 94 | "t":64, 95 | "do":256, 96 | "TLinear":{ 97 | "hid_layers":[ 98 | 4096, 99 | 1024, 100 | 256 101 | ], 102 | "activation_fn":"mish", 103 | "drop_out":0.5, 104 | "end_with_softmax":false 105 | }, 106 | "csinet":{ 107 | "n_refine":2, 108 | "hid_layers":[ 109 | 8, 110 | 16, 111 | 32, 112 | 32, 113 | 16, 114 | 8 115 | ], 116 | "activation_fn":"relu" 117 | }, 118 | "trans2":{ 119 | "max_len":600, 120 | "is_norm_first":true, 121 | "d_fc":2048, 122 | "n_heads": 4, 123 | "n_layers": 5 124 | }, 125 | "tail_net":{ 126 | "hid_layers":[ 127 | 1024, 128 | 512, 129 | 256 130 | ], 131 | "activation_fn":"mish", 132 | "drop_out":0.5, 133 | "end_with_softmax":false 134 | }, 135 | "p_classifier":{ 136 | "hid_layers":[ 137 | 
128, 138 | 64 139 | ], 140 | "activation_fn":"mish", 141 | "drop_out":0.5, 142 | "end_with_softmax":true 143 | }, 144 | "env_classifier":{ 145 | "hid_layers":[ 146 | 128, 147 | 64 148 | ], 149 | "activation_fn":"mish", 150 | "drop_out":0.5, 151 | "end_with_softmax":true 152 | }, 153 | "lambda_":-1 154 | } 155 | } -------------------------------------------------------------------------------- /config/bird/bird_encoder_v4_env1.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed":2024, 3 | "gpu_is_available":true, 4 | "known_p_num":10, 5 | "known_env_num":2, 6 | "data":{ 7 | "dataset":"v4", 8 | "train":{ 9 | "dir":"data/v4/train_env1.mat", 10 | "loader":{ 11 | "num_workers":4, 12 | "pin_memory":true, 13 | "batch_size":8, 14 | "drop_last":false, 15 | "shuffle":true, 16 | "prefetch_factor":4 17 | }, 18 | "augmentation":{ 19 | "gaussian_noise":{ 20 | "type":"gaussian_noise", 21 | "snr_db":30 22 | }, 23 | "multipath_fading":{ 24 | "type":"multipath_fading", 25 | "delay":[ 26 | [50], 27 | [50, 70], 28 | [50, 70, 80] 29 | ], 30 | "coef":[ 31 | [0.8], 32 | [0.8, 0.6], 33 | [0.8, 0.6, 0.5] 34 | ] 35 | }, 36 | "freq_selective_fading":{ 37 | "type":"freq_selective_fading", 38 | "scope":[ 39 | [0.8, 0.3], 40 | [0.7, 0.2], 41 | [0.6, 0.1], 42 | [0.8, 0.1] 43 | ] 44 | }, 45 | "antenna_sequence":{ 46 | "type":"antenna_sequence" 47 | }, 48 | "subcarrier_mask":{ 49 | "type":"subcarrier_mask", 50 | "ratio":0.1 51 | } 52 | } 53 | }, 54 | "valid":{ 55 | "dir":"data/v4/valid_env1.mat", 56 | "loader":{ 57 | "num_workers":4, 58 | "pin_memory":true, 59 | "batch_size":8, 60 | "drop_last":false, 61 | "shuffle":true, 62 | "prefetch_factor":4 63 | } 64 | }, 65 | "test":{ 66 | "dir":"data/v4/test_env1.mat", 67 | "loader":{ 68 | "num_workers":4, 69 | "pin_memory":true, 70 | "batch_size":8, 71 | "drop_last":false, 72 | "shuffle":true, 73 | "prefetch_factor":4 74 | } 75 | } 76 | }, 77 | "train":{ 78 | "is_DA":false, 79 | "max_epoch":300, 80 | 
"valid_start_epoch":5, 81 | "valid_step":1, 82 | "stop_train_step_valid_not_improve":30, 83 | "valid_metrics":"acc", 84 | "valid_metrics_less":false, 85 | "optimizer":{ 86 | "lr":1e-5, 87 | "weight_decay":1e-4 88 | } 89 | }, 90 | "model":{ 91 | "name":"BIRDEncoder", 92 | "is_norm_first":true, 93 | "d":64, 94 | "t":64, 95 | "do":256, 96 | "TLinear":{ 97 | "hid_layers":[ 98 | 4096, 99 | 1024, 100 | 256 101 | ], 102 | "activation_fn":"mish", 103 | "drop_out":0.5, 104 | "end_with_softmax":false 105 | }, 106 | "csinet":{ 107 | "n_refine":2, 108 | "hid_layers":[ 109 | 8, 110 | 16, 111 | 32, 112 | 32, 113 | 16, 114 | 8 115 | ], 116 | "activation_fn":"relu" 117 | }, 118 | "trans2":{ 119 | "max_len":600, 120 | "is_norm_first":true, 121 | "d_fc":2048, 122 | "n_heads": 4, 123 | "n_layers": 5 124 | }, 125 | "tail_net":{ 126 | "hid_layers":[ 127 | 1024, 128 | 512, 129 | 256 130 | ], 131 | "activation_fn":"mish", 132 | "drop_out":0.5, 133 | "end_with_softmax":false 134 | }, 135 | "p_classifier":{ 136 | "hid_layers":[ 137 | 128, 138 | 64 139 | ], 140 | "activation_fn":"mish", 141 | "drop_out":0.5, 142 | "end_with_softmax":true 143 | }, 144 | "env_classifier":{ 145 | "hid_layers":[ 146 | 128, 147 | 64 148 | ], 149 | "activation_fn":"mish", 150 | "drop_out":0.5, 151 | "end_with_softmax":true 152 | }, 153 | "lambda_":-1 154 | } 155 | } --------------------------------------------------------------------------------