├── .gitignore
├── .idea
│   ├── deployment.xml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── label_smoothing.iml
│   ├── misc.xml
│   ├── modules.xml
│   └── workspace.xml
├── README.md
├── checkpoints
│   └── .gitignore
├── data
│   └── .gitignore
├── feature
│   └── .gitignore
├── generate_feature.py
├── log_plot.py
├── logs
│   └── .gitignore
├── lsr.py
├── nn.py
├── png
│   ├── .gitignore
│   ├── join1.png
│   └── join2.png
├── progressbar.py
├── run.py
├── tools.py
├── trainingmonitor.py
└── tsne_plot.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/.idea/ (deployment.xml, inspectionProfiles/profiles_settings.xml, label_smoothing.iml, misc.xml, modules.xml, workspace.xml):
--------------------------------------------------------------------------------
(PyCharm project configuration files; their XML content was not preserved in this dump)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Label Smoothing PyTorch
2 |
3 | This repository contains a PyTorch implementation of label smoothing.
4 |
5 | ## Dependencies
6 |
7 | * PyTorch
8 | * torchvision
9 | * matplotlib
10 | * scikit-learn
11 |
12 | ## Example
13 |
14 | To reproduce the results below, we train ResNet18 on the CIFAR-10 dataset.
15 |
16 | ```bash
17 | # no label smoothing
18 | python run.py
19 |
20 | # use label smoothing
21 | python run.py --do_lsr
22 |
23 | # extract feature
24 | python generate_feature.py
25 |
26 | python generate_feature.py --do_lsr
27 |
28 | # plot t-SNE embeddings
29 | python tsne_plot.py
30 |
31 | python tsne_plot.py --do_lsr
32 |
33 | ```
34 | ## Results
35 |
36 | Training result
37 |
38 | ![training result](./png/join1.png)
39 |
40 | t-SNE visualisation
41 |
42 | ![t-SNE visualisation](./png/join2.png)
43 |
44 |
45 |
--------------------------------------------------------------------------------
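The loss implemented in lsr.py below mixes standard cross-entropy with a uniform-target term: for C classes and smoothing factor eps, loss = (1 - eps) * CE + (eps / C) * sum over classes of -log p_c. A minimal sketch of using it as a drop-in replacement for nn.CrossEntropyLoss, mirroring what run.py does with --do_lsr (eps=0.1 here is the class default, not a tuned value):

```python
import torch.nn as nn
from lsr import LabelSmoothingCrossEntropy

do_lsr = True  # corresponds to the --do_lsr flag
# eps is the smoothing factor; eps=0 recovers plain cross-entropy
loss_fn = LabelSmoothingCrossEntropy(eps=0.1) if do_lsr else nn.CrossEntropyLoss()
```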
/checkpoints/.gitignore:
--------------------------------------------------------------------------------
(identical to the root /.gitignore above)
--------------------------------------------------------------------------------
/data/.gitignore:
--------------------------------------------------------------------------------
(identical to the root /.gitignore above)
--------------------------------------------------------------------------------
/feature/.gitignore:
--------------------------------------------------------------------------------
(identical to the root /.gitignore above)
--------------------------------------------------------------------------------
/generate_feature.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import argparse
4 | import numpy as np
5 | from nn import ResNet18
6 | from torch.utils.data import DataLoader
7 | from progressbar import ProgressBar
8 | from torchvision import datasets, transforms
9 | from tools import seed_everything
10 |
11 | class ExtractFeature(nn.Module):
12 | def __init__(self, pretrained):
13 |         super(ExtractFeature, self).__init__()
14 | self.pretrained = pretrained
15 | self._reset_model()
16 |
17 | def _reset_model(self):
18 | model = ResNet18()
19 |         model.load_state_dict(torch.load(self.pretrained))
20 | self._features = nn.Sequential(
21 | model.conv1,
22 | model.layer1,
23 | model.layer2,
24 | model.layer3,
25 | model.layer4,
26 | model.avgpool
27 | )
28 | def forward(self,inputs):
29 | out = self._features(inputs)
30 | out = torch.flatten(out,1)
31 | return out
32 |
33 | def generate_feature(data_loader):
34 | extract_feature.eval()
35 | out_target = []
36 |     out_output = []
37 | pbar = ProgressBar(n_total=len(data_loader), desc='GenerateFeature')
38 | for batch_idx,(data, target) in enumerate(data_loader):
39 | data, target = data.to(device), target.to(device)
40 | output = extract_feature(data)
41 | output_np = output.data.cpu().numpy()
42 | target_np = target.data.cpu().numpy()
43 |
44 | out_output.append(output_np)
45 | out_target.append(target_np[:, np.newaxis])
46 | pbar(step=batch_idx)
47 | output_array = np.concatenate(out_output, axis=0)
48 | target_array = np.concatenate(out_target, axis=0)
49 | np.save(f'./feature/{arch}_feature.npy', output_array, allow_pickle=False)
50 | np.save(f'./feature/{arch}_target.npy', target_array, allow_pickle=False)
51 |
52 | if __name__ == "__main__":
53 | parser = argparse.ArgumentParser(description='CIFAR10')
54 | parser.add_argument("--model", type=str, default='ResNet18')
55 | parser.add_argument('--seed',type=int,default=42)
56 | parser.add_argument('--epoch',type=int,default=30)
57 | parser.add_argument('--batch_size',type=int,default=128)
58 | parser.add_argument("--task", type=str, default='image')
59 | parser.add_argument("--do_lsr", action='store_true',help="Whether to do label smoothing.")
60 | args = parser.parse_args()
61 | seed_everything(args.seed)
62 |
63 | if args.do_lsr:
64 | arch = args.model+'_label_smoothing'
65 | else:
66 | arch = args.model
67 |
68 | model_path = f"./checkpoints/{arch}.bin"
69 | extract_feature = ExtractFeature(model_path)
70 |     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
71 | extract_feature.to(device)
72 |
73 | data = {
74 | 'valid': datasets.CIFAR10(
75 | root='./data', train=False, download=True,
76 | transform=transforms.Compose([
77 | transforms.ToTensor(),
78 | transforms.Normalize((0.4914,0.4822,0.4465),(0.2023,0.1994,0.2010))]
79 | )
80 | )
81 | }
82 |
83 | loaders = {
84 |         'valid': DataLoader(data['valid'], batch_size=args.batch_size,
85 | num_workers=10, pin_memory=True,
86 | drop_last=False)
87 | }
88 | generate_feature(loaders['valid'])
89 |
--------------------------------------------------------------------------------
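The arrays written by generate_feature.py can be inspected directly. A sketch, assuming the script was run with the default --model name on the CIFAR-10 test split:

```python
import numpy as np

feature = np.load('./feature/ResNet18_feature.npy')
target = np.load('./feature/ResNet18_target.npy')
print(feature.shape)  # (10000, 512): CIFAR-10 test set, 512-d avgpool features
print(target.shape)   # (10000, 1): one class index per sample
```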
/log_plot.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | from tools import load_json
4 | plt.switch_backend('agg')  # avoid display issues when plotting over ssh
5 |
6 | data1 = load_json('./logs/ResNet18_training_monitor.json')
7 | data2 = load_json('./logs/ResNet18_label_smoothing_training_monitor.json')
8 |
9 | N = np.arange(0, len(data1['loss']))
10 | plt.style.use("ggplot")
11 | plt.figure()
12 | plt.plot(N, data1['loss'], label="ResNet18")
13 | plt.plot(N, data2['loss'], label="ResNet18_label_smooth")
14 | plt.legend()
15 | plt.xlabel("Epoch #")
16 | plt.ylabel('loss')
17 | plt.title(f"Training loss [Epoch {len(data1['loss'])}]")
18 | plt.savefig('./png/training_loss.png')
19 | plt.close()
20 |
21 | N = np.arange(0, len(data1['loss']))
22 | plt.style.use("ggplot")
23 | plt.figure()
24 | plt.plot(N, data1['valid_acc'], label="ResNet18")
25 | plt.plot(N, data2['valid_acc'], label="ResNet18_label_smooth")
26 | plt.legend()
27 | plt.xlabel("Epoch #")
28 | plt.ylabel('accuracy')
29 | plt.title(f"Valid accuracy [Epoch {len(data1['loss'])}]")
30 | plt.savefig('./png/valid_accuracy.png')
31 | plt.close()
--------------------------------------------------------------------------------
/logs/.gitignore:
--------------------------------------------------------------------------------
(identical to the root /.gitignore above)
--------------------------------------------------------------------------------
/lsr.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 | class LabelSmoothingCrossEntropy(nn.Module):
5 | def __init__(self, eps=0.1, reduction='mean'):
6 | super(LabelSmoothingCrossEntropy, self).__init__()
7 | self.eps = eps
8 | self.reduction = reduction
9 |
10 | def forward(self, output, target):
11 | c = output.size()[-1]
12 | log_preds = F.log_softmax(output, dim=-1)
13 |         if self.reduction == 'sum':
14 |             loss = -log_preds.sum()
15 |         else:
16 |             loss = -log_preds.sum(dim=-1)
17 |             if self.reduction == 'mean':
18 |                 loss = loss.mean()
19 |         return loss * self.eps / c + (1 - self.eps) * F.nll_loss(log_preds, target, reduction=self.reduction)
--------------------------------------------------------------------------------
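A quick numerical check of the loss above: with eps=0 it must coincide with nn.CrossEntropyLoss, and with eps>0 it adds the uniform-target term (eps/C) * sum_c(-log p_c). A sketch, assuming lsr.py is importable from the project root:

```python
import torch
import torch.nn as nn
from lsr import LabelSmoothingCrossEntropy

torch.manual_seed(0)
logits = torch.randn(4, 10)            # batch of 4, 10 classes
target = torch.randint(0, 10, (4,))

# eps=0 reduces the smoothed loss to standard cross-entropy
assert torch.allclose(nn.CrossEntropyLoss()(logits, target),
                      LabelSmoothingCrossEntropy(eps=0.0)(logits, target))

# eps>0 mixes in the uniform-target term
print(LabelSmoothingCrossEntropy(eps=0.1)(logits, target).item())
```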
/nn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class ResidualBlock(nn.Module):
7 | def __init__(self, inchannel, outchannel, stride=1):
8 | super(ResidualBlock, self).__init__()
9 | self.left = nn.Sequential(
10 | nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False),
11 | nn.BatchNorm2d(outchannel),
12 | nn.ReLU(inplace=True),
13 | nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False),
14 | nn.BatchNorm2d(outchannel)
15 | )
16 | self.shortcut = nn.Sequential()
17 | if stride != 1 or inchannel != outchannel:
18 | self.shortcut = nn.Sequential(
19 | nn.Conv2d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False),
20 | nn.BatchNorm2d(outchannel)
21 | )
22 |
23 | def forward(self, x):
24 | out = self.left(x)
25 | out += self.shortcut(x)
26 | out = F.relu(out)
27 | return out
28 |
29 | class ResNet(nn.Module):
30 |     def __init__(self, block, num_classes=10):
31 | super(ResNet, self).__init__()
32 | self.inchannel = 64
33 | self.conv1 = nn.Sequential(
34 | nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
35 | nn.BatchNorm2d(64),
36 | nn.ReLU(),
37 | )
38 |         self.layer1 = self.make_layer(block, 64, 2, stride=1)
39 |         self.layer2 = self.make_layer(block, 128, 2, stride=2)
40 |         self.layer3 = self.make_layer(block, 256, 2, stride=2)
41 |         self.layer4 = self.make_layer(block, 512, 2, stride=2)
42 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
43 | self.fc = nn.Linear(512, num_classes)
44 |
45 | def make_layer(self, block, channels, num_blocks, stride):
46 |         strides = [stride] + [1] * (num_blocks - 1)  # e.g. stride=2 -> [2, 1]
47 | layers = []
48 | for stride in strides:
49 | layers.append(block(self.inchannel, channels, stride))
50 | self.inchannel = channels
51 | return nn.Sequential(*layers)
52 |
53 | def forward(self, x):
54 | out = self.conv1(x)
55 | out = self.layer1(out)
56 | out = self.layer2(out)
57 | out = self.layer3(out)
58 | out = self.layer4(out)
59 | out = self.avgpool(out)
60 | out = torch.flatten(out, 1)
61 | out = self.fc(out)
62 | return out
63 |
64 |
65 | def ResNet18():
66 |     return ResNet(ResidualBlock)
--------------------------------------------------------------------------------
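A shape check for the network above: CIFAR-10 inputs are 3x32x32 and the classification head emits 10 logits. A minimal sketch (note that nn here is the project's nn.py, not torch.nn):

```python
import torch
from nn import ResNet18

model = ResNet18()
x = torch.randn(2, 3, 32, 32)  # two CIFAR-10 sized images
with torch.no_grad():
    logits = model(x)
print(logits.shape)            # torch.Size([2, 10])
```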
/png/.gitignore:
--------------------------------------------------------------------------------
(identical to the root /.gitignore above)
--------------------------------------------------------------------------------
/png/join1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lonePatient/label_smoothing_pytorch/499abd4c7ca466c432cd6f62b58fb238ad2fb703/png/join1.png
--------------------------------------------------------------------------------
/png/join2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lonePatient/label_smoothing_pytorch/499abd4c7ca466c432cd6f62b58fb238ad2fb703/png/join2.png
--------------------------------------------------------------------------------
/progressbar.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | class ProgressBar(object):
4 | '''
5 | custom progress bar
6 | Example:
7 | >>> pbar = ProgressBar(n_total=30,desc='training')
8 | >>> step = 2
9 | >>> pbar(step=step)
10 | '''
11 | def __init__(self, n_total,width=30,desc = 'Training'):
12 | self.width = width
13 | self.n_total = n_total
14 | self.start_time = time.time()
15 | self.desc = desc
16 |
17 |     def __call__(self, step, info=None):
18 | now = time.time()
19 | current = step + 1
20 | recv_per = current / self.n_total
21 | bar = f'[{self.desc}] {current}/{self.n_total} ['
22 | if recv_per >= 1:
23 | recv_per = 1
24 | prog_width = int(self.width * recv_per)
25 | if prog_width > 0:
26 | bar += '=' * (prog_width - 1)
27 | if current< self.n_total:
28 | bar += ">"
29 | else:
30 | bar += '='
31 | bar += '.' * (self.width - prog_width)
32 | bar += ']'
33 | show_bar = f"\r{bar}"
34 | time_per_unit = (now - self.start_time) / current
35 | if current < self.n_total:
36 | eta = time_per_unit * (self.n_total - current)
37 | if eta > 3600:
38 | eta_format = ('%d:%02d:%02d' %
39 | (eta // 3600, (eta % 3600) // 60, eta % 60))
40 | elif eta > 60:
41 | eta_format = '%d:%02d' % (eta // 60, eta % 60)
42 | else:
43 | eta_format = '%ds' % eta
44 | time_info = f' - ETA: {eta_format}'
45 | else:
46 | if time_per_unit >= 1:
47 | time_info = f' {time_per_unit:.1f}s/step'
48 | elif time_per_unit >= 1e-3:
49 | time_info = f' {time_per_unit * 1e3:.1f}ms/step'
50 | else:
51 | time_info = f' {time_per_unit * 1e6:.1f}us/step'
52 |
53 | show_bar += time_info
54 |         if info:
55 | show_info = f'{show_bar} ' + \
56 | "-".join([f' {key}: {value:.4f} ' for key, value in info.items()])
57 | print(show_info, end='')
58 | else:
59 | print(show_bar, end='')
60 |
--------------------------------------------------------------------------------
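Beyond the docstring above, a small self-contained demo of the bar with its optional info dict, a sketch:

```python
import time
from progressbar import ProgressBar

pbar = ProgressBar(n_total=50, desc='Demo')
for step in range(50):
    time.sleep(0.02)  # stand-in for real work
    pbar(step=step, info={'loss': 1.0 / (step + 1)})
print()               # move past the \r-updated line
```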
/run.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | import torch.nn as nn
4 | from nn import ResNet18
5 | from progressbar import ProgressBar
6 | from torchvision import datasets, transforms
7 | from torch.utils.data import DataLoader
8 | import torch.optim as optim
9 | from trainingmonitor import TrainingMonitor
10 | from lsr import LabelSmoothingCrossEntropy
11 | from tools import save_model
12 | from tools import AverageMeter
13 | from tools import seed_everything
14 |
15 | epochs = 30
16 | batch_size = 128
17 | seed = 42
18 |
19 | seed_everything(seed)
20 | model = ResNet18()
21 |
22 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
23 | model.to(device)
24 |
25 | parser = argparse.ArgumentParser(description='CIFAR10')
26 | parser.add_argument("--model", type=str, default='ResNet18')
27 | parser.add_argument("--task", type=str, default='image')
28 | parser.add_argument("--do_lsr", action='store_true',help="Whether to do label smoothing.")
29 | args = parser.parse_args()
30 |
31 | if args.do_lsr:
32 | arch = args.model+'_label_smoothing'
33 | loss_fn = LabelSmoothingCrossEntropy()
34 | else:
35 | arch = args.model
36 | loss_fn = nn.CrossEntropyLoss()
37 |
38 | optimizer = optim.Adam(model.parameters(), lr=0.001)
39 | train_monitor = TrainingMonitor(file_dir='./logs/',arch = arch)
40 |
41 | def train(train_loader):
42 | pbar = ProgressBar(n_total=len(train_loader),desc='Training')
43 | train_loss = AverageMeter()
44 | model.train()
45 | for batch_idx, (data, target) in enumerate(train_loader):
46 | data, target = data.to(device), target.to(device)
47 | optimizer.zero_grad()
48 | output = model(data)
49 | loss = loss_fn(output, target)
50 | loss.backward()
51 | optimizer.step()
52 |         pbar(step=batch_idx, info={'loss': loss.item()})
53 |         train_loss.update(loss.item(), n=1)
54 | return {'loss':train_loss.avg}
55 |
56 | def test(test_loader):
57 | pbar = ProgressBar(n_total=len(test_loader),desc='Testing')
58 | valid_loss = AverageMeter()
59 | valid_acc = AverageMeter()
60 | model.eval()
61 | count = 0
62 | with torch.no_grad():
63 | for batch_idx,(data, target) in enumerate(test_loader):
64 | data, target = data.to(device), target.to(device)
65 | output = model(data)
66 |             loss = loss_fn(output, target).item()  # mean loss over the batch
67 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
68 | correct = pred.eq(target.view_as(pred)).sum().item()
69 | valid_loss.update(loss,n = data.size(0))
70 | valid_acc.update(correct, n=1)
71 | count += data.size(0)
72 | pbar(step=batch_idx)
73 |     return {'valid_loss': valid_loss.avg,
74 |             'valid_acc': valid_acc.sum / count}
75 |
76 | data = {
77 | 'train': datasets.CIFAR10(
78 | root='./data', download=True,
79 | transform=transforms.Compose([
80 | transforms.RandomCrop((32, 32), padding=4),
81 | transforms.RandomHorizontalFlip(),
82 | transforms.ToTensor(),
83 | transforms.Normalize((0.4914,0.4822,0.4465),(0.2023,0.1994,0.2010))]
84 | )
85 | ),
86 | 'valid': datasets.CIFAR10(
87 | root='./data', train=False, download=True,
88 | transform=transforms.Compose([
89 | transforms.ToTensor(),
90 | transforms.Normalize((0.4914,0.4822,0.4465),(0.2023,0.1994,0.2010))]
91 | )
92 | )
93 | }
94 |
95 | loaders = {
96 |     'train': DataLoader(data['train'], batch_size=batch_size, shuffle=True,
97 | num_workers=10, pin_memory=True,
98 | drop_last=True),
99 |     'valid': DataLoader(data['valid'], batch_size=batch_size,
100 | num_workers=10, pin_memory=True,
101 | drop_last=False)
102 | }
103 | best_acc = 0.0
104 | for epoch in range(1, epochs + 1):
105 | train_log = train(loaders['train'])
106 | valid_log = test(loaders['valid'])
107 | logs = dict(train_log, **valid_log)
108 | show_info = f'\nEpoch: {epoch} - ' + "-".join([f' {key}: {value:.4f} ' for key, value in logs.items()])
109 | print(show_info)
110 | train_monitor.epoch_step(logs)
111 |     if logs['valid_acc'] >= best_acc:
112 | print(f"Epoch {epoch}: valid_acc improved from {best_acc:.5f} to {logs['valid_acc']:.5f}")
113 | best_acc = logs['valid_acc']
114 | save_model(model,f'./checkpoints/{arch}.bin')
115 |
116 |
117 |
--------------------------------------------------------------------------------
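run.py persists only the best model's bare state_dict (via tools.save_model), so it can be restored directly for evaluation. A sketch, assuming training was run with the default --model name:

```python
import torch
from nn import ResNet18

model = ResNet18()
state_dict = torch.load('./checkpoints/ResNet18.bin', map_location='cpu')
model.load_state_dict(state_dict)
model.eval()  # switch to inference mode
```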
/tools.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pathlib import Path
3 | import json
4 | import random
5 | import torch
6 | import os
7 | import logging
8 | import torch.nn as nn
9 |
10 | logger = logging.getLogger()
11 | def print_config(config):
12 | info = "Running with the following configs:\n"
13 | for k, v in config.items():
14 | info += f"\t{k} : {str(v)}\n"
15 | print("\n" + info + "\n")
16 | return
17 |
18 | def init_logger(log_file=None, log_file_level=logging.NOTSET):
19 | '''
20 | Example:
21 | >>> init_logger(log_file)
22 | >>> logger.info("abc'")
23 | '''
24 | if isinstance(log_file,Path):
25 | log_file = str(log_file)
26 | log_format = logging.Formatter(fmt='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
27 | datefmt='%m/%d/%Y %H:%M:%S')
28 |
29 | logger = logging.getLogger()
30 | logger.setLevel(logging.INFO)
31 | console_handler = logging.StreamHandler()
32 | console_handler.setFormatter(log_format)
33 | logger.handlers = [console_handler]
34 | if log_file and log_file != '':
35 | file_handler = logging.FileHandler(log_file)
36 | file_handler.setLevel(log_file_level)
37 | # file_handler.setFormatter(log_format)
38 | logger.addHandler(file_handler)
39 | return logger
40 |
41 | def save_json(data, file_path):
42 |     '''
43 |     Save data as JSON.
44 |     :param data: object to serialize
45 |     :param file_path: destination path (str or Path),
46 |         created/overwritten as needed
47 |     :return:
48 |     '''
49 | if not isinstance(file_path, Path):
50 | file_path = Path(file_path)
51 | # if isinstance(data,dict):
52 | # data = json.dumps(data)
53 | with open(str(file_path), 'w') as f:
54 | json.dump(data, f)
55 |
56 |
57 | def load_json(file_path):
58 |     '''
59 |     Load JSON from a file.
60 |     :param file_path: source path (str or Path)
61 |     :return: the deserialized object
62 |     '''
63 |
64 | if not isinstance(file_path, Path):
65 | file_path = Path(file_path)
66 | with open(str(file_path), 'r') as f:
67 | data = json.load(f)
68 | return data
69 |
70 | def save_model(model, model_path):
71 | """ 存储不含有显卡信息的state_dict或model
72 | :param model:
73 | :param model_name:
74 | :param only_param:
75 | :return:
76 | """
77 | if isinstance(model_path, Path):
78 | model_path = str(model_path)
79 | if isinstance(model, nn.DataParallel):
80 | model = model.module
81 | state_dict = model.state_dict()
82 | for key in state_dict:
83 | state_dict[key] = state_dict[key].cpu()
84 | torch.save(state_dict, model_path)
85 |
86 | def load_model(model, model_path):
87 |     '''
88 |     Load a model from a checkpoint written by save_model
89 |     (which stores a bare state_dict, see above).
90 |     :param model: model (or nn.DataParallel wrapper) to load into
91 |     :param model_path: checkpoint path (str or Path)
92 |     :return: the model with weights restored
93 |     '''
94 |
95 | if isinstance(model_path, Path):
96 | model_path = str(model_path)
97 | logging.info(f"loading model from {str(model_path)} .")
98 |     # save_model writes a bare state_dict, so load it directly
99 |     state = torch.load(model_path, map_location='cpu')
100 | if isinstance(model, nn.DataParallel):
101 | model.module.load_state_dict(state)
102 | else:
103 | model.load_state_dict(state)
104 | return model
105 |
106 | class AverageMeter(object):
107 |     '''
108 |     Computes and stores the average and current value.
109 |     Example:
110 |         >>> loss = AverageMeter()
111 |         >>> for step, batch in enumerate(train_data):
112 |         >>>     pred = self.model(batch)
113 |         >>>     raw_loss = self.metrics(pred, target)
114 |         >>>     loss.update(raw_loss.item(), n=1)
115 |         >>> cur_loss = loss.avg
116 |     '''
117 |
118 | def __init__(self):
119 | self.reset()
120 |
121 | def reset(self):
122 | self.val = 0
123 | self.avg = 0
124 | self.sum = 0
125 | self.count = 0
126 |
127 | def update(self, val, n=1):
128 | self.val = val
129 | self.sum += val * n
130 | self.count += n
131 | self.avg = self.sum / self.count
132 |
133 | def seed_everything(seed=1029):
134 |     '''
135 |     Fix the random seeds of python, numpy and torch (CPU and CUDA).
136 |     :param seed: the seed value
137 |     :return:
138 |     '''
139 | random.seed(seed)
140 | os.environ['PYTHONHASHSEED'] = str(seed)
141 | np.random.seed(seed)
142 | torch.manual_seed(seed)
143 | torch.cuda.manual_seed(seed)
144 | torch.cuda.manual_seed_all(seed)
145 | # some cudnn methods can be random even after fixing the seed
146 | # unless you tell it to be deterministic
147 | torch.backends.cudnn.deterministic = True
--------------------------------------------------------------------------------
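A short round-trip check for AverageMeter and the JSON helpers above, a sketch (assumes the ./logs directory exists, as in this repo):

```python
from tools import AverageMeter, save_json, load_json

meter = AverageMeter()
meter.update(0.5, n=2)  # running sum 1.0, count 2
meter.update(0.3, n=2)  # running sum 1.6, count 4
assert abs(meter.avg - 0.4) < 1e-9

save_json({'avg': meter.avg}, './logs/demo.json')
print(load_json('./logs/demo.json'))  # {'avg': 0.4}
```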
/trainingmonitor.py:
--------------------------------------------------------------------------------
1 | # encoding:utf-8
2 | import numpy as np
3 | from pathlib import Path
4 | import matplotlib.pyplot as plt
5 | from tools import load_json
6 | from tools import save_json
7 | plt.switch_backend('agg')  # avoid display issues when plotting over ssh
8 |
9 | class TrainingMonitor():
10 | def __init__(self, file_dir, arch, add_test=False):
11 |         '''
12 |         :param file_dir: directory for the monitor JSON and plots; to resume from a given epoch, call reset(start_at)
13 |         '''
14 |         # accept both str and Path
15 |         if not isinstance(file_dir, Path):
16 |             file_dir = Path(file_dir)
17 |
18 | file_dir.mkdir(parents=True, exist_ok=True)
19 |
20 | self.arch = arch
21 | self.file_dir = file_dir
22 | self.H = {}
23 | self.add_test = add_test
24 | self.json_path = file_dir / (arch + "_training_monitor.json")
25 |
26 | def reset(self,start_at):
27 | if start_at > 0:
28 | if self.json_path is not None:
29 | if self.json_path.exists():
30 | self.H = load_json(self.json_path)
31 | for k in self.H.keys():
32 | self.H[k] = self.H[k][:start_at]
33 |
34 | def epoch_step(self, logs={}):
35 | for (k, v) in logs.items():
36 | l = self.H.get(k, [])
37 |             # numpy scalars (e.g. np.float32) are not JSON-serializable
38 |             if not isinstance(v, float):
39 | v = round(float(v), 4)
40 | l.append(v)
41 | self.H[k] = l
42 |
43 |         # write the history to the JSON file
44 | if self.json_path is not None:
45 | save_json(data = self.H,file_path=self.json_path)
46 |
47 |         # save the training curves
48 | if len(self.H["loss"]) == 1:
49 | self.paths = {key: self.file_dir / (self.arch + f'_{key.upper()}') for key in self.H.keys()}
50 |
51 | if len(self.H["loss"]) > 1:
52 |             # plot each metric's learning curve; keys must come in
53 |             # train/valid pairs (e.g. 'loss' and 'valid_loss'),
54 |             # plus 'test_' keys when add_test is set
55 | keys = [key for key, _ in self.H.items() if '_' not in key]
56 | for key in keys:
57 | N = np.arange(0, len(self.H[key]))
58 | plt.style.use("ggplot")
59 | plt.figure()
60 | plt.plot(N, self.H[key], label=f"train_{key}")
61 | plt.plot(N, self.H[f"valid_{key}"], label=f"valid_{key}")
62 | if self.add_test:
63 | plt.plot(N, self.H[f"test_{key}"], label=f"test_{key}")
64 | plt.legend()
65 | plt.xlabel("Epoch #")
66 | plt.ylabel(key)
67 | plt.title(f"Training {key} [Epoch {len(self.H[key])}]")
68 | plt.savefig(str(self.paths[key]))
69 | plt.close()
70 |
--------------------------------------------------------------------------------
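TrainingMonitor expects metric keys in train/valid pairs, as run.py produces. A minimal usage sketch with made-up values:

```python
from trainingmonitor import TrainingMonitor

monitor = TrainingMonitor(file_dir='./logs/', arch='demo')
for epoch in range(3):
    # 'loss' is paired with 'valid_loss' for the plot
    monitor.epoch_step({'loss': 1.0 / (epoch + 1),
                        'valid_loss': 1.2 / (epoch + 1)})
# writes ./logs/demo_training_monitor.json and a demo_LOSS curve image
```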
/tsne_plot.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import argparse
3 | import matplotlib.pyplot as plt
4 | from sklearn.manifold import TSNE
5 |
6 | parser = argparse.ArgumentParser(description='CIFAR10')
7 | parser.add_argument("--model", type=str, default='ResNet18')
8 | parser.add_argument("--task", type=str, default='image')
9 | parser.add_argument("--do_lsr", action='store_true',help="Whether to do label smoothing.")
10 | args = parser.parse_args()
11 |
12 | if args.do_lsr:
13 | arch = args.model+'_label_smoothing'
14 | else:
15 | arch = args.model
16 | feature = np.load(f'./feature/{arch}_feature.npy').astype(np.float64)
17 | target = np.load(f'./feature/{arch}_target.npy')
18 | print('target shape: ', target.shape)
19 | print('feature shape: ', feature.shape)
20 |
21 | tsne = TSNE(n_components=2, init='pca', random_state=0)
22 | output_2d = tsne.fit_transform(feature)
23 | plt.rcParams['figure.figsize'] = 10, 10
24 | plt.scatter(output_2d[:, 0], output_2d[:, 1], c=target[:, 0])
25 | plt.title(f"Validation {arch} t-SNE")
26 | plt.savefig(f'./png/{arch}_feature_2d.png', bbox_inches='tight')
27 | plt.show()
28 |
--------------------------------------------------------------------------------