├── .gitignore
├── lib
├── colortext.py
├── decorator.py
├── glb_var.py
└── json_util.py
├── model
├── __init__.py
└── framework
│ ├── deep_wiid.py
│ ├── csiid.py
│ ├── base.py
│ └── autoencoder.py
├── config
├── bird
│ ├── bird_encoder_svm_v1.json
│ ├── bird_encoder_svm_v5_env0.json
│ ├── bird_encoder_svm_v5_env1.json
│ ├── bird_encoder_svm_v6_env0.json
│ ├── bird_encoder_svm_v7_env0.json
│ ├── bird_encoder_svm_v8_env0.json
│ ├── bird_v1_wo_aug.json
│ ├── bird_v5_env0_wo_aug.json
│ ├── bird_v5_env1_wo_aug.json
│ ├── bird_v6_env0_wo_aug.json
│ ├── bird_v7_env0_wo_aug.json
│ ├── bird_v8_env0_wo_aug.json
│ ├── bird_encoder_v2_wo_aug+al.json
│ ├── bird_encoder_v4_env0_wo_aug.json
│ ├── bird_encoder_v4_env1_wo_aug.json
│ ├── bird_encoder_v5_env0_wo_aug.json
│ ├── bird_encoder_v5_env1_wo_aug.json
│ ├── bird_encoder_v6_env0_wo_aug.json
│ ├── bird_encoder_v7_env0_wo_aug.json
│ ├── bird_encoder_v8_env0_wo_aug.json
│ ├── bird_encoder_v2_wo_aug.json
│ ├── bird_v1_aug.json
│ ├── bird_encoder_v2_wo_al.json
│ ├── bird_encoder_v4_env0.json
│ └── bird_encoder_v4_env1.json
├── csiid
│ ├── csiid_v1.json
│ ├── csiid_v4_env0.json
│ └── csiid_v4_env1.json
├── wiau
│ ├── wiau_id_v1.json
│ ├── wiau_id_v4_env0.json
│ ├── wiau_id_v4_env1.json
│ ├── wiau_v1.json
│ ├── wiau_v5_env0.json
│ ├── wiau_v5_env1.json
│ ├── wiau_v6_env0.json
│ ├── wiau_v7_env0.json
│ └── wiau_v8_env0.json
├── gateid
│ ├── gateid_v1.json
│ ├── gateid_v4_env0.json
│ └── gateid_v4_env1.json
├── caution
│ ├── caution_v3.json
│ ├── caution_v6_env0.json
│ ├── caution_v5_env1.json
│ ├── caution_v5_env0.json
│ ├── caution_v7_env0.json
│ ├── caution_v8_env0.json
│ ├── caution_encoder_v3.json
│ ├── caution_encoder_v5_env0.json
│ ├── caution_encoder_v5_env1.json
│ ├── caution_encoder_v6_env0.json
│ ├── caution_encoder_v7_env0.json
│ └── caution_encoder_v8_env0.json
├── autoencoder
│ ├── cae_v1.json
│ ├── cae_v5_env0.json
│ ├── cae_v5_env1.json
│ ├── cae_v6_env0.json
│ ├── cae_v7_env0.json
│ ├── cae_v8_env0.json
│ ├── ae_v1.json
│ ├── ae_v5_env0.json
│ ├── ae_v5_env1.json
│ ├── ae_v6_env0.json
│ ├── ae_v7_env0.json
│ └── ae_v8_env0.json
├── deep_wiid
│ ├── deep_wiid_v1.json
│ ├── deep_wiid_v4_env0.json
│ ├── deep_wiid_v4_env1.json
│ └── deep_wiid_al_v2.json
├── gait_enhance
│ ├── gait_enhance_v1.json
│ ├── gait_enhance_v4_env0.json
│ ├── gait_enhance_v4_env1.json
│ └── gait_enhance_al_v2.json
├── dcs_gait
│ ├── dcs_gait_encoder_v1.json
│ └── dcs_gait_v2.json
└── wiai_id
│ ├── wiai_id_v4_env0.json
│ ├── wiai_id_v4_env1.json
│ └── wiai_id_v2.json
├── requirements.txt
├── README.md
├── data
└── augmentation.py
└── executor.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # IDEs
2 | .vscode
3 |
4 | # Python
5 | *.ipynb
6 | *.pyc
7 |
8 | #file
9 | *.log
10 | *.mat
11 | *.txt
12 | !requirements.txt
13 | *.mlx
14 | *.rar
15 | *.zip
16 |
17 | #Data
18 | __pycache__/
19 | cache/
20 | runs/
21 | model/pretrained
22 |
23 |
24 |
--------------------------------------------------------------------------------
/lib/colortext.py:
--------------------------------------------------------------------------------
1 | # @Time : 2023.05.16
2 | # @Author : Darrius Lei
3 | # @Email : darrius.lei@outlook.com
4 |
# ANSI SGR escape sequences for colouring terminal output.
# Every colour sequence first resets attributes (0) and then selects a
# foreground colour code.
BLACK = '\033[0;30m'
RED = '\033[0;31m'
GREEN = '\033[0;32m'
YELLOW = '\033[0;33m'
BLUE = '\033[0;34m'
PURPLE = '\033[0;35m'
CYAN = '\033[0;36m'
WHITE = '\033[0;37m'
# SGR 38 is only the introducer for extended colours and selects nothing by
# itself; 90 ("bright black") is the code that actually renders as gray.
GRAY = '\033[0;90m'

# Restores the terminal's default attributes.
RESET = '\033[0m'

# A red triangle marker, followed by a reset so later text is unaffected.
RED_TRIANGLE = f"{RED}▲{RESET}"
18 |
--------------------------------------------------------------------------------
/lib/decorator.py:
--------------------------------------------------------------------------------
1 | # @Time : 2023.07.08
2 | # @Author : Darrius Lei
3 | # @Email : darrius.lei@outlook.com
4 |
5 | from typing import Any
6 | from lib import glb_var, util
7 | import time
8 |
# Module-level logger, looked up in the shared global store while this module
# is being imported, so the 'logger' entry has to be registered beforehand.
logger = glb_var.get_value('logger')
10 |
class Decorator(object):
    '''Abstract base class for class-style decorators.

    A subclass implements __call__ and reaches the decorated callable through
    self.func. The instance takes over the decorated callable's metadata
    (__name__, __doc__, ...) and can also decorate methods.
    '''
    def __init__(self, func) -> None:
        # Local import keeps this block self-contained.
        import functools
        # Copy __name__, __doc__, __module__, ... first, then set self.func so
        # an attribute copied from func.__dict__ can never overwrite it.
        functools.update_wrapper(self, func)
        self.func = func

    def __get__(self, instance: Any, owner: Any = None) -> Any:
        '''Bind the decorator to `instance` when it decorates a method.

        Without this hook the instance would not be passed along, and the
        wrapped method would be called without its `self` argument.
        '''
        if instance is None:
            # Accessed on the class itself: hand back the decorator object.
            return self
        import functools
        return functools.partial(self, instance)

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        '''Method needs to be called after being implemented'''
        raise NotImplementedError
20 |
class Timer(Decorator):
    '''Decorator that logs how long each call of the wrapped callable takes.

    The elapsed seconds are passed to util.s2hms and the resulting text is
    written through the module-level logger via logger.info; the wrapped
    callable's return value is handed back unchanged.
    '''
    def __init__(self, func) -> None:
        super().__init__(func)

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        # perf_counter is monotonic, so the measured interval cannot be
        # distorted by system clock adjustments the way time.time() can.
        start = time.perf_counter()
        result = self.func(*args, **kwds)
        elapsed = time.perf_counter() - start
        # Not every callable has __name__ (e.g. functools.partial objects);
        # fall back to repr so logging never discards a computed result.
        name = getattr(self.func, '__name__', repr(self.func))
        logger.info(f'The time consumption of {name}: {util.s2hms(elapsed)}')
        return result
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | from lib import glb_var, util, json_util
4 |
# Shared runtime objects, read from the global store when the package is
# imported (both keys must already be registered at that point).
device = glb_var.get_value('device')
logger = glb_var.get_value('logger')

# Public API of the package.
__all__ = ['generate_model', 'load_model']

# Load every module under the sibling `framework` directory as part of the
# `model.framework` package.
# NOTE(review): presumably this is what lets each model class register itself
# with glb_var before generate_model looks it up — confirm in lib/util.
framework_dir = os.path.join(os.path.dirname(__file__), 'framework')
util.load_modules_from_directory(framework_dir, 'model.framework')
12 |
def generate_model(model_cfg):
    '''Instantiate the model described by model_cfg and move it to `device`.

    Parameters:
    -----------
    model_cfg: mapping
        Model section of the configuration. model_cfg['name'] selects the
        registered model class; the whole mapping is passed to its constructor.

    Returns:
    --------
    The instantiated model after `.to(device)`.

    Raises:
    -------
    RuntimeError
        If no class is registered under model_cfg['name'], or if the
        constructor raises.
    '''
    name = model_cfg['name']
    model_class = glb_var.query_model(name)
    if model_class is None:
        raise RuntimeError(f"Unable to find model class named {name}, please check configuration")

    try:
        model = model_class(model_cfg)
    except Exception as e:
        # Chain explicitly so the traceback shows the constructor failure as
        # the direct cause of this error.
        raise RuntimeError(f"Error instantiating model {name} : {e}") from e
    num_params = sum(p.numel() for p in model.parameters())
    logger.info(f'The number of parameters of the [{model.name}]: {num_params / 10 ** 6:.6f} M\n')
    return model.to(device)
25 |
def load_model(load_dir):
    '''Rebuild a saved model from the directory that contains its files.

    if A/B/model.pth then load_dir = A/B/

    The directory must also hold the `config.json` whose 'model' section is
    used to construct the model before model.load(load_dir) is called.

    Raises:
    -------
    RuntimeError
        If load_dir points at a file instead of a directory.
    '''
    if os.path.isfile(load_dir):
        raise RuntimeError('Please use the directory of the file instead of its path')
    # os.path.join copes with load_dir given with or without a trailing
    # separator (and with path-like objects), unlike plain string concatenation.
    cfg = json_util.jsonload(os.path.join(load_dir, 'config.json'))
    model = generate_model(cfg['model'])
    model.load(load_dir)
    return model.to(device)
--------------------------------------------------------------------------------
/lib/glb_var.py:
--------------------------------------------------------------------------------
1 | # @Time : 2022/4/12
2 | # @Author : Junwei Lei
3 | # @Email : darrius.lei@outlook.com
4 |
5 | from lib.callback import CustomException
6 | import pydash as ps
7 | from lib import util
8 | '''
9 | Store global variables for the project
10 |
11 | Methods:
12 | --------
13 |
14 | __init__():None
15 | initialization
16 |
17 | set_value(key, value):None
18 | set global variables
19 |
20 | get_value(key):value
21 | retrieve global variables
22 |
23 | Examples:
24 | ---------
25 | main.py
26 | >>> from lib import glb_var
27 | >>> glb_var.set_value('a', 1);
28 | >>> submain;
29 |
30 | sub.py
31 | >>> from lib import glb_var
32 | >>> print(glb_var.get_value('a'));
33 |
34 | '''
def __init__():
    '''Create (or reset) the global store.

    Must be called before any other function of this module: it is the only
    place where `glb_dict` is created. The store starts with an empty 'model'
    sub-dictionary that serves as the model registry.
    '''
    global glb_dict
    glb_dict = dict()
    glb_dict['model'] = {}

def set_values(dict, keys = None, except_type = None):
    '''Store several entries at once by delegating to util.set_attr.

    NOTE(review): the exact meaning of `keys` and `except_type` is defined by
    util.set_attr, which is not visible here — confirm there.
    '''
    util.set_attr(glb_dict, dict, keys = keys, except_type = except_type)

def set_value(key, value):
    '''Store `value` under `key`, replacing any previous entry.'''
    glb_dict[key] = value

def register_model(key, value):
    '''Register `value` in the model registry under the name `key`.'''
    glb_dict['model'][key] = value

def query_model(key):
    '''Return the entry registered under `key`, or None if there is none.'''
    return glb_dict['model'].get(key)

def get_value(key):
    '''Return the value stored under `key`.

    Raises:
    -------
    CustomException
        If `key` has not been stored.
    '''
    try:
        return glb_dict[key]
    except KeyError:
        # The KeyError carries no extra information; drop it from the
        # traceback so only the descriptive error is shown.
        raise CustomException(f'The retrieved key [{key}] does not exist') from None
--------------------------------------------------------------------------------
/config/bird/bird_encoder_svm_v1.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":false,
4 | "known_p_num":10,
5 | "known_env_num":2,
6 | "data":{
7 | "dataset":"v1",
8 | "train":{
9 | "dir":"data/v1/alltrain.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":48,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "test":{
20 | "dir":"data/v1/test.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":24,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | }
30 | },
31 | "train":{
32 | "is_DA":false,
33 | "max_epoch":100,
34 | "valid_start_epoch":20,
35 | "valid_step":5,
36 | "stop_train_step_valid_not_improve":6,
37 | "valid_metrics":"loss",
38 | "valid_metrics_less":true,
39 | "optimizer":{
40 | "lr":1e-5,
41 | "weight_decay":1e-4
42 | }
43 | },
44 | "model":{
45 | "name":"BIRDEncoderSVM",
46 | "pretrained_enc_dir":"model/pretrained/bird_encoder_opt2/end",
47 | "do":256,
48 | "abnormality_rate":0.01
49 | }
50 | }
--------------------------------------------------------------------------------
/config/bird/bird_encoder_svm_v5_env0.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":false,
4 | "known_p_num":9,
5 | "known_env_num":1,
6 | "data":{
7 | "dataset":"v5",
8 | "train":{
9 | "dir":"data/v5/train_env0.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":48,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "test":{
20 | "dir":"data/v5/test_env0.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":24,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | }
30 | },
31 | "train":{
32 | "is_DA":false,
33 | "max_epoch":100,
34 | "valid_start_epoch":5,
35 | "valid_step":1,
36 | "stop_train_step_valid_not_improve":30,
37 | "valid_metrics":"loss",
38 | "valid_metrics_less":true,
39 | "optimizer":{
40 | "lr":1e-5,
41 | "weight_decay":1e-4
42 | }
43 | },
44 | "model":{
45 | "name":"BIRDEncoderSVM",
46 | "pretrained_enc_dir":"model/pretrained/BIRDEncoder_v5_env0_opt1/end",
47 | "do":256,
48 | "abnormality_rate":0.3
49 | }
50 | }
--------------------------------------------------------------------------------
/config/bird/bird_encoder_svm_v5_env1.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":false,
4 | "known_p_num":9,
5 | "known_env_num":1,
6 | "data":{
7 | "dataset":"v5",
8 | "train":{
9 | "dir":"data/v5/train_env1.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":48,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "test":{
20 | "dir":"data/v5/test_env1.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":24,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | }
30 | },
31 | "train":{
32 | "is_DA":false,
33 | "max_epoch":100,
34 | "valid_start_epoch":5,
35 | "valid_step":1,
36 | "stop_train_step_valid_not_improve":30,
37 | "valid_metrics":"loss",
38 | "valid_metrics_less":true,
39 | "optimizer":{
40 | "lr":1e-5,
41 | "weight_decay":1e-4
42 | }
43 | },
44 | "model":{
45 | "name":"BIRDEncoderSVM",
46 | "pretrained_enc_dir":"model/pretrained/BIRDEncoder_v5_env0_opt1/end",
47 | "do":256,
48 | "abnormality_rate":0.3
49 | }
50 | }
--------------------------------------------------------------------------------
/config/bird/bird_encoder_svm_v6_env0.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":false,
4 | "known_p_num":8,
5 | "known_env_num":1,
6 | "data":{
7 | "dataset":"v6",
8 | "train":{
9 | "dir":"data/v6/train_env0.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":48,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "test":{
20 | "dir":"data/v6/test_env0.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":24,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | }
30 | },
31 | "train":{
32 | "is_DA":false,
33 | "max_epoch":100,
34 | "valid_start_epoch":5,
35 | "valid_step":1,
36 | "stop_train_step_valid_not_improve":30,
37 | "valid_metrics":"loss",
38 | "valid_metrics_less":true,
39 | "optimizer":{
40 | "lr":1e-5,
41 | "weight_decay":1e-4
42 | }
43 | },
44 | "model":{
45 | "name":"BIRDEncoderSVM",
46 | "pretrained_enc_dir":"model/pretrained/BIRDEncoder_v6_env0_opt2/best",
47 | "do":256,
48 | "abnormality_rate":0.3
49 | }
50 | }
--------------------------------------------------------------------------------
/config/bird/bird_encoder_svm_v7_env0.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":false,
4 | "known_p_num":7,
5 | "known_env_num":1,
6 | "data":{
7 | "dataset":"v7",
8 | "train":{
9 | "dir":"data/v7/train_env0.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":48,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "test":{
20 | "dir":"data/v7/test_env0.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":24,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | }
30 | },
31 | "train":{
32 | "is_DA":false,
33 | "max_epoch":100,
34 | "valid_start_epoch":5,
35 | "valid_step":1,
36 | "stop_train_step_valid_not_improve":30,
37 | "valid_metrics":"loss",
38 | "valid_metrics_less":true,
39 | "optimizer":{
40 | "lr":1e-5,
41 | "weight_decay":1e-4
42 | }
43 | },
44 | "model":{
45 | "name":"BIRDEncoderSVM",
46 | "pretrained_enc_dir":"model/pretrained/BIRDEncoder_v7_env0_opt2/end",
47 | "do":256,
48 | "abnormality_rate":0.3
49 | }
50 | }
--------------------------------------------------------------------------------
/config/bird/bird_encoder_svm_v8_env0.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":false,
4 | "known_p_num":6,
5 | "known_env_num":1,
6 | "data":{
7 | "dataset":"v8",
8 | "train":{
9 | "dir":"data/v8/train_env0.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":48,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "test":{
20 | "dir":"data/v8/test_env0.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":24,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | }
30 | },
31 | "train":{
32 | "is_DA":false,
33 | "max_epoch":100,
34 | "valid_start_epoch":5,
35 | "valid_step":1,
36 | "stop_train_step_valid_not_improve":30,
37 | "valid_metrics":"loss",
38 | "valid_metrics_less":true,
39 | "optimizer":{
40 | "lr":1e-5,
41 | "weight_decay":1e-4
42 | }
43 | },
44 | "model":{
45 | "name":"BIRDEncoderSVM",
46 | "pretrained_enc_dir":"model/pretrained/BIRDEncoder_v8_env0_opt1/best",
47 | "do":256,
48 | "abnormality_rate":0.3
49 | }
50 | }
--------------------------------------------------------------------------------
/config/csiid/csiid_v1.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":10,
5 | "known_env_num":2,
6 | "data":{
7 | "dataset":"v1",
8 | "train":{
9 | "dir":"data/v1/train.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":64,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v1/valid.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":24,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v1/test_legal.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":100,
45 | "valid_start_epoch":20,
46 | "valid_step":5,
47 | "stop_train_step_valid_not_improve":6,
48 | "valid_metrics":"acc",
49 | "valid_metrics_less":false,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"CSIID"
57 | }
58 | }
--------------------------------------------------------------------------------
/config/wiau/wiau_id_v1.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":10,
5 | "known_env_num":2,
6 | "data":{
7 | "dataset":"v1",
8 | "train":{
9 | "dir":"data/v1/train.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":64,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v1/valid.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":24,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v1/test_legal.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":100,
45 | "valid_start_epoch":20,
46 | "valid_step":5,
47 | "stop_train_step_valid_not_improve":6,
48 | "valid_metrics":"acc",
49 | "valid_metrics_less":false,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"WIAUId"
57 | }
58 | }
--------------------------------------------------------------------------------
/config/csiid/csiid_v4_env0.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":10,
5 | "known_env_num":2,
6 | "data":{
7 | "dataset":"v4",
8 | "train":{
9 | "dir":"data/v4/train_env0.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":64,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v4/valid_env0.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":24,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v4/test_env0.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":100,
45 | "valid_start_epoch":5,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":30,
48 | "valid_metrics":"acc",
49 | "valid_metrics_less":false,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"CSIID"
57 | }
58 | }
--------------------------------------------------------------------------------
/config/csiid/csiid_v4_env1.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":10,
5 | "known_env_num":2,
6 | "data":{
7 | "dataset":"v4",
8 | "train":{
9 | "dir":"data/v4/train_env1.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":64,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v4/valid_env1.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":24,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v4/test_env1.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":100,
45 | "valid_start_epoch":5,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":30,
48 | "valid_metrics":"acc",
49 | "valid_metrics_less":false,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"CSIID"
57 | }
58 | }
--------------------------------------------------------------------------------
/config/wiau/wiau_id_v4_env0.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":10,
5 | "known_env_num":2,
6 | "data":{
7 | "dataset":"v4",
8 | "train":{
9 | "dir":"data/v4/train_env0.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":64,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v4/valid_env0.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":24,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v4/test_env0.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":100,
45 | "valid_start_epoch":5,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":30,
48 | "valid_metrics":"acc",
49 | "valid_metrics_less":false,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"WIAUId"
57 | }
58 | }
--------------------------------------------------------------------------------
/config/wiau/wiau_id_v4_env1.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":10,
5 | "known_env_num":2,
6 | "data":{
7 | "dataset":"v4",
8 | "train":{
9 | "dir":"data/v4/train_env1.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":64,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v4/valid_env1.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":24,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v4/test_env1.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":100,
45 | "valid_start_epoch":5,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":30,
48 | "valid_metrics":"acc",
49 | "valid_metrics_less":false,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"WIAUId"
57 | }
58 | }
--------------------------------------------------------------------------------
/config/gateid/gateid_v1.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":10,
5 | "known_env_num":2,
6 | "data":{
7 | "dataset":"v1",
8 | "train":{
9 | "dir":"data/v1/train.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":16,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v1/valid.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":24,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v1/test_legal.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":150,
45 | "valid_start_epoch":20,
46 | "valid_step":5,
47 | "stop_train_step_valid_not_improve":6,
48 | "valid_metrics":"acc",
49 | "valid_metrics_less":false,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"GateID",
57 | "window_size":40
58 | }
59 | }
--------------------------------------------------------------------------------
/config/gateid/gateid_v4_env0.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":10,
5 | "known_env_num":2,
6 | "data":{
7 | "dataset":"v4",
8 | "train":{
9 | "dir":"data/v4/train_env0.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":16,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v4/valid_env0.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":24,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v4/test_env0.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":150,
45 | "valid_start_epoch":5,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":30,
48 | "valid_metrics":"acc",
49 | "valid_metrics_less":false,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"GateID",
57 | "window_size":40
58 | }
59 | }
--------------------------------------------------------------------------------
/config/gateid/gateid_v4_env1.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":10,
5 | "known_env_num":2,
6 | "data":{
7 | "dataset":"v4",
8 | "train":{
9 | "dir":"data/v4/train_env1.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":16,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v4/valid_env1.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":24,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v4/test_env1.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":150,
45 | "valid_start_epoch":5,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":30,
48 | "valid_metrics":"acc",
49 | "valid_metrics_less":false,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"GateID",
57 | "window_size":40
58 | }
59 | }
--------------------------------------------------------------------------------
/config/wiau/wiau_v1.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":10,
5 | "known_env_num":2,
6 | "data":{
7 | "dataset":"v1",
8 | "train":{
9 | "dir":"data/v1/train.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":64,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v1/valid.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":24,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v1/test.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":100,
45 | "valid_start_epoch":20,
46 | "valid_step":5,
47 | "stop_train_step_valid_not_improve":6,
48 | "valid_metrics":"acc",
49 | "valid_metrics_less":false,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"WIAU",
57 | "p":0.3,
58 | "d":10,
59 | "lambda_reg":0.01
60 | }
61 | }
--------------------------------------------------------------------------------
/config/wiau/wiau_v5_env0.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":9,
5 | "known_env_num":1,
6 | "data":{
7 | "dataset":"v5",
8 | "train":{
9 | "dir":"data/v5/train_env0.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":32,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v5/valid_env0.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":24,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v5/test_env0.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":100,
45 | "valid_start_epoch":5,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":30,
48 | "valid_metrics":"acc",
49 | "valid_metrics_less":false,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"WIAU",
57 | "p":0.3,
58 | "d":10,
59 | "lambda_reg":0.01
60 | }
61 | }
--------------------------------------------------------------------------------
/config/wiau/wiau_v5_env1.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":9,
5 | "known_env_num":1,
6 | "data":{
7 | "dataset":"v5",
8 | "train":{
9 | "dir":"data/v5/train_env1.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":32,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v5/valid_env1.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":24,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v5/test_env1.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":100,
45 | "valid_start_epoch":5,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":30,
48 | "valid_metrics":"acc",
49 | "valid_metrics_less":false,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"WIAU",
57 | "p":0.3,
58 | "d":10,
59 | "lambda_reg":0.01
60 | }
61 | }
--------------------------------------------------------------------------------
/config/wiau/wiau_v6_env0.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":8,
5 | "known_env_num":1,
6 | "data":{
7 | "dataset":"v6",
8 | "train":{
9 | "dir":"data/v6/train_env0.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":32,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v6/valid_env0.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":24,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v6/test_env0.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":100,
45 | "valid_start_epoch":5,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":30,
48 | "valid_metrics":"acc",
49 | "valid_metrics_less":false,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"WIAU",
57 | "p":0.3,
58 | "d":10,
59 | "lambda_reg":0.01
60 | }
61 | }
--------------------------------------------------------------------------------
/config/wiau/wiau_v7_env0.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":7,
5 | "known_env_num":1,
6 | "data":{
7 | "dataset":"v7",
8 | "train":{
9 | "dir":"data/v7/train_env0.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":32,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v7/valid_env0.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":24,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v7/test_env0.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":100,
45 | "valid_start_epoch":5,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":30,
48 | "valid_metrics":"acc",
49 | "valid_metrics_less":false,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"WIAU",
57 | "p":0.3,
58 | "d":10,
59 | "lambda_reg":0.01
60 | }
61 | }
--------------------------------------------------------------------------------
/config/wiau/wiau_v8_env0.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":6,
5 | "known_env_num":1,
6 | "data":{
7 | "dataset":"v8",
8 | "train":{
9 | "dir":"data/v8/train_env0.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":32,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v8/valid_env0.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":24,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v8/test_env0.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":100,
45 | "valid_start_epoch":5,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":30,
48 | "valid_metrics":"acc",
49 | "valid_metrics_less":false,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"WIAU",
57 | "p":0.3,
58 | "d":10,
59 | "lambda_reg":0.01
60 | }
61 | }
--------------------------------------------------------------------------------
/config/caution/caution_v3.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":8,
5 | "known_env_num":2,
6 | "data":{
7 | "dataset":"v3",
8 | "train":{
9 | "dir":"data/v3/support.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":1,
14 | "drop_last":false,
15 | "shuffle":false,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v3/valid.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":8,
25 | "drop_last":false,
26 | "shuffle":false,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v3/test.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":50,
45 | "valid_metrics":"loss",
46 | "valid_metrics_less":true,
47 | "optimizer":{
48 | "lr":1e-4,
49 | "weight_decay":1e-4
50 | }
51 | },
52 | "model":{
53 | "name":"Caution",
54 | "update_start_epoch":5,
55 | "update_step":10,
56 | "pretrained_enc_dir":"model/pretrained/caution_encoder_opt1/end",
57 | "initial_threshold":0.5,
58 | "num_iterations":50,
59 | "num_thresholds":20,
60 | "threshold_step":0.05
61 | }
62 | }
--------------------------------------------------------------------------------
/config/caution/caution_v6_env0.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":6,
5 | "known_env_num":1,
6 | "data":{
7 | "dataset":"v6",
8 | "train":{
9 | "dir":"data/v6/support_caution_env0.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":1,
14 | "drop_last":false,
15 | "shuffle":false,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v6/valid_caution_env0.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":8,
25 | "drop_last":false,
26 | "shuffle":false,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v6/test_env0.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":50,
45 | "valid_metrics":"loss",
46 | "valid_metrics_less":true,
47 | "optimizer":{
48 | "lr":1e-4,
49 | "weight_decay":1e-4
50 | }
51 | },
52 | "model":{
53 | "name":"Caution",
54 | "update_start_epoch":5,
55 | "update_step":10,
56 | "pretrained_enc_dir":"",
57 | "initial_threshold":0.75,
58 | "num_iterations":100,
59 | "num_thresholds":200,
60 | "threshold_step":0.001
61 | }
62 | }
--------------------------------------------------------------------------------
/config/caution/caution_v5_env1.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":7,
5 | "known_env_num":1,
6 | "data":{
7 | "dataset":"v5",
8 | "train":{
9 | "dir":"data/v5/support_caution_env1.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":1,
14 | "drop_last":false,
15 | "shuffle":false,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v5/valid_caution_env1.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":8,
25 | "drop_last":false,
26 | "shuffle":false,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v5/test_env1.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":50,
45 | "valid_metrics":"loss",
46 | "valid_metrics_less":true,
47 | "optimizer":{
48 | "lr":1e-4,
49 | "weight_decay":1e-4
50 | }
51 | },
52 | "model":{
53 | "name":"Caution",
54 | "update_start_epoch":5,
55 | "update_step":10,
56 | "pretrained_enc_dir":"model/pretrained/CautionEncoder_v5_env1_opt1/end",
57 | "initial_threshold":0.5,
58 | "num_iterations":50,
59 | "num_thresholds":20,
60 | "threshold_step":0.05
61 | }
62 | }
--------------------------------------------------------------------------------
/config/caution/caution_v5_env0.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":7,
5 | "known_env_num":1,
6 | "data":{
7 | "dataset":"v5",
8 | "train":{
9 | "dir":"data/v5/support_caution_env0.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":1,
14 | "drop_last":false,
15 | "shuffle":false,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v5/valid_caution_env0.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":8,
25 | "drop_last":false,
26 | "shuffle":false,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v5/test_env0.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":50,
45 | "valid_metrics":"loss",
46 | "valid_metrics_less":true,
47 | "optimizer":{
48 | "lr":1e-4,
49 | "weight_decay":1e-4
50 | }
51 | },
52 | "model":{
53 | "name":"Caution",
54 | "update_start_epoch":5,
55 | "update_step":10,
56 | "pretrained_enc_dir":"model/pretrained/CautionEncoder_v5_env0_opt1/end",
57 | "initial_threshold":0.75,
58 | "num_iterations":100,
59 | "num_thresholds":200,
60 | "threshold_step":0.001
61 | }
62 | }
--------------------------------------------------------------------------------
/config/caution/caution_v7_env0.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":5,
5 | "known_env_num":1,
6 | "data":{
7 | "dataset":"v7",
8 | "train":{
9 | "dir":"data/v7/support_caution_env0.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":1,
14 | "drop_last":false,
15 | "shuffle":false,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v7/valid_caution_env0.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":8,
25 | "drop_last":false,
26 | "shuffle":false,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v7/test_env0.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":50,
45 | "valid_metrics":"loss",
46 | "valid_metrics_less":true,
47 | "optimizer":{
48 | "lr":1e-4,
49 | "weight_decay":1e-4
50 | }
51 | },
52 | "model":{
53 | "name":"Caution",
54 | "update_start_epoch":5,
55 | "update_step":10,
56 | "pretrained_enc_dir":"model/pretrained/CautionEncoder_v7_env0_opt1/best",
57 | "initial_threshold":0.75,
58 | "num_iterations":100,
59 | "num_thresholds":200,
60 | "threshold_step":0.001
61 | }
62 | }
--------------------------------------------------------------------------------
/config/caution/caution_v8_env0.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":4,
5 | "known_env_num":1,
6 | "data":{
7 | "dataset":"v8",
8 | "train":{
9 | "dir":"data/v8/support_caution_env0.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":1,
14 | "drop_last":false,
15 | "shuffle":false,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v8/valid_caution_env0.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":8,
25 | "drop_last":false,
26 | "shuffle":false,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v8/test_env0.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":50,
45 | "valid_metrics":"loss",
46 | "valid_metrics_less":true,
47 | "optimizer":{
48 | "lr":1e-4,
49 | "weight_decay":1e-4
50 | }
51 | },
52 | "model":{
53 | "name":"Caution",
54 | "update_start_epoch":5,
55 | "update_step":10,
56 | "pretrained_enc_dir":"model/pretrained/CautionEncoder_v8_env0_opt1/best",
57 | "initial_threshold":0.75,
58 | "num_iterations":100,
59 | "num_thresholds":200,
60 | "threshold_step":0.001
61 | }
62 | }
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==2.1.0
2 | asttokens==2.4.1
3 | comm==0.2.2
4 | contourpy==1.3.0
5 | cycler==0.12.1
6 | debugpy==1.8.5
7 | decorator==5.1.1
8 | executing==2.0.1
9 | feedparser==6.0.11
10 | filelock==3.13.1
11 | fonttools==4.53.1
12 | fsspec==2024.2.0
13 | grpcio==1.66.2
14 | h5py==3.11.0
15 | ipykernel==6.29.5
16 | ipython==8.26.0
17 | ipywidgets==8.1.5
18 | jedi==0.19.1
19 | Jinja2==3.1.3
20 | joblib==1.4.2
21 | jupyter_client==8.6.2
22 | jupyter_core==5.7.2
23 | jupyterlab_widgets==3.0.13
24 | kiwisolver==1.4.5
25 | Markdown==3.7
26 | MarkupSafe==2.1.5
27 | matplotlib==3.9.2
28 | matplotlib-inline==0.1.7
29 | mpmath==1.3.0
30 | nest-asyncio==1.6.0
31 | networkx==3.2.1
32 | numpy==1.26.4
33 | nvidia-cublas-cu12==12.4.2.65
34 | nvidia-cuda-cupti-cu12==12.4.99
35 | nvidia-cuda-nvrtc-cu12==12.4.99
36 | nvidia-cuda-runtime-cu12==12.4.99
37 | nvidia-cudnn-cu12==9.1.0.70
38 | nvidia-cufft-cu12==11.2.0.44
39 | nvidia-curand-cu12==10.3.5.119
40 | nvidia-cusolver-cu12==11.6.0.99
41 | nvidia-cusparse-cu12==12.3.0.142
42 | nvidia-nccl-cu12==2.20.5
43 | nvidia-nvjitlink-cu12==12.4.99
44 | nvidia-nvtx-cu12==12.4.99
45 | packaging==24.1
46 | parso==0.8.4
47 | pexpect==4.9.0
48 | pillow==10.4.0
49 | platformdirs==4.2.2
50 | prettytable==3.11.0
51 | prompt_toolkit==3.0.47
52 | protobuf==5.28.2
53 | psutil==6.0.0
54 | ptyprocess==0.7.0
55 | pure_eval==0.2.3
56 | pydash==8.0.3
57 | Pygments==2.18.0
58 | pyparsing==3.1.4
59 | python-dateutil==2.9.0.post0
60 | pyzmq==26.2.0
61 | scikit-learn==1.5.2
62 | scipy==1.14.1
63 | sgmllib3k==1.0.0
64 | six==1.16.0
65 | stack-data==0.6.3
66 | sympy==1.12
67 | tensorboard==2.18.0
68 | tensorboard-data-server==0.7.2
69 | threadpoolctl==3.5.0
70 | torch==2.4.0+cu124
71 | torchaudio==2.4.0+cu124
72 | torchvision==0.19.0+cu124
73 | tornado==6.4.1
74 | tqdm==4.66.5
75 | traitlets==5.14.3
76 | triton==3.0.0
77 | typing_extensions==4.12.2
78 | wcwidth==0.2.13
79 | Werkzeug==3.0.4
80 | widgetsnbextension==4.0.13
81 |
--------------------------------------------------------------------------------
/config/autoencoder/cae_v1.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":10,
5 | "known_env_num":2,
6 | "data":{
7 | "dataset":"v1",
8 | "train":{
9 | "dir":"data/v1/train.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":16,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v1/valid.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":16,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v1/test.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":100,
45 | "valid_start_epoch":20,
46 | "valid_step":5,
47 | "stop_train_step_valid_not_improve":6,
48 | "valid_metrics":"loss",
49 | "valid_metrics_less":true,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"CAE",
57 | "loss":"nmse",
58 | "encoder_hid_layers":[
59 | 128,
60 | 1024,
61 | 1024
62 | ],
63 | "decoder_hid_layers":[
64 | 1024,
65 | 1024,
66 | 128
67 | ]
68 | }
69 | }
--------------------------------------------------------------------------------
/config/autoencoder/cae_v5_env0.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":9,
5 | "known_env_num":1,
6 | "data":{
7 | "dataset":"v5",
8 | "train":{
9 | "dir":"data/v5/train_env0.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":8,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v5/valid_env0.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":16,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v5/test_env0.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":100,
45 | "valid_start_epoch":5,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":30,
48 | "valid_metrics":"loss",
49 | "valid_metrics_less":true,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"CAE",
57 | "loss":"nmse",
58 | "encoder_hid_layers":[
59 | 128,
60 | 1024,
61 | 1024
62 | ],
63 | "decoder_hid_layers":[
64 | 1024,
65 | 1024,
66 | 128
67 | ]
68 | }
69 | }
--------------------------------------------------------------------------------
/config/autoencoder/cae_v5_env1.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":9,
5 | "known_env_num":1,
6 | "data":{
7 | "dataset":"v5",
8 | "train":{
9 | "dir":"data/v5/train_env1.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":8,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v5/valid_env1.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":16,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v5/test_env1.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":100,
45 | "valid_start_epoch":5,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":30,
48 | "valid_metrics":"loss",
49 | "valid_metrics_less":true,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"CAE",
57 | "loss":"nmse",
58 | "encoder_hid_layers":[
59 | 128,
60 | 1024,
61 | 1024
62 | ],
63 | "decoder_hid_layers":[
64 | 1024,
65 | 1024,
66 | 128
67 | ]
68 | }
69 | }
--------------------------------------------------------------------------------
/config/autoencoder/cae_v6_env0.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":8,
5 | "known_env_num":1,
6 | "data":{
7 | "dataset":"v6",
8 | "train":{
9 | "dir":"data/v6/train_env0.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":8,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v6/valid_env0.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":16,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v6/test_env0.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":8,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":100,
45 | "valid_start_epoch":5,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":30,
48 | "valid_metrics":"loss",
49 | "valid_metrics_less":true,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"CAE",
57 | "loss":"nmse",
58 | "encoder_hid_layers":[
59 | 128,
60 | 1024,
61 | 1024
62 | ],
63 | "decoder_hid_layers":[
64 | 1024,
65 | 1024,
66 | 128
67 | ]
68 | }
69 | }
--------------------------------------------------------------------------------
/config/autoencoder/cae_v7_env0.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":7,
5 | "known_env_num":1,
6 | "data":{
7 | "dataset":"v7",
8 | "train":{
9 | "dir":"data/v7/train_env0.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":8,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v7/valid_env0.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":16,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v7/test_env0.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":8,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":100,
45 | "valid_start_epoch":5,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":30,
48 | "valid_metrics":"loss",
49 | "valid_metrics_less":true,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"CAE",
57 | "loss":"nmse",
58 | "encoder_hid_layers":[
59 | 128,
60 | 1024,
61 | 1024
62 | ],
63 | "decoder_hid_layers":[
64 | 1024,
65 | 1024,
66 | 128
67 | ]
68 | }
69 | }
--------------------------------------------------------------------------------
/config/autoencoder/cae_v8_env0.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":6,
5 | "known_env_num":1,
6 | "data":{
7 | "dataset":"v8",
8 | "train":{
9 | "dir":"data/v8/train_env0.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":8,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v8/valid_env0.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":16,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v8/test_env0.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":8,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":100,
45 | "valid_start_epoch":5,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":30,
48 | "valid_metrics":"loss",
49 | "valid_metrics_less":true,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"CAE",
57 | "loss":"nmse",
58 | "encoder_hid_layers":[
59 | 128,
60 | 1024,
61 | 1024
62 | ],
63 | "decoder_hid_layers":[
64 | 1024,
65 | 1024,
66 | 128
67 | ]
68 | }
69 | }
--------------------------------------------------------------------------------
/config/deep_wiid/deep_wiid_v1.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":10,
5 | "known_env_num":2,
6 | "data":{
7 | "dataset":"v1",
8 | "train":{
9 | "dir":"data/v1/train.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":64,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v1/valid.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":24,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v1/test_legal.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":100,
45 | "valid_start_epoch":20,
46 | "valid_step":5,
47 | "stop_train_step_valid_not_improve":6,
48 | "valid_metrics":"acc",
49 | "valid_metrics_less":false,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"DeepWiID",
57 | "gru_cfg":{
58 | "hidden_size":256,
59 | "num_layers":4,
60 | "batch_first":true
61 | },
62 | "tail_net":{
63 | "hid_layers":[
64 | 128,
65 | 32
66 | ],
67 | "activation_fn":"tanh",
68 | "drop_out":0.5,
69 | "end_with_softmax":true
70 | }
71 | }
72 | }
--------------------------------------------------------------------------------
/config/deep_wiid/deep_wiid_v4_env0.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":10,
5 | "known_env_num":2,
6 | "data":{
7 | "dataset":"v4",
8 | "train":{
9 | "dir":"data/v4/train_env0.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":64,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v4/valid_env0.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":24,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v4/test_env0.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":100,
45 | "valid_start_epoch":5,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":30,
48 | "valid_metrics":"acc",
49 | "valid_metrics_less":false,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"DeepWiID",
57 | "gru_cfg":{
58 | "hidden_size":256,
59 | "num_layers":4,
60 | "batch_first":true
61 | },
62 | "tail_net":{
63 | "hid_layers":[
64 | 128,
65 | 32
66 | ],
67 | "activation_fn":"tanh",
68 | "drop_out":0.5,
69 | "end_with_softmax":true
70 | }
71 | }
72 | }
--------------------------------------------------------------------------------
/config/deep_wiid/deep_wiid_v4_env1.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":10,
5 | "known_env_num":2,
6 | "data":{
7 | "dataset":"v4",
8 | "train":{
9 | "dir":"data/v4/train_env1.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":64,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v4/valid_env1.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":24,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v4/test_env1.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":100,
45 | "valid_start_epoch":5,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":30,
48 | "valid_metrics":"acc",
49 | "valid_metrics_less":false,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"DeepWiID",
57 | "gru_cfg":{
58 | "hidden_size":256,
59 | "num_layers":4,
60 | "batch_first":true
61 | },
62 | "tail_net":{
63 | "hid_layers":[
64 | 128,
65 | 32
66 | ],
67 | "activation_fn":"tanh",
68 | "drop_out":0.5,
69 | "end_with_softmax":true
70 | }
71 | }
72 | }
--------------------------------------------------------------------------------
/config/gait_enhance/gait_enhance_v1.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":10,
5 | "known_env_num":2,
6 | "data":{
7 | "dataset":"v1",
8 | "train":{
9 | "dir":"data/v1/train.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":96,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v1/valid.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":24,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v1/test.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":150,
45 | "valid_start_epoch":20,
46 | "valid_step":5,
47 | "stop_train_step_valid_not_improve":6,
48 | "valid_metrics":"acc",
49 | "valid_metrics_less":false,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"GaitEnhance",
57 | "window_size":40,
58 | "blks_cfg":{
59 | "hid_layers":[
60 | 32,
61 | 64,
62 | 128
63 | ]
64 | },
65 | "dropout_rate":0.2,
66 | "output_layer":{
67 | "hid_layers":[
68 | 128,
69 | 64
70 | ],
71 | "activation_fn":"tanh",
72 | "drop_out":0.5,
73 | "end_with_softmax":true
74 | }
75 | }
76 | }
--------------------------------------------------------------------------------
/config/gait_enhance/gait_enhance_v4_env0.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":10,
5 | "known_env_num":2,
6 | "data":{
7 | "dataset":"v4",
8 | "train":{
9 | "dir":"data/v4/train_env0.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":96,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v4/valid_env0.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":24,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v4/test_env0.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":150,
45 | "valid_start_epoch":5,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":30,
48 | "valid_metrics":"acc",
49 | "valid_metrics_less":false,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"GaitEnhance",
57 | "window_size":40,
58 | "blks_cfg":{
59 | "hid_layers":[
60 | 32,
61 | 64,
62 | 128
63 | ]
64 | },
65 | "dropout_rate":0.2,
66 | "output_layer":{
67 | "hid_layers":[
68 | 128,
69 | 64
70 | ],
71 | "activation_fn":"tanh",
72 | "drop_out":0.5,
73 | "end_with_softmax":true
74 | }
75 | }
76 | }
--------------------------------------------------------------------------------
/config/gait_enhance/gait_enhance_v4_env1.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":10,
5 | "known_env_num":2,
6 | "data":{
7 | "dataset":"v4",
8 | "train":{
9 | "dir":"data/v4/train_env1.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":96,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v4/valid_env1.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":24,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v4/test_env1.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":150,
45 | "valid_start_epoch":5,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":30,
48 | "valid_metrics":"acc",
49 | "valid_metrics_less":false,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"GaitEnhance",
57 | "window_size":40,
58 | "blks_cfg":{
59 | "hid_layers":[
60 | 32,
61 | 64,
62 | 128
63 | ]
64 | },
65 | "dropout_rate":0.2,
66 | "output_layer":{
67 | "hid_layers":[
68 | 128,
69 | 64
70 | ],
71 | "activation_fn":"tanh",
72 | "drop_out":0.5,
73 | "end_with_softmax":true
74 | }
75 | }
76 | }
--------------------------------------------------------------------------------
/config/caution/caution_encoder_v3.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":8,
5 | "known_env_num":2,
6 | "data":{
7 | "dataset":"v3",
8 | "train":{
9 | "dir":"data/v3/support.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":1,
14 | "drop_last":false,
15 | "shuffle":false,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v3/query.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":8,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v3/test_legal.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":50,
45 | "valid_start_epoch":1e+5,
46 | "valid_step":5,
47 | "stop_train_step_valid_not_improve":6,
48 | "valid_metrics":"loss",
49 | "valid_metrics_less":true,
50 | "optimizer":{
51 | "lr":1e-4,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"CautionEncoder",
57 | "do":256,
58 | "update_start_epoch":5,
59 | "update_step":10,
60 | "cnn":{
61 | "hid_layer":[
62 | 32,
63 | 128,
64 | 128
65 | ],
66 | "activation_fn":"leaky_relu"
67 | },
68 | "FeatureLayer":{
69 | "hid_layers":[
70 | 512
71 | ],
72 | "activation_fn":"tanh",
73 | "drop_out":0.2,
74 | "end_with_softmax":false
75 | }
76 | }
77 | }
--------------------------------------------------------------------------------
/config/caution/caution_encoder_v5_env0.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":7,
5 | "known_env_num":1,
6 | "data":{
7 | "dataset":"v5",
8 | "train":{
9 | "dir":"data/v5/support_caution_env0.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":1,
14 | "drop_last":false,
15 | "shuffle":false,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v5/query_caution_env0.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":8,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v5/test_env0_pre.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":100,
45 | "valid_start_epoch":1e+5,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":5,
48 | "valid_metrics":"loss",
49 | "valid_metrics_less":true,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"CautionEncoder",
57 | "do":256,
58 | "update_start_epoch":5,
59 | "update_step":10,
60 | "cnn":{
61 | "hid_layer":[
62 | 32,
63 | 128,
64 | 128
65 | ],
66 | "activation_fn":"leaky_relu"
67 | },
68 | "FeatureLayer":{
69 | "hid_layers":[
70 | 512
71 | ],
72 | "activation_fn":"tanh",
73 | "drop_out":0.2,
74 | "end_with_softmax":false
75 | }
76 | }
77 | }
--------------------------------------------------------------------------------
/config/caution/caution_encoder_v5_env1.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":7,
5 | "known_env_num":1,
6 | "data":{
7 | "dataset":"v5",
8 | "train":{
9 | "dir":"data/v5/support_caution_env1.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":1,
14 | "drop_last":false,
15 | "shuffle":false,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v5/query_caution_env1.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":8,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v5/test_env1_pre.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":100,
45 | "valid_start_epoch":1e+5,
46 | "valid_step":5,
47 | "stop_train_step_valid_not_improve":6,
48 | "valid_metrics":"loss",
49 | "valid_metrics_less":true,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"CautionEncoder",
57 | "do":256,
58 | "update_start_epoch":5,
59 | "update_step":10,
60 | "cnn":{
61 | "hid_layer":[
62 | 32,
63 | 128,
64 | 128
65 | ],
66 | "activation_fn":"leaky_relu"
67 | },
68 | "FeatureLayer":{
69 | "hid_layers":[
70 | 512
71 | ],
72 | "activation_fn":"tanh",
73 | "drop_out":0.2,
74 | "end_with_softmax":false
75 | }
76 | }
77 | }
--------------------------------------------------------------------------------
/config/caution/caution_encoder_v6_env0.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":6,
5 | "known_env_num":1,
6 | "data":{
7 | "dataset":"v6",
8 | "train":{
9 | "dir":"data/v6/support_caution_env0.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":1,
14 | "drop_last":false,
15 | "shuffle":false,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v6/query_caution_env0.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":8,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v6/test_env0_pre.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":100,
45 | "valid_start_epoch":1e+5,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":5,
48 | "valid_metrics":"loss",
49 | "valid_metrics_less":true,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"CautionEncoder",
57 | "do":256,
58 | "update_start_epoch":5,
59 | "update_step":10,
60 | "cnn":{
61 | "hid_layer":[
62 | 32,
63 | 128,
64 | 128
65 | ],
66 | "activation_fn":"leaky_relu"
67 | },
68 | "FeatureLayer":{
69 | "hid_layers":[
70 | 512
71 | ],
72 | "activation_fn":"tanh",
73 | "drop_out":0.2,
74 | "end_with_softmax":false
75 | }
76 | }
77 | }
--------------------------------------------------------------------------------
/config/caution/caution_encoder_v7_env0.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":5,
5 | "known_env_num":1,
6 | "data":{
7 | "dataset":"v7",
8 | "train":{
9 | "dir":"data/v7/support_caution_env0.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":1,
14 | "drop_last":false,
15 | "shuffle":false,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v7/query_caution_env0.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":8,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v7/test_env0_pre.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":100,
45 | "valid_start_epoch":1e+5,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":5,
48 | "valid_metrics":"loss",
49 | "valid_metrics_less":true,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"CautionEncoder",
57 | "do":256,
58 | "update_start_epoch":5,
59 | "update_step":10,
60 | "cnn":{
61 | "hid_layer":[
62 | 32,
63 | 128,
64 | 128
65 | ],
66 | "activation_fn":"leaky_relu"
67 | },
68 | "FeatureLayer":{
69 | "hid_layers":[
70 | 512
71 | ],
72 | "activation_fn":"tanh",
73 | "drop_out":0.2,
74 | "end_with_softmax":false
75 | }
76 | }
77 | }
--------------------------------------------------------------------------------
/config/caution/caution_encoder_v8_env0.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":4,
5 | "known_env_num":1,
6 | "data":{
7 | "dataset":"v8",
8 | "train":{
9 | "dir":"data/v8/support_caution_env0.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":1,
14 | "drop_last":false,
15 | "shuffle":false,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v8/query_caution_env0.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":8,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v8/test_env0_pre.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":100,
45 | "valid_start_epoch":1e+5,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":5,
48 | "valid_metrics":"loss",
49 | "valid_metrics_less":true,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"CautionEncoder",
57 | "do":256,
58 | "update_start_epoch":5,
59 | "update_step":10,
60 | "cnn":{
61 | "hid_layer":[
62 | 32,
63 | 128,
64 | 128
65 | ],
66 | "activation_fn":"leaky_relu"
67 | },
68 | "FeatureLayer":{
69 | "hid_layers":[
70 | 512
71 | ],
72 | "activation_fn":"tanh",
73 | "drop_out":0.2,
74 | "end_with_softmax":false
75 | }
76 | }
77 | }
--------------------------------------------------------------------------------
/config/autoencoder/ae_v1.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":10,
5 | "known_env_num":2,
6 | "data":{
7 | "dataset":"v1",
8 | "train":{
9 | "dir":"data/v1/train.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":48,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v1/valid.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":8,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v1/test.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":100,
45 | "valid_start_epoch":20,
46 | "valid_step":5,
47 | "stop_train_step_valid_not_improve":6,
48 | "valid_metrics":"loss",
49 | "valid_metrics_less":true,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"AE",
57 | "loss":"nmse",
58 | "do":256,
59 | "encoder_cfg":{
60 | "hid_layers":[
61 | 1024,
62 | 1024,
63 | 512
64 | ],
65 | "activation_fn":"relu",
66 | "drop_out":0.5,
67 | "end_with_softmax":false
68 | },
69 | "decoder_cfg":{
70 | "hid_layers":[
71 | 512,
72 | 1024,
73 | 1024
74 | ],
75 | "activation_fn":"relu",
76 | "drop_out":0.5,
77 | "end_with_softmax":false
78 | }
79 | }
80 | }
--------------------------------------------------------------------------------
/config/autoencoder/ae_v5_env0.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":9,
5 | "known_env_num":1,
6 | "data":{
7 | "dataset":"v5",
8 | "train":{
9 | "dir":"data/v5/train_env0.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":24,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v5/valid_env0.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":8,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v5/test_env0.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":100,
45 | "valid_start_epoch":5,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":30,
48 | "valid_metrics":"loss",
49 | "valid_metrics_less":true,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"AE",
57 | "loss":"nmse",
58 | "do":256,
59 | "encoder_cfg":{
60 | "hid_layers":[
61 | 1024,
62 | 1024,
63 | 512
64 | ],
65 | "activation_fn":"relu",
66 | "drop_out":0.5,
67 | "end_with_softmax":false
68 | },
69 | "decoder_cfg":{
70 | "hid_layers":[
71 | 512,
72 | 1024,
73 | 1024
74 | ],
75 | "activation_fn":"relu",
76 | "drop_out":0.5,
77 | "end_with_softmax":false
78 | }
79 | }
80 | }
--------------------------------------------------------------------------------
/config/autoencoder/ae_v5_env1.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":9,
5 | "known_env_num":1,
6 | "data":{
7 | "dataset":"v5",
8 | "train":{
9 | "dir":"data/v5/train_env1.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":24,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v5/valid_env1.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":8,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v5/test_env1.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":100,
45 | "valid_start_epoch":5,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":30,
48 | "valid_metrics":"loss",
49 | "valid_metrics_less":true,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"AE",
57 | "loss":"nmse",
58 | "do":256,
59 | "encoder_cfg":{
60 | "hid_layers":[
61 | 1024,
62 | 1024,
63 | 512
64 | ],
65 | "activation_fn":"relu",
66 | "drop_out":0.5,
67 | "end_with_softmax":false
68 | },
69 | "decoder_cfg":{
70 | "hid_layers":[
71 | 512,
72 | 1024,
73 | 1024
74 | ],
75 | "activation_fn":"relu",
76 | "drop_out":0.5,
77 | "end_with_softmax":false
78 | }
79 | }
80 | }
--------------------------------------------------------------------------------
/config/autoencoder/ae_v6_env0.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":8,
5 | "known_env_num":1,
6 | "data":{
7 | "dataset":"v6",
8 | "train":{
9 | "dir":"data/v6/train_env0.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":24,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v6/valid_env0.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":8,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v6/test_env0.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":100,
45 | "valid_start_epoch":5,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":30,
48 | "valid_metrics":"loss",
49 | "valid_metrics_less":true,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"AE",
57 | "loss":"nmse",
58 | "do":256,
59 | "encoder_cfg":{
60 | "hid_layers":[
61 | 1024,
62 | 1024,
63 | 512
64 | ],
65 | "activation_fn":"relu",
66 | "drop_out":0.5,
67 | "end_with_softmax":false
68 | },
69 | "decoder_cfg":{
70 | "hid_layers":[
71 | 512,
72 | 1024,
73 | 1024
74 | ],
75 | "activation_fn":"relu",
76 | "drop_out":0.5,
77 | "end_with_softmax":false
78 | }
79 | }
80 | }
--------------------------------------------------------------------------------
/config/autoencoder/ae_v7_env0.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":7,
5 | "known_env_num":1,
6 | "data":{
7 | "dataset":"v7",
8 | "train":{
9 | "dir":"data/v7/train_env0.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":24,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v7/valid_env0.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":8,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v7/test_env0.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":100,
45 | "valid_start_epoch":5,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":30,
48 | "valid_metrics":"loss",
49 | "valid_metrics_less":true,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"AE",
57 | "loss":"nmse",
58 | "do":256,
59 | "encoder_cfg":{
60 | "hid_layers":[
61 | 1024,
62 | 1024,
63 | 512
64 | ],
65 | "activation_fn":"relu",
66 | "drop_out":0.5,
67 | "end_with_softmax":false
68 | },
69 | "decoder_cfg":{
70 | "hid_layers":[
71 | 512,
72 | 1024,
73 | 1024
74 | ],
75 | "activation_fn":"relu",
76 | "drop_out":0.5,
77 | "end_with_softmax":false
78 | }
79 | }
80 | }
--------------------------------------------------------------------------------
/config/autoencoder/ae_v8_env0.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":6,
5 | "known_env_num":1,
6 | "data":{
7 | "dataset":"v8",
8 | "train":{
9 | "dir":"data/v8/train_env0.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":24,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v8/valid_env0.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":8,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v8/test_env0.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":100,
45 | "valid_start_epoch":5,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":30,
48 | "valid_metrics":"loss",
49 | "valid_metrics_less":true,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"AE",
57 | "loss":"nmse",
58 | "do":256,
59 | "encoder_cfg":{
60 | "hid_layers":[
61 | 1024,
62 | 1024,
63 | 512
64 | ],
65 | "activation_fn":"relu",
66 | "drop_out":0.5,
67 | "end_with_softmax":false
68 | },
69 | "decoder_cfg":{
70 | "hid_layers":[
71 | 512,
72 | 1024,
73 | 1024
74 | ],
75 | "activation_fn":"relu",
76 | "drop_out":0.5,
77 | "end_with_softmax":false
78 | }
79 | }
80 | }
--------------------------------------------------------------------------------
/lib/json_util.py:
--------------------------------------------------------------------------------
1 | # @Time : 2023.03.03
2 | # @Author : Darrius Lei
3 | # @Email : darrius.lei@outlook.com
4 | import json, os
5 |
def jsonload(src):
    '''Read a json file that contains exactly one json document

    Parameters:
    -----------

    src: str
        path of the json file

    Returns:
    --------

    json_dict: dict
        the decoded document (a dict when the file holds a json object)

    Raises:
    -------

    AssertionError
        if `src` does not exist
    '''
    # Kept as an assert so callers still see the same exception type
    assert os.path.exists(src), 'json file does not exist: {}'.format(src)
    # The context manager closes the handle even if decoding fails
    with open(src, 'r') as f:
        return json.load(f)
24 |
def jsonparse(src):
    '''Lazily read a file that stores one json document per line
    Applicable to files containing multiple json strings (one per line)

    Parameters:
    -----------

    src: str
        path of the json file

    Returns:
    --------

    generator
        yields the decoded value of every non-blank line, in file order

    Raises:
    -------

    AssertionError
        if `src` does not exist (raised when iteration starts)

    Example:
    ---------

    >>> from json_util import jsonparse
    >>> src = './file.json'
    >>> for item in jsonparse(src):
    ...     print(item)
    '''
    assert os.path.exists(src), 'json file does not exist: {}'.format(src)
    # `with` releases the handle when the generator is exhausted or closed
    with open(src, 'r') as f:
        for line in f:
            # Blank lines hold no document; json.loads would raise on them
            if not line.strip():
                continue
            yield json.loads(line)
52 |
def jsonlen(src):
    '''Count the lines of a json-lines file that decode to a json object

    Parameters:
    -----------

    src: str
        path of the json file

    Returns:
    --------

    cnt: int
        number of lines whose decoded value is a dict; lines holding other
        json values (lists, strings, numbers, ...) and blank lines are not
        counted

    Raises:
    -------

    AssertionError
        if `src` does not exist
    '''
    assert os.path.exists(src), 'json file does not exist: {}'.format(src)
    # The context manager closes the handle even if a line fails to decode
    with open(src, 'r') as f:
        return sum(
            1
            for line in f
            # Skip blank lines first: json.loads would raise on them
            if line.strip() and isinstance(json.loads(line), dict)
        )
75 |
def dict2jsonstr(dict):
    '''Serialize a dict into a json formatted string (4-space indent)
    '''
    json_str = json.dumps(dict, indent=4)
    return json_str
80 |
def jsonsave(dict, tgt):
    '''Save a python dict as a json file (4-space indent)

    Parameters:
    -----------
    dict: dict
        the object to serialize

    tgt: str
        Destination path (including file name); an existing file is
        overwritten
    '''
    # Serialize before opening so a non-serializable object cannot leave
    # a truncated file behind
    json_data = json.dumps(dict, indent=4)
    # The context manager closes the handle even if the write fails
    with open(tgt, 'w') as json_file:
        json_file.write(json_data)
95 |
--------------------------------------------------------------------------------
/config/dcs_gait/dcs_gait_encoder_v1.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":10,
5 | "known_env_num":2,
6 | "data":{
7 | "dataset":"v1",
8 | "train":{
9 | "dir":"data/v1/train.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":32,
14 | "drop_last":true,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v1/valid.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":32,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v1/test_legal.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":32,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":300,
45 | "valid_start_epoch":1,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":10,
48 | "valid_metrics":"acc",
49 | "valid_metrics_less":false,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"DCSGaitEncoder",
57 | "do":256,
58 | "cnn_cfg":{
59 | "hid_layers":[
60 | 16,
61 | 32
62 | ]
63 | },
64 | "tail_net_cfg":{
65 | "hid_layers":[
66 | 1024
67 | ],
68 | "activation_fn":"tanh",
69 | "drop_out":0.5,
70 | "end_with_softmax":false
71 | },
72 | "p_classifier":{
73 | "hid_layers":[
74 | 64
75 | ],
76 | "activation_fn":"tanh",
77 | "drop_out":0.5,
78 | "end_with_softmax":true
79 | }
80 | }
81 | }
--------------------------------------------------------------------------------
/config/wiai_id/wiai_id_v4_env0.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":10,
5 | "known_env_num":2,
6 | "data":{
7 | "dataset":"v4",
8 | "train":{
9 | "dir":"data/v4/train_env0.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":4,
14 | "drop_last":true,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v4/valid_env0.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":8,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v4/test_env0.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":8,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":100,
45 | "valid_start_epoch":1,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":1,
48 | "valid_metrics":"acc",
49 | "valid_metrics_less":false,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"WiAiId",
57 | "do":256,
58 | "alpha":0.1,
59 | "beta":0.5,
60 | "gamma":0.5,
61 | "MHA_cfg":{
62 | "d_q":64,
63 | "d_k":64,
64 | "d_v":64,
65 | "n_heads":3
66 | },
67 | "MultiScaleCNN_cfg":{
68 | "h_layers":[
69 | 84
70 | ],
71 | "hid_layers":[
72 | 64,
73 | 32,
74 | 8,
75 | 1
76 | ]
77 | },
78 | "p_classifier":{
79 | "hid_layers":[
80 | 128,
81 | 32
82 | ]
83 | },
84 | "env_classifier":{
85 | "hid_layers":[
86 | 32
87 | ]
88 | }
89 |
90 | }
91 | }
--------------------------------------------------------------------------------
/config/wiai_id/wiai_id_v4_env1.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":10,
5 | "known_env_num":2,
6 | "data":{
7 | "dataset":"v4",
8 | "train":{
9 | "dir":"data/v4/train_env1.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":4,
14 | "drop_last":true,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v4/valid_env1.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":8,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v4/test_env1.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":8,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":100,
45 | "valid_start_epoch":1,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":1,
48 | "valid_metrics":"acc",
49 | "valid_metrics_less":false,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"WiAiId",
57 | "do":256,
58 | "alpha":0.1,
59 | "beta":0.5,
60 | "gamma":0.5,
61 | "MHA_cfg":{
62 | "d_q":64,
63 | "d_k":64,
64 | "d_v":64,
65 | "n_heads":3
66 | },
67 | "MultiScaleCNN_cfg":{
68 | "h_layers":[
69 | 84
70 | ],
71 | "hid_layers":[
72 | 64,
73 | 32,
74 | 8,
75 | 1
76 | ]
77 | },
78 | "p_classifier":{
79 | "hid_layers":[
80 | 128,
81 | 32
82 | ]
83 | },
84 | "env_classifier":{
85 | "hid_layers":[
86 | 32
87 | ]
88 | }
89 |
90 | }
91 | }
--------------------------------------------------------------------------------
/config/dcs_gait/dcs_gait_v2.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":10,
5 | "known_env_num":2,
6 | "data":{
7 | "dataset":"v4",
8 | "train":{
9 | "dir":"data/v2/train.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":16,
14 | "drop_last":false,
15 | "shuffle":false,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "target":{
20 | "dir":"data/v2/target.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":16,
25 | "drop_last":false,
26 | "shuffle":false,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "valid":{
31 | "dir":"data/v2/valid.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":8,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | },
41 | "test":{
42 | "dir":"data/v2/test_legal.mat",
43 | "loader":{
44 | "num_workers":4,
45 | "pin_memory":true,
46 | "batch_size":8,
47 | "drop_last":false,
48 | "shuffle":true,
49 | "prefetch_factor":4
50 | }
51 | }
52 | },
53 | "train":{
54 | "is_DA":true,
55 | "max_epoch":100,
56 | "valid_start_epoch":1,
57 | "valid_step":1,
58 | "stop_train_step_valid_not_improve":30,
59 | "valid_metrics":"acc",
60 | "valid_metrics_less":false,
61 | "optimizer":{
62 | "lr":1e-6,
63 | "weight_decay":1e-4
64 | }
65 | },
66 | "model":{
67 | "name":"DCSGait",
68 | "pretrained_enc_dir":"model/pretrained/DCSGaitEncoder_opt1/best",
69 | "alpha":0.5,
70 | "num_iterations":50,
71 | "MHA_cfg":{
72 | "d_q":64,
73 | "d_k":64,
74 | "d_v":64,
75 | "n_heads":3
76 | },
77 | "p_classifier":{
78 | "hid_layers":[
79 | 64
80 | ],
81 | "activation_fn":"tanh",
82 | "drop_out":0.5,
83 | "end_with_softmax":true
84 | }
85 | }
86 | }
--------------------------------------------------------------------------------
/config/deep_wiid/deep_wiid_al_v2.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":10,
5 | "known_env_num":2,
6 | "data":{
7 | "dataset":"v2",
8 | "train":{
9 | "dir":"data/v2/train.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":8,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v2/valid.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":8,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "target":{
31 | "dir":"data/v2/target.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":8,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | },
41 | "test":{
42 | "dir":"data/v2/test_legal.mat",
43 | "loader":{
44 | "num_workers":4,
45 | "pin_memory":true,
46 | "batch_size":8,
47 | "drop_last":false,
48 | "shuffle":true,
49 | "prefetch_factor":4
50 | }
51 | }
52 | },
53 | "train":{
54 | "is_DA":true,
55 | "max_epoch":200,
56 | "valid_start_epoch":1,
57 | "valid_step":1,
58 | "stop_train_step_valid_not_improve":30,
59 | "valid_metrics":"acc",
60 | "valid_metrics_less":false,
61 | "optimizer":{
62 | "lr":1e-5,
63 | "weight_decay":1e-4
64 | }
65 | },
66 | "model":{
67 | "name":"DeepWiIDAL",
68 | "gru_cfg":{
69 | "hidden_size":256,
70 | "num_layers":4,
71 | "batch_first":true
72 | },
73 | "tail_net":{
74 | "hid_layers":[
75 | 128,
76 | 32
77 | ],
78 | "activation_fn":"tanh",
79 | "drop_out":0.5,
80 | "end_with_softmax":true
81 | },
82 | "env_classifier":{
83 | "hid_layers":[
84 | 128,
85 | 64
86 | ],
87 | "activation_fn":"mish",
88 | "drop_out":0.5,
89 | "end_with_softmax":true
90 | },
91 | "lambda_":1
92 | }
93 | }
--------------------------------------------------------------------------------
/config/gait_enhance/gait_enhance_al_v2.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":10,
5 | "known_env_num":2,
6 | "data":{
7 | "dataset":"v2",
8 | "train":{
9 | "dir":"data/v2/train.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":8,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v2/valid.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":8,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "target":{
31 | "dir":"data/v2/target.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":8,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | },
41 | "test":{
42 | "dir":"data/v2/test_legal.mat",
43 | "loader":{
44 | "num_workers":4,
45 | "pin_memory":true,
46 | "batch_size":8,
47 | "drop_last":false,
48 | "shuffle":true,
49 | "prefetch_factor":4
50 | }
51 | }
52 | },
53 | "train":{
54 | "is_DA":true,
55 | "max_epoch":150,
56 | "valid_start_epoch":1,
57 | "valid_step":1,
58 | "stop_train_step_valid_not_improve":30,
59 | "valid_metrics":"acc",
60 | "valid_metrics_less":false,
61 | "optimizer":{
62 | "lr":1e-5,
63 | "weight_decay":1e-4
64 | }
65 | },
66 | "model":{
67 | "name":"GaitEnhanceAL",
68 | "window_size":40,
69 | "blks_cfg":{
70 | "hid_layers":[
71 | 32,
72 | 64,
73 | 128
74 | ]
75 | },
76 | "dropout_rate":0.2,
77 | "output_layer":{
78 | "hid_layers":[
79 | 128,
80 | 64
81 | ],
82 | "activation_fn":"tanh",
83 | "drop_out":0.5,
84 | "end_with_softmax":true
85 | },
86 | "env_classifier":{
87 | "hid_layers":[
88 | 128,
89 | 64
90 | ],
91 | "activation_fn":"mish",
92 | "drop_out":0.5,
93 | "end_with_softmax":true
94 | },
95 | "lambda_":1
96 | }
97 | }
--------------------------------------------------------------------------------
/config/wiai_id/wiai_id_v2.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":10,
5 | "known_env_num":2,
6 | "data":{
7 | "dataset":"v4",
8 | "train":{
9 | "dir":"data/v2/train.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":4,
14 | "drop_last":true,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "target":{
20 | "dir":"data/v2/target.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":4,
25 | "drop_last":true,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "valid":{
31 | "dir":"data/v2/valid.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":8,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | },
41 | "test":{
42 | "dir":"data/v2/test_legal.mat",
43 | "loader":{
44 | "num_workers":4,
45 | "pin_memory":true,
46 | "batch_size":8,
47 | "drop_last":false,
48 | "shuffle":true,
49 | "prefetch_factor":4
50 | }
51 | }
52 | },
53 | "train":{
54 | "is_DA":true,
55 | "max_epoch":30,
56 | "valid_start_epoch":1,
57 | "valid_step":1,
58 | "stop_train_step_valid_not_improve":10,
59 | "valid_metrics":"acc",
60 | "valid_metrics_less":false,
61 | "optimizer":{
62 | "lr":1e-5,
63 | "weight_decay":1e-4
64 | }
65 | },
66 | "model":{
67 | "name":"WiAiId",
68 | "do":256,
69 | "alpha":0.1,
70 | "beta":0.5,
71 | "gamma":0.5,
72 | "MHA_cfg":{
73 | "d_q":64,
74 | "d_k":64,
75 | "d_v":64,
76 | "n_heads":3
77 | },
78 | "MultiScaleCNN_cfg":{
79 | "h_layers":[
80 | 84
81 | ],
82 | "hid_layers":[
83 | 64,
84 | 32,
85 | 8,
86 | 1
87 | ]
88 | },
89 | "p_classifier":{
90 | "hid_layers":[
91 | 128,
92 | 32
93 | ]
94 | },
95 | "env_classifier":{
96 | "hid_layers":[
97 | 32
98 | ]
99 | }
100 |
101 | }
102 | }
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # IDLab
2 |
3 | Code for identity recognition and intrusion detection in wireless sensing.
4 |
5 | **Dataset acquisition:**
6 | The dataset will be made available after its public release.
7 |
8 | ## Command
9 |
10 | ### usage
11 |
12 | ```shell
13 | usage: executor.py [-h] [--config CONFIG] [--saved_config SAVED_CONFIG] [--auto_shutdown AUTO_SHUTDOWN]
14 | ```
15 |
16 | ### options
17 |
18 | ```shell
19 | -h, --help show this help message and exit
20 | --config CONFIG, -cfg CONFIG
21 | config for run
22 | --saved_config SAVED_CONFIG, -sc SAVED_CONFIG
23 | path for saved config to test
24 | --auto_shutdown AUTO_SHUTDOWN, -as AUTO_SHUTDOWN
25 | automatic shutdown after program completion
26 | ```
27 |
28 | ### quick start
29 | ```shell
30 | python executor.py -cfg=path/to/config
31 | ```
32 |
33 | ## Model
34 |
35 | - [x] Gait-Enhance[[1](#ref1)]
36 | - [x] Deep-WiID[[2](#ref2)]
37 | - [x] Caution[[3](#ref3)]
38 | - [x] CSIID[[4](#ref4)]
39 | - [x] Gate-ID[[5](#ref5)]
40 | - [x] WiAU[[6](#ref6)]
41 | - [x] WiAi-ID[[7](#ref7)]
42 | - [x] DCS-Gait[[8](#ref8)]
43 | - [x] Bird
44 |
45 |
46 |
47 | # Reference
48 |
49 | 1. Yang J, Liu Y, Wu Y, et al. Gait-Enhance: Robust Gait Recognition of Complex Walking Patterns Based on WiFi CSI[C]//2023 IEEE Smart World Congress (SWC). IEEE, 2023: 1-9.
50 | 2. Zhou Z, Liu C, Yu X, et al. Deep-WiID: WiFi-based contactless human identification via deep learning[C]//2019 IEEE SmartWorld, Ubiquitous Intelligence & Computing, Advanced & Trusted Computing, Scalable Computing & Communications, Cloud & Big Data Computing, Internet of People and Smart City Innovation (SmartWorld/SCALCOM/UIC/ATC/CBDCom/IOP/SCI). IEEE, 2019: 877-884.
51 | 3. Wang D, Yang J, Cui W, et al. CAUTION: A Robust WiFi-based human authentication system via few-shot open-set recognition[J]. IEEE Internet of Things Journal, 2022, 9(18): 17323-17333.
52 | 4. Wang D, Zhou Z, Yu X, et al. CSIID: WiFi-based human identification via deep learning[C]//2019 14th International Conference on Computer Science & Education (ICCSE). IEEE, 2019: 326-330.
53 | 5. Zhang J, Wei B, Wu F, et al. Gate-ID: WiFi-based human identification irrespective of walking directions in smart home[J]. IEEE Internet of Things Journal, 2020, 8(9): 7610-7624.
54 | 6. Lin C, Hu J, Sun Y, et al. WiAU: An accurate device-free authentication system with ResNet[C]//2018 15th Annual IEEE International Conference on Sensing, Communication, and Networking (SECON). IEEE, 2018: 1-9.
55 | 7. Liang Y, Wu W, Li H, et al. WiAi-ID: Wi-Fi-Based Domain Adaptation for Appearance-Independent Passive Person Identification[J]. IEEE Internet of Things Journal, 2023, 11(1): 1012-1027.
56 | 8. Liang Y, Wu W, Li H, et al. DCS-Gait: A Class-Level Domain Adaptation Approach for Cross-Scene and Cross-State Gait Recognition Using Wi-Fi CSI[J]. IEEE Transactions on Information Forensics and Security, 2024.
--------------------------------------------------------------------------------
/model/framework/deep_wiid.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from lib import util, glb_var
4 | from model.framework.base import Net
5 | from model import net_util
6 |
class DeepWiID(Net):
    """GRU-based person-identity classifier (Deep-WiID).

    A GRU runs over the time axis of the amplitude input, its per-step
    outputs are averaged over time, and an MLP ("tail net") maps the pooled
    feature to ``known_p_num`` identity outputs.

    NOTE(review): ``self.dim_in`` and ``self.known_p_num`` are not assigned
    in this class; presumably the ``Net`` base class provides them -- confirm.
    """

    def __init__(self, model_cfg) -> None:
        """Build the GRU and the tail network from ``model_cfg``.

        Args:
            model_cfg: model section of the run config. Reads ``'gru_cfg'``
                (keyword arguments for ``torch.nn.GRU``) and ``'tail_net'``
                (arguments for ``net_util.get_mlp_net``). Both sub-dicts are
                updated in place with the derived sizes; ``DeepWiIDAL``
                reads ``model_cfg['tail_net']['dim_in']`` afterwards.
        """
        super().__init__(model_cfg)
        # GRU: every input dimension after the first two is flattened into the
        # per-step feature vector (see the reshape in ``_encode``).
        gru_cfg = model_cfg['gru_cfg']
        # Plain int rather than numpy.int64, so the config stays a plain-Python
        # structure after this in-place update.
        gru_cfg['input_size'] = int(np.prod(self.dim_in[2:]))
        self.gru = torch.nn.GRU(**gru_cfg)
        # Tail network
        tail_net_cfg = model_cfg['tail_net']
        tail_net_cfg['dim_in'] = gru_cfg['hidden_size']
        tail_net_cfg['dim_out'] = self.known_p_num
        # With a single-layer network these two parameters are not needed,
        # so both keys are optional and default to None.
        if 'activation_fn' in tail_net_cfg.keys():
            tail_net_cfg['activation_fn'] = net_util.get_activation_fn(tail_net_cfg['activation_fn'])
        else:
            tail_net_cfg['activation_fn'] = None
        if 'drop_out' not in tail_net_cfg.keys():
            tail_net_cfg['drop_out'] = None
        self.tail_net = util.get_func_rets(net_util.get_mlp_net, tail_net_cfg)

        self.is_Intrusion_Detection = False

    def _encode(self, amps):
        """Compute time-pooled GRU features with autograd left enabled.

        The training paths (``p_classify`` / ``cal_loss``) use this method so
        that the loss can back-propagate into the GRU weights.

        Args:
            amps: amplitude tensor; dimensions after the first two are
                flattened (source comment: ``[B, T, R * F]``).

        Returns:
            Mean over dimension 1 of the GRU output sequence (source comment:
            ``[B, d]``; assumes ``batch_first`` is set in ``gru_cfg``).
        """
        # amps: [B, T, R * F]
        amps = amps.reshape(amps.shape[0], amps.shape[1], -1)
        # [B, T, d]
        features, _ = self.gru(amps)
        # [B, d]
        return features.mean(dim=1)

    @torch.no_grad()
    def encoder(self, amps):
        """Return the pooled GRU features without tracking gradients.

        Kept gradient-free for feature extraction by external callers.
        Training code must not go through this method: under ``no_grad`` the
        GRU would never be updated.
        """
        return self._encode(amps)

    def p_classify(self, amps):
        """Return the tail-net identity outputs for ``amps``."""
        # amps: [B, T, R * F] -> features: [B, d]
        features = self._encode(amps)
        id_probs = self.tail_net(features)
        return id_probs

    def cal_loss(self, amps, ids, envs):
        """Cross-entropy loss between the identity outputs and ``ids``.

        ``envs`` is accepted but not used by this model.
        """
        id_probs = self.p_classify(amps)
        return torch.nn.CrossEntropyLoss()(id_probs, ids)


class DeepWiIDAL(DeepWiID):
    """Deep-WiID with an adversarial environment classifier.

    An extra MLP predicts the environment from the encoder features passed
    through ``net_util.GradientReversalF``.

    NOTE(review): ``self.lambda_`` and ``self.known_env_num`` are not assigned
    here; presumably the base class sets them from the config -- confirm.
    """

    def __init__(self, model_cfg):
        """Build the base model, then the environment classifier.

        Args:
            model_cfg: as for ``DeepWiID``, plus ``'env_classifier'``
                (arguments for ``net_util.get_mlp_net``; ``'activation_fn'``
                is required here). The sub-dict is updated in place.
        """
        super().__init__(model_cfg)

        # env layer: same input width as the tail net (set by the parent init)
        env_cfg = model_cfg['env_classifier']
        env_cfg['activation_fn'] = net_util.get_activation_fn(env_cfg['activation_fn'])
        env_cfg['dim_in'] = model_cfg['tail_net']['dim_in']
        env_cfg['dim_out'] = self.known_env_num
        self.env_layer = util.get_func_rets(net_util.get_mlp_net, env_cfg)

    def cal_loss(self, amps, ids, envs, amps_t=None, ids_t=None, envs_t=None):
        """Training loss, optionally including the target-domain batch.

        Args:
            amps, ids, envs: source batch inputs, identity labels and
                environment labels.
            amps_t, ids_t: optional target batch inputs and identity labels.
            envs_t: accepted but not used.

        Returns:
            The source identity loss alone when ``amps_t`` is None. Otherwise
            source identity loss + source environment loss (through gradient
            reversal) + target identity loss.
        """
        criterion = torch.nn.CrossEntropyLoss()
        # Gradient-enabled features: required for the identity loss and the
        # reversed environment gradient to reach the GRU.
        features = self._encode(amps)
        id_probs = self.tail_net(features)
        loss_id = criterion(id_probs, ids)

        if amps_t is None:
            return loss_id

        env_probs = self.env_layer(net_util.GradientReversalF.apply(features, self.lambda_))
        loss_env = criterion(env_probs, envs)

        loss_t = loss_id + loss_env

        feature_t = self._encode(amps_t)
        id_probs_t = self.tail_net(feature_t)
        id_loss_t = criterion(id_probs_t, ids_t)

        return loss_t + id_loss_t
76 |
# Register both model variants with the global registry, keyed by class name.
for _model_cls in (DeepWiID, DeepWiIDAL):
    glb_var.register_model(_model_cls.__name__, _model_cls)
--------------------------------------------------------------------------------
/config/bird/bird_v1_wo_aug.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":10,
5 | "known_env_num":2,
6 | "data":{
7 | "dataset":"v1",
8 | "train":{
9 | "dir":"data/v1/train.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":24,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v1/valid.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":8,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v1/test.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":50,
45 | "valid_start_epoch":1,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":30,
48 | "valid_metrics":"loss",
49 | "valid_metrics_less":true,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"BIRD",
57 | "pretrained_enc_dir":"model/pretrained/BIRDEncoder_opt4/end",
58 | "loss":"mse",
59 | "do":256,
60 | "d":16,
61 | "t":64,
62 | "TLinear1":{
63 | "hid_layers":[
64 | 8,
65 | 32
66 | ],
67 | "activation_fn":"mish",
68 | "drop_out":0.5,
69 | "end_with_softmax":false
70 | },
71 | "csinet":{
72 | "n_refine":4,
73 | "hid_layers":[
74 | 8,
75 | 16,
76 | 32,
77 | 32,
78 | 16,
79 | 8
80 | ],
81 | "activation_fn":"relu"
82 | },
83 | "FLinear":{
84 | "hid_layers":[
85 | 16,
86 | 32
87 | ],
88 | "activation_fn":"mish",
89 | "drop_out":0.5,
90 | "end_with_softmax":false
91 | },
92 | "RLinear":{
93 | "hid_layers":[
94 | 16,
95 | 8
96 | ],
97 | "activation_fn":"mish",
98 | "drop_out":0.5,
99 | "end_with_softmax":false
100 | },
101 | "TLinear2":{
102 | "hid_layers":[
103 | 256,
104 | 1024,
105 | 4096
106 | ],
107 | "activation_fn":"mish",
108 | "drop_out":0.5,
109 | "end_with_softmax":false
110 | }
111 | }
112 | }
--------------------------------------------------------------------------------
/config/bird/bird_v5_env0_wo_aug.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":9,
5 | "known_env_num":1,
6 | "data":{
7 | "dataset":"v5",
8 | "train":{
9 | "dir":"data/v5/train_env0.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":24,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v5/valid_env0.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":8,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v5/test_env0.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":100,
45 | "valid_start_epoch":5,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":10,
48 | "valid_metrics":"loss",
49 | "valid_metrics_less":true,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"BIRD",
57 | "pretrained_enc_dir":"model/pretrained/BIRDEncoder_v5_env0_opt1/end",
58 | "loss":"mse",
59 | "do":256,
60 | "d":16,
61 | "t":64,
62 | "TLinear1":{
63 | "hid_layers":[
64 | 8,
65 | 32
66 | ],
67 | "activation_fn":"mish",
68 | "drop_out":0.5,
69 | "end_with_softmax":false
70 | },
71 | "csinet":{
72 | "n_refine":4,
73 | "hid_layers":[
74 | 8,
75 | 16,
76 | 32,
77 | 32,
78 | 16,
79 | 8
80 | ],
81 | "activation_fn":"relu"
82 | },
83 | "FLinear":{
84 | "hid_layers":[
85 | 16,
86 | 32
87 | ],
88 | "activation_fn":"mish",
89 | "drop_out":0.5,
90 | "end_with_softmax":false
91 | },
92 | "RLinear":{
93 | "hid_layers":[
94 | 16,
95 | 8
96 | ],
97 | "activation_fn":"mish",
98 | "drop_out":0.5,
99 | "end_with_softmax":false
100 | },
101 | "TLinear2":{
102 | "hid_layers":[
103 | 256,
104 | 1024,
105 | 4096
106 | ],
107 | "activation_fn":"mish",
108 | "drop_out":0.5,
109 | "end_with_softmax":false
110 | }
111 | }
112 | }
--------------------------------------------------------------------------------
/config/bird/bird_v5_env1_wo_aug.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":9,
5 | "known_env_num":1,
6 | "data":{
7 | "dataset":"v5",
8 | "train":{
9 | "dir":"data/v5/train_env1.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":24,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v5/valid_env1.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":8,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v5/test_env1.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":300,
45 | "valid_start_epoch":1,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":90,
48 | "valid_metrics":"loss",
49 | "valid_metrics_less":true,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"BIRD",
57 | "pretrained_enc_dir":"model/pretrained/BIRDEncoder_v5_env1_opt1/end",
58 | "loss":"mse",
59 | "do":256,
60 | "d":16,
61 | "t":64,
62 | "TLinear1":{
63 | "hid_layers":[
64 | 8,
65 | 32
66 | ],
67 | "activation_fn":"mish",
68 | "drop_out":0.5,
69 | "end_with_softmax":false
70 | },
71 | "csinet":{
72 | "n_refine":4,
73 | "hid_layers":[
74 | 8,
75 | 16,
76 | 32,
77 | 32,
78 | 16,
79 | 8
80 | ],
81 | "activation_fn":"relu"
82 | },
83 | "FLinear":{
84 | "hid_layers":[
85 | 16,
86 | 32
87 | ],
88 | "activation_fn":"mish",
89 | "drop_out":0.5,
90 | "end_with_softmax":false
91 | },
92 | "RLinear":{
93 | "hid_layers":[
94 | 16,
95 | 8
96 | ],
97 | "activation_fn":"mish",
98 | "drop_out":0.5,
99 | "end_with_softmax":false
100 | },
101 | "TLinear2":{
102 | "hid_layers":[
103 | 256,
104 | 1024,
105 | 4096
106 | ],
107 | "activation_fn":"mish",
108 | "drop_out":0.5,
109 | "end_with_softmax":false
110 | }
111 | }
112 | }
--------------------------------------------------------------------------------
/config/bird/bird_v6_env0_wo_aug.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":8,
5 | "known_env_num":1,
6 | "data":{
7 | "dataset":"v6",
8 | "train":{
9 | "dir":"data/v6/train_env0.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":24,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v6/valid_env0.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":8,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v6/test_env0.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":100,
45 | "valid_start_epoch":5,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":10,
48 | "valid_metrics":"loss",
49 | "valid_metrics_less":true,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"BIRD",
57 | "pretrained_enc_dir":"model/pretrained/BIRDEncoder_v6_env0_opt2/best",
58 | "loss":"mse",
59 | "do":256,
60 | "d":16,
61 | "t":64,
62 | "TLinear1":{
63 | "hid_layers":[
64 | 8,
65 | 32
66 | ],
67 | "activation_fn":"mish",
68 | "drop_out":0.5,
69 | "end_with_softmax":false
70 | },
71 | "csinet":{
72 | "n_refine":4,
73 | "hid_layers":[
74 | 8,
75 | 16,
76 | 32,
77 | 32,
78 | 16,
79 | 8
80 | ],
81 | "activation_fn":"relu"
82 | },
83 | "FLinear":{
84 | "hid_layers":[
85 | 16,
86 | 32
87 | ],
88 | "activation_fn":"mish",
89 | "drop_out":0.5,
90 | "end_with_softmax":false
91 | },
92 | "RLinear":{
93 | "hid_layers":[
94 | 16,
95 | 8
96 | ],
97 | "activation_fn":"mish",
98 | "drop_out":0.5,
99 | "end_with_softmax":false
100 | },
101 | "TLinear2":{
102 | "hid_layers":[
103 | 256,
104 | 1024,
105 | 4096
106 | ],
107 | "activation_fn":"mish",
108 | "drop_out":0.5,
109 | "end_with_softmax":false
110 | }
111 | }
112 | }
--------------------------------------------------------------------------------
/config/bird/bird_v7_env0_wo_aug.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":7,
5 | "known_env_num":1,
6 | "data":{
7 | "dataset":"v7",
8 | "train":{
9 | "dir":"data/v7/train_env0.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":16,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v7/valid_env0.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":8,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v7/test_env0.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":100,
45 | "valid_start_epoch":1,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":10,
48 | "valid_metrics":"loss",
49 | "valid_metrics_less":true,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"BIRD",
57 | "pretrained_enc_dir":"model/pretrained/BIRDEncoder_v7_env0_opt2/end",
58 | "loss":"mse",
59 | "do":256,
60 | "d":16,
61 | "t":64,
62 | "TLinear1":{
63 | "hid_layers":[
64 | 8,
65 | 32
66 | ],
67 | "activation_fn":"mish",
68 | "drop_out":0.5,
69 | "end_with_softmax":false
70 | },
71 | "csinet":{
72 | "n_refine":4,
73 | "hid_layers":[
74 | 8,
75 | 16,
76 | 32,
77 | 32,
78 | 16,
79 | 8
80 | ],
81 | "activation_fn":"relu"
82 | },
83 | "FLinear":{
84 | "hid_layers":[
85 | 16,
86 | 32
87 | ],
88 | "activation_fn":"mish",
89 | "drop_out":0.5,
90 | "end_with_softmax":false
91 | },
92 | "RLinear":{
93 | "hid_layers":[
94 | 16,
95 | 8
96 | ],
97 | "activation_fn":"mish",
98 | "drop_out":0.5,
99 | "end_with_softmax":false
100 | },
101 | "TLinear2":{
102 | "hid_layers":[
103 | 256,
104 | 1024,
105 | 4096
106 | ],
107 | "activation_fn":"mish",
108 | "drop_out":0.5,
109 | "end_with_softmax":false
110 | }
111 | }
112 | }
--------------------------------------------------------------------------------
/config/bird/bird_v8_env0_wo_aug.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":6,
5 | "known_env_num":1,
6 | "data":{
7 | "dataset":"v8",
8 | "train":{
9 | "dir":"data/v8/train_env0.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":16,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v8/valid_env0.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":8,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v8/test_env0.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":24,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":100,
45 | "valid_start_epoch":1,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":10,
48 | "valid_metrics":"loss",
49 | "valid_metrics_less":true,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"BIRD",
57 | "pretrained_enc_dir":"model/pretrained/BIRDEncoder_v8_env0_opt1/best",
58 | "loss":"mse",
59 | "do":256,
60 | "d":16,
61 | "t":64,
62 | "TLinear1":{
63 | "hid_layers":[
64 | 8,
65 | 32
66 | ],
67 | "activation_fn":"mish",
68 | "drop_out":0.5,
69 | "end_with_softmax":false
70 | },
71 | "csinet":{
72 | "n_refine":4,
73 | "hid_layers":[
74 | 8,
75 | 16,
76 | 32,
77 | 32,
78 | 16,
79 | 8
80 | ],
81 | "activation_fn":"relu"
82 | },
83 | "FLinear":{
84 | "hid_layers":[
85 | 16,
86 | 32
87 | ],
88 | "activation_fn":"mish",
89 | "drop_out":0.5,
90 | "end_with_softmax":false
91 | },
92 | "RLinear":{
93 | "hid_layers":[
94 | 16,
95 | 8
96 | ],
97 | "activation_fn":"mish",
98 | "drop_out":0.5,
99 | "end_with_softmax":false
100 | },
101 | "TLinear2":{
102 | "hid_layers":[
103 | 256,
104 | 1024,
105 | 4096
106 | ],
107 | "activation_fn":"mish",
108 | "drop_out":0.5,
109 | "end_with_softmax":false
110 | }
111 | }
112 | }
--------------------------------------------------------------------------------
/data/augmentation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
def data_augmentation(batch, aug_cfg, device):
    """Extend a batch with one augmented copy per configured augmentation.

    Args:
        batch: tuple ``(amps, ids, envs)`` of tensors concatenable along
            dimension 0.
        aug_cfg: mapping whose values are per-augmentation config dicts. Each
            needs a ``'type'`` key (matched case-insensitively) naming one of
            ``gaussian_noise``, ``multipath_fading``, ``freq_selective_fading``,
            ``antenna_sequence`` or ``subcarrier_mask``, plus the keys the
            chosen augmentation reads.
        device: forwarded unchanged to the augmentation functions.

    Returns:
        ``(new_amps, new_ids, new_envs)``: the original samples followed by
        each augmented copy, in ``aug_cfg`` iteration order, along dim 0.
        The inputs themselves are never modified.

    Raises:
        RuntimeError: if a config names an unsupported augmentation type.
    """
    def _augmentation(amps, ids, envs, s_aug_cfg, device):
        # Resolve the type once, then dispatch to the matching function.
        aug_type = s_aug_cfg['type'].lower()
        if aug_type == 'gaussian_noise':
            aug_fn = add_gaussian_noise
        elif aug_type == 'multipath_fading':
            aug_fn = mutipath_fading
        elif aug_type == 'freq_selective_fading':
            aug_fn = freq_selective_fading
        elif aug_type == 'antenna_sequence':
            aug_fn = antenna_sequence
        elif aug_type == 'subcarrier_mask':
            aug_fn = subcarrier_mask
        else:
            raise RuntimeError(
                f"unsupported augmentation type: {s_aug_cfg['type']!r}"
            )
        return aug_fn(amps, ids, envs, s_aug_cfg, device)

    amps, ids, envs = batch
    amps_parts, ids_parts, envs_parts = [amps], [ids], [envs]
    for s_aug_cfg in aug_cfg.values():
        # Each augmentation gets its own copy because some modify in place.
        amps_aug, ids_aug, envs_aug = _augmentation(amps.clone(), ids, envs, s_aug_cfg, device)
        amps_parts.append(amps_aug)
        ids_parts.append(ids_aug)
        envs_parts.append(envs_aug)

    if len(amps_parts) == 1:
        # No augmentation configured: still hand back fresh tensors.
        return amps.clone(), ids.clone(), envs.clone()

    # Concatenate once instead of re-allocating the batch per augmentation.
    new_amps = torch.cat(amps_parts, dim=0)
    new_ids = torch.cat(ids_parts, dim=0)
    new_envs = torch.cat(envs_parts, dim=0)
    return new_amps, new_ids, new_envs
28 |
29 |
def add_gaussian_noise(amps, ids, envs, s_aug_cfg, device):
    """Add zero-mean Gaussian noise at the configured signal-to-noise ratio.

    The noise standard deviation is ``sqrt(mean(amps**2) / 10**(snr_db/10))``,
    with the mean taken over the whole ``amps`` tensor. The noise is sampled
    as float32 on ``device``.

    Args:
        amps: tensor to perturb (not modified).
        ids, envs: returned unchanged.
        s_aug_cfg: config dict; reads ``'snr_db'``.
        device: device on which the noise tensor is created.

    Returns:
        ``(amps + noise, ids, envs)``.
    """
    # dB -> linear power ratio
    snr_linear = 10 ** (s_aug_cfg['snr_db'] / 10)
    signal_power = torch.mean(amps ** 2)
    noise_std = torch.sqrt(signal_power / snr_linear)
    noise = torch.normal(0, noise_std, amps.shape, dtype=torch.float32, device=device)
    return amps + noise, ids, envs
34 |
def mutipath_fading(amps, ids, envs, s_aug_cfg, device):
    """Simulate multipath by adding delayed, scaled copies of each sample.

    For every sample one (delay, coef) profile is drawn at random with
    ``np.random.choice``. Each tap ``j`` of the profile adds
    ``coef[j] * original_signal`` shifted by ``delay[j]`` steps along
    dimension 1 of the sample.

    Args:
        amps: tensor indexed as ``[sample, time, ...]`` (not modified).
        ids: tensor whose first dimension gives the number of samples;
            returned unchanged.
        envs: returned unchanged.
        s_aug_cfg: config dict; reads ``'delay'`` and ``'coef'``, two parallel
            lists of profiles, each profile being a list of per-tap values.
        device: unused; kept for a uniform augmentation signature.

    Returns:
        ``(amps_mp, ids, envs)`` where ``amps_mp`` is a faded copy of ``amps``.
    """
    delay = s_aug_cfg['delay']
    coef = s_aug_cfg['coef']
    n = ids.shape[0]
    # One randomly chosen profile index per sample.
    mp = np.random.choice(len(delay), n)
    amps_mp = amps.clone()
    for i in range(n):
        tap_delays = delay[mp[i]]
        tap_coefs = coef[mp[i]]
        for j, tap_delay in enumerate(tap_delays):
            # ``:-0`` is an empty slice, so a zero-delay tap must use the
            # whole signal instead of ``amps[i, :-tap_delay]``.
            delayed = amps[i, :-tap_delay] if tap_delay else amps[i]
            amps_mp[i, tap_delay:] = amps_mp[i, tap_delay:] + tap_coefs[j] * delayed
    return amps_mp, ids, envs
47 |
def freq_selective_fading(amps, ids, envs, s_aug_cfg, device):
    """Scale the last axis of ``amps`` by a linear gain ramp.

    One ``[start, end]`` pair is drawn at random from ``s_aug_cfg['scope']``.
    The gains are ``torch.linspace(start, end)`` with one value per entry of
    the last dimension, broadcast over all leading dimensions.

    Args:
        amps: input tensor (source comment: ``[N, T, R, F]``); not modified.
        ids, envs: returned unchanged.
        s_aug_cfg: config dict; reads ``'scope'``, a list of ``[start, end]``.
        device: device on which the gain vector is created.

    Returns:
        ``(amps * gains, ids, envs)``.
    """
    # amps: [N, T, R, F]
    candidate_ranges = s_aug_cfg['scope']
    picked = np.random.choice(len(candidate_ranges), 1, replace = False)[0]
    gain_range = candidate_ranges[picked]
    gains = torch.linspace(
        start = gain_range[0],
        end = gain_range[1],
        steps = amps.shape[-1],
        device = device,
    )
    faded = amps * gains
    return faded, ids, envs
55 |
def antenna_sequence(amps, ids, envs, s_aug_cfg, device):
    """Reorder dimension 2 of ``amps`` with a random permutation.

    The permutation is drawn with ``np.random.choice`` without replacement,
    and the same order is applied to every sample.

    Args:
        amps: input tensor (source comment: ``[N, T, R, F]``); not modified.
        ids, envs: returned unchanged.
        s_aug_cfg: unused; kept for a uniform augmentation signature.
        device: unused; kept for a uniform augmentation signature.

    Returns:
        ``(permuted_amps, ids, envs)``.
    """
    # amps: [N, T, R, F]
    n_rx = amps.shape[2]
    shuffled_order = np.random.choice(n_rx, n_rx, replace = False)
    return amps[:, :, shuffled_order, :], ids, envs
62 |
def subcarrier_mask(amps, ids, envs, s_aug_cfg, device):
    """Zero out a random subset of positions along the last axis of ``amps``.

    ``int(ratio * F)`` distinct indices of the last dimension are drawn with
    ``np.random.choice`` and set to zero for every sample.

    NOTE: ``amps`` is modified in place and the same tensor is returned, so
    pass a copy if the original must be kept.

    Args:
        amps: 4-D input tensor (source comment: ``[N, T, R, F]``).
        ids, envs: returned unchanged.
        s_aug_cfg: config dict; reads ``'ratio'``.
        device: unused; kept for a uniform augmentation signature.

    Returns:
        ``(amps, ids, envs)`` with the masked ``amps``.
    """
    # amps: [N, T, R, F]
    total_sc = amps.shape[-1]
    n_masked = int(s_aug_cfg['ratio'] * total_sc)
    masked_idxs = np.random.choice(total_sc, n_masked, replace = False)
    amps[:, :, :, masked_idxs] = 0
    return amps, ids, envs
69 |
--------------------------------------------------------------------------------
/config/bird/bird_encoder_v2_wo_aug+al.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":10,
5 | "known_env_num":2,
6 | "data":{
7 | "dataset":"v2",
8 | "train":{
9 | "dir":"data/v2/train.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":8,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v2/valid.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":8,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v2/test_legal.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":8,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":300,
45 | "valid_start_epoch":1,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":30,
48 | "valid_metrics":"acc",
49 | "valid_metrics_less":false,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"BIRDEncoder",
57 | "is_norm_first":true,
58 | "d":64,
59 | "t":64,
60 | "do":256,
61 | "TLinear":{
62 | "hid_layers":[
63 | 4096,
64 | 1024,
65 | 256
66 | ],
67 | "activation_fn":"mish",
68 | "drop_out":0.5,
69 | "end_with_softmax":false
70 | },
71 | "csinet":{
72 | "n_refine":2,
73 | "hid_layers":[
74 | 8,
75 | 16,
76 | 32,
77 | 32,
78 | 16,
79 | 8
80 | ],
81 | "activation_fn":"relu"
82 | },
83 | "trans2":{
84 | "max_len":600,
85 | "is_norm_first":true,
86 | "d_fc":2048,
87 | "n_heads": 4,
88 | "n_layers": 5
89 | },
90 | "tail_net":{
91 | "hid_layers":[
92 | 1024,
93 | 512,
94 | 256
95 | ],
96 | "activation_fn":"mish",
97 | "drop_out":0.5,
98 | "end_with_softmax":false
99 | },
100 | "p_classifier":{
101 | "hid_layers":[
102 | 128,
103 | 64
104 | ],
105 | "activation_fn":"mish",
106 | "drop_out":0.5,
107 | "end_with_softmax":true
108 | },
109 | "env_classifier":{
110 | "hid_layers":[
111 | 128,
112 | 64
113 | ],
114 | "activation_fn":"mish",
115 | "drop_out":0.5,
116 | "end_with_softmax":true
117 | },
118 | "lambda_":1
119 | }
120 | }
--------------------------------------------------------------------------------
/config/bird/bird_encoder_v4_env0_wo_aug.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":10,
5 | "known_env_num":2,
6 | "data":{
7 | "dataset":"v4",
8 | "train":{
9 | "dir":"data/v4/train_env0.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":8,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v4/valid_env0.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":8,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v4/test_env0.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":8,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":300,
45 | "valid_start_epoch":1,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":50,
48 | "valid_metrics":"acc",
49 | "valid_metrics_less":false,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"BIRDEncoder",
57 | "is_norm_first":true,
58 | "d":64,
59 | "t":64,
60 | "do":256,
61 | "TLinear":{
62 | "hid_layers":[
63 | 4096,
64 | 1024,
65 | 256
66 | ],
67 | "activation_fn":"mish",
68 | "drop_out":0.5,
69 | "end_with_softmax":false
70 | },
71 | "csinet":{
72 | "n_refine":2,
73 | "hid_layers":[
74 | 8,
75 | 16,
76 | 32,
77 | 32,
78 | 16,
79 | 8
80 | ],
81 | "activation_fn":"relu"
82 | },
83 | "trans2":{
84 | "max_len":600,
85 | "is_norm_first":true,
86 | "d_fc":2048,
87 | "n_heads": 4,
88 | "n_layers": 5
89 | },
90 | "tail_net":{
91 | "hid_layers":[
92 | 1024,
93 | 512,
94 | 256
95 | ],
96 | "activation_fn":"mish",
97 | "drop_out":0.5,
98 | "end_with_softmax":false
99 | },
100 | "p_classifier":{
101 | "hid_layers":[
102 | 128,
103 | 64
104 | ],
105 | "activation_fn":"mish",
106 | "drop_out":0.5,
107 | "end_with_softmax":true
108 | },
109 | "env_classifier":{
110 | "hid_layers":[
111 | 128,
112 | 64
113 | ],
114 | "activation_fn":"mish",
115 | "drop_out":0.5,
116 | "end_with_softmax":true
117 | },
118 | "lambda_":-1
119 | }
120 | }
--------------------------------------------------------------------------------
/config/bird/bird_encoder_v4_env1_wo_aug.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":10,
5 | "known_env_num":2,
6 | "data":{
7 | "dataset":"v4",
8 | "train":{
9 | "dir":"data/v4/train_env1.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":8,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v4/valid_env1.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":8,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v4/test_env1.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":8,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":300,
45 | "valid_start_epoch":5,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":30,
48 | "valid_metrics":"acc",
49 | "valid_metrics_less":false,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"BIRDEncoder",
57 | "is_norm_first":true,
58 | "d":64,
59 | "t":64,
60 | "do":256,
61 | "TLinear":{
62 | "hid_layers":[
63 | 4096,
64 | 1024,
65 | 256
66 | ],
67 | "activation_fn":"mish",
68 | "drop_out":0.5,
69 | "end_with_softmax":false
70 | },
71 | "csinet":{
72 | "n_refine":2,
73 | "hid_layers":[
74 | 8,
75 | 16,
76 | 32,
77 | 32,
78 | 16,
79 | 8
80 | ],
81 | "activation_fn":"relu"
82 | },
83 | "trans2":{
84 | "max_len":600,
85 | "is_norm_first":true,
86 | "d_fc":2048,
87 | "n_heads": 4,
88 | "n_layers": 5
89 | },
90 | "tail_net":{
91 | "hid_layers":[
92 | 1024,
93 | 512,
94 | 256
95 | ],
96 | "activation_fn":"mish",
97 | "drop_out":0.5,
98 | "end_with_softmax":false
99 | },
100 | "p_classifier":{
101 | "hid_layers":[
102 | 128,
103 | 64
104 | ],
105 | "activation_fn":"mish",
106 | "drop_out":0.5,
107 | "end_with_softmax":true
108 | },
109 | "env_classifier":{
110 | "hid_layers":[
111 | 128,
112 | 64
113 | ],
114 | "activation_fn":"mish",
115 | "drop_out":0.5,
116 | "end_with_softmax":true
117 | },
118 | "lambda_":-1
119 | }
120 | }
--------------------------------------------------------------------------------
/config/bird/bird_encoder_v5_env0_wo_aug.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":9,
5 | "known_env_num":1,
6 | "data":{
7 | "dataset":"v5",
8 | "train":{
9 | "dir":"data/v5/train_env0.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":8,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v5/valid_env0.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":8,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v5/test_env0_pre.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":8,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":100,
45 | "valid_start_epoch":1,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":30,
48 | "valid_metrics":"acc",
49 | "valid_metrics_less":false,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"BIRDEncoder",
57 | "is_norm_first":true,
58 | "d":64,
59 | "t":64,
60 | "do":256,
61 | "TLinear":{
62 | "hid_layers":[
63 | 4096,
64 | 1024,
65 | 256
66 | ],
67 | "activation_fn":"mish",
68 | "drop_out":0.5,
69 | "end_with_softmax":false
70 | },
71 | "csinet":{
72 | "n_refine":2,
73 | "hid_layers":[
74 | 8,
75 | 16,
76 | 32,
77 | 32,
78 | 16,
79 | 8
80 | ],
81 | "activation_fn":"relu"
82 | },
83 | "trans2":{
84 | "max_len":600,
85 | "is_norm_first":true,
86 | "d_fc":2048,
87 | "n_heads": 4,
88 | "n_layers": 5
89 | },
90 | "tail_net":{
91 | "hid_layers":[
92 | 1024,
93 | 512,
94 | 256
95 | ],
96 | "activation_fn":"mish",
97 | "drop_out":0.5,
98 | "end_with_softmax":false
99 | },
100 | "p_classifier":{
101 | "hid_layers":[
102 | 128,
103 | 64
104 | ],
105 | "activation_fn":"mish",
106 | "drop_out":0.5,
107 | "end_with_softmax":true
108 | },
109 | "env_classifier":{
110 | "hid_layers":[
111 | 128,
112 | 64
113 | ],
114 | "activation_fn":"mish",
115 | "drop_out":0.5,
116 | "end_with_softmax":true
117 | },
118 | "lambda_":-1
119 | }
120 | }
--------------------------------------------------------------------------------
/config/bird/bird_encoder_v5_env1_wo_aug.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":9,
5 | "known_env_num":1,
6 | "data":{
7 | "dataset":"v5",
8 | "train":{
9 | "dir":"data/v5/train_env1.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":8,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v5/valid_env1.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":8,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v5/test_env1_pre.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":8,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":300,
45 | "valid_start_epoch":1,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":50,
48 | "valid_metrics":"acc",
49 | "valid_metrics_less":false,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"BIRDEncoder",
57 | "is_norm_first":true,
58 | "d":64,
59 | "t":64,
60 | "do":256,
61 | "TLinear":{
62 | "hid_layers":[
63 | 4096,
64 | 1024,
65 | 256
66 | ],
67 | "activation_fn":"mish",
68 | "drop_out":0.5,
69 | "end_with_softmax":false
70 | },
71 | "csinet":{
72 | "n_refine":2,
73 | "hid_layers":[
74 | 8,
75 | 16,
76 | 32,
77 | 32,
78 | 16,
79 | 8
80 | ],
81 | "activation_fn":"relu"
82 | },
83 | "trans2":{
84 | "max_len":600,
85 | "is_norm_first":true,
86 | "d_fc":2048,
87 | "n_heads": 4,
88 | "n_layers": 5
89 | },
90 | "tail_net":{
91 | "hid_layers":[
92 | 1024,
93 | 512,
94 | 256
95 | ],
96 | "activation_fn":"mish",
97 | "drop_out":0.5,
98 | "end_with_softmax":false
99 | },
100 | "p_classifier":{
101 | "hid_layers":[
102 | 128,
103 | 64
104 | ],
105 | "activation_fn":"mish",
106 | "drop_out":0.5,
107 | "end_with_softmax":true
108 | },
109 | "env_classifier":{
110 | "hid_layers":[
111 | 128,
112 | 64
113 | ],
114 | "activation_fn":"mish",
115 | "drop_out":0.5,
116 | "end_with_softmax":true
117 | },
118 | "lambda_":-1
119 | }
120 | }
--------------------------------------------------------------------------------
/config/bird/bird_encoder_v6_env0_wo_aug.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":8,
5 | "known_env_num":1,
6 | "data":{
7 | "dataset":"v6",
8 | "train":{
9 | "dir":"data/v6/train_env0.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":8,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v6/valid_env0.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":8,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v6/test_env0_pre.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":8,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":300,
45 | "valid_start_epoch":1,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":30,
48 | "valid_metrics":"acc",
49 | "valid_metrics_less":false,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"BIRDEncoder",
57 | "is_norm_first":true,
58 | "d":64,
59 | "t":64,
60 | "do":256,
61 | "TLinear":{
62 | "hid_layers":[
63 | 4096,
64 | 1024,
65 | 256
66 | ],
67 | "activation_fn":"mish",
68 | "drop_out":0.5,
69 | "end_with_softmax":false
70 | },
71 | "csinet":{
72 | "n_refine":2,
73 | "hid_layers":[
74 | 8,
75 | 16,
76 | 32,
77 | 32,
78 | 16,
79 | 8
80 | ],
81 | "activation_fn":"relu"
82 | },
83 | "trans2":{
84 | "max_len":600,
85 | "is_norm_first":true,
86 | "d_fc":2048,
87 | "n_heads": 4,
88 | "n_layers": 5
89 | },
90 | "tail_net":{
91 | "hid_layers":[
92 | 1024,
93 | 512,
94 | 256
95 | ],
96 | "activation_fn":"mish",
97 | "drop_out":0.5,
98 | "end_with_softmax":false
99 | },
100 | "p_classifier":{
101 | "hid_layers":[
102 | 128,
103 | 64
104 | ],
105 | "activation_fn":"mish",
106 | "drop_out":0.5,
107 | "end_with_softmax":true
108 | },
109 | "env_classifier":{
110 | "hid_layers":[
111 | 128,
112 | 64
113 | ],
114 | "activation_fn":"mish",
115 | "drop_out":0.5,
116 | "end_with_softmax":true
117 | },
118 | "lambda_":-1
119 | }
120 | }
--------------------------------------------------------------------------------
/config/bird/bird_encoder_v7_env0_wo_aug.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":7,
5 | "known_env_num":1,
6 | "data":{
7 | "dataset":"v7",
8 | "train":{
9 | "dir":"data/v7/train_env0.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":8,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v7/valid_env0.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":8,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v7/test_env0_pre.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":8,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":300,
45 | "valid_start_epoch":1,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":30,
48 | "valid_metrics":"acc",
49 | "valid_metrics_less":false,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"BIRDEncoder",
57 | "is_norm_first":true,
58 | "d":64,
59 | "t":64,
60 | "do":256,
61 | "TLinear":{
62 | "hid_layers":[
63 | 4096,
64 | 1024,
65 | 256
66 | ],
67 | "activation_fn":"mish",
68 | "drop_out":0.5,
69 | "end_with_softmax":false
70 | },
71 | "csinet":{
72 | "n_refine":2,
73 | "hid_layers":[
74 | 8,
75 | 16,
76 | 32,
77 | 32,
78 | 16,
79 | 8
80 | ],
81 | "activation_fn":"relu"
82 | },
83 | "trans2":{
84 | "max_len":600,
85 | "is_norm_first":true,
86 | "d_fc":2048,
87 | "n_heads": 4,
88 | "n_layers": 5
89 | },
90 | "tail_net":{
91 | "hid_layers":[
92 | 1024,
93 | 512,
94 | 256
95 | ],
96 | "activation_fn":"mish",
97 | "drop_out":0.5,
98 | "end_with_softmax":false
99 | },
100 | "p_classifier":{
101 | "hid_layers":[
102 | 128,
103 | 64
104 | ],
105 | "activation_fn":"mish",
106 | "drop_out":0.5,
107 | "end_with_softmax":true
108 | },
109 | "env_classifier":{
110 | "hid_layers":[
111 | 128,
112 | 64
113 | ],
114 | "activation_fn":"mish",
115 | "drop_out":0.5,
116 | "end_with_softmax":true
117 | },
118 | "lambda_":-1
119 | }
120 | }
--------------------------------------------------------------------------------
/config/bird/bird_encoder_v8_env0_wo_aug.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":6,
5 | "known_env_num":1,
6 | "data":{
7 | "dataset":"v8",
8 | "train":{
9 | "dir":"data/v8/train_env0.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":8,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v8/valid_env0.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":8,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "test":{
31 | "dir":"data/v8/test_env0_pre.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":8,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | }
41 | },
42 | "train":{
43 | "is_DA":false,
44 | "max_epoch":300,
45 | "valid_start_epoch":1,
46 | "valid_step":1,
47 | "stop_train_step_valid_not_improve":30,
48 | "valid_metrics":"acc",
49 | "valid_metrics_less":false,
50 | "optimizer":{
51 | "lr":1e-5,
52 | "weight_decay":1e-4
53 | }
54 | },
55 | "model":{
56 | "name":"BIRDEncoder",
57 | "is_norm_first":true,
58 | "d":64,
59 | "t":64,
60 | "do":256,
61 | "TLinear":{
62 | "hid_layers":[
63 | 4096,
64 | 1024,
65 | 256
66 | ],
67 | "activation_fn":"mish",
68 | "drop_out":0.5,
69 | "end_with_softmax":false
70 | },
71 | "csinet":{
72 | "n_refine":2,
73 | "hid_layers":[
74 | 8,
75 | 16,
76 | 32,
77 | 32,
78 | 16,
79 | 8
80 | ],
81 | "activation_fn":"relu"
82 | },
83 | "trans2":{
84 | "max_len":600,
85 | "is_norm_first":true,
86 | "d_fc":2048,
87 | "n_heads": 4,
88 | "n_layers": 5
89 | },
90 | "tail_net":{
91 | "hid_layers":[
92 | 1024,
93 | 512,
94 | 256
95 | ],
96 | "activation_fn":"mish",
97 | "drop_out":0.5,
98 | "end_with_softmax":false
99 | },
100 | "p_classifier":{
101 | "hid_layers":[
102 | 128,
103 | 64
104 | ],
105 | "activation_fn":"mish",
106 | "drop_out":0.5,
107 | "end_with_softmax":true
108 | },
109 | "env_classifier":{
110 | "hid_layers":[
111 | 128,
112 | 64
113 | ],
114 | "activation_fn":"mish",
115 | "drop_out":0.5,
116 | "end_with_softmax":true
117 | },
118 | "lambda_":-1
119 | }
120 | }
--------------------------------------------------------------------------------
/executor.py:
--------------------------------------------------------------------------------
1 | import argparse, logging, os, torch, platform, subprocess, webbrowser, time
2 | from torch.utils.tensorboard import SummaryWriter
3 | from lib import glb_var, json_util, util, callback, colortext
4 |
def str2bool(value):
    '''Convert a command-line token into a real boolean.

    `argparse` with `type=bool` calls `bool()` on the raw string, so every
    non-empty value (including "False" and "0") becomes True. This parser
    accepts the usual spellings instead.

    Parameters:
    -----------
    value : str | bool
        Raw token from the command line (or an already parsed bool).

    Returns:
    --------
    bool

    Raises:
    -------
    argparse.ArgumentTypeError
        If the token is not a recognised boolean spelling.
    '''
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


if __name__ == '__main__':
    glb_var.__init__()
    parse = argparse.ArgumentParser()
    parse.add_argument('--config', '-cfg', type = str, default = None, help = 'config for run')
    parse.add_argument('--saved_config', '-sc', type = str, default = None, help = 'path for saved config to test')
    # `-as` alone, `-as True` and `-as False` are all accepted.
    parse.add_argument(
        '--auto_shutdown', '-as', type = str2bool, nargs = '?', const = True, default = False,
        help = 'automatic shutdown after program completion'
    )

    args = parse.parse_args()
    is_train, is_test = False, False
    save_dir = None
    if args.config is not None:
        # Fresh run: train, then test the trained model.
        config = json_util.jsonload(args.config)
        is_train = True
        is_test = True
    elif args.saved_config is not None:
        # Test only: reuse the directory that holds the saved config.
        config = json_util.jsonload(args.saved_config)
        # dirname() is '' for a bare file name; fall back to the current
        # directory instead of ending up with '/' (filesystem root).
        save_dir = (os.path.dirname(args.saved_config) or '.') + '/'
        config['save_dir'] = save_dir
        is_test = True
    else:
        # Exits with a usage message instead of a NameError on `config` below.
        parse.error('either --config or --saved_config must be given')

    from lib.callback import Logger
    DATASET = config['data']['dataset']
    INFO_LEVEL = logging.INFO
    MODEL_NAME = config['model']['name']
    # One timestamp shared by the save directory and the log file name.
    run_tag = f'{MODEL_NAME}_{util.get_date()}_{util.get_time()}'

    os.makedirs(f'./cache/logger/{DATASET}/', exist_ok = True)
    if save_dir is None:
        save_dir = f'./cache/save/{DATASET}/{run_tag}/'
        config['save_dir'] = save_dir
    os.makedirs(save_dir, exist_ok = True)
    logger = Logger(
        level = INFO_LEVEL,
        filename = f'./cache/logger/{DATASET}/{run_tag}.log',
    ).get_log()
    logger.debug(f'save dir:{save_dir}')
    os.makedirs(save_dir + 'tb', exist_ok = True)
    glb_var.set_value('logger', logger)
    if config['gpu_is_available']:
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device('cpu')
    glb_var.set_value('device', device)
    util.kill_process_on_port(6006)
    tb_writer = SummaryWriter(log_dir = save_dir + 'tb')

    tb_process = subprocess.Popen(["tensorboard", "--logdir", save_dir + 'tb', "--port", "6006", "--reload_interval", "5"])
    time.sleep(5)
    try:
        # Windows / WSL path. A missing executable raises OSError
        # (FileNotFoundError), not webbrowser.Error.
        subprocess.run(["explorer.exe", "http://localhost:6006"])
    except OSError:
        try:
            webbrowser.open("http://localhost:6006")
        except webbrowser.Error:
            logger.info('Could not open a browser; TensorBoard is at http://localhost:6006')

    glb_var.set_value('tb_writer', tb_writer)
    util.set_seed(config['seed'])
    callback.set_custom_tqdm_warning()

    # Imported only after 'logger', 'device' and 'tb_writer' were stored in glb_var.
    from t2 import train, test
    try:
        if is_train:
            config = train.train_model(config)
        if is_test:
            test.test_model(config)
    except KeyboardInterrupt:
        logger.info(colortext.YELLOW + "Exiting program..." + colortext.RESET)
    except Exception:
        # Do not leave the TensorBoard child process orphaned when the run crashes.
        tb_writer.close()
        tb_process.terminate()
        raise

    tb_writer.close()
    if args.auto_shutdown:
        logger.info('Automatic shutdown.')
        tb_process.terminate()
        if platform.system().lower() == 'linux' and not util.is_wsl():
            os.system("shutdown -h now")
        elif util.is_wsl():
            os.system("/mnt/c/Windows/System32/cmd.exe /c shutdown /s /t 1")
        else:#windows
            os.system("shutdown /s /t 1")
    else:
        # Keep TensorBoard alive until the user has finished inspecting the run.
        logger.info("\nTensorBoard is still running.")
        try:
            input(colortext.RED + "Press Enter to stop TensorBoard and exit..." + colortext.RESET)
        finally:
            tb_process.terminate()
--------------------------------------------------------------------------------
/model/framework/csiid.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from lib import glb_var, util
3 | from model.framework.base import Net
4 | from model import net_util
5 |
class CSIID(Net):
    '''CNN + LSTM identity classifier.

    Three Conv2d/BatchNorm2d/ReLU stages feed a 5-layer LSTM. The LSTM output
    of the last time step is mapped by a linear layer to `known_p_num` scores,
    which are normalised with softmax.
    '''
    def __init__(self, model_cfg):
        super().__init__(model_cfg)

        # First convolutional layer
        self.conv1 = torch.nn.Sequential(
            torch.nn.Conv2d(3, 30, kernel_size=(100, 3), stride=(1, 3)),
            torch.nn.BatchNorm2d(30),
            torch.nn.ReLU()
        )

        # Second convolutional layer
        self.conv2 = torch.nn.Sequential(
            torch.nn.Conv2d(30, 30, kernel_size=(5, 3), stride=(2, 1), padding=(2, 1)),
            torch.nn.BatchNorm2d(30),
            torch.nn.ReLU()
        )

        # Third convolutional layer
        self.conv3 = torch.nn.Sequential(
            torch.nn.Conv2d(30, 30, kernel_size=(3, 3), stride=(1, 1), padding=(1, 0)),
            torch.nn.BatchNorm2d(30),
            torch.nn.ReLU()
        )

        # LSTM layer; input_size 480 = 30 channels * 16 remaining width positions
        self.lstm = torch.nn.LSTM(input_size=480, hidden_size=128, num_layers=5, batch_first=True)

        # Fully connected layer
        # NOTE(review): known_p_num is presumably set on the instance by Net.__init__ from model_cfg - confirm.
        self.fc = torch.nn.Linear(128, self.known_p_num)

        # Softmax
        self.softmax = torch.nn.Softmax(dim=1)

    def encoder(self, amps):
        '''Extract a 128-dim feature vector per sample.

        Gradients are tracked here on purpose: `cal_loss` goes through this
        method, so wrapping it in `torch.no_grad()` would freeze the conv and
        LSTM layers at their initial weights. Evaluation paths that need no
        graph (e.g. `Net.cal_accuracy`) carry their own `no_grad` decorator.

        Parameters:
        -----------
        amps : 4-D tensor whose dim 2 has size 3 (it becomes the channel axis
            of conv1 after the permute below).

        Returns:
        --------
        Tensor [batch, 128]: LSTM output at the last time step.
        '''
        # Move dim 2 to the channel position expected by Conv2d.
        x = amps.permute(0, 2, 1, 3)

        # Convolutional layers
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)

        # Reshape for LSTM input (batch_size, seq_len, input_size):
        # [B, C, H, W] -> [B, H, C, W] -> [B, H, C * W]
        x = x.permute(0, 2, 1, 3)
        x = x.reshape(x.size(0), x.size(1), -1)

        # LSTM layer
        x, _ = self.lstm(x)
        x = x[:, -1, :]  # Take the output from the last time step
        return x

    def p_classifier(self, features):
        '''Map encoder features to softmax probabilities over the known identities.'''
        x = self.fc(features)

        # Softmax
        x = self.softmax(x)
        return x

    def p_classify(self, x):
        '''Run encoder and identity classifier; returns per-class probabilities.'''
        x = self.encoder(x)
        x = self.p_classifier(x)
        return x

    def cal_loss(self, amps, ids, envs):
        '''Cross-entropy between the predicted identity probabilities and `ids`.

        `envs` is accepted for interface compatibility and is not used.
        '''
        id_probs = self.p_classify(amps)
        # NOTE(review): CrossEntropyLoss applies log-softmax itself, so the
        # probabilities are normalised twice; left unchanged because the
        # softmax output of p_classify is part of the public behaviour.
        return torch.nn.CrossEntropyLoss()(id_probs, ids)
77 |
class CSIIDAL(CSIID):
    '''CSIID with an additional environment classifier head.

    The environment head receives the encoder features through
    `net_util.GradientReversalF` with coefficient `lambda_`.
    '''
    def __init__(self, model_cfg):
        super().__init__(model_cfg)
        self.lambda_ = model_cfg['lambda_']

        #env layer
        # Work on a shallow copy: the caller's config keeps the activation
        # *name*, so the same config can build another model or be saved
        # again. Previously the name was overwritten in place with the
        # resolved activation object.
        env_cfg = dict(model_cfg['env_classifier'])
        env_cfg['activation_fn'] = net_util.get_activation_fn(env_cfg['activation_fn'])
        env_cfg['dim_in'] = 128  # matches the LSTM hidden size of CSIID
        env_cfg['dim_out'] = self.known_env_num
        self.env_layer = util.get_func_rets(net_util.get_mlp_net, env_cfg)

    def cal_loss(self, amps, ids, envs, amps_t = None, ids_t = None, envs_t = None):
        '''Compute the training loss.

        Without target data (`amps_t is None`) only the identity
        cross-entropy on (amps, ids) is returned. With target data the result
        is: identity loss + environment loss on the source batch + identity
        loss on the target batch. `envs_t` is accepted but not used.
        '''
        features = self.encoder(amps)
        id_probs = self.p_classifier(features)
        loss_id = torch.nn.CrossEntropyLoss()(id_probs, ids)

        if amps_t is None:
            return loss_id

        # Environment head on the (gradient-reversed) source features.
        env_probs = self.env_layer(net_util.GradientReversalF.apply(features, self.lambda_))
        loss_env = torch.nn.CrossEntropyLoss()(env_probs, envs)

        loss_t = loss_id + loss_env

        # Identity loss on the target batch.
        feature_t = self.encoder(amps_t)
        id_probs_t = self.p_classifier(feature_t)
        id_loss_t = torch.nn.CrossEntropyLoss()(id_probs_t, ids_t)

        return loss_t + id_loss_t

# Make both models selectable by name through the global model registry.
glb_var.register_model('CSIID', CSIID)
glb_var.register_model('CSIIDAL', CSIIDAL)
111 |
--------------------------------------------------------------------------------
/config/bird/bird_encoder_v2_wo_aug.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":10,
5 | "known_env_num":2,
6 | "data":{
7 | "dataset":"v2",
8 | "train":{
9 | "dir":"data/v2/train.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":8,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | }
18 | },
19 | "valid":{
20 | "dir":"data/v2/valid.mat",
21 | "loader":{
22 | "num_workers":4,
23 | "pin_memory":true,
24 | "batch_size":8,
25 | "drop_last":false,
26 | "shuffle":true,
27 | "prefetch_factor":4
28 | }
29 | },
30 | "target":{
31 | "dir":"data/v2/target.mat",
32 | "loader":{
33 | "num_workers":4,
34 | "pin_memory":true,
35 | "batch_size":8,
36 | "drop_last":false,
37 | "shuffle":true,
38 | "prefetch_factor":4
39 | }
40 | },
41 | "test":{
42 | "dir":"data/v2/test_legal.mat",
43 | "loader":{
44 | "num_workers":4,
45 | "pin_memory":true,
46 | "batch_size":8,
47 | "drop_last":false,
48 | "shuffle":true,
49 | "prefetch_factor":4
50 | }
51 | }
52 | },
53 | "train":{
54 | "is_DA":true,
55 | "max_epoch":300,
56 | "valid_start_epoch":1,
57 | "valid_step":1,
58 | "stop_train_step_valid_not_improve":30,
59 | "valid_metrics":"acc",
60 | "valid_metrics_less":false,
61 | "optimizer":{
62 | "lr":1e-5,
63 | "weight_decay":1e-4
64 | }
65 | },
66 | "model":{
67 | "name":"BIRDEncoder",
68 | "is_norm_first":true,
69 | "d":64,
70 | "t":64,
71 | "do":256,
72 | "TLinear":{
73 | "hid_layers":[
74 | 4096,
75 | 1024,
76 | 256
77 | ],
78 | "activation_fn":"mish",
79 | "drop_out":0.5,
80 | "end_with_softmax":false
81 | },
82 | "csinet":{
83 | "n_refine":2,
84 | "hid_layers":[
85 | 8,
86 | 16,
87 | 32,
88 | 32,
89 | 16,
90 | 8
91 | ],
92 | "activation_fn":"relu"
93 | },
94 | "trans2":{
95 | "max_len":600,
96 | "is_norm_first":true,
97 | "d_fc":2048,
98 | "n_heads": 4,
99 | "n_layers": 5
100 | },
101 | "tail_net":{
102 | "hid_layers":[
103 | 1024,
104 | 512,
105 | 256
106 | ],
107 | "activation_fn":"mish",
108 | "drop_out":0.5,
109 | "end_with_softmax":false
110 | },
111 | "p_classifier":{
112 | "hid_layers":[
113 | 128,
114 | 64
115 | ],
116 | "activation_fn":"mish",
117 | "drop_out":0.5,
118 | "end_with_softmax":true
119 | },
120 | "env_classifier":{
121 | "hid_layers":[
122 | 128,
123 | 64
124 | ],
125 | "activation_fn":"mish",
126 | "drop_out":0.5,
127 | "end_with_softmax":true
128 | },
129 | "lambda_":1
130 | }
131 | }
--------------------------------------------------------------------------------
/config/bird/bird_v1_aug.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":10,
5 | "known_env_num":2,
6 | "data":{
7 | "dataset":"v1",
8 | "train":{
9 | "dir":"data/v1/train.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":24,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | },
18 | "augmentation":{
19 | "gaussian_noise":{
20 | "type":"gaussian_noise",
21 | "snr_db":30
22 | },
23 | "multipath_fading":{
24 | "type":"multipath_fading",
25 | "delay":[
26 | [50],
27 | [50, 70],
28 | [50, 70, 80]
29 | ],
30 | "coef":[
31 | [0.8],
32 | [0.8, 0.6],
33 | [0.8, 0.6, 0.5]
34 | ]
35 | },
36 | "freq_selective_fading":{
37 | "type":"freq_selective_fading",
38 | "scope":[
39 | [0.8, 0.3],
40 | [0.7, 0.2],
41 | [0.6, 0.1],
42 | [0.8, 0.1]
43 | ]
44 | },
45 | "antenna_sequence":{
46 | "type":"antenna_sequence"
47 | },
48 | "subcarrier_mask":{
49 | "type":"subcarrier_mask",
50 | "ratio":0.1
51 | }
52 | }
53 | },
54 | "valid":{
55 | "dir":"data/v1/valid.mat",
56 | "loader":{
57 | "num_workers":4,
58 | "pin_memory":true,
59 | "batch_size":8,
60 | "drop_last":false,
61 | "shuffle":true,
62 | "prefetch_factor":4
63 | }
64 | },
65 | "test":{
66 | "dir":"data/v1/test.mat",
67 | "loader":{
68 | "num_workers":4,
69 | "pin_memory":true,
70 | "batch_size":8,
71 | "drop_last":false,
72 | "shuffle":true,
73 | "prefetch_factor":4
74 | }
75 | }
76 | },
77 | "train":{
78 | "is_DA":false,
79 | "max_epoch":100,
80 | "valid_start_epoch":20,
81 | "valid_step":5,
82 | "stop_train_step_valid_not_improve":6,
83 | "valid_metrics":"loss",
84 | "valid_metrics_less":true,
85 | "optimizer":{
86 | "lr":1e-5,
87 | "weight_decay":1e-4
88 | }
89 | },
90 | "model":{
91 | "name":"BIRD",
92 | "pretrained_enc_dir":"model/pretrained/BIRDEncoder_opt4/end",
93 | "loss":"nmse",
94 | "do":256,
95 | "d":16,
96 | "t":64,
97 | "TLinear1":{
98 | "hid_layers":[
99 | 8,
100 | 32
101 | ],
102 | "activation_fn":"mish",
103 | "drop_out":0.5,
104 | "end_with_softmax":false
105 | },
106 | "csinet":{
107 | "n_refine":4,
108 | "hid_layers":[
109 | 8,
110 | 16,
111 | 32,
112 | 32,
113 | 16,
114 | 8
115 | ],
116 | "activation_fn":"relu"
117 | },
118 | "FLinear":{
119 | "hid_layers":[
120 | 16,
121 | 32
122 | ],
123 | "activation_fn":"mish",
124 | "drop_out":0.5,
125 | "end_with_softmax":false
126 | },
127 | "RLinear":{
128 | "hid_layers":[
129 | 16,
130 | 8
131 | ],
132 | "activation_fn":"mish",
133 | "drop_out":0.5,
134 | "end_with_softmax":false
135 | },
136 | "TLinear2":{
137 | "hid_layers":[
138 | 256,
139 | 1024,
140 | 4096
141 | ],
142 | "activation_fn":"mish",
143 | "drop_out":0.5,
144 | "end_with_softmax":false
145 | }
146 | }
147 | }
--------------------------------------------------------------------------------
/model/framework/base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from lib import glb_var, util
3 |
# Looked up once, when this module is imported.
# NOTE(review): presumably 'logger' must already be stored in glb_var before this import runs - confirm against the entry script.
logger = glb_var.get_value('logger');
5 |
class Net(torch.nn.Module):
    '''Abstract Net class to define the API methods
    '''
    def __init__(self, model_cfg) -> None:
        super().__init__()
        # NOTE(review): presumably copies every non-dict entry of model_cfg
        # onto the instance as an attribute - confirm in lib.util.set_attr.
        util.set_attr(self, model_cfg, except_type = dict)
        # Subclasses doing intrusion detection switch this on; it selects the
        # branch taken in cal_accuracy.
        self.is_Intrusion_Detection = False

    def _init_para(self, module):
        '''Initialise the parameters of a single Embedding / LayerNorm / Linear module.'''
        if isinstance(module, torch.nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=1/module.embedding_dim)
        elif isinstance(module, torch.nn.LayerNorm):
            module.bias.data.fill_(1.0)
            module.weight.data.fill_(1.0)
        elif isinstance(module, torch.nn.Linear):
            module.weight.data.normal_()
            if module.bias is not None:
                module.bias.data.fill_(1.0)

    def forward(self):
        '''Defines the forward pass of the network. Must be implemented by subclasses.'''
        logger.error('Method needs to be called after being implemented')
        raise NotImplementedError

    def encoder(self):
        '''Defines the encoder part of the network. Must be implemented by subclasses.'''
        logger.error('Method needs to be called after being implemented')
        raise NotImplementedError

    @torch.no_grad()
    def cal_accuracy(self, amps, ids, envs) -> float:
        '''
        Calculates the accuracy of the model's predictions.

        If `is_Intrusion_Detection` is `False`, it calculates classification accuracy by comparing
        the predicted class with the ground truth (`ids`). If `is_Intrusion_Detection` is `True`,
        it compares `intrusion_detection(amps)` with `ids >= known_p_num`, i.e. every id outside
        the known range counts as an intruder. `envs` is not used.

        Returns:
        ---------
        acc : float
            The accuracy of the model's predictions.
        '''
        #ids:[N]
        #envs:[N]
        if not self.is_Intrusion_Detection:
            #p
            id_pred = self.p_classify(amps).argmax(dim = -1)
            acc = (id_pred == ids).cpu().float().mean().item()
        else:
            intrude_pred = self.intrusion_detection(amps)
            intrude_gt = ids >= self.known_p_num
            acc = (intrude_gt == intrude_pred).cpu().float().mean().item()
        return acc

    def p_classify(self, amps):
        '''Defines the id classify part of the network. Must be implemented by subclasses.'''
        logger.error('Method needs to be called after being implemented')
        raise NotImplementedError

    def train_epoch_hook(self, trainer, epoch):
        '''Hook for executing custom logic during each training epoch.'''
        pass

    def valid_epoch_hook(self, trainer, epoch):
        '''Hook for executing custom logic during each validation epoch.'''
        pass

    def pre_test_hook(self, tester):
        '''Hook for executing custom logic before testing the model.'''
        pass

    def cal_loss(self, amps, ids, envs, amps_t = None, ids_t = None, envs_t = None):
        '''Calculate the loss'''
        logger.error('Method needs to be called after being implemented')
        raise NotImplementedError

    def conventional_train(self, X, Y):
        '''Only applicable to traditional methods'''
        logger.error('Method needs to be called after being implemented')
        raise NotImplementedError

    def save(self, save_dir):
        '''
        General model saving method for pytorch model

        Saves the model's state dictionary to the specified directory.

        The model's parameters are saved as `model_state_dict.pth`.
        '''
        torch.save(self.state_dict(), save_dir + '/model_state_dict.pth')
        logger.info(f"Model saved to {save_dir}\n")

    def load(self, load_dir):
        '''General method for loading models

        Loads the model's state dictionary from the specified directory.

        The model's parameters are loaded from `model_state_dict.pth`.
        '''
        # map_location='cpu' lets a checkpoint written on a CUDA machine be
        # read on a CPU-only host; load_state_dict then copies each tensor
        # into the existing parameter on whatever device the model lives on.
        state_dict = torch.load(
            load_dir + '/model_state_dict.pth',
            map_location = 'cpu',
            weights_only = True,
        )
        self.load_state_dict(state_dict)
        logger.info(f"Model loaded from {load_dir}\n")
--------------------------------------------------------------------------------
/model/framework/autoencoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from lib import util, glb_var
3 | import numpy as np
4 | from model import net_util
5 | from model.framework.base import Net
6 |
# Process-wide objects, looked up once when this module is imported.
# NOTE(review): presumably 'device' and 'logger' must already be stored in glb_var before this import runs - confirm against the entry script.
device = glb_var.get_value('device');
logger = glb_var.get_value('logger');
9 |
class AE(Net):
    '''MLP auto-encoder used for intrusion detection.

    A sample is flagged as an intruder when its reconstruction error is at
    least `self.threshold`. Thresholds are percentiles of the reconstruction
    errors measured on a reference loader (see `update_thresholds`).
    '''
    def __init__(self, model_cfg) -> None:
        super().__init__(model_cfg)

        # Size of one flattened sample (all dims of dim_in except the first),
        # as a plain Python int rather than a numpy scalar.
        # NOTE(review): dim_in and do are presumably set on the instance by Net.__init__ from model_cfg - confirm.
        flat_dim = int(np.prod(self.dim_in[1:]))

        #encoder
        # Shallow copies keep the caller's config untouched (the activation
        # stays a name there), so the same config can be reused or saved.
        encoder_cfg = dict(model_cfg['encoder_cfg'])
        encoder_cfg['activation_fn'] = net_util.get_activation_fn(encoder_cfg['activation_fn'])
        encoder_cfg['dim_in'] = flat_dim
        encoder_cfg['dim_out'] = self.do
        self.encnet = util.get_func_rets(net_util.get_mlp_net, encoder_cfg)

        #decoder
        decoder_cfg = dict(model_cfg['decoder_cfg'])
        decoder_cfg['activation_fn'] = net_util.get_activation_fn(decoder_cfg['activation_fn'])
        decoder_cfg['dim_in'] = self.do
        decoder_cfg['dim_out'] = flat_dim
        self.decnet = util.get_func_rets(net_util.get_mlp_net, decoder_cfg)

        self.loss_func = net_util.get_loss_func(model_cfg['loss'])
        # Active decision threshold and the table it is selected from.
        self.threshold = .0
        self.thresholds = dict()

        self.is_Intrusion_Detection = True

    def encoder(self, amps):
        '''Flatten every sample and encode it with the encoder MLP.'''
        return self.encnet(amps.flatten(1, -1))

    def reconstruct(self, amps):
        '''Encode then decode `amps`; the result has the same shape as the input.'''
        #amps[B, T, R, F]
        feature = self.encoder(amps)
        csi_rcst = self.decnet(feature).reshape(amps.shape)
        return csi_rcst

    def cal_loss(self, amps, ids, envs, keep_batch = False):
        '''Reconstruction loss of `amps`.

        `ids` and `envs` are not used. `keep_batch` is forwarded to the
        configured loss function.
        NOTE(review): presumably keep_batch=True yields one loss value per sample - confirm in net_util.
        '''
        #amps:[N, T, R, F]
        #ids:[N]
        #envs:[N]
        csi_rcst = self.reconstruct(amps)
        loss = self.loss_func(csi_rcst, amps, keep_batch)
        return loss

    @torch.no_grad()
    def update_thresholds(self, loader, threshold_percents):
        '''Recompute `self.thresholds` from the reconstruction errors on `loader`.

        For every value in `threshold_percents` the matching percentile of the
        collected errors is stored under that key. The model is left in eval
        mode and augmentation stays disabled on the loader afterwards.
        '''
        self.thresholds = dict()
        loader.disable_aug()
        self.eval()
        # Collect per-batch errors and concatenate once instead of growing a
        # tensor inside the loop.
        error_chunks = [torch.zeros(0, device = device)]
        for amps, ids, envs in loader:
            error_chunks.append(self.cal_loss(amps, ids, envs, keep_batch = True))
        rcst_errors = torch.cat(error_chunks).cpu()
        for percent in threshold_percents:
            self.thresholds[percent] = np.percentile(rcst_errors, percent)
            # Report the value just computed (the active self.threshold does
            # not change until set_threshold is called).
            logger.debug(f'Updated threshold({percent:.2f}%): {self.thresholds[percent]}')

    def set_threshold(self, percent):
        '''Activate the threshold stored for `percent`; raises KeyError if it was never computed.'''
        self.threshold = self.thresholds[percent]

    def intrusion_detection(self, amps):
        '''
        Returns:
        --------
        Is it an intruder.
        '''
        rcst_errors = self.cal_loss(amps, None, None, keep_batch = True)
        return rcst_errors >= self.threshold
77 |
class CAE(AE):
    '''Convolutional counterpart of `AE`.

    The encoder is a stack of stride-2 `Conv2d` + ReLU stages and the decoder a
    stack of stride-2 `ConvTranspose2d` stages whose last activation is a
    Sigmoid. Loss, threshold handling and intrusion detection are inherited
    from `AE`.
    '''

    def __init__(self, model_cfg) -> None:
        '''Build the convolutional encoder/decoder.

        Parameters:
        -----------
        model_cfg: dict
            Reads `encoder_hid_layers` and `decoder_hid_layers` (lists of
            channel counts) and `loss`; the whole dict goes to `Net.__init__`.
        '''
        # Net.__init__ is called directly, so AE.__init__ (which builds the MLP
        # encoder/decoder) does not run for this class.
        Net.__init__(self, model_cfg)

        # Channel count of the network input/output, taken from dim_in[2]
        # (presumably the R axis of [B, T, R, F] -- confirm in base.py).
        io_channels = self.dim_in[2]

        # encoder
        enc_channels = [io_channels] + model_cfg['encoder_hid_layers']
        enc_stages = []
        for c_in, c_out in zip(enc_channels[:-1], enc_channels[1:]):
            enc_stages.append(torch.nn.Sequential(
                torch.nn.Conv2d(c_in, c_out, kernel_size = 3, stride = 2, padding = 1),
                torch.nn.ReLU()
            ))
        self.encnet = torch.nn.Sequential(*enc_stages)

        # decoder
        dec_channels = model_cfg['decoder_hid_layers'] + [io_channels]
        last_stage = len(dec_channels) - 2
        dec_stages = []
        for idx, (c_in, c_out) in enumerate(zip(dec_channels[:-1], dec_channels[1:])):
            upconv = torch.nn.ConvTranspose2d(
                c_in, c_out,
                kernel_size = 3, stride = 2, padding = 1, output_padding = (1, 1)
            )
            # Only the final stage ends with a Sigmoid; the others use ReLU.
            if idx == last_stage:
                activation = torch.nn.Sigmoid()
            else:
                activation = torch.nn.ReLU()
            dec_stages.append(torch.nn.Sequential(upconv, activation))
        self.decnet = torch.nn.Sequential(*dec_stages)

        self.loss_func = net_util.get_loss_func(model_cfg['loss'])
        self.threshold = .0
        self.thresholds = dict()

        self.is_Intrusion_Detection = True

    def encoder(self, amps):
        '''Swap axes 1 and 2 of `amps`, then apply the convolutional encoder.'''
        channels_first = amps.permute(0, 2, 1, 3)
        return self.encnet(channels_first)

    def reconstruct(self, amps):
        '''Encode and decode `amps`, swapping axes 1 and 2 back on the way out.'''
        # amps: [B, T, R, F]
        decoded = self.decnet(self.encoder(amps))
        return decoded.permute(0, 2, 1, 3)
116 |
# Register both autoencoder variants in the project-wide model registry under
# the names 'AE' and 'CAE'.
glb_var.register_model('AE', AE)
glb_var.register_model('CAE', CAE)
--------------------------------------------------------------------------------
/config/bird/bird_encoder_v2_wo_al.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2025,
3 | "gpu_is_available":true,
4 | "known_p_num":10,
5 | "known_env_num":2,
6 | "data":{
7 | "dataset":"v2",
8 | "train":{
9 | "dir":"data/v2/train.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":8,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | },
18 | "augmentation":{
19 | "gaussian_noise":{
20 | "type":"gaussian_noise",
21 | "snr_db":30
22 | },
23 | "multipath_fading":{
24 | "type":"multipath_fading",
25 | "delay":[
26 | [50],
27 | [50, 70],
28 | [50, 70, 80]
29 | ],
30 | "coef":[
31 | [0.8],
32 | [0.8, 0.6],
33 | [0.8, 0.6, 0.5]
34 | ]
35 | },
36 | "freq_selective_fading":{
37 | "type":"freq_selective_fading",
38 | "scope":[
39 | [0.8, 0.3],
40 | [0.7, 0.2],
41 | [0.6, 0.1],
42 | [0.8, 0.1]
43 | ]
44 | },
45 | "antenna_sequence":{
46 | "type":"antenna_sequence"
47 | },
48 | "subcarrier_mask":{
49 | "type":"subcarrier_mask",
50 | "ratio":0.1
51 | }
52 | }
53 | },
54 | "valid":{
55 | "dir":"data/v2/valid.mat",
56 | "loader":{
57 | "num_workers":4,
58 | "pin_memory":true,
59 | "batch_size":8,
60 | "drop_last":false,
61 | "shuffle":true,
62 | "prefetch_factor":4
63 | }
64 | },
65 | "test":{
66 | "dir":"data/v2/test_legal.mat",
67 | "loader":{
68 | "num_workers":4,
69 | "pin_memory":true,
70 | "batch_size":8,
71 | "drop_last":false,
72 | "shuffle":true,
73 | "prefetch_factor":4
74 | }
75 | }
76 | },
77 | "train":{
78 | "is_DA":false,
79 | "max_epoch":100,
80 | "valid_start_epoch":1,
81 | "valid_step":1,
82 | "stop_train_step_valid_not_improve":10,
83 | "valid_metrics":"acc",
84 | "valid_metrics_less":false,
85 | "optimizer":{
86 | "lr":1e-5,
87 | "weight_decay":1e-4
88 | }
89 | },
90 | "model":{
91 | "name":"BIRDEncoder",
92 | "is_norm_first":true,
93 | "d":64,
94 | "t":64,
95 | "do":256,
96 | "TLinear":{
97 | "hid_layers":[
98 | 4096,
99 | 1024,
100 | 256
101 | ],
102 | "activation_fn":"mish",
103 | "drop_out":0.5,
104 | "end_with_softmax":false
105 | },
106 | "csinet":{
107 | "n_refine":2,
108 | "hid_layers":[
109 | 8,
110 | 16,
111 | 32,
112 | 32,
113 | 16,
114 | 8
115 | ],
116 | "activation_fn":"relu"
117 | },
118 | "trans2":{
119 | "max_len":600,
120 | "is_norm_first":true,
121 | "d_fc":2048,
122 | "n_heads": 4,
123 | "n_layers": 5
124 | },
125 | "tail_net":{
126 | "hid_layers":[
127 | 1024,
128 | 512,
129 | 256
130 | ],
131 | "activation_fn":"mish",
132 | "drop_out":0.5,
133 | "end_with_softmax":false
134 | },
135 | "p_classifier":{
136 | "hid_layers":[
137 | 128,
138 | 64
139 | ],
140 | "activation_fn":"mish",
141 | "drop_out":0.5,
142 | "end_with_softmax":true
143 | },
144 | "env_classifier":{
145 | "hid_layers":[
146 | 128,
147 | 64
148 | ],
149 | "activation_fn":"mish",
150 | "drop_out":0.5,
151 | "end_with_softmax":true
152 | },
153 | "lambda_":0
154 | }
155 | }
--------------------------------------------------------------------------------
/config/bird/bird_encoder_v4_env0.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":10,
5 | "known_env_num":2,
6 | "data":{
7 | "dataset":"v4",
8 | "train":{
9 | "dir":"data/v4/train_env0.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":8,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | },
18 | "augmentation":{
19 | "gaussian_noise":{
20 | "type":"gaussian_noise",
21 | "snr_db":30
22 | },
23 | "multipath_fading":{
24 | "type":"multipath_fading",
25 | "delay":[
26 | [50],
27 | [50, 70],
28 | [50, 70, 80]
29 | ],
30 | "coef":[
31 | [0.8],
32 | [0.8, 0.6],
33 | [0.8, 0.6, 0.5]
34 | ]
35 | },
36 | "freq_selective_fading":{
37 | "type":"freq_selective_fading",
38 | "scope":[
39 | [0.8, 0.3],
40 | [0.7, 0.2],
41 | [0.6, 0.1],
42 | [0.8, 0.1]
43 | ]
44 | },
45 | "antenna_sequence":{
46 | "type":"antenna_sequence"
47 | },
48 | "subcarrier_mask":{
49 | "type":"subcarrier_mask",
50 | "ratio":0.1
51 | }
52 | }
53 | },
54 | "valid":{
55 | "dir":"data/v4/valid_env0.mat",
56 | "loader":{
57 | "num_workers":4,
58 | "pin_memory":true,
59 | "batch_size":8,
60 | "drop_last":false,
61 | "shuffle":true,
62 | "prefetch_factor":4
63 | }
64 | },
65 | "test":{
66 | "dir":"data/v4/test_env0.mat",
67 | "loader":{
68 | "num_workers":4,
69 | "pin_memory":true,
70 | "batch_size":8,
71 | "drop_last":false,
72 | "shuffle":true,
73 | "prefetch_factor":4
74 | }
75 | }
76 | },
77 | "train":{
78 | "is_DA":false,
79 | "max_epoch":300,
80 | "valid_start_epoch":1,
81 | "valid_step":1,
82 | "stop_train_step_valid_not_improve":50,
83 | "valid_metrics":"acc",
84 | "valid_metrics_less":false,
85 | "optimizer":{
86 | "lr":1e-5,
87 | "weight_decay":1e-4
88 | }
89 | },
90 | "model":{
91 | "name":"BIRDEncoder",
92 | "is_norm_first":true,
93 | "d":64,
94 | "t":64,
95 | "do":256,
96 | "TLinear":{
97 | "hid_layers":[
98 | 4096,
99 | 1024,
100 | 256
101 | ],
102 | "activation_fn":"mish",
103 | "drop_out":0.5,
104 | "end_with_softmax":false
105 | },
106 | "csinet":{
107 | "n_refine":2,
108 | "hid_layers":[
109 | 8,
110 | 16,
111 | 32,
112 | 32,
113 | 16,
114 | 8
115 | ],
116 | "activation_fn":"relu"
117 | },
118 | "trans2":{
119 | "max_len":600,
120 | "is_norm_first":true,
121 | "d_fc":2048,
122 | "n_heads": 4,
123 | "n_layers": 5
124 | },
125 | "tail_net":{
126 | "hid_layers":[
127 | 1024,
128 | 512,
129 | 256
130 | ],
131 | "activation_fn":"mish",
132 | "drop_out":0.5,
133 | "end_with_softmax":false
134 | },
135 | "p_classifier":{
136 | "hid_layers":[
137 | 128,
138 | 64
139 | ],
140 | "activation_fn":"mish",
141 | "drop_out":0.5,
142 | "end_with_softmax":true
143 | },
144 | "env_classifier":{
145 | "hid_layers":[
146 | 128,
147 | 64
148 | ],
149 | "activation_fn":"mish",
150 | "drop_out":0.5,
151 | "end_with_softmax":true
152 | },
153 | "lambda_":-1
154 | }
155 | }
--------------------------------------------------------------------------------
/config/bird/bird_encoder_v4_env1.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed":2024,
3 | "gpu_is_available":true,
4 | "known_p_num":10,
5 | "known_env_num":2,
6 | "data":{
7 | "dataset":"v4",
8 | "train":{
9 | "dir":"data/v4/train_env1.mat",
10 | "loader":{
11 | "num_workers":4,
12 | "pin_memory":true,
13 | "batch_size":8,
14 | "drop_last":false,
15 | "shuffle":true,
16 | "prefetch_factor":4
17 | },
18 | "augmentation":{
19 | "gaussian_noise":{
20 | "type":"gaussian_noise",
21 | "snr_db":30
22 | },
23 | "multipath_fading":{
24 | "type":"multipath_fading",
25 | "delay":[
26 | [50],
27 | [50, 70],
28 | [50, 70, 80]
29 | ],
30 | "coef":[
31 | [0.8],
32 | [0.8, 0.6],
33 | [0.8, 0.6, 0.5]
34 | ]
35 | },
36 | "freq_selective_fading":{
37 | "type":"freq_selective_fading",
38 | "scope":[
39 | [0.8, 0.3],
40 | [0.7, 0.2],
41 | [0.6, 0.1],
42 | [0.8, 0.1]
43 | ]
44 | },
45 | "antenna_sequence":{
46 | "type":"antenna_sequence"
47 | },
48 | "subcarrier_mask":{
49 | "type":"subcarrier_mask",
50 | "ratio":0.1
51 | }
52 | }
53 | },
54 | "valid":{
55 | "dir":"data/v4/valid_env1.mat",
56 | "loader":{
57 | "num_workers":4,
58 | "pin_memory":true,
59 | "batch_size":8,
60 | "drop_last":false,
61 | "shuffle":true,
62 | "prefetch_factor":4
63 | }
64 | },
65 | "test":{
66 | "dir":"data/v4/test_env1.mat",
67 | "loader":{
68 | "num_workers":4,
69 | "pin_memory":true,
70 | "batch_size":8,
71 | "drop_last":false,
72 | "shuffle":true,
73 | "prefetch_factor":4
74 | }
75 | }
76 | },
77 | "train":{
78 | "is_DA":false,
79 | "max_epoch":300,
80 | "valid_start_epoch":5,
81 | "valid_step":1,
82 | "stop_train_step_valid_not_improve":30,
83 | "valid_metrics":"acc",
84 | "valid_metrics_less":false,
85 | "optimizer":{
86 | "lr":1e-5,
87 | "weight_decay":1e-4
88 | }
89 | },
90 | "model":{
91 | "name":"BIRDEncoder",
92 | "is_norm_first":true,
93 | "d":64,
94 | "t":64,
95 | "do":256,
96 | "TLinear":{
97 | "hid_layers":[
98 | 4096,
99 | 1024,
100 | 256
101 | ],
102 | "activation_fn":"mish",
103 | "drop_out":0.5,
104 | "end_with_softmax":false
105 | },
106 | "csinet":{
107 | "n_refine":2,
108 | "hid_layers":[
109 | 8,
110 | 16,
111 | 32,
112 | 32,
113 | 16,
114 | 8
115 | ],
116 | "activation_fn":"relu"
117 | },
118 | "trans2":{
119 | "max_len":600,
120 | "is_norm_first":true,
121 | "d_fc":2048,
122 | "n_heads": 4,
123 | "n_layers": 5
124 | },
125 | "tail_net":{
126 | "hid_layers":[
127 | 1024,
128 | 512,
129 | 256
130 | ],
131 | "activation_fn":"mish",
132 | "drop_out":0.5,
133 | "end_with_softmax":false
134 | },
135 | "p_classifier":{
136 | "hid_layers":[
137 | 128,
138 | 64
139 | ],
140 | "activation_fn":"mish",
141 | "drop_out":0.5,
142 | "end_with_softmax":true
143 | },
144 | "env_classifier":{
145 | "hid_layers":[
146 | 128,
147 | 64
148 | ],
149 | "activation_fn":"mish",
150 | "drop_out":0.5,
151 | "end_with_softmax":true
152 | },
153 | "lambda_":-1
154 | }
155 | }
--------------------------------------------------------------------------------