├── tllib ├── reweight │ ├── __init__.py │ └── groupdro.py ├── normalization │ ├── __init__.py │ └── mixstyle │ │ └── __init__.py ├── self_training │ ├── __init__.py │ ├── uda.py │ ├── pseudo_label.py │ └── self_ensemble.py ├── translation │ ├── __init__.py │ ├── spgan │ │ ├── __init__.py │ │ ├── loss.py │ │ └── siamese.py │ ├── cyclegan │ │ ├── __init__.py │ │ └── transform.py │ └── cycada.py ├── alignment │ ├── d_adapt │ │ ├── __init__.py │ │ └── modeling │ │ │ ├── __init__.py │ │ │ ├── roi_heads │ │ │ └── __init__.py │ │ │ ├── meta_arch │ │ │ └── __init__.py │ │ │ └── matcher.py │ ├── __init__.py │ ├── rsd.py │ ├── coral.py │ ├── bsp.py │ └── adda.py ├── vision │ ├── __init__.py │ ├── models │ │ ├── object_detection │ │ │ ├── proposal_generator │ │ │ │ └── __init__.py │ │ │ ├── backbone │ │ │ │ ├── __init__.py │ │ │ │ └── mmdetection │ │ │ │ │ └── weight_init.py │ │ │ ├── roi_heads │ │ │ │ └── __init__.py │ │ │ ├── meta_arch │ │ │ │ └── __init__.py │ │ │ └── __init__.py │ │ ├── reid │ │ │ └── __init__.py │ │ ├── segmentation │ │ │ └── __init__.py │ │ ├── __init__.py │ │ ├── keypoint_detection │ │ │ └── __init__.py │ │ └── digits.py │ └── datasets │ │ ├── regression │ │ └── __init__.py │ │ ├── reid │ │ ├── __init__.py │ │ └── convert.py │ │ ├── keypoint_detection │ │ └── __init__.py │ │ ├── segmentation │ │ ├── __init__.py │ │ ├── gta5.py │ │ └── synthia.py │ │ ├── cifar.py │ │ ├── __init__.py │ │ ├── retinopathy.py │ │ ├── resisc45.py │ │ ├── food101.py │ │ ├── patchcamelyon.py │ │ ├── eurosat.py │ │ ├── _util.py │ │ ├── sun397.py │ │ └── visda2017.py ├── utils │ ├── __init__.py │ ├── analysis │ │ ├── __init__.py │ │ └── tsne.py │ └── scheduler.py ├── regularization │ ├── __init__.py │ ├── knowledge_distillation.py │ └── bss.py ├── modules │ ├── __init__.py │ ├── entropy.py │ ├── domain_discriminator.py │ ├── kernels.py │ └── gl.py ├── ranking │ ├── __init__.py │ ├── transrate.py │ ├── leep.py │ └── nce.py └── __init__.py ├── docs ├── requirements.txt ├── html.zip ├── _static │ └── images │ │ ├── DANN.png │ │ └── TransLearn.png ├── tllib │ ├── utils │ │ ├── index.rst │ │ ├── analysis.rst │ │ ├── metric.rst │ │ └── base.rst │ ├── vision │ │ ├── index.rst │ │ ├── transforms.rst │ │ └── models.rst │ ├── alignment │ │ ├── index.rst │ │ ├── statistics_matching.rst │ │ ├── hypothesis_adversarial.rst │ │ └── domain_adversarial.rst │ ├── ranking.rst │ ├── reweight.rst │ ├── modules.rst │ ├── normalization.rst │ ├── regularization.rst │ └── translation.rst ├── index.rst ├── Makefile └── make.bat ├── examples ├── model_selection │ └── requirements.txt ├── domain_adaptation │ ├── image_classification │ │ ├── requirements.txt │ │ └── fig │ │ │ ├── dann_A2W.png │ │ │ └── resnet_A2W.png │ ├── partial_domain_adaptation │ │ └── requirements.txt │ ├── re_identification │ │ ├── requirements.txt │ │ ├── baseline.sh │ │ ├── ibn.sh │ │ └── spgan.sh │ ├── wilds_poverty │ │ ├── requirements.txt │ │ ├── fig │ │ │ └── poverty_train_loss.png │ │ ├── erm.sh │ │ ├── resnet_ms.py │ │ └── README.md │ ├── wilds_text │ │ ├── requirements.txt │ │ ├── fig │ │ │ ├── amazon_train_loss.png │ │ │ └── civilcomments_train_loss.png │ │ ├── erm.sh │ │ └── README.md │ ├── object_detection │ │ ├── requirements.txt │ │ ├── d_adapt │ │ │ ├── fig │ │ │ │ ├── comparison.png │ │ │ │ └── d_adapt_pipeline.png │ │ │ ├── config │ │ │ │ ├── retinanet_R_101_FPN_voc.yaml │ │ │ │ ├── faster_rcnn_R_101_C4_cityscapes.yaml │ │ │ │ ├── faster_rcnn_R_101_C4_voc.yaml │ │ │ │ └── faster_rcnn_vgg_16_cityscapes.yaml │ │ │ └── README.md │ │ ├── 
visualize.sh │ │ ├── config │ │ │ ├── faster_rcnn_R_101_C4_cityscapes.yaml │ │ │ ├── faster_rcnn_R_101_C4_voc.yaml │ │ │ ├── retinanet_R_101_FPN_voc.yaml │ │ │ └── faster_rcnn_vgg_16_cityscapes.yaml │ │ └── oracle.sh │ ├── wilds_image_classification │ │ ├── requirements.txt │ │ ├── erm.sh │ │ ├── dan.sh │ │ ├── jan.sh │ │ ├── mdd.sh │ │ ├── cdan.sh │ │ ├── dann.sh │ │ └── fixmatch.sh │ ├── wilds_ogb_molpcba │ │ ├── requirements.txt │ │ ├── fig │ │ │ └── ogb-molpcba_train_loss.png │ │ ├── erm.sh │ │ └── README.md │ ├── keypoint_detection │ │ ├── fig │ │ │ └── keypoint_detection.jpg │ │ ├── regda.sh │ │ ├── regda_fast.sh │ │ └── erm.sh │ ├── semantic_segmentation │ │ ├── fig │ │ │ ├── cyclegan_fake_T.png │ │ │ ├── cyclegan_real_S.png │ │ │ ├── segmentation_image.png │ │ │ ├── segmentation_label.png │ │ │ └── segmentation_pred.png │ │ ├── advent.sh │ │ ├── fda.sh │ │ ├── erm.sh │ │ ├── cycle_gan.sh │ │ └── cycada.sh │ ├── image_regression │ │ ├── erm.sh │ │ ├── rsd.sh │ │ ├── dann.sh │ │ ├── dd.sh │ │ └── utils.py │ └── openset_domain_adaptation │ │ ├── erm.sh │ │ └── osbp.sh ├── task_adaptation │ └── image_classification │ │ ├── requirements.txt │ │ └── convert_moco_to_pretrained.py ├── semi_supervised_learning │ └── image_classification │ │ ├── requirements.txt │ │ ├── convert_moco_to_pretrained.py │ │ └── noisy_student.sh └── domain_generalization │ ├── re_identification │ ├── requirements.txt │ ├── baseline.sh │ ├── mixstyle.sh │ └── ibn.sh │ └── image_classification │ ├── requirements.txt │ ├── erm.sh │ ├── mldg.sh │ ├── coral.sh │ ├── ibn.sh │ ├── irm.sh │ ├── groupdro.sh │ ├── vrex.sh │ └── mixstyle.sh ├── logo.png ├── Tllib.png ├── requirements.txt ├── .github └── ISSUE_TEMPLATE │ ├── custom.md │ ├── feature_request.md │ └── bug_report.md ├── CONTRIBUTING.md ├── LICENSE ├── setup.py └── .gitignore /tllib/reweight/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tllib/normalization/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tllib/self_training/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tllib/translation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tllib/alignment/d_adapt/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinxcontrib-httpdomain 2 | sphinx 3 | -------------------------------------------------------------------------------- /examples/model_selection/requirements.txt: -------------------------------------------------------------------------------- 1 | timm 2 | numba 3 | -------------------------------------------------------------------------------- /examples/domain_adaptation/image_classification/requirements.txt: -------------------------------------------------------------------------------- 1 | timm -------------------------------------------------------------------------------- 
/examples/task_adaptation/image_classification/requirements.txt: -------------------------------------------------------------------------------- 1 | timm -------------------------------------------------------------------------------- /examples/domain_adaptation/partial_domain_adaptation/requirements.txt: -------------------------------------------------------------------------------- 1 | timm -------------------------------------------------------------------------------- /examples/semi_supervised_learning/image_classification/requirements.txt: -------------------------------------------------------------------------------- 1 | timm -------------------------------------------------------------------------------- /tllib/vision/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['datasets', 'models', 'transforms'] 2 | -------------------------------------------------------------------------------- /examples/domain_adaptation/re_identification/requirements.txt: -------------------------------------------------------------------------------- 1 | timm 2 | opencv-python -------------------------------------------------------------------------------- /logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Transfer-Learning-Library/HEAD/logo.png -------------------------------------------------------------------------------- /Tllib.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Transfer-Learning-Library/HEAD/Tllib.png -------------------------------------------------------------------------------- /examples/domain_generalization/re_identification/requirements.txt: -------------------------------------------------------------------------------- 1 | timm 2 | opencv-python -------------------------------------------------------------------------------- /docs/html.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Transfer-Learning-Library/HEAD/docs/html.zip -------------------------------------------------------------------------------- /examples/domain_adaptation/wilds_poverty/requirements.txt: -------------------------------------------------------------------------------- 1 | wilds 2 | tensorflow 3 | tensorboard -------------------------------------------------------------------------------- /examples/domain_generalization/image_classification/requirements.txt: -------------------------------------------------------------------------------- 1 | timm 2 | wilds 3 | higher -------------------------------------------------------------------------------- /tllib/vision/models/object_detection/proposal_generator/__init__.py: -------------------------------------------------------------------------------- 1 | from .rpn import TLRPN -------------------------------------------------------------------------------- /tllib/vision/models/reid/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import * 2 | 3 | __all__ = ['resnet'] 4 | -------------------------------------------------------------------------------- /tllib/alignment/d_adapt/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | from . import meta_arch 2 | from . 
import roi_heads -------------------------------------------------------------------------------- /tllib/alignment/d_adapt/modeling/roi_heads/__init__.py: -------------------------------------------------------------------------------- 1 | from .roi_heads import DecoupledRes5ROIHeads 2 | -------------------------------------------------------------------------------- /tllib/vision/models/segmentation/__init__.py: -------------------------------------------------------------------------------- 1 | from .deeplabv2 import * 2 | 3 | __all__ = ['deeplabv2'] -------------------------------------------------------------------------------- /examples/domain_adaptation/wilds_text/requirements.txt: -------------------------------------------------------------------------------- 1 | wilds 2 | tensorflow 3 | tensorboard 4 | transformers -------------------------------------------------------------------------------- /tllib/vision/models/object_detection/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from .vgg import VGG, build_vgg_fpn_backbone 2 | -------------------------------------------------------------------------------- /examples/domain_adaptation/object_detection/requirements.txt: -------------------------------------------------------------------------------- 1 | mmcv 2 | timm 3 | prettytable 4 | pascal_voc_writer 5 | -------------------------------------------------------------------------------- /tllib/vision/models/object_detection/roi_heads/__init__.py: -------------------------------------------------------------------------------- 1 | from .roi_heads import TLRes5ROIHeads, TLStandardROIHeads -------------------------------------------------------------------------------- /docs/_static/images/DANN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Transfer-Learning-Library/HEAD/docs/_static/images/DANN.png -------------------------------------------------------------------------------- /examples/domain_adaptation/wilds_image_classification/requirements.txt: -------------------------------------------------------------------------------- 1 | wilds 2 | timm 3 | tensorflow 4 | tensorboard 5 | -------------------------------------------------------------------------------- /tllib/vision/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import * 2 | from .digits import * 3 | 4 | __all__ = ['resnet', 'digits'] 5 | -------------------------------------------------------------------------------- /examples/domain_adaptation/wilds_ogb_molpcba/requirements.txt: -------------------------------------------------------------------------------- 1 | torch_geometric 2 | wilds 3 | tensorflow 4 | tensorboard 5 | ogb -------------------------------------------------------------------------------- /tllib/vision/models/object_detection/meta_arch/__init__.py: -------------------------------------------------------------------------------- 1 | from .rcnn import TLGeneralizedRCNN 2 | from .retinanet import TLRetinaNet -------------------------------------------------------------------------------- /docs/_static/images/TransLearn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Transfer-Learning-Library/HEAD/docs/_static/images/TransLearn.png -------------------------------------------------------------------------------- 
/tllib/translation/spgan/__init__.py: -------------------------------------------------------------------------------- 1 | from . import siamese 2 | from . import loss 3 | from .siamese import * 4 | from .loss import * 5 | -------------------------------------------------------------------------------- /tllib/vision/models/keypoint_detection/__init__.py: -------------------------------------------------------------------------------- 1 | from .pose_resnet import * 2 | from . import loss 3 | 4 | __all__ = ['pose_resnet'] -------------------------------------------------------------------------------- /tllib/alignment/d_adapt/modeling/meta_arch/__init__.py: -------------------------------------------------------------------------------- 1 | from .rcnn import DecoupledGeneralizedRCNN 2 | from .retinanet import DecoupledRetinaNet -------------------------------------------------------------------------------- /tllib/vision/datasets/regression/__init__.py: -------------------------------------------------------------------------------- 1 | from .image_regression import ImageRegression 2 | from .dsprites import DSprites 3 | from .mpi3d import MPI3D -------------------------------------------------------------------------------- /tllib/vision/models/object_detection/__init__.py: -------------------------------------------------------------------------------- 1 | from . import meta_arch 2 | from . import roi_heads 3 | from . import proposal_generator 4 | from . import backbone 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.7.0 2 | torchvision>=0.5.0 3 | numpy 4 | prettytable 5 | tqdm 6 | scikit-learn 7 | webcolors 8 | matplotlib 9 | opencv-python 10 | numba 11 | -------------------------------------------------------------------------------- /examples/domain_adaptation/wilds_text/fig/amazon_train_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Transfer-Learning-Library/HEAD/examples/domain_adaptation/wilds_text/fig/amazon_train_loss.png -------------------------------------------------------------------------------- /tllib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import CompleteLogger 2 | from .meter import * 3 | from .data import ForeverDataIterator 4 | 5 | __all__ = ['metric', 'analysis', 'meter', 'data', 'logger'] -------------------------------------------------------------------------------- /examples/domain_adaptation/image_classification/fig/dann_A2W.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Transfer-Learning-Library/HEAD/examples/domain_adaptation/image_classification/fig/dann_A2W.png -------------------------------------------------------------------------------- /examples/domain_adaptation/image_classification/fig/resnet_A2W.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Transfer-Learning-Library/HEAD/examples/domain_adaptation/image_classification/fig/resnet_A2W.png -------------------------------------------------------------------------------- /examples/domain_adaptation/wilds_poverty/fig/poverty_train_loss.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/thuml/Transfer-Learning-Library/HEAD/examples/domain_adaptation/wilds_poverty/fig/poverty_train_loss.png -------------------------------------------------------------------------------- /examples/domain_adaptation/object_detection/d_adapt/fig/comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Transfer-Learning-Library/HEAD/examples/domain_adaptation/object_detection/d_adapt/fig/comparison.png -------------------------------------------------------------------------------- /examples/domain_adaptation/wilds_text/fig/civilcomments_train_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Transfer-Learning-Library/HEAD/examples/domain_adaptation/wilds_text/fig/civilcomments_train_loss.png -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/custom.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Custom issue template 3 | about: Describe this issue template's purpose here. 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | 11 | -------------------------------------------------------------------------------- /examples/domain_adaptation/keypoint_detection/fig/keypoint_detection.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Transfer-Learning-Library/HEAD/examples/domain_adaptation/keypoint_detection/fig/keypoint_detection.jpg -------------------------------------------------------------------------------- /examples/domain_adaptation/semantic_segmentation/fig/cyclegan_fake_T.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Transfer-Learning-Library/HEAD/examples/domain_adaptation/semantic_segmentation/fig/cyclegan_fake_T.png -------------------------------------------------------------------------------- /examples/domain_adaptation/semantic_segmentation/fig/cyclegan_real_S.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Transfer-Learning-Library/HEAD/examples/domain_adaptation/semantic_segmentation/fig/cyclegan_real_S.png -------------------------------------------------------------------------------- /examples/domain_adaptation/semantic_segmentation/fig/segmentation_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Transfer-Learning-Library/HEAD/examples/domain_adaptation/semantic_segmentation/fig/segmentation_image.png -------------------------------------------------------------------------------- /examples/domain_adaptation/semantic_segmentation/fig/segmentation_label.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Transfer-Learning-Library/HEAD/examples/domain_adaptation/semantic_segmentation/fig/segmentation_label.png -------------------------------------------------------------------------------- /examples/domain_adaptation/semantic_segmentation/fig/segmentation_pred.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/thuml/Transfer-Learning-Library/HEAD/examples/domain_adaptation/semantic_segmentation/fig/segmentation_pred.png -------------------------------------------------------------------------------- /examples/domain_adaptation/wilds_ogb_molpcba/fig/ogb-molpcba_train_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Transfer-Learning-Library/HEAD/examples/domain_adaptation/wilds_ogb_molpcba/fig/ogb-molpcba_train_loss.png -------------------------------------------------------------------------------- /examples/domain_adaptation/object_detection/d_adapt/fig/d_adapt_pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Transfer-Learning-Library/HEAD/examples/domain_adaptation/object_detection/d_adapt/fig/d_adapt_pipeline.png -------------------------------------------------------------------------------- /examples/domain_adaptation/wilds_ogb_molpcba/erm.sh: -------------------------------------------------------------------------------- 1 | # ogb-molpcba 2 | CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds --lr 3e-2 -b 4096 4096 --epochs 200 \ 3 | --seed 0 --deterministic --log logs/erm/obg_lr_0_03_deterministic 4 | -------------------------------------------------------------------------------- /tllib/alignment/__init__.py: -------------------------------------------------------------------------------- 1 | from . import cdan 2 | from . import dann 3 | from . import mdd 4 | from . import dan 5 | from . import jan 6 | from . import mcd 7 | from . import osbp 8 | from . import adda 9 | from . import bsp 10 | -------------------------------------------------------------------------------- /tllib/regularization/__init__.py: -------------------------------------------------------------------------------- 1 | from .bss import * 2 | from .co_tuning import * 3 | from .delta import * 4 | from .bi_tuning import * 5 | from .knowledge_distillation import * 6 | 7 | __all__ = ['bss', 'co_tuning', 'delta', 'bi_tuning', 'knowledge_distillation'] -------------------------------------------------------------------------------- /docs/tllib/utils/index.rst: -------------------------------------------------------------------------------- 1 | ===================================== 2 | Utilities 3 | ===================================== 4 | 5 | .. toctree:: 6 | :maxdepth: 2 7 | :caption: Utilities 8 | :titlesonly: 9 | 10 | base 11 | metric 12 | analysis -------------------------------------------------------------------------------- /docs/tllib/vision/index.rst: -------------------------------------------------------------------------------- 1 | ===================================== 2 | Vision 3 | ===================================== 4 | 5 | .. toctree:: 6 | :maxdepth: 2 7 | :caption: Vision 8 | :titlesonly: 9 | 10 | datasets 11 | models 12 | transforms -------------------------------------------------------------------------------- /tllib/translation/cyclegan/__init__.py: -------------------------------------------------------------------------------- 1 | from . import discriminator 2 | from . import generator 3 | from . import loss 4 | from . 
import transform 5 | 6 | from .discriminator import * 7 | from .generator import * 8 | from .loss import * 9 | from .transform import * 10 | -------------------------------------------------------------------------------- /tllib/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .classifier import * 2 | from .regressor import * 3 | from .grl import * 4 | from .domain_discriminator import * 5 | from .kernels import * 6 | from .entropy import * 7 | 8 | __all__ = ['classifier', 'regressor', 'grl', 'kernels', 'domain_discriminator', 'entropy'] -------------------------------------------------------------------------------- /tllib/vision/datasets/reid/__init__.py: -------------------------------------------------------------------------------- 1 | from .market1501 import Market1501 2 | from .dukemtmc import DukeMTMC 3 | from .msmt17 import MSMT17 4 | from .personx import PersonX 5 | from .unreal import UnrealPerson 6 | 7 | __all__ = ['Market1501', 'DukeMTMC', 'MSMT17', 'PersonX', 'UnrealPerson'] 8 | -------------------------------------------------------------------------------- /tllib/vision/datasets/keypoint_detection/__init__.py: -------------------------------------------------------------------------------- 1 | from .rendered_hand_pose import RenderedHandPose 2 | from .hand_3d_studio import Hand3DStudio, Hand3DStudioAll 3 | from .freihand import FreiHand 4 | 5 | from .surreal import SURREAL 6 | from .lsp import LSP 7 | from .human36m import Human36M 8 | 9 | -------------------------------------------------------------------------------- /tllib/vision/datasets/segmentation/__init__.py: -------------------------------------------------------------------------------- 1 | from .segmentation_list import SegmentationList 2 | from .cityscapes import Cityscapes, FoggyCityscapes 3 | from .gta5 import GTA5 4 | from .synthia import Synthia 5 | 6 | __all__ = ["SegmentationList", "Cityscapes", "GTA5", "Synthia", "FoggyCityscapes"] 7 | -------------------------------------------------------------------------------- /docs/tllib/utils/analysis.rst: -------------------------------------------------------------------------------- 1 | ============== 2 | Analysis Tools 3 | ============== 4 | 5 | 6 | .. autofunction:: tllib.utils.analysis.collect_feature 7 | 8 | 9 | .. autofunction:: tllib.utils.analysis.a_distance.calculate 10 | 11 | 12 | .. autofunction:: tllib.utils.analysis.tsne.visualize 13 | 14 | -------------------------------------------------------------------------------- /tllib/ranking/__init__.py: -------------------------------------------------------------------------------- 1 | from .logme import log_maximum_evidence 2 | from .nce import negative_conditional_entropy 3 | from .leep import log_expected_empirical_prediction 4 | from .hscore import h_score 5 | 6 | __all__ = ['log_maximum_evidence', 'negative_conditional_entropy', 'log_expected_empirical_prediction', 'h_score'] -------------------------------------------------------------------------------- /docs/tllib/alignment/index.rst: -------------------------------------------------------------------------------- 1 | ===================================== 2 | Feature Alignment 3 | ===================================== 4 | 5 | .. 
toctree:: 6 | :maxdepth: 3 7 | :caption: Feature Alignment 8 | :titlesonly: 9 | 10 | statistics_matching 11 | domain_adversarial 12 | hypothesis_adversarial 13 | -------------------------------------------------------------------------------- /tllib/__init__.py: -------------------------------------------------------------------------------- 1 | from . import alignment 2 | from . import self_training 3 | from . import translation 4 | from . import regularization 5 | from . import utils 6 | from . import vision 7 | from . import modules 8 | from . import ranking 9 | 10 | __version__ = '0.4' 11 | 12 | __all__ = ['alignment', 'self_training', 'translation', 'regularization', 'utils', 'vision', 'modules', 'ranking'] 13 | -------------------------------------------------------------------------------- /docs/tllib/utils/metric.rst: -------------------------------------------------------------------------------- 1 | =========== 2 | Metrics 3 | =========== 4 | 5 | Classification & Segmentation 6 | ============================== 7 | 8 | 9 | Accuracy 10 | --------------------------------- 11 | 12 | .. autofunction:: tllib.utils.metric.accuracy 13 | 14 | 15 | ConfusionMatrix 16 | --------------------------------- 17 | 18 | .. autoclass:: tllib.utils.metric.ConfusionMatrix 19 | :members: 20 | -------------------------------------------------------------------------------- /docs/tllib/vision/transforms.rst: -------------------------------------------------------------------------------- 1 | Transforms 2 | ============================= 3 | 4 | 5 | Classification 6 | --------------------------------- 7 | 8 | .. automodule:: tllib.vision.transforms 9 | :members: 10 | 11 | 12 | Segmentation 13 | --------------------------------- 14 | 15 | 16 | .. automodule:: tllib.vision.transforms.segmentation 17 | :members: 18 | 19 | 20 | Keypoint Detection 21 | --------------------------------- 22 | 23 | 24 | .. 
automodule:: tllib.vision.transforms.keypoint_detection 25 | :members: 26 | -------------------------------------------------------------------------------- /examples/domain_adaptation/wilds_text/erm.sh: -------------------------------------------------------------------------------- 1 | # civilcomments 2 | CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds -d "civilcomments" --unlabeled-list "extra_unlabeled" \ 3 | --uniform-over-groups --groupby-fields y black --max-token-length 300 --lr 1e-05 --metric "acc_wg" \ 4 | --seed 0 --deterministic --log logs/erm/civilcomments 5 | 6 | # amazon 7 | CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds -d "amazon" --max-token-length 512 \ 8 | --lr 1e-5 -b 24 24 --epochs 3 --metric "10th_percentile_acc" --seed 0 --deterministic --log logs/erm/amazon 9 | -------------------------------------------------------------------------------- /examples/domain_adaptation/wilds_image_classification/erm.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds -d "fmow" --aa "v0" --arch "densenet121" \ 2 | --lr 0.1 --opt-level O1 --deterministic --vflip 0.5 --log logs/erm/fmow/lr_0_1_aa_v0_densenet121 3 | 4 | CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds -d "iwildcam" --aa "v0" --unlabeled-list "extra_unlabeled" --lr 1 --opt-level O1 \ 5 | --deterministic --img-size 448 448 --crop-pct 1.0 --scale 1.0 1.0 --epochs 18 -b 24 24 -p 500 --metric "F1-macro_all" \ 6 | --log logs/erm/iwildcam/lr_1_deterministic 7 | -------------------------------------------------------------------------------- /examples/domain_adaptation/semantic_segmentation/advent.sh: -------------------------------------------------------------------------------- 1 | # GTA5 to Cityscapes 2 | CUDA_VISIBLE_DEVICES=0 python advent.py data/GTA5 data/Cityscapes -s GTA5 -t Cityscapes \ 3 | --log logs/advent/gtav2cityscapes 4 | 5 | # Synthia to Cityscapes 6 | CUDA_VISIBLE_DEVICES=0 python advent.py data/synthia data/Cityscapes -s Synthia -t Cityscapes \ 7 | --log logs/advent/synthia2cityscapes 8 | 9 | # Cityscapes to Foggy 10 | CUDA_VISIBLE_DEVICES=0 python advent.py data/Cityscapes data/Cityscapes -s Cityscapes -t FoggyCityscapes \ 11 | --log logs/advent/cityscapes2foggy -------------------------------------------------------------------------------- /examples/domain_adaptation/wilds_image_classification/dan.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python dan.py data/wilds -d "fmow" --aa "v0" --arch "densenet121" \ 2 | --lr 0.1 --opt-level O1 --deterministic --vflip 0.5 --log logs/dan/fmow/lr_0_1_aa_v0_densenet121 3 | 4 | CUDA_VISIBLE_DEVICES=0 python dan.py data/wilds -d "iwildcam" --aa "v0" --unlabeled-list "extra_unlabeled" --lr 0.3 --opt-level O1 \ 5 | --deterministic --img-size 448 448 --crop-pct 1.0 --scale 1.0 1.0 --epochs 18 -b 24 24 --trade-off 0.3 -p 500 --metric "F1-macro_all" \ 6 | --log logs/dan/iwildcam/lr_0_3_deterministic 7 | -------------------------------------------------------------------------------- /examples/domain_adaptation/wilds_image_classification/jan.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python jan.py data/wilds -d "fmow" --aa "v0" --arch "densenet121" \ 2 | --lr 0.1 --opt-level O1 --deterministic --vflip 0.5 --log logs/jan/fmow/lr_0_1_aa_v0_densenet121 3 | 4 | CUDA_VISIBLE_DEVICES=0 python jan.py data/wilds -d "iwildcam" --aa "v0" --unlabeled-list 
"extra_unlabeled" --lr 0.3 --opt-level O1 \ 5 | --deterministic --img-size 448 448 --crop-pct 1.0 --scale 1.0 1.0 --epochs 18 -b 24 24 --trade-off 0.3 -p 500 --metric "F1-macro_all" \ 6 | --log logs/jan/iwildcam/lr_0_3_deterministic 7 | -------------------------------------------------------------------------------- /examples/domain_adaptation/wilds_image_classification/mdd.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python mdd.py data/wilds -d "fmow" --aa "v0" --arch "densenet121" \ 2 | --lr 0.1 --opt-level O1 --deterministic --vflip 0.5 --log logs/mdd/fmow/lr_0_1_aa_v0_densenet121 3 | 4 | CUDA_VISIBLE_DEVICES=0 python mdd.py data/wilds -d "iwildcam" --aa "v0" --unlabeled-list "extra_unlabeled" --lr 0.3 --opt-level O1 \ 5 | --deterministic --img-size 448 448 --crop-pct 1.0 --scale 1.0 1.0 --epochs 18 -b 24 24 --trade-off 0.3 -p 500 --metric "F1-macro_all" \ 6 | --log logs/mdd/iwildcam/lr_0_3_deterministic 7 | -------------------------------------------------------------------------------- /examples/domain_adaptation/wilds_image_classification/cdan.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python cdan.py data/wilds -d "fmow" --aa "v0" --arch "densenet121" \ 2 | --lr 0.1 --opt-level O1 --deterministic --vflip 0.5 --log logs/cdan/fmow/lr_0_1_aa_v0_densenet121 3 | 4 | CUDA_VISIBLE_DEVICES=0 python cdan.py data/wilds -d "iwildcam" --aa "v0" --unlabeled-list "extra_unlabeled" --lr 1 --opt-level O1 \ 5 | --deterministic --img-size 448 448 --crop-pct 1.0 --scale 1.0 1.0 --epochs 18 -b 24 24 --trade-off 0.3 -p 500 --metric "F1-macro_all" \ 6 | --log logs/cdan/iwildcam/lr_1_deterministic 7 | -------------------------------------------------------------------------------- /examples/domain_adaptation/wilds_image_classification/dann.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python dann.py data/wilds -d "fmow" --aa "v0" --arch "densenet121" \ 2 | --lr 0.1 --opt-level O1 --deterministic --vflip 0.5 --log logs/dann/fmow/lr_0_1_aa_v0_densenet121 3 | 4 | CUDA_VISIBLE_DEVICES=0 python dann.py data/wilds -d "iwildcam" --aa "v0" --unlabeled-list "extra_unlabeled" --lr 1 --opt-level O1 \ 5 | --deterministic --img-size 448 448 --crop-pct 1.0 --scale 1.0 1.0 --epochs 18 -b 24 24 --trade-off 0.3 -p 500 --metric "F1-macro_all" \ 6 | --log logs/dann/iwildcam/lr_1_deterministic 7 | -------------------------------------------------------------------------------- /examples/domain_adaptation/wilds_image_classification/fixmatch.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/wilds -d "fmow" --aa "v0" --arch "densenet121" \ 2 | --lr 0.1 --opt-level O1 --deterministic --vflip 0.5 --log logs/fixmatch/fmow/lr_0_1_aa_v0_densenet121 3 | 4 | CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/wilds -d "iwildcam" --aa "v0" --unlabeled-list "extra_unlabeled" \ 5 | --lr 0.3 --opt-level O1 --deterministic --img-size 448 448 --crop-pct 1.0 --scale 1.0 1.0 --epochs 12 -b 24 24 -p 500 \ 6 | --metric "F1-macro_all" --log logs/fixmatch/iwildcam/lr_0_3_deterministic 7 | -------------------------------------------------------------------------------- /examples/domain_adaptation/semantic_segmentation/fda.sh: -------------------------------------------------------------------------------- 1 | # GTA5 to Cityscapes 2 | CUDA_VISIBLE_DEVICES=0 python fda.py 
data/GTA5 data/Cityscapes -s GTA5 -t Cityscapes \ 3 | --log logs/fda/gtav2cityscapes --debug 4 | 5 | # Synthia to Cityscapes 6 | CUDA_VISIBLE_DEVICES=0 python fda.py data/synthia data/Cityscapes -s Synthia -t Cityscapes \ 7 | --log logs/fda/synthia2cityscapes --debug 8 | 9 | # Cityscapes to FoggyCityscapes 10 | CUDA_VISIBLE_DEVICES=0 python fda.py data/Cityscapes data/Cityscapes -s Cityscapes -t FoggyCityscapes \ 11 | --log logs/fda/cityscapes2foggy --debug 12 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | ===================================== 2 | Transfer Learning 3 | ===================================== 4 | 5 | .. toctree:: 6 | :maxdepth: 2 7 | :caption: Transfer Learning API 8 | :titlesonly: 9 | 10 | tllib/modules 11 | tllib/alignment/index 12 | tllib/translation 13 | tllib/self_training 14 | tllib/reweight 15 | tllib/normalization 16 | tllib/regularization 17 | tllib/ranking 18 | 19 | 20 | .. toctree:: 21 | :maxdepth: 2 22 | :caption: Common API 23 | :titlesonly: 24 | 25 | tllib/vision/index 26 | tllib/utils/index 27 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | ## Contributing to Transfer-Learning-Library 2 | 3 | All kinds of contributions are welcome, including but not limited to the following. 4 | 5 | - Fix typo or bugs 6 | - Add documentation 7 | - Add new features and components 8 | 9 | ### Workflow 10 | 11 | 1. fork and pull the latest Transfer-Learning-Library repository 12 | 2. checkout a new branch (do not use master branch for PRs) 13 | 3. commit your changes 14 | 4. create a PR 15 | 16 | ```{note} 17 | If you plan to add some new features that involve large changes, it is encouraged to open an issue for discussion first. 18 | ``` 19 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = PyTorchSphinxTheme 8 | SOURCEDIR = . 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 
15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /docs/tllib/alignment/statistics_matching.rst: -------------------------------------------------------------------------------- 1 | ===================== 2 | Statistics Matching 3 | ===================== 4 | 5 | 6 | .. _DAN: 7 | 8 | DAN: Deep Adaptation Network 9 | ----------------------------- 10 | 11 | .. autoclass:: tllib.alignment.dan.MultipleKernelMaximumMeanDiscrepancy 12 | 13 | 14 | .. _CORAL: 15 | 16 | Deep CORAL: Correlation Alignment for Deep Domain Adaptation 17 | -------------------------------------------------------------- 18 | 19 | .. autoclass:: tllib.alignment.coral.CorrelationAlignmentLoss 20 | 21 | 22 | .. _JAN: 23 | 24 | JAN: Joint Adaptation Network 25 | ------------------------------ 26 | 27 | .. autoclass:: tllib.alignment.jan.JointMultipleKernelMaximumMeanDiscrepancy 28 | 29 | -------------------------------------------------------------------------------- /examples/domain_adaptation/keypoint_detection/regda.sh: -------------------------------------------------------------------------------- 1 | # Hands Dataset 2 | CUDA_VISIBLE_DEVICES=0 python regda.py data/RHD data/H3D_crop \ 3 | -s RenderedHandPose -t Hand3DStudio --seed 0 --debug --log logs/regda/rhd2h3d 4 | CUDA_VISIBLE_DEVICES=0 python regda.py data/FreiHand data/RHD \ 5 | -s FreiHand -t RenderedHandPose --seed 0 --debug --log logs/regda/freihand2rhd 6 | 7 | # Body Dataset 8 | CUDA_VISIBLE_DEVICES=0 python regda.py data/surreal_processed data/Human36M \ 9 | -s SURREAL -t Human36M --seed 0 --debug --rotation 30 --epochs 10 --log logs/regda/surreal2human36m 10 | CUDA_VISIBLE_DEVICES=0 python regda.py data/surreal_processed data/lsp \ 11 | -s SURREAL -t LSP --seed 0 --debug --rotation 30 --log logs/regda/surreal2lsp 12 | -------------------------------------------------------------------------------- /examples/domain_adaptation/object_detection/visualize.sh: -------------------------------------------------------------------------------- 1 | # Source Only Faster RCNN: VOC->Clipart 2 | CUDA_VISIBLE_DEVICES=0 python visualize.py --config-file config/faster_rcnn_R_101_C4_voc.yaml \ 3 | --test Clipart datasets/clipart --save-path visualizations/source_only/voc2clipart \ 4 | MODEL.WEIGHTS logs/source_only/faster_rcnn_R_101_C4/voc2clipart/model_final.pth 5 | 6 | # Source Only Faster RCNN: VOC->WaterColor, Comic 7 | CUDA_VISIBLE_DEVICES=0 python visualize.py --config-file config/faster_rcnn_R_101_C4_voc.yaml \ 8 | --test WaterColorTest datasets/watercolor ComicTest datasets/comic --save-path visualizations/source_only/voc2comic_watercolor \ 9 | MODEL.ROI_HEADS.NUM_CLASSES 6 MODEL.WEIGHTS logs/source_only/faster_rcnn_R_101_C4/voc2watercolor_comic/model_final.pth 10 | -------------------------------------------------------------------------------- /docs/tllib/ranking.rst: -------------------------------------------------------------------------------- 1 | ===================== 2 | Ranking 3 | ===================== 4 | 5 | 6 | 7 | .. _H_score: 8 | 9 | H-score 10 | ------------------------------------------- 11 | 12 | .. autofunction:: tllib.ranking.hscore.h_score 13 | 14 | 15 | .. 
_LEEP: 16 | 17 | LEEP: Log Expected Empirical Prediction 18 | ------------------------------------------- 19 | 20 | .. autofunction:: tllib.ranking.leep.log_expected_empirical_prediction 21 | 22 | 23 | .. _NCE: 24 | 25 | NCE: Negative Conditional Entropy 26 | ------------------------------------------- 27 | 28 | .. autofunction:: tllib.ranking.nce.negative_conditional_entropy 29 | 30 | 31 | .. _LogME: 32 | 33 | LogME: Log Maximum Evidence 34 | ------------------------------------------- 35 | 36 | .. autofunction:: tllib.ranking.logme.log_maximum_evidence 37 | 38 | -------------------------------------------------------------------------------- /examples/task_adaptation/image_classification/convert_moco_to_pretrained.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Junguang Jiang 3 | @contact: JiangJunguang1123@outlook.com 4 | """ 5 | import sys 6 | import torch 7 | 8 | if __name__ == "__main__": 9 | input = sys.argv[1] 10 | 11 | obj = torch.load(input, map_location="cpu") 12 | obj = obj["state_dict"] 13 | 14 | newmodel = {} 15 | fc = {} 16 | for k, v in obj.items(): 17 | if not k.startswith("module.encoder_q."): 18 | continue 19 | old_k = k 20 | k = k.replace("module.encoder_q.", "") 21 | if k.startswith("fc"): 22 | print(k) 23 | fc[k] = v 24 | else: 25 | newmodel[k] = v 26 | 27 | with open(sys.argv[2], "wb") as f: 28 | torch.save(newmodel, f) 29 | 30 | with open(sys.argv[3], "wb") as f: 31 | torch.save(fc, f) 32 | -------------------------------------------------------------------------------- /examples/semi_supervised_learning/image_classification/convert_moco_to_pretrained.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Junguang Jiang 3 | @contact: JiangJunguang1123@outlook.com 4 | """ 5 | import sys 6 | import torch 7 | 8 | if __name__ == "__main__": 9 | input = sys.argv[1] 10 | 11 | obj = torch.load(input, map_location="cpu") 12 | obj = obj["state_dict"] 13 | 14 | newmodel = {} 15 | fc = {} 16 | for k, v in obj.items(): 17 | if not k.startswith("module.encoder_q."): 18 | continue 19 | old_k = k 20 | k = k.replace("module.encoder_q.", "") 21 | if k.startswith("fc"): 22 | print(k) 23 | fc[k] = v 24 | else: 25 | newmodel[k] = v 26 | 27 | with open(sys.argv[2], "wb") as f: 28 | torch.save(newmodel, f) 29 | 30 | with open(sys.argv[3], "wb") as f: 31 | torch.save(fc, f) 32 | -------------------------------------------------------------------------------- /docs/tllib/reweight.rst: -------------------------------------------------------------------------------- 1 | ======================================= 2 | Re-weighting 3 | ======================================= 4 | 5 | 6 | .. _PADA: 7 | 8 | PADA: Partial Adversarial Domain Adaptation 9 | --------------------------------------------- 10 | 11 | .. autoclass:: tllib.reweight.pada.ClassWeightModule 12 | 13 | .. autoclass:: tllib.reweight.pada.AutomaticUpdateClassWeightModule 14 | :members: 15 | 16 | .. autofunction:: tllib.reweight.pada.collect_classification_results 17 | 18 | 19 | .. _IWAN: 20 | 21 | IWAN: Importance Weighted Adversarial Nets 22 | --------------------------------------------- 23 | 24 | .. autoclass:: tllib.reweight.iwan.ImportanceWeightModule 25 | :members: 26 | 27 | 28 | 29 | .. _GroupDRO: 30 | 31 | GroupDRO: Group Distributionally robust optimization 32 | ------------------------------------------------------ 33 | 34 | .. 
autoclass:: tllib.reweight.groupdro.AutomaticUpdateDomainWeightModule 35 | :members: 36 | -------------------------------------------------------------------------------- /docs/tllib/utils/base.rst: -------------------------------------------------------------------------------- 1 | Generic Tools 2 | ============== 3 | 4 | 5 | Average Meter 6 | --------------------------------- 7 | 8 | .. autoclass:: tllib.utils.meter.AverageMeter 9 | :members: 10 | 11 | Progress Meter 12 | --------------------------------- 13 | 14 | .. autoclass:: tllib.utils.meter.ProgressMeter 15 | :members: 16 | 17 | Meter 18 | --------------------------------- 19 | 20 | .. autoclass:: tllib.utils.meter.Meter 21 | :members: 22 | 23 | Data 24 | --------------------------------- 25 | 26 | .. autoclass:: tllib.utils.data.ForeverDataIterator 27 | :members: 28 | 29 | .. autoclass:: tllib.utils.data.CombineDataset 30 | :members: 31 | 32 | .. autofunction:: tllib.utils.data.send_to_device 33 | 34 | .. autofunction:: tllib.utils.data.concatenate 35 | 36 | Logger 37 | ----------- 38 | 39 | .. autoclass:: tllib.utils.logger.TextLogger 40 | :members: 41 | 42 | 43 | .. autoclass:: tllib.utils.logger.CompleteLogger 44 | :members: 45 | 46 | -------------------------------------------------------------------------------- /tllib/vision/datasets/cifar.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Junguang Jiang 3 | @contact: JiangJunguang1123@outlook.com 4 | """ 5 | from torchvision.datasets.cifar import CIFAR10 as CIFAR10Base, CIFAR100 as CIFAR100Base 6 | 7 | 8 | class CIFAR10(CIFAR10Base): 9 | """ 10 | `CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset. 11 | """ 12 | 13 | def __init__(self, root, split='train', transform=None, download=True): 14 | super(CIFAR10, self).__init__(root, train=split == 'train', transform=transform, download=download) 15 | self.num_classes = 10 16 | 17 | 18 | class CIFAR100(CIFAR100Base): 19 | """ 20 | `CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset. 
21 | """ 22 | 23 | def __init__(self, root, split='train', transform=None, download=True): 24 | super(CIFAR100, self).__init__(root, train=split == 'train', transform=transform, download=download) 25 | self.num_classes = 100 26 | -------------------------------------------------------------------------------- /examples/domain_adaptation/semantic_segmentation/erm.sh: -------------------------------------------------------------------------------- 1 | # Source Only 2 | # GTA5 to Cityscapes 3 | CUDA_VISIBLE_DEVICES=0 python erm.py data/GTA5 data/Cityscapes \ 4 | -s GTA5 -t Cityscapes --log logs/erm/gtav2cityscapes 5 | 6 | # Synthia to Cityscapes 7 | CUDA_VISIBLE_DEVICES=0 python erm.py data/synthia data/Cityscapes \ 8 | -s Synthia -t Cityscapes --log logs/erm/synthia2cityscapes 9 | 10 | # Cityscapes to FoggyCityscapes 11 | CUDA_VISIBLE_DEVICES=0 python erm.py data/Cityscapes data/Cityscapes \ 12 | -s Cityscapes -t FoggyCityscapes --log logs/erm/cityscapes2foggy 13 | 14 | # Oracle 15 | # Oracle Results on Cityscapes 16 | CUDA_VISIBLE_DEVICES=0 python erm.py data/Cityscapes data/Cityscapes \ 17 | -s Cityscapes -t Cityscapes --log logs/oracle/cityscapes 18 | 19 | # Oracle Results on Foggy Cityscapes 20 | CUDA_VISIBLE_DEVICES=0 python erm.py data/Cityscapes data/Cityscapes \ 21 | -s FoggyCityscapes -t FoggyCityscapes --log logs/oracle/foggy_cityscapes 22 | 23 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=python -msphinx 9 | ) 10 | set SPHINXOPTS= 11 | set SPHINXBUILD=sphinx-build 12 | set SOURCEDIR=. 13 | set BUILDDIR=build 14 | set SPHINXPROJ=PyTorchSphinxTheme 15 | 16 | if "%1" == "" goto help 17 | 18 | %SPHINXBUILD% >NUL 2>NUL 19 | if errorlevel 9009 ( 20 | echo. 21 | echo.The Sphinx module was not found. Make sure you have Sphinx installed, 22 | echo.then set the SPHINXBUILD environment variable to point to the full 23 | echo.path of the 'sphinx-build' executable. Alternatively you may add the 24 | echo.Sphinx directory to PATH. 25 | echo. 26 | echo.If you don't have Sphinx installed, grab it from 27 | echo.http://sphinx-doc.org/ 28 | exit /b 1 29 | ) 30 | 31 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 32 | goto end 33 | 34 | :help 35 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 36 | 37 | :end 38 | popd 39 | -------------------------------------------------------------------------------- /examples/domain_adaptation/keypoint_detection/regda_fast.sh: -------------------------------------------------------------------------------- 1 | # regda_fast is provided by https://github.com/YouJiacheng?tab=repositories 2 | # On single V100(16G), overall adversarial training time is reduced by about 40%. 3 | # yet the PCK might drop 1% for each dataset. 
4 | # Hands Dataset 5 | CUDA_VISIBLE_DEVICES=0 python regda_fast.py data/RHD data/H3D_crop \ 6 | -s RenderedHandPose -t Hand3DStudio --seed 0 --debug --log logs/regda_fast/rhd2h3d 7 | CUDA_VISIBLE_DEVICES=0 python regda_fast.py data/FreiHand data/RHD \ 8 | -s FreiHand -t RenderedHandPose --seed 0 --debug --log logs/regda_fast/freihand2rhd 9 | 10 | # Body Dataset 11 | CUDA_VISIBLE_DEVICES=0 python regda_fast.py data/surreal_processed data/Human36M \ 12 | -s SURREAL -t Human36M --seed 0 --debug --rotation 30 --epochs 10 --log logs/regda_fast/surreal2human36m 13 | CUDA_VISIBLE_DEVICES=0 python regda_fast.py data/surreal_processed data/lsp \ 14 | -s SURREAL -t LSP --seed 0 --debug --rotation 30 --log logs/regda_fast/surreal2lsp 15 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Browser [e.g. chrome, safari] 29 | - Version [e.g. 22] 30 | 31 | **Smartphone (please complete the following information):** 32 | - Device: [e.g. iPhone6] 33 | - OS: [e.g. iOS8.1] 34 | - Browser [e.g. stock browser, safari] 35 | - Version [e.g. 22] 36 | 37 | **Additional context** 38 | Add any other context about the problem here. 
39 | -------------------------------------------------------------------------------- /examples/domain_adaptation/object_detection/config/faster_rcnn_R_101_C4_cityscapes.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "TLGeneralizedRCNN" 3 | WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" 4 | MASK_ON: False 5 | RESNETS: 6 | DEPTH: 101 7 | ROI_HEADS: 8 | NAME: "TLRes5ROIHeads" 9 | NUM_CLASSES: 8 10 | BATCH_SIZE_PER_IMAGE: 512 11 | ANCHOR_GENERATOR: 12 | SIZES: [ [ 64, 128, 256, 512 ] ] 13 | RPN: 14 | PRE_NMS_TOPK_TEST: 6000 15 | POST_NMS_TOPK_TEST: 1000 16 | BATCH_SIZE_PER_IMAGE: 256 17 | PROPOSAL_GENERATOR: 18 | NAME: "TLRPN" 19 | INPUT: 20 | MIN_SIZE_TRAIN: (512, 544, 576, 608, 640, 672, 704,) 21 | MIN_SIZE_TEST: 608 22 | MAX_SIZE_TRAIN: 1166 23 | DATASETS: 24 | TRAIN: ("cityscapes_trainval",) 25 | TEST: ("cityscapes_test",) 26 | SOLVER: 27 | STEPS: (12000,) 28 | MAX_ITER: 16000 # 16 epochs 29 | WARMUP_ITERS: 100 30 | CHECKPOINT_PERIOD: 2000 31 | IMS_PER_BATCH: 2 32 | BASE_LR: 0.005 33 | TEST: 34 | EVAL_PERIOD: 2000 35 | VIS_PERIOD: 500 36 | VERSION: 2 37 | -------------------------------------------------------------------------------- /examples/domain_adaptation/object_detection/config/faster_rcnn_R_101_C4_voc.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "TLGeneralizedRCNN" 3 | WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" 4 | MASK_ON: False 5 | RESNETS: 6 | DEPTH: 101 7 | ROI_HEADS: 8 | NAME: "TLRes5ROIHeads" 9 | NUM_CLASSES: 20 10 | BATCH_SIZE_PER_IMAGE: 256 11 | ANCHOR_GENERATOR: 12 | SIZES: [ [ 64, 128, 256, 512 ] ] 13 | RPN: 14 | PRE_NMS_TOPK_TEST: 6000 15 | POST_NMS_TOPK_TEST: 1000 16 | BATCH_SIZE_PER_IMAGE: 128 17 | PROPOSAL_GENERATOR: 18 | NAME: "TLRPN" 19 | INPUT: 20 | MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704,) 21 | MIN_SIZE_TEST: 608 22 | MAX_SIZE_TRAIN: 1166 23 | DATASETS: 24 | TRAIN: ('voc_2007_trainval', 'voc_2012_trainval') 25 | TEST: ('voc_2007_test',) 26 | SOLVER: 27 | STEPS: (12000, ) 28 | MAX_ITER: 16000 # 16 epochs 29 | WARMUP_ITERS: 100 30 | CHECKPOINT_PERIOD: 2000 31 | IMS_PER_BATCH: 4 32 | BASE_LR: 0.005 33 | TEST: 34 | EVAL_PERIOD: 2000 35 | VIS_PERIOD: 500 36 | VERSION: 2 -------------------------------------------------------------------------------- /docs/tllib/modules.rst: -------------------------------------------------------------------------------- 1 | ===================== 2 | Modules 3 | ===================== 4 | 5 | 6 | Classifier 7 | ------------------------------- 8 | .. autoclass:: tllib.modules.classifier.Classifier 9 | :members: 10 | 11 | Regressor 12 | ------------------------------- 13 | .. autoclass:: tllib.modules.regressor.Regressor 14 | :members: 15 | 16 | Domain Discriminator 17 | ------------------------------- 18 | .. autoclass:: tllib.modules.domain_discriminator.DomainDiscriminator 19 | :members: 20 | 21 | GRL: Gradient Reverse Layer 22 | ----------------------------- 23 | .. autoclass:: tllib.modules.grl.WarmStartGradientReverseLayer 24 | :members: 25 | 26 | Gaussian Kernels 27 | ------------------------ 28 | .. autoclass:: tllib.modules.kernels.GaussianKernel 29 | 30 | 31 | Entropy 32 | ------------------------ 33 | .. autofunction:: tllib.modules.entropy.entropy 34 | 35 | 36 | Knowledge Distillation Loss 37 | ------------------------------- 38 | .. 
autoclass:: tllib.modules.loss.KnowledgeDistillationLoss 39 | :members: 40 | 41 | 42 | -------------------------------------------------------------------------------- /examples/domain_adaptation/keypoint_detection/erm.sh: -------------------------------------------------------------------------------- 1 | # Source Only 2 | # Hands Dataset 3 | CUDA_VISIBLE_DEVICES=0 python erm.py data/RHD data/H3D_crop \ 4 | -s RenderedHandPose -t Hand3DStudio --log logs/erm/rhd2h3d --debug --seed 0 5 | CUDA_VISIBLE_DEVICES=0 python erm.py data/FreiHand data/RHD \ 6 | -s FreiHand -t RenderedHandPose --log logs/erm/freihand2rhd --debug --seed 0 7 | 8 | # Body Dataset 9 | CUDA_VISIBLE_DEVICES=0 python erm.py data/surreal_processed data/Human36M \ 10 | -s SURREAL -t Human36M --log logs/erm/surreal2human36m --debug --seed 0 --rotation 30 11 | CUDA_VISIBLE_DEVICES=0 python erm.py data/surreal_processed data/lsp \ 12 | -s SURREAL -t LSP --log logs/erm/surreal2lsp --debug --seed 0 --rotation 30 13 | 14 | # Oracle Results 15 | CUDA_VISIBLE_DEVICES=0 python erm.py data/H3D_crop data/H3D_crop \ 16 | -s Hand3DStudio -t Hand3DStudio --log logs/oracle/h3d --debug --seed 0 17 | CUDA_VISIBLE_DEVICES=0 python erm.py data/Human36M data/Human36M \ 18 | -s Human36M -t Human36M --log logs/oracle/human36m --debug --seed 0 --rotation 30 19 | -------------------------------------------------------------------------------- /examples/domain_adaptation/semantic_segmentation/cycle_gan.sh: -------------------------------------------------------------------------------- 1 | # GTA5 to Cityscapes 2 | # First, train the CycleGAN 3 | CUDA_VISIBLE_DEVICES=0 python cycle_gan.py data/GTA5 data/Cityscapes -s GTA5 -t Cityscapes \ 4 | --log logs/cyclegan/gtav2cityscapes --translated-root data/GTA52Cityscapes/CycleGAN_39 5 | # Then, train the src_only model on the translated source dataset 6 | CUDA_VISIBLE_DEVICES=0 python source_only.py data/GTA52Cityscapes/CycleGAN_39 data/Cityscapes \ 7 | -s GTA5 -t Cityscapes --log logs/cyclegan_src_only/gtav2cityscapes 8 | 9 | 10 | # Cityscapes to FoggyCityscapes 11 | # First, train the CycleGAN 12 | CUDA_VISIBLE_DEVICES=0 python cycle_gan.py data/Cityscapes data/Cityscapes -s Cityscapes -t FoggyCityscapes \ 13 | --log logs/cyclegan/cityscapes2foggy --translated-root data/Cityscapes2Foggy/CycleGAN_39 14 | # Then, train the src_only model on the translated source dataset 15 | CUDA_VISIBLE_DEVICES=0 python source_only.py data/Cityscapes2Foggy/CycleGAN_39 data/Cityscapes \ 16 | -s Cityscapes -t FoggyCityscapes --log logs/cyclegan_src_only/cityscapes2foggy 17 | -------------------------------------------------------------------------------- /examples/domain_adaptation/wilds_poverty/erm.sh: -------------------------------------------------------------------------------- 1 | # official split scheme 2 | CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds --split-scheme official --fold A \ 3 | --arch 'resnet18_ms' --lr 1e-3 --epochs 200 -b 64 64 --opt-level O1 --deterministic --log logs/erm/poverty_fold_A 4 | CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds --split-scheme official --fold B \ 5 | --arch 'resnet18_ms' --lr 1e-3 --epochs 200 -b 64 64 --opt-level O1 --deterministic --log logs/erm/poverty_fold_B 6 | CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds --split-scheme official --fold C \ 7 | --arch 'resnet18_ms' --lr 1e-3 --epochs 200 -b 64 64 --opt-level O1 --deterministic --log logs/erm/poverty_fold_C 8 | CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds --split-scheme official --fold D \ 9 | --arch 'resnet18_ms' 
--lr 1e-3 --epochs 200 -b 64 64 --opt-level O1 --deterministic --log logs/erm/poverty_fold_D 10 | CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds --split-scheme official --fold E \ 11 | --arch 'resnet18_ms' --lr 1e-3 --epochs 200 -b 64 64 --opt-level O1 --deterministic --log logs/erm/poverty_fold_E 12 | -------------------------------------------------------------------------------- /examples/domain_adaptation/object_detection/d_adapt/config/retinanet_R_101_FPN_voc.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "DecoupledRetinaNet" 3 | WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" 4 | BACKBONE: 5 | NAME: "build_retinanet_resnet_fpn_backbone" 6 | MASK_ON: False 7 | RESNETS: 8 | DEPTH: 101 9 | OUT_FEATURES: [ "res4", "res5" ] 10 | ANCHOR_GENERATOR: 11 | SIZES: !!python/object/apply:eval [ "[[x, x * 2**(1.0/3), x * 2**(2.0/3) ] for x in [64, 128, 256, 512 ]]" ] 12 | RETINANET: 13 | NUM_CLASSES: 20 14 | IN_FEATURES: ["p4", "p5", "p6", "p7"] 15 | FPN: 16 | IN_FEATURES: ["res4", "res5"] 17 | INPUT: 18 | MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, ) 19 | MIN_SIZE_TEST: 608 20 | MAX_SIZE_TRAIN: 1166 21 | DATASETS: 22 | TRAIN: ('voc_2007_trainval', 'voc_2012_trainval') 23 | TEST: ('voc_2007_test',) 24 | SOLVER: 25 | STEPS: (3999, ) 26 | MAX_ITER: 4000 # 16 epochs 27 | WARMUP_ITERS: 100 28 | CHECKPOINT_PERIOD: 1000 29 | IMS_PER_BATCH: 8 30 | BASE_LR: 0.001 31 | TEST: 32 | EVAL_PERIOD: 500 33 | VIS_PERIOD: 20 34 | VERSION: 2 -------------------------------------------------------------------------------- /examples/domain_adaptation/object_detection/d_adapt/config/faster_rcnn_R_101_C4_cityscapes.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "DecoupledGeneralizedRCNN" 3 | WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" 4 | MASK_ON: False 5 | RESNETS: 6 | DEPTH: 101 7 | ROI_HEADS: 8 | NAME: "DecoupledRes5ROIHeads" 9 | NUM_CLASSES: 8 10 | BATCH_SIZE_PER_IMAGE: 512 11 | ANCHOR_GENERATOR: 12 | SIZES: [ [ 64, 128, 256, 512 ] ] 13 | RPN: 14 | PRE_NMS_TOPK_TEST: 6000 15 | POST_NMS_TOPK_TEST: 1000 16 | BATCH_SIZE_PER_IMAGE: 256 17 | PROPOSAL_GENERATOR: 18 | NAME: "TLRPN" 19 | INPUT: 20 | MIN_SIZE_TRAIN: (512, 544, 576, 608, 640, 672, 704,) 21 | MIN_SIZE_TEST: 800 22 | MAX_SIZE_TRAIN: 1166 23 | DATASETS: 24 | TRAIN: ("cityscapes_trainval",) 25 | TEST: ("cityscapes_test",) 26 | SOLVER: 27 | STEPS: (3999, ) 28 | MAX_ITER: 4000 # 4 epochs 29 | WARMUP_ITERS: 100 30 | CHECKPOINT_PERIOD: 1000 31 | IMS_PER_BATCH: 2 32 | BASE_LR: 0.005 33 | LR_SCHEDULER_NAME: "ExponentialLR" 34 | GAMMA: 0.1 35 | TEST: 36 | EVAL_PERIOD: 500 37 | VIS_PERIOD: 20 38 | VERSION: 2 39 | -------------------------------------------------------------------------------- /examples/domain_adaptation/object_detection/d_adapt/config/faster_rcnn_R_101_C4_voc.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "DecoupledGeneralizedRCNN" 3 | WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" 4 | MASK_ON: False 5 | RESNETS: 6 | DEPTH: 101 7 | ROI_HEADS: 8 | NAME: "DecoupledRes5ROIHeads" 9 | NUM_CLASSES: 20 10 | BATCH_SIZE_PER_IMAGE: 256 11 | ANCHOR_GENERATOR: 12 | SIZES: [ [ 64, 128, 256, 512 ] ] 13 | RPN: 14 | PRE_NMS_TOPK_TEST: 6000 15 | POST_NMS_TOPK_TEST: 1000 16 | BATCH_SIZE_PER_IMAGE: 128 17 | PROPOSAL_GENERATOR: 18 | NAME: "TLRPN" 19 | INPUT: 20 | MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 
640, 672, 704,) 21 | MIN_SIZE_TEST: 608 22 | MAX_SIZE_TRAIN: 1166 23 | DATASETS: 24 | TRAIN: ('voc_2007_trainval', 'voc_2012_trainval') 25 | TEST: ('voc_2007_test',) 26 | SOLVER: 27 | STEPS: (3999, ) 28 | MAX_ITER: 4000 # 16 epochs 29 | WARMUP_ITERS: 100 30 | CHECKPOINT_PERIOD: 1000 31 | IMS_PER_BATCH: 4 32 | BASE_LR: 0.00025 33 | LR_SCHEDULER_NAME: "ExponentialLR" 34 | GAMMA: 0.1 35 | TEST: 36 | EVAL_PERIOD: 500 37 | VIS_PERIOD: 20 38 | VERSION: 2 -------------------------------------------------------------------------------- /examples/domain_generalization/re_identification/baseline.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Market1501 -> Duke 3 | CUDA_VISIBLE_DEVICES=0 python baseline.py data -s Market1501 -t DukeMTMC -a reid_resnet50 \ 4 | --finetune --seed 0 --log logs/baseline/Market2Duke 5 | 6 | # Duke -> Market1501 7 | CUDA_VISIBLE_DEVICES=0 python baseline.py data -s DukeMTMC -t Market1501 -a reid_resnet50 \ 8 | --finetune --seed 0 --log logs/baseline/Duke2Market 9 | 10 | # Market1501 -> MSMT 11 | CUDA_VISIBLE_DEVICES=0 python baseline.py data -s Market1501 -t MSMT17 -a reid_resnet50 \ 12 | --finetune --seed 0 --log logs/baseline/Market2MSMT 13 | 14 | # MSMT -> Market1501 15 | CUDA_VISIBLE_DEVICES=0 python baseline.py data -s MSMT17 -t Market1501 -a reid_resnet50 \ 16 | --finetune --seed 0 --log logs/baseline/MSMT2Market 17 | 18 | # Duke -> MSMT 19 | CUDA_VISIBLE_DEVICES=0 python baseline.py data -s DukeMTMC -t MSMT17 -a reid_resnet50 \ 20 | --finetune --seed 0 --log logs/baseline/Duke2MSMT 21 | 22 | # MSMT -> Duke 23 | CUDA_VISIBLE_DEVICES=0 python baseline.py data -s MSMT17 -t DukeMTMC -a reid_resnet50 \ 24 | --finetune --seed 0 --log logs/baseline/MSMT2Duke 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018 The Python Packaging Authority 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 
-------------------------------------------------------------------------------- /examples/domain_adaptation/object_detection/config/retinanet_R_101_FPN_voc.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "TLRetinaNet" 3 | WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" 4 | BACKBONE: 5 | NAME: "build_retinanet_resnet_fpn_backbone" 6 | MASK_ON: False 7 | RESNETS: 8 | DEPTH: 101 9 | OUT_FEATURES: [ "res4", "res5" ] 10 | ANCHOR_GENERATOR: 11 | SIZES: !!python/object/apply:eval [ "[[x, x * 2**(1.0/3), x * 2**(2.0/3) ] for x in [64, 128, 256, 512 ]]" ] 12 | RETINANET: 13 | NUM_CLASSES: 20 14 | IN_FEATURES: ["p4", "p5", "p6", "p7"] 15 | IOU_THRESHOLDS: [ 0.4, 0.5 ] 16 | IOU_LABELS: [ 0, -1, 1 ] 17 | SMOOTH_L1_LOSS_BETA: 0.0 18 | FPN: 19 | IN_FEATURES: ["res4", "res5"] 20 | INPUT: 21 | MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, ) 22 | MIN_SIZE_TEST: 608 23 | MAX_SIZE_TRAIN: 1166 24 | DATASETS: 25 | TRAIN: ('voc_2007_trainval', 'voc_2012_trainval') 26 | TEST: ('voc_2007_test',) 27 | SOLVER: 28 | STEPS: (12000, ) 29 | MAX_ITER: 16000 # 16 epochs 30 | WARMUP_ITERS: 100 31 | CHECKPOINT_PERIOD: 2000 32 | IMS_PER_BATCH: 8 33 | BASE_LR: 0.005 34 | TEST: 35 | EVAL_PERIOD: 2000 36 | VIS_PERIOD: 500 37 | VERSION: 2 -------------------------------------------------------------------------------- /examples/domain_generalization/re_identification/mixstyle.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Market1501 -> Duke 3 | CUDA_VISIBLE_DEVICES=0 python mixstyle.py data -s Market1501 -t DukeMTMC -a resnet50 \ 4 | --mix-layers layer1 layer2 --finetune --seed 0 --log logs/mixstyle/Market2Duke 5 | 6 | # Duke -> Market1501 7 | CUDA_VISIBLE_DEVICES=0 python mixstyle.py data -s DukeMTMC -t Market1501 -a resnet50 \ 8 | --mix-layers layer1 layer2 --finetune --seed 0 --log logs/mixstyle/Duke2Market 9 | 10 | # Market1501 -> MSMT 11 | CUDA_VISIBLE_DEVICES=0 python mixstyle.py data -s Market1501 -t MSMT17 -a resnet50 \ 12 | --mix-layers layer1 layer2 --finetune --seed 0 --log logs/mixstyle/Market2MSMT 13 | 14 | # MSMT -> Market1501 15 | CUDA_VISIBLE_DEVICES=0 python mixstyle.py data -s MSMT17 -t Market1501 -a resnet50 \ 16 | --mix-layers layer1 layer2 --finetune --seed 0 --log logs/mixstyle/MSMT2Market 17 | 18 | # Duke -> MSMT 19 | CUDA_VISIBLE_DEVICES=0 python mixstyle.py data -s DukeMTMC -t MSMT17 -a resnet50 \ 20 | --mix-layers layer1 layer2 --finetune --seed 0 --log logs/mixstyle/Duke2MSMT 21 | 22 | # MSMT -> Duke 23 | CUDA_VISIBLE_DEVICES=0 python mixstyle.py data -s MSMT17 -t DukeMTMC -a resnet50 \ 24 | --mix-layers layer1 layer2 --finetune --seed 0 --log logs/mixstyle/MSMT2Duke 25 | -------------------------------------------------------------------------------- /tllib/modules/entropy.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Junguang Jiang 3 | @contact: JiangJunguang1123@outlook.com 4 | """ 5 | import torch 6 | 7 | 8 | def entropy(predictions: torch.Tensor, reduction='none') -> torch.Tensor: 9 | r"""Entropy of prediction. 10 | The definition is: 11 | 12 | .. math:: 13 | entropy(p) = - \sum_{c=1}^C p_c \log p_c 14 | 15 | where C is number of classes. 16 | 17 | Args: 18 | predictions (tensor): Classifier predictions. Expected to contain raw, normalized scores for each class 19 | reduction (str, optional): Specifies the reduction to apply to the output: 20 | ``'none'`` | ``'mean'``. 
``'none'``: no reduction will be applied, 21 | ``'mean'``: the sum of the output will be divided by the number of 22 | elements in the output. Default: ``'none'`` 23 | 24 | Shape: 25 | - predictions: :math:`(minibatch, C)` where C means the number of classes. 26 | - Output: :math:`(minibatch, )` by default. If :attr:`reduction` is ``'mean'``, then scalar. 27 | """ 28 | epsilon = 1e-5 29 | H = -predictions * torch.log(predictions + epsilon) 30 | H = H.sum(dim=1) 31 | if reduction == 'mean': 32 | return H.mean() 33 | else: 34 | return H 35 | -------------------------------------------------------------------------------- /tllib/utils/analysis/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | import torch.nn as nn 4 | import tqdm 5 | 6 | 7 | def collect_feature(data_loader: DataLoader, feature_extractor: nn.Module, 8 | device: torch.device, max_num_features=None) -> torch.Tensor: 9 | """ 10 | Fetch data from `data_loader`, and then use `feature_extractor` to collect features 11 | 12 | Args: 13 | data_loader (torch.utils.data.DataLoader): Data loader. 14 | feature_extractor (torch.nn.Module): A feature extractor. 15 | device (torch.device) 16 | max_num_features (int): The max number of mini-batches used to collect features. If None, features are collected from the whole data loader. 17 | 18 | Returns: 19 | Features in shape (min(len(data_loader.dataset), max_num_features * mini-batch size), :math:`|\mathcal{F}|`). 20 | """ 21 | feature_extractor.eval() 22 | all_features = [] 23 | with torch.no_grad(): 24 | for i, data in enumerate(tqdm.tqdm(data_loader)): 25 | if max_num_features is not None and i >= max_num_features: 26 | break 27 | inputs = data[0].to(device) 28 | feature = feature_extractor(inputs).cpu() 29 | all_features.append(feature) 30 | return torch.cat(all_features, dim=0) 31 | -------------------------------------------------------------------------------- /tllib/regularization/knowledge_distillation.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class KnowledgeDistillationLoss(nn.Module): 6 | """Knowledge Distillation Loss. 7 | 8 | Args: 9 | T (double): Temperature. Default: 1. 10 | reduction (str, optional): Specifies the reduction to apply to the output: 11 | ``'none'`` | ``'batchmean'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'batchmean'``: the sum of the output will be divided by the batch size, 12 | ``'mean'``: the sum of the output will be divided by the number of 13 | elements in the output, ``'sum'``: the output will be summed.
Default: ``'batchmean'`` 14 | 15 | Inputs: 16 | - y_student (tensor): logits output of the student 17 | - y_teacher (tensor): logits output of the teacher 18 | 19 | Shape: 20 | - y_student: (minibatch, `num_classes`) 21 | - y_teacher: (minibatch, `num_classes`) 22 | 23 | """ 24 | def __init__(self, T=1., reduction='batchmean'): 25 | super(KnowledgeDistillationLoss, self).__init__() 26 | self.T = T 27 | self.kl = nn.KLDivLoss(reduction=reduction) 28 | 29 | def forward(self, y_student, y_teacher): 30 | """""" 31 | return self.kl(F.log_softmax(y_student / self.T, dim=-1), F.softmax(y_teacher / self.T, dim=-1)) 32 | -------------------------------------------------------------------------------- /tllib/vision/datasets/reid/convert.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Baixu Chen 3 | @contact: cbx_99_hasta@outlook.com 4 | """ 5 | import os.path as osp 6 | from torch.utils.data import Dataset 7 | from PIL import Image 8 | 9 | 10 | def convert_to_pytorch_dataset(dataset, root=None, transform=None, return_idxes=False): 11 | class ReidDataset(Dataset): 12 | def __init__(self, dataset, root, transform): 13 | super(ReidDataset, self).__init__() 14 | self.dataset = dataset 15 | self.root = root 16 | self.transform = transform 17 | self.return_idxes = return_idxes 18 | 19 | def __len__(self): 20 | return len(self.dataset) 21 | 22 | def __getitem__(self, index): 23 | fname, pid, cid = self.dataset[index] 24 | fpath = fname 25 | if self.root is not None: 26 | fpath = osp.join(self.root, fname) 27 | 28 | img = Image.open(fpath).convert('RGB') 29 | 30 | if self.transform is not None: 31 | img = self.transform(img) 32 | 33 | if not self.return_idxes: 34 | return img, fname, pid, cid 35 | else: 36 | return img, fname, pid, cid, index 37 | 38 | return ReidDataset(dataset, root, transform) 39 | -------------------------------------------------------------------------------- /examples/domain_adaptation/re_identification/baseline.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Market1501 -> Duke 3 | CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s Market1501 -t DukeMTMC -a reid_resnet50 \ 4 | --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/Market2Duke 5 | 6 | # Duke -> Market1501 7 | CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s DukeMTMC -t Market1501 -a reid_resnet50 \ 8 | --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/Duke2Market 9 | 10 | # Market1501 -> MSMT 11 | CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s Market1501 -t MSMT17 -a reid_resnet50 \ 12 | --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/Market2MSMT 13 | 14 | # MSMT -> Market1501 15 | CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s MSMT17 -t Market1501 -a reid_resnet50 \ 16 | --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/MSMT2Market 17 | 18 | # Duke -> MSMT 19 | CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s DukeMTMC -t MSMT17 -a reid_resnet50 \ 20 | --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/Duke2MSMT 21 | 22 | # MSMT -> Duke 23 | CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s MSMT17 -t DukeMTMC -a reid_resnet50 \ 24 | --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/MSMT2Duke 25 | -------------------------------------------------------------------------------- 
/examples/domain_adaptation/wilds_ogb_molpcba/README.md: -------------------------------------------------------------------------------- 1 | # Unsupervised Domain Adaptation for WILDS (Molecule classification) 2 | 3 | ## Installation 4 | 5 | It's suggested to use **pytorch==1.10.1** in order to reproduce the benchmark results. 6 | 7 | Then, you need to run 8 | 9 | ``` 10 | pip install -r requirements.txt 11 | ``` 12 | 13 | At last, you need to install torch_sparse following `https://github.com/rusty1s/pytorch_sparse`. 14 | 15 | ## Dataset 16 | 17 | Following datasets can be downloaded automatically: 18 | 19 | - [OGB-MolPCBA (WILDS)](https://wilds.stanford.edu/datasets/) 20 | 21 | ## Supported Methods 22 | 23 | TODO 24 | 25 | ## Usage 26 | 27 | The shell files give all the training scripts we use, e.g. 28 | 29 | ``` 30 | CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds --lr 3e-2 -b 4096 4096 --epochs 200 \ 31 | --seed 0 --deterministic --log logs/erm/obg_lr_0_03_deterministic 32 | ``` 33 | 34 | ## Results 35 | 36 | ### Performance on WILDS-OGB-MolPCBA (GIN-virtual) 37 | 38 | | Methods | Val Avg Precision | Test Avg Precision | GPU Memory Usage(GB)| 39 | | --- | --- | --- | --- | 40 | | ERM | 29.0 | 28.0 | 17.8 | 41 | 42 | ### Visualization 43 | 44 | We use tensorboard to record the training process and visualize the outputs of the models. 45 | 46 | ``` 47 | tensorboard --logdir=logs 48 | ``` 49 | 50 | 51 | -------------------------------------------------------------------------------- /tllib/alignment/rsd.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Junguang Jiang 3 | @contact: JiangJunguang1123@outlook.com 4 | """ 5 | import torch.nn as nn 6 | import torch 7 | 8 | 9 | class RepresentationSubspaceDistance(nn.Module): 10 | """ 11 | `Representation Subspace Distance (ICML 2021) `_ 12 | 13 | Args: 14 | trade_off (float): The trade-off value between Representation Subspace Distance 15 | and Base Mismatch Penalization. 
Default: 0.1 16 | 17 | Inputs: 18 | - f_s (tensor): feature representations on source domain, :math:`f^s` 19 | - f_t (tensor): feature representations on target domain, :math:`f^t` 20 | 21 | """ 22 | def __init__(self, trade_off=0.1): 23 | super(RepresentationSubspaceDistance, self).__init__() 24 | self.trade_off = trade_off 25 | 26 | def forward(self, f_s, f_t): 27 | U_s, _, _ = torch.svd(f_s.t()) 28 | U_t, _, _ = torch.svd(f_t.t()) 29 | P_s, cosine, P_t = torch.svd(torch.mm(U_s.t(), U_t)) 30 | sine = torch.sqrt(1 - torch.pow(cosine, 2)) 31 | rsd = torch.norm(sine, 1) # Representation Subspace Distance 32 | bmp = torch.norm(torch.abs(P_s) - torch.abs(P_t), 2) # Base Mismatch Penalization 33 | return rsd + self.trade_off * bmp -------------------------------------------------------------------------------- /tllib/vision/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .imagelist import ImageList 2 | from .office31 import Office31 3 | from .officehome import OfficeHome 4 | from .visda2017 import VisDA2017 5 | from .officecaltech import OfficeCaltech 6 | from .domainnet import DomainNet 7 | from .imagenet_r import ImageNetR 8 | from .imagenet_sketch import ImageNetSketch 9 | from .pacs import PACS 10 | from .digits import * 11 | from .aircrafts import Aircraft 12 | from .cub200 import CUB200 13 | from .stanford_cars import StanfordCars 14 | from .stanford_dogs import StanfordDogs 15 | from .coco70 import COCO70 16 | from .oxfordpets import OxfordIIITPets 17 | from .dtd import DTD 18 | from .oxfordflowers import OxfordFlowers102 19 | from .patchcamelyon import PatchCamelyon 20 | from .retinopathy import Retinopathy 21 | from .eurosat import EuroSAT 22 | from .resisc45 import Resisc45 23 | from .food101 import Food101 24 | from .sun397 import SUN397 25 | from .caltech101 import Caltech101 26 | from .cifar import CIFAR10, CIFAR100 27 | 28 | __all__ = ['ImageList', 'Office31', 'OfficeHome', "VisDA2017", "OfficeCaltech", "DomainNet", "ImageNetR", 29 | "ImageNetSketch", "Aircraft", "CUB200", "StanfordCars", "StanfordDogs", "COCO70", "OxfordIIITPets", "PACS", 30 | "DTD", "OxfordFlowers102", "PatchCamelyon", "Retinopathy", "EuroSAT", "Resisc45", "Food101", "SUN397", 31 | "Caltech101", "CIFAR10", "CIFAR100"] 32 | -------------------------------------------------------------------------------- /tllib/vision/datasets/retinopathy.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Junguang Jiang 3 | @contact: JiangJunguang1123@outlook.com 4 | """ 5 | import os 6 | from .imagelist import ImageList 7 | 8 | 9 | class Retinopathy(ImageList): 10 | """`Retinopathy `_ dataset \ 11 | consists of image-label pairs with high-resolution retina images, and labels that indicate \ 12 | the presence of Diabetic Retinopathy (DR) on a 0-4 scale (No DR, Mild, Moderate, Severe, \ 13 | or Proliferative DR). 14 | 15 | .. note:: You need to download the source data manually into `root` directory. 16 | 17 | Args: 18 | root (str): Root directory of dataset 19 | split (str, optional): The dataset split, supports ``train``, or ``test``. 20 | transform (callable, optional): A function/transform that takes in a PIL image and returns a \ 21 | transformed version. E.g, :class:`torchvision.transforms.RandomCrop`. 22 | target_transform (callable, optional): A function/transform that takes in the target and transforms it.
23 | 24 | """ 25 | CLASSES = ['No DR', 'Mild', 'Moderate', 'Severe', 'Proliferative DR'] 26 | 27 | def __init__(self, root, split, download=False, **kwargs): 28 | 29 | super(Retinopathy, self).__init__(os.path.join(root, split), Retinopathy.CLASSES, os.path.join(root, "image_list", "{}.txt".format(split)), **kwargs) 30 | -------------------------------------------------------------------------------- /docs/tllib/alignment/hypothesis_adversarial.rst: -------------------------------------------------------------------------------- 1 | ========================================== 2 | Hypothesis Adversarial Learning 3 | ========================================== 4 | 5 | 6 | 7 | .. _MCD: 8 | 9 | MCD: Maximum Classifier Discrepancy 10 | -------------------------------------------- 11 | 12 | .. autofunction:: tllib.alignment.mcd.classifier_discrepancy 13 | 14 | .. autofunction:: tllib.alignment.mcd.entropy 15 | 16 | .. autoclass:: tllib.alignment.mcd.ImageClassifierHead 17 | 18 | 19 | .. _MDD: 20 | 21 | 22 | MDD: Margin Disparity Discrepancy 23 | -------------------------------------------- 24 | 25 | 26 | .. autoclass:: tllib.alignment.mdd.MarginDisparityDiscrepancy 27 | 28 | 29 | **MDD for Classification** 30 | 31 | 32 | .. autoclass:: tllib.alignment.mdd.ClassificationMarginDisparityDiscrepancy 33 | 34 | 35 | .. autoclass:: tllib.alignment.mdd.ImageClassifier 36 | :members: 37 | 38 | .. autofunction:: tllib.alignment.mdd.shift_log 39 | 40 | 41 | **MDD for Regression** 42 | 43 | .. autoclass:: tllib.alignment.mdd.RegressionMarginDisparityDiscrepancy 44 | 45 | .. autoclass:: tllib.alignment.mdd.ImageRegressor 46 | 47 | 48 | .. _RegDA: 49 | 50 | RegDA: Regressive Domain Adaptation 51 | -------------------------------------------- 52 | 53 | .. autoclass:: tllib.alignment.regda.PseudoLabelGenerator2d 54 | 55 | .. autoclass:: tllib.alignment.regda.RegressionDisparity 56 | 57 | .. autoclass:: tllib.alignment.regda.PoseResNet2d 58 | -------------------------------------------------------------------------------- /docs/tllib/normalization.rst: -------------------------------------------------------------------------------- 1 | ===================== 2 | Normalization 3 | ===================== 4 | 5 | 6 | 7 | .. _AFN: 8 | 9 | AFN: Adaptive Feature Norm 10 | ----------------------------- 11 | 12 | .. autoclass:: tllib.normalization.afn.AdaptiveFeatureNorm 13 | 14 | .. autoclass:: tllib.normalization.afn.Block 15 | 16 | .. autoclass:: tllib.normalization.afn.ImageClassifier 17 | 18 | 19 | StochNorm: Stochastic Normalization 20 | ------------------------------------------ 21 | 22 | .. autoclass:: tllib.normalization.stochnorm.StochNorm1d 23 | 24 | .. autoclass:: tllib.normalization.stochnorm.StochNorm2d 25 | 26 | .. autoclass:: tllib.normalization.stochnorm.StochNorm3d 27 | 28 | .. autofunction:: tllib.normalization.stochnorm.convert_model 29 | 30 | 31 | .. _IBN: 32 | 33 | IBN-Net: Instance-Batch Normalization Network 34 | ------------------------------------------------ 35 | 36 | .. autoclass:: tllib.normalization.ibn.InstanceBatchNorm2d 37 | 38 | .. autoclass:: tllib.normalization.ibn.IBNNet 39 | :members: 40 | 41 | .. automodule:: tllib.normalization.ibn 42 | :members: 43 | 44 | 45 | .. _MIXSTYLE: 46 | 47 | MixStyle: Domain Generalization with MixStyle 48 | ------------------------------------------------- 49 | 50 | .. autoclass:: tllib.normalization.mixstyle.MixStyle 51 | 52 | .. note:: 53 | MixStyle is only activated during `training` stage, with some probability :math:`p`. 54 | 55 | .. 
automodule:: tllib.normalization.mixstyle.resnet 56 | :members: 57 | -------------------------------------------------------------------------------- /tllib/self_training/uda.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Baixu Chen 3 | @contact: cbx_99_hasta@outlook.com 4 | """ 5 | 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class StrongWeakConsistencyLoss(nn.Module): 11 | """ 12 | Consistency loss between strong and weak augmented samples from `Unsupervised Data Augmentation for 13 | Consistency Training (NIPS 2020) `_. 14 | 15 | Args: 16 | threshold (float): Confidence threshold. 17 | temperature (float): Temperature. 18 | 19 | Inputs: 20 | - y_strong: unnormalized classifier predictions on strong augmented samples. 21 | - y: unnormalized classifier predictions on weak augmented samples. 22 | 23 | Shape: 24 | - y, y_strong: :math:`(minibatch, C)` where C means the number of classes. 25 | - Output: scalar. 26 | 27 | """ 28 | 29 | def __init__(self, threshold: float, temperature: float): 30 | super(StrongWeakConsistencyLoss, self).__init__() 31 | self.threshold = threshold 32 | self.temperature = temperature 33 | 34 | def forward(self, y_strong, y): 35 | confidence, _ = F.softmax(y.detach(), dim=1).max(dim=1) 36 | mask = (confidence > self.threshold).float() 37 | log_prob = F.log_softmax(y_strong / self.temperature, dim=1) 38 | con_loss = (F.kl_div(log_prob, F.softmax(y.detach(), dim=1), reduction='none').sum(dim=1)) 39 | con_loss = (con_loss * mask).sum() / max(mask.sum(), 1) 40 | 41 | return con_loss 42 | -------------------------------------------------------------------------------- /docs/tllib/regularization.rst: -------------------------------------------------------------------------------- 1 | =========================================== 2 | Regularization 3 | =========================================== 4 | 5 | .. _L2: 6 | 7 | L2 8 | ------ 9 | 10 | .. autoclass:: tllib.regularization.delta.L2Regularization 11 | 12 | 13 | .. _L2SP: 14 | 15 | L2-SP 16 | ------ 17 | 18 | .. autoclass:: tllib.regularization.delta.SPRegularization 19 | 20 | 21 | .. _DELTA: 22 | 23 | DELTA: DEep Learning Transfer using Feature Map with Attention 24 | ------------------------------------------------------------------------------------- 25 | 26 | .. autoclass:: tllib.regularization.delta.BehavioralRegularization 27 | 28 | .. autoclass:: tllib.regularization.delta.AttentionBehavioralRegularization 29 | 30 | .. autoclass:: tllib.regularization.delta.IntermediateLayerGetter 31 | 32 | 33 | .. _LWF: 34 | 35 | LWF: Learning without Forgetting 36 | ------------------------------------------ 37 | 38 | .. autoclass:: tllib.regularization.lwf.Classifier 39 | 40 | 41 | 42 | .. _CoTuning: 43 | 44 | Co-Tuning 45 | ------------------------------------------ 46 | 47 | .. autoclass:: tllib.regularization.co_tuning.CoTuningLoss 48 | 49 | .. autoclass:: tllib.regularization.co_tuning.Relationship 50 | 51 | 52 | .. _StochNorm: 53 | 54 | 55 | .. _BiTuning: 56 | 57 | Bi-Tuning 58 | ------------------------------------------ 59 | 60 | .. autoclass:: tllib.regularization.bi_tuning.BiTuning 61 | 62 | 63 | .. _BSS: 64 | 65 | BSS: Batch Spectral Shrinkage 66 | ------------------------------------------ 67 | 68 | .. 
autoclass:: tllib.regularization.bss.BatchSpectralShrinkage 69 | 70 | -------------------------------------------------------------------------------- /tllib/translation/cyclegan/transform.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Junguang Jiang 3 | @contact: JiangJunguang1123@outlook.com 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torchvision.transforms as T 8 | 9 | from tllib.vision.transforms import Denormalize 10 | 11 | 12 | class Translation(nn.Module): 13 | """ 14 | Image Translation Transform Module 15 | 16 | Args: 17 | generator (torch.nn.Module): An image generator, e.g. :meth:`~tllib.translation.cyclegan.resnet_9_generator` 18 | device (torch.device): device to put the generator. Default: 'cpu' 19 | mean (tuple): the normalized mean for image 20 | std (tuple): the normalized std for image 21 | Input: 22 | - image (PIL.Image): raw image in shape H x W x C 23 | 24 | Output: 25 | raw image in shape H x W x 3 26 | 27 | """ 28 | def __init__(self, generator, device=torch.device("cpu"), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)): 29 | super(Translation, self).__init__() 30 | self.generator = generator.to(device) 31 | self.device = device 32 | self.pre_process = T.Compose([ 33 | T.ToTensor(), 34 | T.Normalize(mean, std) 35 | ]) 36 | self.post_process = T.Compose([ 37 | Denormalize(mean, std), 38 | T.ToPILImage() 39 | ]) 40 | 41 | def forward(self, image): 42 | image = self.pre_process(image.copy()) # C x H x W 43 | image = image.to(self.device) 44 | generated_image = self.generator(image.unsqueeze(dim=0)).squeeze(dim=0).cpu() 45 | return self.post_process(generated_image) 46 | -------------------------------------------------------------------------------- /examples/domain_adaptation/object_detection/d_adapt/config/faster_rcnn_vgg_16_cityscapes.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "DecoupledGeneralizedRCNN" 3 | WEIGHTS: 'https://open-mmlab.oss-cn-beijing.aliyuncs.com/pretrain/vgg16_caffe-292e1171.pth' 4 | PIXEL_MEAN: [123.675, 116.280, 103.530] 5 | PIXEL_STD: [58.395, 57.120, 57.375] 6 | MASK_ON: False 7 | BACKBONE: 8 | NAME: "build_vgg_fpn_backbone" 9 | ROI_HEADS: 10 | IN_FEATURES: ["p3", "p4", "p5", "p6"] 11 | NAME: "DecoupledStandardROIHeads" 12 | NUM_CLASSES: 8 13 | ROI_BOX_HEAD: 14 | NAME: "FastRCNNConvFCHead" 15 | NUM_FC: 2 16 | POOLER_RESOLUTION: 7 17 | ANCHOR_GENERATOR: 18 | SIZES: [ [ 32 ], [ 64 ], [ 128 ], [ 256 ], [ 512 ] ] # One size for each in feature map 19 | ASPECT_RATIOS: [ [ 0.5, 1.0, 2.0 ] ] # Three aspect 20 | RPN: 21 | IN_FEATURES: ["p3", "p4", "p5", "p6", "p7"] 22 | PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level 23 | PRE_NMS_TOPK_TEST: 1000 # Per FPN level 24 | POST_NMS_TOPK_TRAIN: 1000 25 | POST_NMS_TOPK_TEST: 1000 26 | PROPOSAL_GENERATOR: 27 | NAME: "TLRPN" 28 | INPUT: 29 | FORMAT: "RGB" 30 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 31 | MIN_SIZE_TEST: 800 32 | MAX_SIZE_TEST: 1280 33 | MAX_SIZE_TRAIN: 1280 34 | DATASETS: 35 | TRAIN: ("cityscapes_trainval",) 36 | TEST: ("cityscapes_test",) 37 | SOLVER: 38 | STEPS: (3999, ) 39 | MAX_ITER: 4000 # 4 epochs 40 | WARMUP_ITERS: 100 41 | CHECKPOINT_PERIOD: 1000 42 | IMS_PER_BATCH: 8 43 | BASE_LR: 0.01 44 | LR_SCHEDULER_NAME: "ExponentialLR" 45 | GAMMA: 0.1 46 | TEST: 47 | EVAL_PERIOD: 500 48 | VIS_PERIOD: 20 49 | VERSION: 2 -------------------------------------------------------------------------------- /tllib/vision/datasets/resisc45.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | @author: Junguang Jiang 3 | @contact: JiangJunguang1123@outlook.com 4 | """ 5 | 6 | from torchvision.datasets.folder import ImageFolder 7 | import random 8 | 9 | 10 | class Resisc45(ImageFolder): 11 | """`Resisc45 `_ dataset \ 12 | is a scene classification task from remote sensing images. There are 45 classes, \ 13 | containing 700 images each, including tennis court, ship, island, lake, \ 14 | parking lot, sparse residential, or stadium. \ 15 | The image size is RGB 256x256 pixels. 16 | 17 | .. note:: You need to download the source data manually into `root` directory. 18 | 19 | Args: 20 | root (str): Root directory of dataset 21 | split (str, optional): The dataset split, supports ``train``, or ``test``. 22 | transform (callable, optional): A function/transform that takes in a PIL image and returns a \ 23 | transformed version. E.g, :class:`torchvision.transforms.RandomCrop`. 24 | target_transform (callable, optional): A function/transform that takes in the target and transforms it. 25 | 26 | """ 27 | def __init__(self, root, split='train', download=False, **kwargs): 28 | super(Resisc45, self).__init__(root, **kwargs) 29 | random.seed(0) 30 | random.shuffle(self.samples) 31 | if split == 'train': 32 | self.samples = self.samples[:25200] 33 | else: 34 | self.samples = self.samples[25200:] 35 | 36 | @property 37 | def num_classes(self) -> int: 38 | """Number of classes""" 39 | return len(self.classes) 40 | -------------------------------------------------------------------------------- /tllib/ranking/transrate.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Louis Fouquet 3 | @contact: louisfouquet75@gmail.com 4 | """ 5 | import numpy as np 6 | 7 | __all__ = ['transrate'] 8 | 9 | 10 | def coding_rate(features: np.ndarray, eps=1e-4): 11 | f = features 12 | n, d = f.shape 13 | (_, rate) = np.linalg.slogdet((np.eye(d) + 1 / (n * eps) * f.transpose() @ f)) 14 | return 0.5 * rate 15 | 16 | 17 | def transrate(features: np.ndarray, labels: np.ndarray, eps=1e-4): 18 | r""" 19 | TransRate in `Frustratingly easy transferability estimation (ICML 2022) 20 | `_. 21 | 22 | The TransRate :math:`TrR` can be described as: 23 | 24 | .. math:: 25 | TrR= R\left(f, \epsilon \right) - R\left(f, \epsilon \mid y \right) 26 | 27 | where :math:`f` is the features extracted by the model to be ranked, :math:`y` is the ground-truth label vector, 28 | :math:`R` is the coding rate with distortion rate :math:`\epsilon` 29 | 30 | Args: 31 | features (np.ndarray): features extracted by pre-trained model. 32 | labels (np.ndarray): ground-truth labels. 33 | eps (float, optional): distortion rate (Default: 1e-4) 34 | 35 | Shape: 36 | - features: (N, F), with number of samples N and feature dimension F. 37 | - labels: (N, ) elements in [0, :math:`C_t`), with target class number :math:`C_t`. 38 | - score: scalar.
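Example (a minimal illustrative sketch; the random features, labels and sizes below are placeholders, not taken from the library's documentation)::

    >>> import numpy as np
    >>> from tllib.ranking.transrate import transrate
    >>> features = np.random.rand(100, 32)     # placeholder features, shape (N, F)
    >>> labels = np.random.randint(0, 4, 100)  # placeholder labels in [0, C_t) with C_t = 4
    >>> score = transrate(features, labels)    # a higher score suggests better transferability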
39 | """ 40 | f = features 41 | y = labels 42 | f = f - np.mean(f, axis=0, keepdims=True) 43 | Rf = coding_rate(f, eps) 44 | Rfy = 0.0 45 | C = int(y.max() + 1) 46 | for i in range(C): 47 | Rfy += coding_rate(f[(y == i).flatten()], eps) 48 | return Rf - Rfy / C 49 | -------------------------------------------------------------------------------- /examples/domain_adaptation/image_regression/erm.sh: -------------------------------------------------------------------------------- 1 | # DSprites 2 | CUDA_VISIBLE_DEVICES=0 python erm.py data/dSprites -d DSprites -s C -t N -a resnet18 --epochs 20 --seed 0 -b 128 --log logs/erm/DSprites_C2N 3 | CUDA_VISIBLE_DEVICES=0 python erm.py data/dSprites -d DSprites -s C -t S -a resnet18 --epochs 20 --seed 0 -b 128 --log logs/erm/DSprites_C2S 4 | CUDA_VISIBLE_DEVICES=0 python erm.py data/dSprites -d DSprites -s N -t C -a resnet18 --epochs 20 --seed 0 -b 128 --log logs/erm/DSprites_N2C 5 | CUDA_VISIBLE_DEVICES=0 python erm.py data/dSprites -d DSprites -s N -t S -a resnet18 --epochs 20 --seed 0 -b 128 --log logs/erm/DSprites_N2S 6 | CUDA_VISIBLE_DEVICES=0 python erm.py data/dSprites -d DSprites -s S -t C -a resnet18 --epochs 20 --seed 0 -b 128 --log logs/erm/DSprites_S2C 7 | CUDA_VISIBLE_DEVICES=0 python erm.py data/dSprites -d DSprites -s S -t N -a resnet18 --epochs 20 --seed 0 -b 128 --log logs/erm/DSprites_S2N 8 | 9 | # MPI3D 10 | CUDA_VISIBLE_DEVICES=0 python erm.py data/mpi3d -d MPI3D -s RL -t RC -a resnet18 --epochs 40 --seed 0 -b 36 --log logs/erm/MPI3D_RL2RC 11 | CUDA_VISIBLE_DEVICES=0 python erm.py data/mpi3d -d MPI3D -s RL -t T -a resnet18 --epochs 40 --seed 0 -b 36 --log logs/erm/MPI3D_RL2T 12 | CUDA_VISIBLE_DEVICES=0 python erm.py data/mpi3d -d MPI3D -s RC -t RL -a resnet18 --epochs 40 --seed 0 -b 36 --log logs/erm/MPI3D_RC2RL 13 | CUDA_VISIBLE_DEVICES=0 python erm.py data/mpi3d -d MPI3D -s RC -t T -a resnet18 --epochs 40 --seed 0 -b 36 --log logs/erm/MPI3D_RC2T 14 | CUDA_VISIBLE_DEVICES=0 python erm.py data/mpi3d -d MPI3D -s T -t RL -a resnet18 --epochs 40 --seed 0 -b 36 --log logs/erm/MPI3D_T2RL 15 | CUDA_VISIBLE_DEVICES=0 python erm.py data/mpi3d -d MPI3D -s T -t RC -a resnet18 --epochs 40 --seed 0 -b 36 --log logs/erm/MPI3D_T2RC 16 | -------------------------------------------------------------------------------- /examples/domain_adaptation/image_regression/rsd.sh: -------------------------------------------------------------------------------- 1 | # DSprites 2 | CUDA_VISIBLE_DEVICES=0 python rsd.py data/dSprites -d DSprites -s C -t N -a resnet18 --epochs 40 --seed 0 --log logs/rsd/DSprites_C2N 3 | CUDA_VISIBLE_DEVICES=0 python rsd.py data/dSprites -d DSprites -s C -t S -a resnet18 --epochs 40 --seed 0 --log logs/rsd/DSprites_C2S 4 | CUDA_VISIBLE_DEVICES=0 python rsd.py data/dSprites -d DSprites -s N -t C -a resnet18 --epochs 40 --seed 0 --log logs/rsd/DSprites_N2C 5 | CUDA_VISIBLE_DEVICES=0 python rsd.py data/dSprites -d DSprites -s N -t S -a resnet18 --epochs 40 --seed 0 --log logs/rsd/DSprites_N2S 6 | CUDA_VISIBLE_DEVICES=0 python rsd.py data/dSprites -d DSprites -s S -t C -a resnet18 --epochs 40 --seed 0 --log logs/rsd/DSprites_S2C 7 | CUDA_VISIBLE_DEVICES=0 python rsd.py data/dSprites -d DSprites -s S -t N -a resnet18 --epochs 40 --seed 0 --log logs/rsd/DSprites_S2N 8 | 9 | # MPI3D 10 | CUDA_VISIBLE_DEVICES=0 python rsd.py data/mpi3d -d MPI3D -s RL -t RC -a resnet18 --epochs 40 --seed 0 --log logs/rsd/MPI3D_RL2RC --resize-size 224 11 | CUDA_VISIBLE_DEVICES=0 python rsd.py data/mpi3d -d MPI3D -s RL -t T -a resnet18 --epochs 40 --seed 0 
--log logs/rsd/MPI3D_RL2T --resize-size 224 12 | CUDA_VISIBLE_DEVICES=0 python rsd.py data/mpi3d -d MPI3D -s RC -t RL -a resnet18 --epochs 40 --seed 0 --log logs/rsd/MPI3D_RC2RL --resize-size 224 13 | CUDA_VISIBLE_DEVICES=0 python rsd.py data/mpi3d -d MPI3D -s RC -t T -a resnet18 --epochs 40 --seed 0 --log logs/rsd/MPI3D_RC2T --resize-size 224 14 | CUDA_VISIBLE_DEVICES=0 python rsd.py data/mpi3d -d MPI3D -s T -t RL -a resnet18 --epochs 40 --seed 0 --log logs/rsd/MPI3D_T2RL --resize-size 224 15 | CUDA_VISIBLE_DEVICES=0 python rsd.py data/mpi3d -d MPI3D -s T -t RC -a resnet18 --epochs 40 --seed 0 --log logs/rsd/MPI3D_T2RC --resize-size 224 16 | -------------------------------------------------------------------------------- /examples/domain_adaptation/semantic_segmentation/cycada.sh: -------------------------------------------------------------------------------- 1 | # GTA5 to Cityscapes 2 | # First, train the CycleGAN 3 | CUDA_VISIBLE_DEVICES=0 python cycada.py data/GTA5 data/Cityscapes -s GTA5 -t Cityscapes \ 4 | --log logs/cycada/gtav2cityscapes --pretrain logs/src_only/gtav2cityscapes/checkpoints/59.pth \ 5 | --translated-root data/GTA52Cityscapes/cycada_39 6 | # Then, train the src_only model on the translated source dataset 7 | CUDA_VISIBLE_DEVICES=0 python source_only.py data/GTA52Cityscapes/cycada_39 data/Cityscapes \ 8 | -s GTA5 -t Cityscapes --log logs/cycada_src_only/gtav2cityscapes 9 | 10 | 11 | ## Synthia to Cityscapes 12 | # First, train the Cycada 13 | CUDA_VISIBLE_DEVICES=0 python cycada.py data/synthia data/Cityscapes -s Synthia -t Cityscapes \ 14 | --log logs/cycada/synthia2cityscapes --pretrain logs/src_only/synthia2cityscapes/checkpoints/59.pth \ 15 | --translated-root data/Synthia2Cityscapes/cycada_39 16 | # Then, train the src_only model on the translated source dataset 17 | CUDA_VISIBLE_DEVICES=0 python source_only.py data/Synthia2Cityscapes/cycada_39 data/Cityscapes \ 18 | -s Synthia -t Cityscapes --log logs/cycada_src_only/synthia2cityscapes 19 | 20 | 21 | # Cityscapes to FoggyCityscapes 22 | # First, train the CycleGAN 23 | CUDA_VISIBLE_DEVICES=0 python cycada.py data/Cityscapes data/Cityscapes -s Cityscapes -t FoggyCityscapes \ 24 | --log logs/cycada/cityscapes2foggy --pretrain logs/src_only/cityscapes2foggy/checkpoints/59.pth \ 25 | --translated-root data/Cityscapes2Foggy/cycada_39 26 | # Then, train the src_only model on the translated source dataset 27 | CUDA_VISIBLE_DEVICES=0 python source_only.py data/Cityscapes2Foggy/cycada_39 data/Cityscapes \ 28 | -s Cityscapes -t FoggyCityscapes --log logs/cycada_src_only/cityscapes2foggy 29 | -------------------------------------------------------------------------------- /tllib/regularization/bss.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Yifei Ji 3 | @contact: jiyf990330@163.com 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | __all__ = ['BatchSpectralShrinkage'] 9 | 10 | 11 | class BatchSpectralShrinkage(nn.Module): 12 | r""" 13 | The regularization term in `Catastrophic Forgetting Meets Negative Transfer: 14 | Batch Spectral Shrinkage for Safe Transfer Learning (NIPS 2019) `_. 15 | 16 | 17 | The BSS regularization of feature matrix :math:`F` can be described as: 18 | 19 | .. math:: 20 | L_{bss}(F) = \sum_{i=1}^{k} \sigma_{-i}^2 , 21 | 22 | where :math:`k` is the number of singular values to be penalized, :math:`\sigma_{-i}` is the :math:`i`-th smallest singular value of feature matrix :math:`F`. 
23 | 24 | All the singular values of feature matrix :math:`F` are computed by `SVD`: 25 | 26 | .. math:: 27 | F = U\Sigma V^T, 28 | 29 | where the main diagonal elements of the singular value matrix :math:`\Sigma` is :math:`[\sigma_1, \sigma_2, ..., \sigma_b]`. 30 | 31 | 32 | Args: 33 | k (int): The number of singular values to be penalized. Default: 1 34 | 35 | Shape: 36 | - Input: :math:`(b, |\mathcal{f}|)` where :math:`b` is the batch size and :math:`|\mathcal{f}|` is feature dimension. 37 | - Output: scalar. 38 | 39 | """ 40 | def __init__(self, k=1): 41 | super(BatchSpectralShrinkage, self).__init__() 42 | self.k = k 43 | 44 | def forward(self, feature): 45 | result = 0 46 | u, s, v = torch.svd(feature.t()) 47 | num = s.size(0) 48 | for i in range(self.k): 49 | result += torch.pow(s[num-1-i], 2) 50 | return result 51 | -------------------------------------------------------------------------------- /examples/domain_adaptation/image_regression/dann.sh: -------------------------------------------------------------------------------- 1 | # DSprites 2 | CUDA_VISIBLE_DEVICES=0 python dann.py data/dSprites -d DSprites -s C -t N -a resnet18 --epochs 40 --seed 0 --log logs/dann/DSprites_C2N 3 | CUDA_VISIBLE_DEVICES=0 python dann.py data/dSprites -d DSprites -s C -t S -a resnet18 --epochs 40 --seed 0 --log logs/dann/DSprites_C2S 4 | CUDA_VISIBLE_DEVICES=0 python dann.py data/dSprites -d DSprites -s N -t C -a resnet18 --epochs 40 --seed 0 --log logs/dann/DSprites_N2C 5 | CUDA_VISIBLE_DEVICES=0 python dann.py data/dSprites -d DSprites -s N -t S -a resnet18 --epochs 40 --seed 0 --log logs/dann/DSprites_N2S 6 | CUDA_VISIBLE_DEVICES=0 python dann.py data/dSprites -d DSprites -s S -t C -a resnet18 --epochs 40 --seed 0 --log logs/dann/DSprites_S2C 7 | CUDA_VISIBLE_DEVICES=0 python dann.py data/dSprites -d DSprites -s S -t N -a resnet18 --epochs 40 --seed 0 --log logs/dann/DSprites_S2N 8 | 9 | # MPI3D 10 | CUDA_VISIBLE_DEVICES=0 python dann.py data/mpi3d -d MPI3D -s RL -t RC -a resnet18 --epochs 40 --seed 0 --log logs/dann/MPI3D_RL2RC --resize-size 224 11 | CUDA_VISIBLE_DEVICES=0 python dann.py data/mpi3d -d MPI3D -s RL -t T -a resnet18 --epochs 40 --seed 0 --log logs/dann/MPI3D_RL2T --resize-size 224 12 | CUDA_VISIBLE_DEVICES=0 python dann.py data/mpi3d -d MPI3D -s RC -t RL -a resnet18 --epochs 40 --seed 0 --log logs/dann/MPI3D_RC2RL --resize-size 224 13 | CUDA_VISIBLE_DEVICES=0 python dann.py data/mpi3d -d MPI3D -s RC -t T -a resnet18 --epochs 40 --seed 0 --log logs/dann/MPI3D_RC2T --resize-size 224 14 | CUDA_VISIBLE_DEVICES=0 python dann.py data/mpi3d -d MPI3D -s T -t RL -a resnet18 --epochs 40 --seed 0 --log logs/dann/MPI3D_T2RL --resize-size 224 15 | CUDA_VISIBLE_DEVICES=0 python dann.py data/mpi3d -d MPI3D -s T -t RC -a resnet18 --epochs 40 --seed 0 --log logs/dann/MPI3D_T2RC --resize-size 224 16 | -------------------------------------------------------------------------------- /examples/domain_adaptation/object_detection/config/faster_rcnn_vgg_16_cityscapes.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "TLGeneralizedRCNN" 3 | WEIGHTS: 'https://open-mmlab.oss-cn-beijing.aliyuncs.com/pretrain/vgg16_caffe-292e1171.pth' 4 | PIXEL_MEAN: [123.675, 116.280, 103.530] 5 | PIXEL_STD: [58.395, 57.120, 57.375] 6 | MASK_ON: False 7 | BACKBONE: 8 | NAME: "build_vgg_fpn_backbone" 9 | ROI_HEADS: 10 | IN_FEATURES: ["p3", "p4", "p5", "p6"] 11 | NAME: "TLStandardROIHeads" 12 | NUM_CLASSES: 8 13 | ROI_BOX_HEAD: 14 | NAME: 
"FastRCNNConvFCHead" 15 | NUM_FC: 2 16 | POOLER_RESOLUTION: 7 17 | ANCHOR_GENERATOR: 18 | SIZES: [ [ 32 ], [ 64 ], [ 128 ], [ 256 ], [ 512 ] ] # One size for each in feature map 19 | ASPECT_RATIOS: [ [ 0.5, 1.0, 2.0 ] ] # Three aspect 20 | RPN: 21 | IN_FEATURES: ["p3", "p4", "p5", "p6", "p7"] 22 | PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level 23 | PRE_NMS_TOPK_TEST: 1000 # Per FPN level 24 | # Detectron1 uses 2000 proposals per-batch, 25 | # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue) 26 | # which is approximately 1000 proposals per-image since the default batch size for FPN is 2. 27 | POST_NMS_TOPK_TRAIN: 1000 28 | POST_NMS_TOPK_TEST: 1000 29 | PROPOSAL_GENERATOR: 30 | NAME: "TLRPN" 31 | INPUT: 32 | FORMAT: "RGB" 33 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 34 | MIN_SIZE_TEST: 800 35 | MAX_SIZE_TEST: 1280 36 | MAX_SIZE_TRAIN: 1280 37 | DATASETS: 38 | TRAIN: ("cityscapes_trainval",) 39 | TEST: ("cityscapes_test",) 40 | SOLVER: 41 | STEPS: (12000,) 42 | MAX_ITER: 16000 # 16 epochs 43 | WARMUP_ITERS: 100 44 | CHECKPOINT_PERIOD: 2000 45 | IMS_PER_BATCH: 8 46 | BASE_LR: 0.01 47 | TEST: 48 | EVAL_PERIOD: 2000 49 | VIS_PERIOD: 500 50 | VERSION: 2 -------------------------------------------------------------------------------- /tllib/vision/datasets/food101.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Baixu Chen 3 | @contact: cbx_99_hasta@outlook.com 4 | """ 5 | from torchvision.datasets.folder import ImageFolder 6 | import os.path as osp 7 | from ._util import download as download_data, check_exits 8 | 9 | 10 | class Food101(ImageFolder): 11 | """`Food-101 `_ is a dataset 12 | for fine-grained visual recognition with 101,000 images in 101 food categories. 13 | 14 | Args: 15 | root (str): Root directory of dataset. 16 | split (str, optional): The dataset split, supports ``train``, or ``test``. 17 | transform (callable, optional): A function/transform that takes in an PIL image and returns a \ 18 | transformed version. E.g, :class:`torchvision.transforms.RandomCrop`. 19 | download (bool, optional): If true, downloads the dataset from the internet and puts it \ 20 | in root directory. If dataset is already downloaded, it is not downloaded again. 21 | 22 | .. note:: In `root`, there will exist following files after downloading. 23 | :: 24 | train/ 25 | test/ 26 | """ 27 | download_list = [ 28 | ("train", "train.tgz", "https://cloud.tsinghua.edu.cn/f/1d7bd727cc1e4ce2bef5/?dl=1"), 29 | ("test", "test.tgz", "https://cloud.tsinghua.edu.cn/f/7e11992d7495417db32b/?dl=1") 30 | ] 31 | 32 | def __init__(self, root, split='train', transform=None, download=True): 33 | if download: 34 | list(map(lambda args: download_data(root, *args), self.download_list)) 35 | else: 36 | list(map(lambda file_name, _: check_exits(root, file_name), self.download_list)) 37 | super(Food101, self).__init__(osp.join(root, split), transform=transform) 38 | self.num_classes = 101 39 | -------------------------------------------------------------------------------- /tllib/vision/datasets/patchcamelyon.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Junguang Jiang 3 | @contact: JiangJunguang1123@outlook.com 4 | """ 5 | import os 6 | from .imagelist import ImageList 7 | from ._util import download as download_data, check_exits 8 | 9 | 10 | class PatchCamelyon(ImageList): 11 | """ 12 | The `PatchCamelyon `_ dataset contains \ 13 | 327680 images of histopathologic scans of lymph node sections. 
\ 14 | The classification task consists in predicting the presence of metastatic tissue \ 15 | in given image (i.e., two classes). All images are 96x96 pixels 16 | 17 | Args: 18 | root (str): Root directory of dataset 19 | split (str, optional): The dataset split, supports ``train``, or ``test``. 20 | download (bool, optional): If true, downloads the dataset from the internet and puts it \ 21 | in root directory. If dataset is already downloaded, it is not downloaded again. 22 | transform (callable, optional): A function/transform that takes in an PIL image and returns a \ 23 | transformed version. E.g, :class:`torchvision.transforms.RandomCrop`. 24 | target_transform (callable, optional): A function/transform that takes in the target and transforms it. 25 | """ 26 | CLASSES = ['0', '1'] 27 | 28 | def __init__(self, root, split, download=False, **kwargs): 29 | if download: 30 | download_data(root, "patch_camelyon", "patch_camelyon.tgz", "https://cloud.tsinghua.edu.cn/f/21360b3441a54274b843/?dl=1") 31 | else: 32 | check_exits(root, "patch_camelyon") 33 | 34 | root = os.path.join(root, "patch_camelyon") 35 | super(PatchCamelyon, self).__init__(root, PatchCamelyon.CLASSES, os.path.join(root, "imagelist", "{}.txt".format(split)), **kwargs) 36 | 37 | -------------------------------------------------------------------------------- /tllib/translation/spgan/loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/Simon4Yan/eSPGAN 3 | @author: Baixu Chen 4 | @contact: cbx_99_hasta@outlook.com 5 | """ 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | 10 | class ContrastiveLoss(torch.nn.Module): 11 | r"""Contrastive loss from `Dimensionality Reduction by Learning an Invariant Mapping (CVPR 2006) 12 | `_. 13 | 14 | Given output features :math:`f_1, f_2`, we use :math:`D` to denote the pairwise euclidean distance between them, 15 | :math:`Y` to denote the ground truth labels, :math:`m` to denote a pre-defined margin, then contrastive loss is 16 | calculated as 17 | 18 | .. math:: 19 | (1 - Y)\frac{1}{2}D^2 + (Y)\frac{1}{2}\{\text{max}(0, m-D)^2\} 20 | 21 | Args: 22 | margin (float, optional): margin for contrastive loss. Default: 2.0 23 | 24 | Inputs: 25 | - output1 (tensor): feature representations of the first set of samples (:math:`f_1` here). 26 | - output2 (tensor): feature representations of the second set of samples (:math:`f_2` here). 27 | - label (tensor): labels (:math:`Y` here). 28 | 29 | Shape: 30 | - output1, output2: :math:`(minibatch, F)` where F means the dimension of input features. 31 | - label: :math:`(minibatch, )` 32 | """ 33 | def __init__(self, margin=2.0): 34 | super(ContrastiveLoss, self).__init__() 35 | self.margin = margin 36 | 37 | def forward(self, output1, output2, label): 38 | euclidean_distance = F.pairwise_distance(output1, output2) 39 | loss = torch.mean((1 - label) * torch.pow(euclidean_distance, 2) + 40 | label * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)) 41 | 42 | return loss 43 | -------------------------------------------------------------------------------- /examples/domain_adaptation/wilds_text/README.md: -------------------------------------------------------------------------------- 1 | # Unsupervised Domain Adaptation for WILDS (Text Classification) 2 | 3 | ## Installation 4 | 5 | It's suggested to use **pytorch==1.10.1** in order to reproduce the benchmark results. 
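For reference, one possible way to pin this version is shown below. The `cu113` wheels are only an assumption here; pick the build that matches your own CUDA setup.

```
pip install torch==1.10.1+cu113 torchvision==0.11.2+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
```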
6 | 7 | You need to run 8 | 9 | ``` 10 | pip install -r requirements.txt 11 | ``` 12 | 13 | ## Dataset 14 | 15 | Following datasets can be downloaded automatically: 16 | 17 | - [CivilComments (WILDS)](https://wilds.stanford.edu/datasets/) 18 | - [Amazon (WILDS)](https://wilds.stanford.edu/datasets/) 19 | 20 | ## Supported Methods 21 | 22 | TODO 23 | 24 | ## Usage 25 | 26 | The shell files give all the training scripts we use, e.g. 27 | 28 | ``` 29 | CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds -d "civilcomments" --unlabeled-list "extra_unlabeled" \ 30 | --uniform-over-groups --groupby-fields y black --max-token-length 300 --lr 1e-05 --metric "acc_wg" \ 31 | --seed 0 --deterministic --log logs/erm/civilcomments 32 | ``` 33 | 34 | ## Results 35 | 36 | ### Performance on WILDS-CivilComments (DistilBert) 37 | 38 | | Methods | Val Avg Acc | Val Worst-Group Acc | Test Avg Acc | Test Worst-Group Acc | GPU Memory Usage(GB)| 39 | | --- | --- | --- | --- | --- | --- | 40 | | ERM | 89.2 | 67.7 | 88.9 | 68.5 | 6.4 | 41 | 42 | ### Performance on WILDS-Amazon (DistilBert) 43 | 44 | | Methods | Val Avg Acc | Test Avg Acc | Val 10% Acc | Test 10% Acc | GPU Memory Usage(GB)| 45 | | --- | --- | --- | --- | --- | --- | 46 | | ERM | 72.6 | 71.6 | 54.7 | 53.8 | 12.8 | 47 | 48 | ### Visualization 49 | 50 | We use tensorboard to record the training process and visualize the outputs of the models. 51 | 52 | ``` 53 | tensorboard --logdir=logs 54 | ``` 55 | 56 | #### WILDS-CivilComments 57 | 58 | 59 | 60 | #### WILDS-Amazon 61 | 62 | -------------------------------------------------------------------------------- /tllib/self_training/pseudo_label.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Baixu Chen 3 | @contact: cbx_99_hasta@outlook.com 4 | """ 5 | 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class ConfidenceBasedSelfTrainingLoss(nn.Module): 11 | """ 12 | Self training loss that adopts confidence threshold to select reliable pseudo labels from 13 | `Pseudo-Label : The Simple and Efficient Semi-Supervised Learning Method for Deep Neural Networks (ICML 2013) 14 | `_. 15 | 16 | Args: 17 | threshold (float): Confidence threshold. 18 | 19 | Inputs: 20 | - y: unnormalized classifier predictions. 21 | - y_target: unnormalized classifier predictions which will used for generating pseudo labels. 22 | 23 | Returns: 24 | A tuple, including 25 | - self_training_loss: self training loss with pseudo labels. 26 | - mask: binary mask that indicates which samples are retained (whose confidence is above the threshold). 27 | - pseudo_labels: generated pseudo labels. 28 | 29 | Shape: 30 | - y, y_target: :math:`(minibatch, C)` where C means the number of classes. 31 | - self_training_loss: scalar. 32 | - mask, pseudo_labels :math:`(minibatch, )`. 
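Example (a minimal illustrative sketch; the random logits and the 0.9 threshold below are placeholders)::

    >>> import torch
    >>> from tllib.self_training.pseudo_label import ConfidenceBasedSelfTrainingLoss
    >>> criterion = ConfidenceBasedSelfTrainingLoss(threshold=0.9)
    >>> y = torch.randn(32, 10)         # predictions on the (strongly) augmented samples
    >>> y_target = torch.randn(32, 10)  # predictions used to generate pseudo labels
    >>> loss, mask, pseudo_labels = criterion(y, y_target)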
33 | 34 | """ 35 | 36 | def __init__(self, threshold: float): 37 | super(ConfidenceBasedSelfTrainingLoss, self).__init__() 38 | self.threshold = threshold 39 | 40 | def forward(self, y, y_target): 41 | confidence, pseudo_labels = F.softmax(y_target.detach(), dim=1).max(dim=1) 42 | mask = (confidence > self.threshold).float() 43 | self_training_loss = (F.cross_entropy(y, pseudo_labels, reduction='none') * mask).mean() 44 | 45 | return self_training_loss, mask, pseudo_labels 46 | -------------------------------------------------------------------------------- /tllib/translation/spgan/siamese.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/Simon4Yan/eSPGAN 3 | @author: Baixu Chen 4 | @contact: cbx_99_hasta@outlook.com 5 | """ 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class ConvBlock(nn.Module): 11 | """Basic block with structure Conv-LeakyReLU->Pool""" 12 | def __init__(self, in_dim, out_dim): 13 | super(ConvBlock, self).__init__() 14 | self.conv_block = nn.Sequential( 15 | nn.Conv2d(in_dim, out_dim, kernel_size=4, stride=2, padding=1), 16 | nn.LeakyReLU(0.2), 17 | nn.MaxPool2d(kernel_size=2, stride=2) 18 | ) 19 | 20 | def forward(self, x): 21 | return self.conv_block(x) 22 | 23 | 24 | class SiameseNetwork(nn.Module): 25 | """Siamese network whose input is an image of shape :math:`(3,H,W)` and output is an one-dimensional feature vector. 26 | 27 | Args: 28 | nsf (int): dimension of output feature representation. 29 | """ 30 | def __init__(self, nsf=64): 31 | super(SiameseNetwork, self).__init__() 32 | self.conv = nn.Sequential( 33 | nn.Conv2d(3, nsf, kernel_size=4, stride=2, padding=1), 34 | nn.LeakyReLU(0.2), 35 | nn.MaxPool2d(kernel_size=2, stride=2), 36 | ConvBlock(nsf, nsf * 2), 37 | ConvBlock(nsf * 2, nsf * 4), 38 | ) 39 | self.flatten = nn.Flatten() 40 | self.fc1 = nn.Linear(2048, nsf * 2, bias=False) 41 | self.leaky_relu = nn.LeakyReLU(0.2) 42 | self.dropout = nn.Dropout(0.5) 43 | self.fc2 = nn.Linear(nsf * 2, nsf, bias=False) 44 | 45 | def forward(self, x): 46 | x = self.flatten(self.conv(x)) 47 | x = self.fc1(x) 48 | x = self.leaky_relu(x) 49 | x = self.dropout(x) 50 | x = self.fc2(x) 51 | x = F.normalize(x) 52 | return x 53 | -------------------------------------------------------------------------------- /tllib/utils/analysis/tsne.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Junguang Jiang 3 | @contact: JiangJunguang1123@outlook.com 4 | """ 5 | import torch 6 | import matplotlib 7 | 8 | matplotlib.use('Agg') 9 | from sklearn.manifold import TSNE 10 | import numpy as np 11 | import matplotlib.pyplot as plt 12 | import matplotlib.colors as col 13 | 14 | 15 | def visualize(source_feature: torch.Tensor, target_feature: torch.Tensor, 16 | filename: str, source_color='r', target_color='b'): 17 | """ 18 | Visualize features from different domains using t-SNE. 19 | 20 | Args: 21 | source_feature (tensor): features from source domain in shape :math:`(minibatch, F)` 22 | target_feature (tensor): features from target domain in shape :math:`(minibatch, F)` 23 | filename (str): the file name to save t-SNE 24 | source_color (str): the color of the source features. Default: 'r' 25 | target_color (str): the color of the target features. 
Default: 'b' 26 | 27 | """ 28 | source_feature = source_feature.numpy() 29 | target_feature = target_feature.numpy() 30 | features = np.concatenate([source_feature, target_feature], axis=0) 31 | 32 | # map features to 2-d using TSNE 33 | X_tsne = TSNE(n_components=2, random_state=33).fit_transform(features) 34 | 35 | # domain labels, 1 represents source while 0 represents target 36 | domains = np.concatenate((np.ones(len(source_feature)), np.zeros(len(target_feature)))) 37 | 38 | # visualize using matplotlib 39 | fig, ax = plt.subplots(figsize=(10, 10)) 40 | ax.spines['top'].set_visible(False) 41 | ax.spines['right'].set_visible(False) 42 | ax.spines['bottom'].set_visible(False) 43 | ax.spines['left'].set_visible(False) 44 | plt.scatter(X_tsne[:, 0], X_tsne[:, 1], c=domains, cmap=col.ListedColormap([target_color, source_color]), s=20) 45 | plt.xticks([]) 46 | plt.yticks([]) 47 | plt.savefig(filename) 48 | -------------------------------------------------------------------------------- /examples/domain_generalization/re_identification/ibn.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Market1501 -> Duke 3 | CUDA_VISIBLE_DEVICES=0 python baseline.py data -s Market1501 -t DukeMTMC -a resnet50_ibn_a \ 4 | --finetune --seed 0 --log logs/ibn/Market2Duke 5 | CUDA_VISIBLE_DEVICES=0 python baseline.py data -s Market1501 -t DukeMTMC -a resnet50_ibn_b \ 6 | --finetune --seed 0 --log logs/ibn/Market2Duke 7 | 8 | # Duke -> Market1501 9 | CUDA_VISIBLE_DEVICES=0 python baseline.py data -s DukeMTMC -t Market1501 -a resnet50_ibn_a \ 10 | --finetune --seed 0 --log logs/ibn/Duke2Market 11 | CUDA_VISIBLE_DEVICES=0 python baseline.py data -s DukeMTMC -t Market1501 -a resnet50_ibn_b \ 12 | --finetune --seed 0 --log logs/ibn/Duke2Market 13 | 14 | # Market1501 -> MSMT 15 | CUDA_VISIBLE_DEVICES=0 python baseline.py data -s Market1501 -t MSMT17 -a resnet50_ibn_a \ 16 | --finetune --seed 0 --log logs/ibn/Market2MSMT 17 | CUDA_VISIBLE_DEVICES=0 python baseline.py data -s Market1501 -t MSMT17 -a resnet50_ibn_b \ 18 | --finetune --seed 0 --log logs/ibn/Market2MSMT 19 | 20 | # MSMT -> Market1501 21 | CUDA_VISIBLE_DEVICES=0 python baseline.py data -s MSMT17 -t Market1501 -a resnet50_ibn_a \ 22 | --finetune --seed 0 --log logs/ibn/MSMT2Market 23 | CUDA_VISIBLE_DEVICES=0 python baseline.py data -s MSMT17 -t Market1501 -a resnet50_ibn_b \ 24 | --finetune --seed 0 --log logs/ibn/MSMT2Market 25 | 26 | # Duke -> MSMT 27 | CUDA_VISIBLE_DEVICES=0 python baseline.py data -s DukeMTMC -t MSMT17 -a resnet50_ibn_a \ 28 | --finetune --seed 0 --log logs/ibn/Duke2MSMT 29 | CUDA_VISIBLE_DEVICES=0 python baseline.py data -s DukeMTMC -t MSMT17 -a resnet50_ibn_b \ 30 | --finetune --seed 0 --log logs/ibn/Duke2MSMT 31 | 32 | # MSMT -> Duke 33 | CUDA_VISIBLE_DEVICES=0 python baseline.py data -s MSMT17 -t DukeMTMC -a resnet50_ibn_a \ 34 | --finetune --seed 0 --log logs/ibn/MSMT2Duke 35 | CUDA_VISIBLE_DEVICES=0 python baseline.py data -s MSMT17 -t DukeMTMC -a resnet50_ibn_b \ 36 | --finetune --seed 0 --log logs/ibn/MSMT2Duke 37 | -------------------------------------------------------------------------------- /examples/domain_adaptation/object_detection/d_adapt/README.md: -------------------------------------------------------------------------------- 1 | # Decoupled Adaptation for Cross-Domain Object Detection 2 | 3 | ## Installation 4 | Our code is based on 5 | - [Detectron latest(v0.6)](https://detectron2.readthedocs.io/en/latest/tutorials/install.html) 6 | - 
[PyTorch-Image-Models](https://github.com/rwightman/pytorch-image-models) 7 | 8 | Please install them before use. 9 | 10 | ## Method 11 | Compared with previous cross-domain object detection methods, D-adapt decouples the adversarial adaptation from the training of the detector. 12 |
13 | <img src="fig/comparison.png" alt="Editor"/> 14 |
15 | 16 | The whole pipeline is as follows: 17 |
18 | <img src="fig/d_adapt_pipeline.png" alt="Editor"/> 19 |
20 | 21 | First, you need to run ``source_only.py`` to obtain pre-trained models. (See source_only.sh for scripts.) 22 | Then you need to run ``d_adapt.py`` to obtain adapted models. (See d_adapt.sh for scripts.) 23 | When the domain discrepancy is large, you need to run ``d_adapt.py`` multiple times. 24 | 25 | For better readability, we implement the training of the category adaptor in ``category_adaptation.py``, 26 | implement the training of the bounding box adaptor in ``bbox_adaptation.py``, 27 | and implement the training of the detector and connect the above components in ``d_adapt.py``. 28 | This makes it easy to modify these adaptors or replace them with your own. 29 | 30 | We provide independent training arguments for the detector, the category adaptor, and the bounding box adaptor. 31 | The arguments of the latter two end with ``-c`` and ``-b``, respectively. 32 | 33 | 34 | ## Citation 35 | If you use these methods in your research, please consider citing: 36 | 37 | ``` 38 | @inproceedings{jiang2021decoupled, 39 | title = {Decoupled Adaptation for Cross-Domain Object Detection}, 40 | author = {Junguang Jiang and Baixu Chen and Jianmin Wang and Mingsheng Long}, 41 | booktitle = {ICLR}, 42 | year = {2022} 43 | } 44 | ``` 45 | -------------------------------------------------------------------------------- /tllib/vision/datasets/eurosat.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Junguang Jiang 3 | @contact: JiangJunguang1123@outlook.com 4 | """ 5 | import os 6 | from .imagelist import ImageList 7 | from ._util import download as download_data, check_exits 8 | 9 | 10 | class EuroSAT(ImageList): 11 | """ 12 | `EuroSAT `_ dataset consists of classifying \ 13 | Sentinel-2 satellite images into 10 different types of land use (Residential, \ 14 | Industrial, River, Highway, etc.). \ 15 | The spatial resolution corresponds to 10 meters per pixel, and the image size \ 16 | is 64x64 pixels. 17 | 18 | Args: 19 | root (str): Root directory of dataset 20 | split (str, optional): The dataset split, supports ``train``, or ``test``. 21 | download (bool, optional): If true, downloads the dataset from the internet and puts it \ 22 | in root directory. If dataset is already downloaded, it is not downloaded again. 23 | transform (callable, optional): A function/transform that takes in a PIL image and returns a \ 24 | transformed version. E.g., :class:`torchvision.transforms.RandomCrop`. 25 | target_transform (callable, optional): A function/transform that takes in the target and transforms it.
26 | """ 27 | CLASSES =['AnnualCrop', 'Forest', 'HerbaceousVegetation', 'Highway', 'Industrial', 'Pasture', 28 | 'PermanentCrop', 'Residential', 'River', 'SeaLake'] 29 | 30 | def __init__(self, root, split='train', download=False, **kwargs): 31 | if download: 32 | download_data(root, "eurosat", "eurosat.tgz", "https://cloud.tsinghua.edu.cn/f/9983d7ab86184d74bb17/?dl=1") 33 | else: 34 | check_exits(root, "eurosat") 35 | split = 'train[:21600]' if split == 'train' else 'train[21600:]' 36 | 37 | root = os.path.join(root, "eurosat") 38 | super(EuroSAT, self).__init__(root, EuroSAT.CLASSES, os.path.join(root, "imagelist", "{}.txt".format(split)), **kwargs) 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /examples/domain_generalization/image_classification/erm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ResNet50, PACS 3 | CUDA_VISIBLE_DEVICES=0 python erm.py data/PACS -d PACS -s A C S -t P -a resnet50 --freeze-bn --seed 0 --log logs/erm/PACS_P 4 | CUDA_VISIBLE_DEVICES=0 python erm.py data/PACS -d PACS -s P C S -t A -a resnet50 --freeze-bn --seed 0 --log logs/erm/PACS_A 5 | CUDA_VISIBLE_DEVICES=0 python erm.py data/PACS -d PACS -s P A S -t C -a resnet50 --freeze-bn --seed 0 --log logs/erm/PACS_C 6 | CUDA_VISIBLE_DEVICES=0 python erm.py data/PACS -d PACS -s P A C -t S -a resnet50 --freeze-bn --seed 0 --log logs/erm/PACS_S 7 | 8 | # ResNet50, Office-Home 9 | CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --seed 0 --log logs/erm/OfficeHome_Pr 10 | CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --seed 0 --log logs/erm/OfficeHome_Rw 11 | CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar Rw Pr -t Cl -a resnet50 --seed 0 --log logs/erm/OfficeHome_Cl 12 | CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl Rw Pr -t Ar -a resnet50 --seed 0 --log logs/erm/OfficeHome_Ar 13 | 14 | # ResNet50, DomainNet 15 | CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/erm/DomainNet_c 16 | CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/erm/DomainNet_i 17 | CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/erm/DomainNet_p 18 | CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/erm/DomainNet_q 19 | CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/erm/DomainNet_r 20 | CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/erm/DomainNet_s 21 | -------------------------------------------------------------------------------- /tllib/vision/datasets/segmentation/gta5.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Junguang Jiang 3 | @contact: JiangJunguang1123@outlook.com 4 | """ 5 | import os 6 | from .segmentation_list import SegmentationList 7 | from .cityscapes import Cityscapes 8 | from .._util import download as download_data 9 | 10 | 11 | class GTA5(SegmentationList): 12 | """`GTA5 `_ 13 | 14 | Args: 15 | root 
(str): Root directory of dataset 16 | split (str, optional): The dataset split, supports ``train``. 17 | data_folder (str, optional): Sub-directory of the image. Default: 'images'. 18 | label_folder (str, optional): Sub-directory of the label. Default: 'labels'. 19 | mean (seq[float]): mean BGR value. Normalize the image if not None. Default: None. 20 | transforms (callable, optional): A function/transform that takes in (PIL image, label) pair \ 21 | and returns a transformed version. E.g, :class:`~tllib.vision.transforms.segmentation.Resize`. 22 | 23 | .. note:: You need to download GTA5 manually. 24 | Ensure that there exist following directories in the `root` directory before you using this class. 25 | :: 26 | images/ 27 | labels/ 28 | """ 29 | download_list = [ 30 | ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/f719733e339544e9a330/?dl=1"), 31 | ] 32 | 33 | def __init__(self, root, split='train', data_folder='images', label_folder='labels', **kwargs): 34 | assert split in ['train'] 35 | # download meta information from Internet 36 | list(map(lambda args: download_data(root, *args), self.download_list)) 37 | data_list_file = os.path.join(root, "image_list", "{}.txt".format(split)) 38 | self.split = split 39 | super(GTA5, self).__init__(root, Cityscapes.CLASSES, data_list_file, data_list_file, data_folder, label_folder, 40 | id_to_train_id=Cityscapes.ID_TO_TRAIN_ID, 41 | train_id_to_color=Cityscapes.TRAIN_ID_TO_COLOR, **kwargs) 42 | -------------------------------------------------------------------------------- /examples/domain_adaptation/image_regression/dd.sh: -------------------------------------------------------------------------------- 1 | # DSprites 2 | CUDA_VISIBLE_DEVICES=0 python dd.py data/dSprites -d DSprites -s C -t N -a resnet18 --epochs 40 --seed 0 -b 128 --log logs/dd/dSprites_C2N --wd 0.0005 3 | CUDA_VISIBLE_DEVICES=0 python dd.py data/dSprites -d DSprites -s C -t S -a resnet18 --epochs 40 --seed 0 -b 128 --log logs/dd/dSprites_C2S --wd 0.0005 4 | CUDA_VISIBLE_DEVICES=0 python dd.py data/dSprites -d DSprites -s N -t C -a resnet18 --epochs 40 --seed 0 -b 128 --log logs/dd/dSprites_N2C --wd 0.0005 5 | CUDA_VISIBLE_DEVICES=0 python dd.py data/dSprites -d DSprites -s N -t S -a resnet18 --epochs 40 --seed 0 -b 128 --log logs/dd/dSprites_N2S --wd 0.0005 6 | CUDA_VISIBLE_DEVICES=0 python dd.py data/dSprites -d DSprites -s S -t C -a resnet18 --epochs 40 --seed 0 -b 128 --log logs/dd/dSprites_S2C --wd 0.0005 7 | CUDA_VISIBLE_DEVICES=0 python dd.py data/dSprites -d DSprites -s S -t N -a resnet18 --epochs 40 --seed 0 -b 128 --log logs/dd/dSprites_S2N --wd 0.0005 8 | 9 | # MPI3D 10 | CUDA_VISIBLE_DEVICES=0 python dd.py data/mpi3d -d MPI3D -s RL -t RC -a resnet18 --epochs 60 --seed 0 -b 36 --log logs/dd/MPI3D_RL2RC --normalization IN --resize-size 224 --weight-decay 0.001 11 | CUDA_VISIBLE_DEVICES=0 python dd.py data/mpi3d -d MPI3D -s RL -t T -a resnet18 --epochs 60 --seed 0 -b 36 --log logs/dd/MPI3D_RL2T --normalization IN --resize-size 224 --weight-decay 0.001 12 | CUDA_VISIBLE_DEVICES=0 python dd.py data/mpi3d -d MPI3D -s RC -t RL -a resnet18 --epochs 60 --seed 0 -b 36 --log logs/dd/MPI3D_RC2RL --normalization IN --resize-size 224 --weight-decay 0.001 13 | CUDA_VISIBLE_DEVICES=0 python dd.py data/mpi3d -d MPI3D -s RC -t T -a resnet18 --epochs 60 --seed 0 -b 36 --log logs/dd/MPI3D_RC2T --normalization IN --resize-size 224 --weight-decay 0.001 14 | CUDA_VISIBLE_DEVICES=0 python dd.py data/mpi3d -d MPI3D -s T -t RL -a resnet18 --epochs 60 --seed 0 -b 36 --log 
logs/dd/MPI3D_T2RL --normalization IN --resize-size 224 --weight-decay 0.001 15 | CUDA_VISIBLE_DEVICES=0 python dd.py data/mpi3d -d MPI3D -s T -t RC -a resnet18 --epochs 60 --seed 0 -b 36 --log logs/dd/MPI3D_T2RC --normalization IN --resize-size 224 --weight-decay 0.001 16 | -------------------------------------------------------------------------------- /tllib/ranking/leep.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Yong Liu 3 | @contact: liuyong1095556447@163.com 4 | """ 5 | 6 | import numpy as np 7 | 8 | __all__ = ['log_expected_empirical_prediction'] 9 | 10 | 11 | def log_expected_empirical_prediction(predictions: np.ndarray, labels: np.ndarray): 12 | r""" 13 | Log Expected Empirical Prediction in `LEEP: A New Measure to 14 | Evaluate Transferability of Learned Representations (ICML 2020) 15 | `_. 16 | 17 | The LEEP :math:`\mathcal{T}` can be described as: 18 | 19 | .. math:: 20 | \mathcal{T}=\mathbb{E}\log \left(\sum_{z \in \mathcal{C}_s} \hat{P}\left(y \mid z\right) \theta\left(y \right)_{z}\right) 21 | 22 | where :math:`\theta\left(y\right)_{z}` is the predictions of pre-trained model on source category, :math:`\hat{P}\left(y \mid z\right)` is the empirical conditional distribution estimated by prediction and ground-truth label. 23 | 24 | Args: 25 | predictions (np.ndarray): predictions of pre-trained model. 26 | labels (np.ndarray): groud-truth labels. 27 | 28 | Shape: 29 | - predictions: (N, :math:`C_s`), with number of samples N and source class number :math:`C_s`. 30 | - labels: (N, ) elements in [0, :math:`C_t`), with target class number :math:`C_t`. 31 | - score: scalar 32 | """ 33 | N, C_s = predictions.shape 34 | labels = labels.reshape(-1) 35 | C_t = int(np.max(labels) + 1) 36 | 37 | normalized_prob = predictions / float(N) 38 | joint = np.zeros((C_t, C_s), dtype=float) # placeholder for joint distribution over (y, z) 39 | 40 | for i in range(C_t): 41 | this_class = normalized_prob[labels == i] 42 | row = np.sum(this_class, axis=0) 43 | joint[i] = row 44 | 45 | p_target_given_source = (joint / joint.sum(axis=0, keepdims=True)).T # P(y | z) 46 | empirical_prediction = predictions @ p_target_given_source 47 | empirical_prob = np.array([predict[label] for predict, label in zip(empirical_prediction, labels)]) 48 | score = np.mean(np.log(empirical_prob)) 49 | 50 | return score 51 | -------------------------------------------------------------------------------- /tllib/self_training/self_ensemble.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Baixu Chen 3 | @contact: cbx_99_hasta@outlook.com 4 | """ 5 | from typing import Optional 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from tllib.modules.classifier import Classifier as ClassifierBase 11 | 12 | 13 | class ClassBalanceLoss(nn.Module): 14 | r""" 15 | Class balance loss that penalises the network for making predictions that exhibit large class imbalance. 16 | Given predictions :math:`p` with dimension :math:`(N, C)`, we first calculate 17 | the mini-batch mean per-class probability :math:`p_{mean}` with dimension :math:`(C, )`, where 18 | 19 | .. math:: 20 | p_{mean}^j = \frac{1}{N} \sum_{i=1}^N p_i^j 21 | 22 | Then we calculate binary cross entropy loss between :math:`p_{mean}` and uniform probability vector :math:`u` with 23 | the same dimension where :math:`u^j` = :math:`\frac{1}{C}` 24 | 25 | .. 
math:: 26 | loss = \text{BCELoss}(p_{mean}, u) 27 | 28 | Args: 29 | num_classes (int): Number of classes 30 | 31 | Inputs: 32 | - p (tensor): predictions from classifier 33 | 34 | Shape: 35 | - p: :math:`(N, C)` where C means the number of classes. 36 | """ 37 | 38 | def __init__(self, num_classes): 39 | super(ClassBalanceLoss, self).__init__() 40 | self.uniform_distribution = torch.ones(num_classes) / num_classes 41 | 42 | def forward(self, p: torch.Tensor): 43 | return F.binary_cross_entropy(p.mean(dim=0), self.uniform_distribution.to(p.device)) 44 | 45 | 46 | class ImageClassifier(ClassifierBase): 47 | def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs): 48 | bottleneck = nn.Sequential( 49 | # nn.AdaptiveAvgPool2d(output_size=(1, 1)), 50 | # nn.Flatten(), 51 | nn.Linear(backbone.out_features, bottleneck_dim), 52 | nn.BatchNorm1d(bottleneck_dim), 53 | nn.ReLU() 54 | ) 55 | super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs) 56 | -------------------------------------------------------------------------------- /tllib/vision/datasets/_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Junguang Jiang 3 | @contact: JiangJunguang1123@outlook.com 4 | """ 5 | import os 6 | from typing import List 7 | from torchvision.datasets.utils import download_and_extract_archive 8 | 9 | 10 | def download(root: str, file_name: str, archive_name: str, url_link: str): 11 | """ 12 | Download file from internet url link. 13 | 14 | Args: 15 | root (str) The directory to put downloaded files. 16 | file_name: (str) The name of the unzipped file. 17 | archive_name: (str) The name of archive(zipped file) downloaded. 18 | url_link: (str) The url link to download data. 19 | 20 | .. note:: 21 | If `file_name` already exists under path `root`, then it is not downloaded again. 22 | Else `archive_name` will be downloaded from `url_link` and extracted to `file_name`. 23 | """ 24 | if not os.path.exists(os.path.join(root, file_name)): 25 | print("Downloading {}".format(file_name)) 26 | # if os.path.exists(os.path.join(root, archive_name)): 27 | # os.remove(os.path.join(root, archive_name)) 28 | try: 29 | download_and_extract_archive(url_link, download_root=root, filename=archive_name, remove_finished=False) 30 | except Exception: 31 | print("Fail to download {} from url link {}".format(archive_name, url_link)) 32 | print('Please check you internet connection.' 33 | "Simply trying again may be fine.") 34 | exit(0) 35 | 36 | 37 | def check_exits(root: str, file_name: str): 38 | """Check whether `file_name` exists under directory `root`. 
""" 39 | if not os.path.exists(os.path.join(root, file_name)): 40 | print("Dataset directory {} not found under {}".format(file_name, root)) 41 | exit(-1) 42 | 43 | 44 | def read_list_from_file(file_name: str) -> List[str]: 45 | """Read data from file and convert each line into an element in the list""" 46 | result = [] 47 | with open(file_name, "r") as f: 48 | for line in f.readlines(): 49 | result.append(line.strip()) 50 | return result 51 | -------------------------------------------------------------------------------- /examples/domain_generalization/image_classification/mldg.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ResNet50, PACS 3 | CUDA_VISIBLE_DEVICES=0 python mldg.py data/PACS -d PACS -s A C S -t P -a resnet50 --freeze-bn --seed 0 --log logs/mldg/PACS_P 4 | CUDA_VISIBLE_DEVICES=0 python mldg.py data/PACS -d PACS -s P C S -t A -a resnet50 --freeze-bn --seed 0 --log logs/mldg/PACS_A 5 | CUDA_VISIBLE_DEVICES=0 python mldg.py data/PACS -d PACS -s P A S -t C -a resnet50 --freeze-bn --seed 0 --log logs/mldg/PACS_C 6 | CUDA_VISIBLE_DEVICES=0 python mldg.py data/PACS -d PACS -s P A C -t S -a resnet50 --freeze-bn --seed 0 --log logs/mldg/PACS_S 7 | 8 | # ResNet50, Office-Home 9 | CUDA_VISIBLE_DEVICES=0 python mldg.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --seed 0 --log logs/mldg/OfficeHome_Pr 10 | CUDA_VISIBLE_DEVICES=0 python mldg.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --seed 0 --log logs/mldg/OfficeHome_Rw 11 | CUDA_VISIBLE_DEVICES=0 python mldg.py data/office-home -d OfficeHome -s Ar Rw Pr -t Cl -a resnet50 --seed 0 --log logs/mldg/OfficeHome_Cl 12 | CUDA_VISIBLE_DEVICES=0 python mldg.py data/office-home -d OfficeHome -s Cl Rw Pr -t Ar -a resnet50 --seed 0 --log logs/mldg/OfficeHome_Ar 13 | 14 | # ResNet50, DomainNet 15 | CUDA_VISIBLE_DEVICES=0 python mldg.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet50 -i 5000 --lr 0.005 --seed 0 --log logs/mldg/DomainNet_c 16 | CUDA_VISIBLE_DEVICES=0 python mldg.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet50 -i 5000 --lr 0.005 --seed 0 --log logs/mldg/DomainNet_i 17 | CUDA_VISIBLE_DEVICES=0 python mldg.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet50 -i 5000 --lr 0.005 --seed 0 --log logs/mldg/DomainNet_p 18 | CUDA_VISIBLE_DEVICES=0 python mldg.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet50 -i 5000 --lr 0.005 --seed 0 --log logs/mldg/DomainNet_q 19 | CUDA_VISIBLE_DEVICES=0 python mldg.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet50 -i 5000 --lr 0.005 --seed 0 --log logs/mldg/DomainNet_r 20 | CUDA_VISIBLE_DEVICES=0 python mldg.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet50 -i 5000 --lr 0.005 --seed 0 --log logs/mldg/DomainNet_s 21 | -------------------------------------------------------------------------------- /tllib/alignment/coral.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Baixu Chen 3 | @contact: cbx_99_hasta@outlook.com 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class CorrelationAlignmentLoss(nn.Module): 10 | r"""The `Correlation Alignment Loss` in 11 | `Deep CORAL: Correlation Alignment for Deep Domain Adaptation (ECCV 2016) `_. 12 | 13 | Given source features :math:`f_S` and target features :math:`f_T`, the covariance matrices are given by 14 | 15 | .. math:: 16 | C_S = \frac{1}{n_S-1}(f_S^Tf_S-\frac{1}{n_S}(\textbf{1}^Tf_S)^T(\textbf{1}^Tf_S)) 17 | .. 
math:: 18 | C_T = \frac{1}{n_T-1}(f_T^Tf_T-\frac{1}{n_T}(\textbf{1}^Tf_T)^T(\textbf{1}^Tf_T)) 19 | 20 | where :math:`\textbf{1}` denotes a column vector with all elements equal to 1, :math:`n_S, n_T` denotes number of 21 | source and target samples, respectively. We use :math:`d` to denote feature dimension, use 22 | :math:`{\Vert\cdot\Vert}^2_F` to denote the squared matrix `Frobenius norm`. The correlation alignment loss is 23 | given by 24 | 25 | .. math:: 26 | l_{CORAL} = \frac{1}{4d^2}\Vert C_S-C_T \Vert^2_F 27 | 28 | Inputs: 29 | - f_s (tensor): feature representations on source domain, :math:`f^s` 30 | - f_t (tensor): feature representations on target domain, :math:`f^t` 31 | 32 | Shape: 33 | - f_s, f_t: :math:`(N, d)` where d means the dimension of input features, :math:`N=n_S=n_T` is mini-batch size. 34 | - Outputs: scalar. 35 | """ 36 | 37 | def __init__(self): 38 | super(CorrelationAlignmentLoss, self).__init__() 39 | 40 | def forward(self, f_s: torch.Tensor, f_t: torch.Tensor) -> torch.Tensor: 41 | mean_s = f_s.mean(0, keepdim=True) 42 | mean_t = f_t.mean(0, keepdim=True) 43 | cent_s = f_s - mean_s 44 | cent_t = f_t - mean_t 45 | cov_s = torch.mm(cent_s.t(), cent_s) / (len(f_s) - 1) 46 | cov_t = torch.mm(cent_t.t(), cent_t) / (len(f_t) - 1) 47 | 48 | mean_diff = (mean_s - mean_t).pow(2).mean() 49 | cov_diff = (cov_s - cov_t).pow(2).mean() 50 | 51 | return mean_diff + cov_diff 52 | -------------------------------------------------------------------------------- /examples/domain_generalization/image_classification/coral.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ResNet50, PACS 3 | CUDA_VISIBLE_DEVICES=0 python coral.py data/PACS -d PACS -s A C S -t P -a resnet50 --freeze-bn --seed 0 --log logs/coral/PACS_P 4 | CUDA_VISIBLE_DEVICES=0 python coral.py data/PACS -d PACS -s P C S -t A -a resnet50 --freeze-bn --seed 0 --log logs/coral/PACS_A 5 | CUDA_VISIBLE_DEVICES=0 python coral.py data/PACS -d PACS -s P A S -t C -a resnet50 --freeze-bn --seed 0 --log logs/coral/PACS_C 6 | CUDA_VISIBLE_DEVICES=0 python coral.py data/PACS -d PACS -s P A C -t S -a resnet50 --freeze-bn --seed 0 --log logs/coral/PACS_S 7 | 8 | # ResNet50, Office-Home 9 | CUDA_VISIBLE_DEVICES=0 python coral.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --seed 0 --log logs/coral/OfficeHome_Pr 10 | CUDA_VISIBLE_DEVICES=0 python coral.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --seed 0 --log logs/coral/OfficeHome_Rw 11 | CUDA_VISIBLE_DEVICES=0 python coral.py data/office-home -d OfficeHome -s Ar Rw Pr -t Cl -a resnet50 --seed 0 --log logs/coral/OfficeHome_Cl 12 | CUDA_VISIBLE_DEVICES=0 python coral.py data/office-home -d OfficeHome -s Cl Rw Pr -t Ar -a resnet50 --seed 0 --log logs/coral/OfficeHome_Ar 13 | 14 | # ResNet50, DomainNet 15 | CUDA_VISIBLE_DEVICES=0 python coral.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/coral/DomainNet_c 16 | CUDA_VISIBLE_DEVICES=0 python coral.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/coral/DomainNet_i 17 | CUDA_VISIBLE_DEVICES=0 python coral.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/coral/DomainNet_p 18 | CUDA_VISIBLE_DEVICES=0 python coral.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/coral/DomainNet_q 19 | CUDA_VISIBLE_DEVICES=0 python coral.py data/domainnet -d 
DomainNet -s c i p q s -t r -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/coral/DomainNet_r 20 | CUDA_VISIBLE_DEVICES=0 python coral.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/coral/DomainNet_s 21 | -------------------------------------------------------------------------------- /tllib/ranking/nce.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Yong Liu 3 | @contact: liuyong1095556447@163.com 4 | """ 5 | import numpy as np 6 | 7 | __all__ = ['negative_conditional_entropy'] 8 | 9 | 10 | def negative_conditional_entropy(source_labels: np.ndarray, target_labels: np.ndarray): 11 | r""" 12 | Negative Conditional Entropy in `Transferability and Hardness of Supervised 13 | Classification Tasks (ICCV 2019) `_. 14 | 15 | The NCE :math:`\mathcal{H}` can be described as: 16 | 17 | .. math:: 18 | \mathcal{H}=-\sum_{y \in \mathcal{C}_t} \sum_{z \in \mathcal{C}_s} \hat{P}(y, z) \log \frac{\hat{P}(y, z)}{\hat{P}(z)} 19 | 20 | where :math:`\hat{P}(z)` is the empirical distribution and :math:`\hat{P}\left(y \mid z\right)` is the empirical 21 | conditional distribution estimated by source and target label. 22 | 23 | Args: 24 | source_labels (np.ndarray): predicted source labels. 25 | target_labels (np.ndarray): groud-truth target labels. 26 | 27 | Shape: 28 | - source_labels: (N, ) elements in [0, :math:`C_s`), with source class number :math:`C_s`. 29 | - target_labels: (N, ) elements in [0, :math:`C_t`), with target class number :math:`C_t`. 30 | """ 31 | C_t = int(np.max(target_labels) + 1) 32 | C_s = int(np.max(source_labels) + 1) 33 | N = len(source_labels) 34 | 35 | joint = np.zeros((C_t, C_s), dtype=float) # placeholder for the joint distribution, shape [C_t, C_s] 36 | for s, t in zip(source_labels, target_labels): 37 | s = int(s) 38 | t = int(t) 39 | joint[t, s] += 1.0 / N 40 | p_z = joint.sum(axis=0, keepdims=True) 41 | 42 | p_target_given_source = (joint / p_z).T # P(y | z), shape [C_s, C_t] 43 | mask = p_z.reshape(-1) != 0 # valid Z, shape [C_s] 44 | p_target_given_source = p_target_given_source[mask] + 1e-20 # remove NaN where p(z) = 0, add 1e-20 to avoid log (0) 45 | entropy_y_given_z = np.sum(- p_target_given_source * np.log(p_target_given_source), axis=1, keepdims=True) 46 | conditional_entropy = np.sum(entropy_y_given_z * p_z.reshape((-1, 1))[mask]) 47 | 48 | return -conditional_entropy 49 | -------------------------------------------------------------------------------- /tllib/vision/datasets/sun397.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Baixu Chen 3 | @contact: cbx_99_hasta@outlook.com 4 | """ 5 | import os 6 | from .imagelist import ImageList 7 | from ._util import download as download_data, check_exits 8 | 9 | 10 | class SUN397(ImageList): 11 | """`SUN397 `_ is a dataset for scene understanding 12 | with 108,754 images in 397 scene categories. The number of images varies across categories, 13 | but there are at least 100 images per category. Note that the authors construct 10 partitions, 14 | where each partition contains 50 training images and 50 testing images per class. We adopt partition 1. 15 | 16 | Args: 17 | root (str): Root directory of dataset 18 | split (str, optional): The dataset split, supports ``train``, or ``test``. 19 | download (bool, optional): If true, downloads the dataset from the internet and puts it \ 20 | in root directory. 
If dataset is already downloaded, it is not downloaded again. 21 | transform (callable, optional): A function/transform that takes in an PIL image and returns a \ 22 | transformed version. E.g, :class:`torchvision.transforms.RandomCrop`. 23 | target_transform (callable, optional): A function/transform that takes in the target and transforms it. 24 | 25 | """ 26 | dataset_url = ("SUN397", "SUN397.tar.gz", "http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz") 27 | image_list_url = ( 28 | "SUN397/image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/dec0775147c144ea9f75/?dl=1") 29 | 30 | def __init__(self, root, split='train', download=True, **kwargs): 31 | if download: 32 | download_data(root, *self.dataset_url) 33 | download_data(os.path.join(root, 'SUN397'), *self.image_list_url) 34 | else: 35 | check_exits(root, "SUN397") 36 | check_exits(root, "SUN397/image_list") 37 | 38 | classes = list([str(i) for i in range(397)]) 39 | root = os.path.join(root, 'SUN397') 40 | super(SUN397, self).__init__(root, classes, os.path.join(root, 'image_list', '{}.txt'.format(split)), **kwargs) 41 | -------------------------------------------------------------------------------- /examples/domain_generalization/image_classification/ibn.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # IBN_ResNet50_b, PACS 3 | CUDA_VISIBLE_DEVICES=0 python erm.py data/PACS -d PACS -s A C S -t P -a resnet50_ibn_b --freeze-bn --seed 0 --log logs/erm/PACS_P 4 | CUDA_VISIBLE_DEVICES=0 python erm.py data/PACS -d PACS -s P C S -t A -a resnet50_ibn_b --freeze-bn --seed 0 --log logs/erm/PACS_A 5 | CUDA_VISIBLE_DEVICES=0 python erm.py data/PACS -d PACS -s P A S -t C -a resnet50_ibn_b --freeze-bn --seed 0 --log logs/erm/PACS_C 6 | CUDA_VISIBLE_DEVICES=0 python erm.py data/PACS -d PACS -s P A C -t S -a resnet50_ibn_b --freeze-bn --seed 0 --log logs/erm/PACS_S 7 | 8 | # IBN_ResNet50_b, Office-Home 9 | CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50_ibn_b --seed 0 --log logs/erm/OfficeHome_Pr 10 | CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50_ibn_b --seed 0 --log logs/erm/OfficeHome_Rw 11 | CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar Rw Pr -t Cl -a resnet50_ibn_b --seed 0 --log logs/erm/OfficeHome_Cl 12 | CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl Rw Pr -t Ar -a resnet50_ibn_b --seed 0 --log logs/erm/OfficeHome_Ar 13 | 14 | # IBN_ResNet50_b, DomainNet 15 | CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet50_ibn_b -i 2500 --lr 0.01 --seed 0 --log logs/erm/DomainNet_c 16 | CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet50_ibn_b -i 2500 --lr 0.01 --seed 0 --log logs/erm/DomainNet_i 17 | CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet50_ibn_b -i 2500 --lr 0.01 --seed 0 --log logs/erm/DomainNet_p 18 | CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet50_ibn_b -i 2500 --lr 0.01 --seed 0 --log logs/erm/DomainNet_q 19 | CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet50_ibn_b -i 2500 --lr 0.01 --seed 0 --log logs/erm/DomainNet_r 20 | CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet50_ibn_b -i 2500 --lr 0.01 --seed 0 --log logs/erm/DomainNet_s 21 | 
-------------------------------------------------------------------------------- /examples/domain_generalization/image_classification/irm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ResNet50, PACS 3 | CUDA_VISIBLE_DEVICES=0 python irm.py data/PACS -d PACS -s A C S -t P -a resnet50 --freeze-bn --seed 0 --log logs/irm/PACS_P 4 | CUDA_VISIBLE_DEVICES=0 python irm.py data/PACS -d PACS -s P C S -t A -a resnet50 --freeze-bn --seed 0 --log logs/irm/PACS_A 5 | CUDA_VISIBLE_DEVICES=0 python irm.py data/PACS -d PACS -s P A S -t C -a resnet50 --freeze-bn --seed 0 --log logs/irm/PACS_C 6 | CUDA_VISIBLE_DEVICES=0 python irm.py data/PACS -d PACS -s P A C -t S -a resnet50 --freeze-bn --seed 0 --log logs/irm/PACS_S 7 | 8 | # ResNet50, Office-Home 9 | CUDA_VISIBLE_DEVICES=0 python irm.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --seed 0 --log logs/irm/OfficeHome_Pr 10 | CUDA_VISIBLE_DEVICES=0 python irm.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --seed 0 --log logs/irm/OfficeHome_Rw 11 | CUDA_VISIBLE_DEVICES=0 python irm.py data/office-home -d OfficeHome -s Ar Rw Pr -t Cl -a resnet50 --seed 0 --log logs/irm/OfficeHome_Cl 12 | CUDA_VISIBLE_DEVICES=0 python irm.py data/office-home -d OfficeHome -s Cl Rw Pr -t Ar -a resnet50 --seed 0 --log logs/irm/OfficeHome_Ar 13 | 14 | # ResNet50, DomainNet 15 | CUDA_VISIBLE_DEVICES=0 python irm.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --seed 0 --log logs/irm/DomainNet_c 16 | CUDA_VISIBLE_DEVICES=0 python irm.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --seed 0 --log logs/irm/DomainNet_i 17 | CUDA_VISIBLE_DEVICES=0 python irm.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --seed 0 --log logs/irm/DomainNet_p 18 | CUDA_VISIBLE_DEVICES=0 python irm.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --seed 0 --log logs/irm/DomainNet_q 19 | CUDA_VISIBLE_DEVICES=0 python irm.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --seed 0 --log logs/irm/DomainNet_r 20 | CUDA_VISIBLE_DEVICES=0 python irm.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --seed 0 --log logs/irm/DomainNet_s 21 | -------------------------------------------------------------------------------- /examples/domain_generalization/image_classification/groupdro.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ResNet50, PACS 3 | CUDA_VISIBLE_DEVICES=0 python groupdro.py data/PACS -d PACS -s A C S -t P -a resnet50 --freeze-bn --seed 0 --log logs/groupdro/PACS_P 4 | CUDA_VISIBLE_DEVICES=0 python groupdro.py data/PACS -d PACS -s P C S -t A -a resnet50 --freeze-bn --seed 0 --log logs/groupdro/PACS_A 5 | CUDA_VISIBLE_DEVICES=0 python groupdro.py data/PACS -d PACS -s P A S -t C -a resnet50 --freeze-bn --seed 0 --log logs/groupdro/PACS_C 6 | CUDA_VISIBLE_DEVICES=0 python groupdro.py data/PACS -d PACS -s P A C -t S -a resnet50 --freeze-bn --seed 0 --log logs/groupdro/PACS_S 7 | 8 | # ResNet50, Office-Home 9 | CUDA_VISIBLE_DEVICES=0 python groupdro.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --seed 0 --log logs/groupdro/OfficeHome_Pr 10 | CUDA_VISIBLE_DEVICES=0 python groupdro.py data/office-home -d OfficeHome -s Ar 
Cl Pr -t Rw -a resnet50 --seed 0 --log logs/groupdro/OfficeHome_Rw 11 | CUDA_VISIBLE_DEVICES=0 python groupdro.py data/office-home -d OfficeHome -s Ar Rw Pr -t Cl -a resnet50 --seed 0 --log logs/groupdro/OfficeHome_Cl 12 | CUDA_VISIBLE_DEVICES=0 python groupdro.py data/office-home -d OfficeHome -s Cl Rw Pr -t Ar -a resnet50 --seed 0 --log logs/groupdro/OfficeHome_Ar 13 | 14 | # ResNet50, DomainNet 15 | CUDA_VISIBLE_DEVICES=0 python groupdro.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet50 -i 4000 --lr 0.005 --seed 0 --log logs/groupdro/DomainNet_c 16 | CUDA_VISIBLE_DEVICES=0 python groupdro.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet50 -i 4000 --lr 0.005 --seed 0 --log logs/groupdro/DomainNet_i 17 | CUDA_VISIBLE_DEVICES=0 python groupdro.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet50 -i 4000 --lr 0.005 --seed 0 --log logs/groupdro/DomainNet_p 18 | CUDA_VISIBLE_DEVICES=0 python groupdro.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet50 -i 4000 --lr 0.005 --seed 0 --log logs/groupdro/DomainNet_q 19 | CUDA_VISIBLE_DEVICES=0 python groupdro.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet50 -i 4000 --lr 0.005 --seed 0 --log logs/groupdro/DomainNet_r 20 | CUDA_VISIBLE_DEVICES=0 python groupdro.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet50 -i 4000 --lr 0.005 --seed 0 --log logs/groupdro/DomainNet_s 21 | -------------------------------------------------------------------------------- /tllib/modules/domain_discriminator.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Junguang Jiang 3 | @contact: JiangJunguang1123@outlook.com 4 | """ 5 | from typing import List, Dict 6 | import torch.nn as nn 7 | 8 | __all__ = ['DomainDiscriminator'] 9 | 10 | 11 | class DomainDiscriminator(nn.Sequential): 12 | r"""Domain discriminator model from 13 | `Domain-Adversarial Training of Neural Networks (ICML 2015) `_ 14 | 15 | Distinguish whether the input features come from the source domain or the target domain. 16 | The source domain label is 1 and the target domain label is 0. 17 | 18 | Args: 19 | in_feature (int): dimension of the input feature 20 | hidden_size (int): dimension of the hidden features 21 | batch_norm (bool): whether use :class:`~torch.nn.BatchNorm1d`. 22 | Use :class:`~torch.nn.Dropout` if ``batch_norm`` is False. Default: True. 
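    Example (illustrative sketch only; the feature dimension and hidden size are arbitrary choices)::

        >>> import torch
        >>> discriminator = DomainDiscriminator(in_feature=1024, hidden_size=1024)
        >>> f = torch.randn(16, 1024)  # bottleneck features from source/target samples
        >>> d = discriminator(f)       # domain predictions in (0, 1) with shape (16, 1)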
23 | 24 | Shape: 25 | - Inputs: (minibatch, `in_feature`) 26 | - Outputs: :math:`(minibatch, 1)` 27 | """ 28 | 29 | def __init__(self, in_feature: int, hidden_size: int, batch_norm=True, sigmoid=True): 30 | if sigmoid: 31 | final_layer = nn.Sequential( 32 | nn.Linear(hidden_size, 1), 33 | nn.Sigmoid() 34 | ) 35 | else: 36 | final_layer = nn.Linear(hidden_size, 2) 37 | if batch_norm: 38 | super(DomainDiscriminator, self).__init__( 39 | nn.Linear(in_feature, hidden_size), 40 | nn.BatchNorm1d(hidden_size), 41 | nn.ReLU(), 42 | nn.Linear(hidden_size, hidden_size), 43 | nn.BatchNorm1d(hidden_size), 44 | nn.ReLU(), 45 | final_layer 46 | ) 47 | else: 48 | super(DomainDiscriminator, self).__init__( 49 | nn.Linear(in_feature, hidden_size), 50 | nn.ReLU(inplace=True), 51 | nn.Dropout(0.5), 52 | nn.Linear(hidden_size, hidden_size), 53 | nn.ReLU(inplace=True), 54 | nn.Dropout(0.5), 55 | final_layer 56 | ) 57 | 58 | def get_parameters(self) -> List[Dict]: 59 | return [{"params": self.parameters(), "lr": 1.}] 60 | 61 | 62 | -------------------------------------------------------------------------------- /tllib/vision/models/object_detection/backbone/mmdetection/weight_init.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Open-MMLab. All rights reserved. 2 | # Source: https://github.com/open-mmlab/mmcv/blob/master/mmcv/cnn/utils/weight_init.py 3 | 4 | import numpy as np 5 | import torch.nn as nn 6 | 7 | 8 | def constant_init(module, val, bias=0): 9 | nn.init.constant_(module.weight, val) 10 | if hasattr(module, 'bias') and module.bias is not None: 11 | nn.init.constant_(module.bias, bias) 12 | 13 | 14 | def xavier_init(module, gain=1, bias=0, distribution='normal'): 15 | assert distribution in ['uniform', 'normal'] 16 | if distribution == 'uniform': 17 | nn.init.xavier_uniform_(module.weight, gain=gain) 18 | else: 19 | nn.init.xavier_normal_(module.weight, gain=gain) 20 | if hasattr(module, 'bias') and module.bias is not None: 21 | nn.init.constant_(module.bias, bias) 22 | 23 | 24 | def normal_init(module, mean=0, std=1, bias=0): 25 | nn.init.normal_(module.weight, mean, std) 26 | if hasattr(module, 'bias') and module.bias is not None: 27 | nn.init.constant_(module.bias, bias) 28 | 29 | 30 | def uniform_init(module, a=0, b=1, bias=0): 31 | nn.init.uniform_(module.weight, a, b) 32 | if hasattr(module, 'bias') and module.bias is not None: 33 | nn.init.constant_(module.bias, bias) 34 | 35 | 36 | def kaiming_init(module, 37 | a=0, 38 | mode='fan_out', 39 | nonlinearity='relu', 40 | bias=0, 41 | distribution='normal'): 42 | assert distribution in ['uniform', 'normal'] 43 | if distribution == 'uniform': 44 | nn.init.kaiming_uniform_( 45 | module.weight, a=a, mode=mode, nonlinearity=nonlinearity) 46 | else: 47 | nn.init.kaiming_normal_( 48 | module.weight, a=a, mode=mode, nonlinearity=nonlinearity) 49 | if hasattr(module, 'bias') and module.bias is not None: 50 | nn.init.constant_(module.bias, bias) 51 | 52 | 53 | def caffe2_xavier_init(module, bias=0): 54 | # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch 55 | # Acknowledgment to FAIR's internal code 56 | kaiming_init( 57 | module, 58 | a=1, 59 | mode='fan_in', 60 | nonlinearity='leaky_relu', 61 | distribution='uniform') -------------------------------------------------------------------------------- /examples/domain_adaptation/wilds_poverty/resnet_ms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified based on 
torchvision.models.resnet 3 | @author: Jiaxin Li 4 | @contact: thulijx@gmail.com 5 | """ 6 | import torch.nn as nn 7 | from torchvision import models 8 | from torchvision.models.resnet import BasicBlock, Bottleneck 9 | import copy 10 | 11 | __all__ = ['resnet18_ms', 'resnet34_ms', 'resnet50_ms', 'resnet101_ms', 'resnet152_ms'] 12 | 13 | 14 | class ResNetMS(models.ResNet): 15 | """ 16 | ResNet with input channels parameter, without fully connected layer. 17 | """ 18 | 19 | def __init__(self, in_channels, *args, **kwargs): 20 | super(ResNetMS, self).__init__(*args, **kwargs) 21 | self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, 22 | bias=False) 23 | self._out_features = self.fc.in_features 24 | nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='relu') 25 | 26 | def forward(self, x): 27 | x = self.conv1(x) 28 | x = self.bn1(x) 29 | x = self.relu(x) 30 | x = self.maxpool(x) 31 | 32 | x = self.layer1(x) 33 | x = self.layer2(x) 34 | x = self.layer3(x) 35 | x = self.layer4(x) 36 | 37 | # x = self.avgpool(x) 38 | # x = torch.flatten(x, 1) 39 | # x = self.fc(x) 40 | return x 41 | 42 | @property 43 | def out_features(self) -> int: 44 | """The dimension of output features""" 45 | return self._out_features 46 | 47 | def copy_head(self) -> nn.Module: 48 | """Copy the origin fully connected layer""" 49 | return copy.deepcopy(self.fc) 50 | 51 | 52 | def resnet18_ms(num_channels=3): 53 | model = ResNetMS(num_channels, BasicBlock, [2, 2, 2, 2]) 54 | return model 55 | 56 | 57 | def resnet34_ms(num_channels=3): 58 | model = ResNetMS(num_channels, BasicBlock, [3, 4, 6, 3]) 59 | return model 60 | 61 | 62 | def resnet50_ms(num_channels=3): 63 | model = ResNetMS(num_channels, Bottleneck, [3, 4, 6, 3]) 64 | return model 65 | 66 | 67 | def resnet101_ms(num_channels=3): 68 | model = ResNetMS(num_channels, Bottleneck, [3, 4, 23, 3]) 69 | return model 70 | 71 | 72 | def resnet152_ms(num_channels=3): 73 | model = ResNetMS(num_channels, Bottleneck, [3, 8, 36, 3]) 74 | return model 75 | -------------------------------------------------------------------------------- /examples/domain_generalization/image_classification/vrex.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ResNet50, PACS 3 | CUDA_VISIBLE_DEVICES=0 python vrex.py data/PACS -d PACS -s A C S -t P -a resnet50 --freeze-bn --seed 0 --log logs/vrex/PACS_P 4 | CUDA_VISIBLE_DEVICES=0 python vrex.py data/PACS -d PACS -s P C S -t A -a resnet50 --freeze-bn --seed 0 --log logs/vrex/PACS_A 5 | CUDA_VISIBLE_DEVICES=0 python vrex.py data/PACS -d PACS -s P A S -t C -a resnet50 --freeze-bn --seed 0 --log logs/vrex/PACS_C 6 | CUDA_VISIBLE_DEVICES=0 python vrex.py data/PACS -d PACS -s P A C -t S -a resnet50 --freeze-bn --seed 0 --log logs/vrex/PACS_S 7 | 8 | # ResNet50, Office-Home 9 | CUDA_VISIBLE_DEVICES=0 python vrex.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --seed 0 --log logs/vrex/OfficeHome_Pr 10 | CUDA_VISIBLE_DEVICES=0 python vrex.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --seed 0 --log logs/vrex/OfficeHome_Rw 11 | CUDA_VISIBLE_DEVICES=0 python vrex.py data/office-home -d OfficeHome -s Ar Rw Pr -t Cl -a resnet50 --seed 0 --log logs/vrex/OfficeHome_Cl 12 | CUDA_VISIBLE_DEVICES=0 python vrex.py data/office-home -d OfficeHome -s Cl Rw Pr -t Ar -a resnet50 --seed 0 --log logs/vrex/OfficeHome_Ar 13 | 14 | # ResNet50, DomainNet 15 | CUDA_VISIBLE_DEVICES=0 python vrex.py data/domainnet -d DomainNet -s i p q r s -t c 
-a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --trade-off 1 --seed 0 --log logs/vrex/DomainNet_c 16 | CUDA_VISIBLE_DEVICES=0 python vrex.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --trade-off 1 --seed 0 --log logs/vrex/DomainNet_i 17 | CUDA_VISIBLE_DEVICES=0 python vrex.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --trade-off 1 --seed 0 --log logs/vrex/DomainNet_p 18 | CUDA_VISIBLE_DEVICES=0 python vrex.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --trade-off 1 --seed 0 --log logs/vrex/DomainNet_q 19 | CUDA_VISIBLE_DEVICES=0 python vrex.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --trade-off 1 --seed 0 --log logs/vrex/DomainNet_r 20 | CUDA_VISIBLE_DEVICES=0 python vrex.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --trade-off 1 --seed 0 --log logs/vrex/DomainNet_s 21 | -------------------------------------------------------------------------------- /tllib/alignment/bsp.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Baixu Chen 3 | @contact: cbx_99_hasta@outlook.com 4 | """ 5 | from typing import Optional 6 | import torch 7 | import torch.nn as nn 8 | from tllib.modules.classifier import Classifier as ClassifierBase 9 | 10 | 11 | class BatchSpectralPenalizationLoss(nn.Module): 12 | r"""Batch spectral penalization loss from `Transferability vs. Discriminability: Batch 13 | Spectral Penalization for Adversarial Domain Adaptation (ICML 2019) 14 | `_. 15 | 16 | Given source features :math:`f_s` and target features :math:`f_t` in current mini batch, singular value 17 | decomposition is first performed 18 | 19 | .. math:: 20 | f_s = U_s\Sigma_sV_s^T 21 | 22 | .. math:: 23 | f_t = U_t\Sigma_tV_t^T 24 | 25 | Then batch spectral penalization loss is calculated as 26 | 27 | .. math:: 28 | loss=\sum_{i=1}^k(\sigma_{s,i}^2+\sigma_{t,i}^2) 29 | 30 | where :math:`\sigma_{s,i},\sigma_{t,i}` refer to the :math:`i-th` largest singular value of source features 31 | and target features respectively. We empirically set :math:`k=1`. 32 | 33 | Inputs: 34 | - f_s (tensor): feature representations on source domain, :math:`f^s` 35 | - f_t (tensor): feature representations on target domain, :math:`f^t` 36 | 37 | Shape: 38 | - f_s, f_t: :math:`(N, F)` where F means the dimension of input features. 39 | - Outputs: scalar. 
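    Example (illustrative sketch only; random features stand in for real source/target representations)::

        >>> import torch
        >>> bsp_penalty = BatchSpectralPenalizationLoss()
        >>> f_s, f_t = torch.randn(32, 256), torch.randn(32, 256)
        >>> loss = bsp_penalty(f_s, f_t)  # sum of squared largest singular values of f_s and f_t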
40 | 41 | """ 42 | 43 | def __init__(self): 44 | super(BatchSpectralPenalizationLoss, self).__init__() 45 | 46 | def forward(self, f_s, f_t): 47 | _, s_s, _ = torch.svd(f_s) 48 | _, s_t, _ = torch.svd(f_t) 49 | loss = torch.pow(s_s[0], 2) + torch.pow(s_t[0], 2) 50 | return loss 51 | 52 | 53 | class ImageClassifier(ClassifierBase): 54 | def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs): 55 | bottleneck = nn.Sequential( 56 | # nn.AdaptiveAvgPool2d(output_size=(1, 1)), 57 | # nn.Flatten(), 58 | nn.Linear(backbone.out_features, bottleneck_dim), 59 | nn.BatchNorm1d(bottleneck_dim), 60 | nn.ReLU(), 61 | ) 62 | super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs) 63 | -------------------------------------------------------------------------------- /examples/semi_supervised_learning/image_classification/noisy_student.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # ImageNet Supervised Pretrain (ResNet50) 4 | # ====================================================================================================================== 5 | # CIFAR 100 6 | CUDA_VISIBLE_DEVICES=0 python noisy_student.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \ 7 | --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \ 8 | --lr 0.01 --finetune --epochs 20 --seed 0 --log logs/noisy_student/cifar100_4_labels_per_class/iter_0 9 | 10 | for round in 0 1 2; do 11 | CUDA_VISIBLE_DEVICES=0 python noisy_student.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \ 12 | --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \ 13 | --pretrained-teacher logs/noisy_student/cifar100_4_labels_per_class/iter_$round/checkpoints/latest.pth \ 14 | --lr 0.01 --finetune --epochs 40 --T 0.5 --seed 0 --log logs/noisy_student/cifar100_4_labels_per_class/iter_$((round + 1)) 15 | done 16 | 17 | # ImageNet Unsupervised Pretrain (MoCov2, ResNet50) 18 | # ====================================================================================================================== 19 | # CIFAR100 20 | CUDA_VISIBLE_DEVICES=0 python noisy_student.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \ 21 | --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \ 22 | --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ 23 | --lr 0.001 --finetune --lr-scheduler cos --epochs 20 --seed 0 \ 24 | --log logs/noisy_student_moco_pretrain/cifar100_4_labels_per_class/iter_0 25 | 26 | for round in 0 1 2; do 27 | CUDA_VISIBLE_DEVICES=0 python noisy_student.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \ 28 | --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \ 29 | --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ 30 | --pretrained-teacher logs/noisy_student_moco_pretrain/cifar100_4_labels_per_class/iter_$round/checkpoints/latest.pth \ 31 | --lr 0.001 --finetune --lr-scheduler cos --epochs 40 --T 1 --seed 0 \ 32 | --log logs/noisy_student_moco_pretrain/cifar100_4_labels_per_class/iter_$((round + 1)) 33 | done 34 | -------------------------------------------------------------------------------- /examples/domain_adaptation/image_regression/utils.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | @author: Junguang Jiang 3 | @contact: JiangJunguang1123@outlook.com 4 | """ 5 | import sys 6 | import time 7 | import torch 8 | import torch.nn.functional as F 9 | from torch.nn.modules.batchnorm import BatchNorm1d, BatchNorm2d, BatchNorm3d 10 | from torch.nn.modules.instancenorm import InstanceNorm1d, InstanceNorm2d, InstanceNorm3d 11 | 12 | sys.path.append('../../..') 13 | from tllib.utils.meter import AverageMeter, ProgressMeter 14 | 15 | 16 | def convert_model(module): 17 | """convert BatchNorms in the `module` into InstanceNorms""" 18 | source_modules = (BatchNorm1d, BatchNorm2d, BatchNorm3d) 19 | target_modules = (InstanceNorm1d, InstanceNorm2d, InstanceNorm3d) 20 | for src_module, tgt_module in zip(source_modules, target_modules): 21 | if isinstance(module, src_module): 22 | mod = tgt_module(module.num_features, module.eps, module.momentum, module.affine) 23 | module = mod 24 | 25 | for name, child in module.named_children(): 26 | module.add_module(name, convert_model(child)) 27 | 28 | return module 29 | 30 | 31 | def validate(val_loader, model, args, factors, device): 32 | batch_time = AverageMeter('Time', ':6.3f') 33 | mae_losses = [AverageMeter('mae {}'.format(factor), ':6.3f') for factor in factors] 34 | progress = ProgressMeter( 35 | len(val_loader), 36 | [batch_time] + mae_losses, 37 | prefix='Test: ') 38 | 39 | # switch to evaluate mode 40 | model.eval() 41 | 42 | with torch.no_grad(): 43 | end = time.time() 44 | for i, (images, target) in enumerate(val_loader): 45 | images = images.to(device) 46 | target = target.to(device) 47 | 48 | # compute output 49 | output = model(images) 50 | for j in range(len(factors)): 51 | mae_loss = F.l1_loss(output[:, j], target[:, j]) 52 | mae_losses[j].update(mae_loss.item(), images.size(0)) 53 | 54 | # measure elapsed time 55 | batch_time.update(time.time() - end) 56 | end = time.time() 57 | 58 | if i % args.print_freq == 0: 59 | progress.display(i) 60 | 61 | for i, factor in enumerate(factors): 62 | print("{} MAE {mae.avg:6.3f}".format(factor, mae=mae_losses[i])) 63 | mean_mae = sum(l.avg for l in mae_losses) / len(factors) 64 | return mean_mae 65 | -------------------------------------------------------------------------------- /examples/domain_adaptation/re_identification/ibn.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Market1501 -> Duke 3 | CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s Market1501 -t DukeMTMC -a resnet50_ibn_a \ 4 | --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/ibn/Market2Duke 5 | CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s Market1501 -t DukeMTMC -a resnet50_ibn_b \ 6 | --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/ibn/Market2Duke 7 | 8 | # Duke -> Market1501 9 | CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s DukeMTMC -t Market1501 -a resnet50_ibn_a \ 10 | --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/ibn/Duke2Market 11 | CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s DukeMTMC -t Market1501 -a resnet50_ibn_b \ 12 | --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/ibn/Duke2Market 13 | 14 | # Market1501 -> MSMT 15 | CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s Market1501 -t MSMT17 -a resnet50_ibn_a \ 16 | --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/ibn/Market2MSMT 17 | CUDA_VISIBLE_DEVICES=0 python 
baseline.py data data -s Market1501 -t MSMT17 -a resnet50_ibn_b \ 18 | --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/ibn/Market2MSMT 19 | 20 | # MSMT -> Market1501 21 | CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s MSMT17 -t Market1501 -a resnet50_ibn_a \ 22 | --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/ibn/MSMT2Market 23 | CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s MSMT17 -t Market1501 -a resnet50_ibn_b \ 24 | --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/ibn/MSMT2Market 25 | 26 | # Duke -> MSMT 27 | CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s DukeMTMC -t MSMT17 -a resnet50_ibn_a \ 28 | --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/ibn/Duke2MSMT 29 | CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s DukeMTMC -t MSMT17 -a resnet50_ibn_b \ 30 | --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/ibn/Duke2MSMT 31 | 32 | # MSMT -> Duke 33 | CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s MSMT17 -t DukeMTMC -a resnet50_ibn_a \ 34 | --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/ibn/MSMT2Duke 35 | CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s MSMT17 -t DukeMTMC -a resnet50_ibn_b \ 36 | --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/ibn/MSMT2Duke 37 | -------------------------------------------------------------------------------- /examples/domain_adaptation/wilds_poverty/README.md: -------------------------------------------------------------------------------- 1 | # Unsupervised Domain Adaptation for WILDS (Image Regression) 2 | 3 | ## Installation 4 | 5 | It's suggested to use **pytorch==1.10.1** in order to reproduce the benchmark results. 6 | 7 | You need to install apex following `https://github.com/NVIDIA/apex`. Then run 8 | 9 | ``` 10 | pip install -r requirements.txt 11 | ``` 12 | 13 | ## Dataset 14 | 15 | The following datasets can be downloaded automatically: 16 | 17 | - [PovertyMap (WILDS)](https://wilds.stanford.edu/datasets/) 18 | 19 | ## Supported Methods 20 | 21 | TODO 22 | 23 | ## Usage 24 | 25 | Our code is based 26 | on [https://github.com/NVIDIA/apex/edit/master/examples/imagenet](https://github.com/NVIDIA/apex/edit/master/examples/imagenet) 27 | . It implements Automatic Mixed Precision (Amp) training of popular model architectures, such as ResNet, AlexNet, and 28 | VGG, on the WILDS dataset. 29 | Command-line flags forwarded to `amp.initialize` are used to easily manipulate and switch between various pure and mixed 30 | precision "optimization levels" or `opt_level`s. 31 | For a detailed explanation of `opt_level`s, see the [updated API guide](https://nvidia.github.io/apex/amp.html). 32 | 33 | The shell files give all the training scripts we use, e.g. 34 | 35 | ``` 36 | CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds --split-scheme official --fold A \ 37 | --arch 'resnet18_ms' --lr 1e-3 --epochs 200 -b 64 64 --opt-level O1 --deterministic --log logs/erm/poverty_fold_A 38 | ``` 39 | 40 | ## Results 41 | 42 | ### Performance on WILDS-PovertyMap (ResNet18-MultiSpectral) 43 | 44 | | Method | Val Pearson r | Test Pearson r | Val Worst-U/R Pearson r | Test Worst-U/R Pearson r | GPU Memory Usage(GB) | 45 | | --- | --- | --- | --- | --- | --- | 46 | | ERM | 0.80 | 0.80 | 0.54 | 0.50 | 3.5 | 47 | 48 | ### Distributed training 49 | 50 | We use `apex.parallel.DistributedDataParallel` (DDP) for multiprocess training with one GPU per process.
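Inside each process, the model and optimizer are wrapped with Amp and DDP before the training loop starts. A minimal sketch of that per-process setup is shown below (the `build_model` helper is hypothetical and the exact flow may differ from `erm.py`; the apex calls follow its standard API):

```python
# Per-process setup sketch; torch.distributed.launch fills in --local_rank.
import argparse
import torch
from apex import amp
from apex.parallel import DistributedDataParallel as DDP

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()

torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://')

model = build_model().cuda()  # hypothetical model factory (e.g. a resnet18_ms builder)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')  # mixed precision
model = DDP(model)  # one GPU per process
```

The corresponding launch command is: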
51 | 52 | ``` 53 | CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 erm.py /data/wilds --arch 'resnet18_ms' \ 54 | --opt-level O1 --deterministic --log logs/erm/poverty --lr 1e-3 --wd 0.0 --epochs 200 --metric r_wg --split_scheme official -b 64 64 --fold A 55 | ``` 56 | 57 | ### Visualization 58 | 59 | We use tensorboard to record the training process and visualize the outputs of the models. 60 | 61 | ``` 62 | tensorboard --logdir=logs 63 | ``` 64 | 65 | -------------------------------------------------------------------------------- /examples/domain_generalization/image_classification/mixstyle.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ResNet50, PACS 3 | CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/PACS -d PACS -s A C S -t P -a resnet50 --mix-layers layer1 layer2 layer3 --freeze-bn --seed 0 --log logs/mixstyle/PACS_P 4 | CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/PACS -d PACS -s P C S -t A -a resnet50 --mix-layers layer1 layer2 layer3 --freeze-bn --seed 0 --log logs/mixstyle/PACS_A 5 | CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/PACS -d PACS -s P A S -t C -a resnet50 --mix-layers layer1 layer2 layer3 --freeze-bn --seed 0 --log logs/mixstyle/PACS_C 6 | CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/PACS -d PACS -s P A C -t S -a resnet50 --mix-layers layer1 layer2 layer3 --freeze-bn --seed 0 --log logs/mixstyle/PACS_S 7 | 8 | # ResNet50, Office-Home 9 | CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --mix-layers layer1 layer2 --seed 0 --log logs/mixstyle/OfficeHome_Pr 10 | CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --mix-layers layer1 layer2 --seed 0 --log logs/mixstyle/OfficeHome_Rw 11 | CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/office-home -d OfficeHome -s Ar Rw Pr -t Cl -a resnet50 --mix-layers layer1 layer2 --seed 0 --log logs/mixstyle/OfficeHome_Cl 12 | CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/office-home -d OfficeHome -s Cl Rw Pr -t Ar -a resnet50 --mix-layers layer1 layer2 --seed 0 --log logs/mixstyle/OfficeHome_Ar 13 | 14 | # ResNet50, DomainNet 15 | CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/mixstyle/DomainNet_c 16 | CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/mixstyle/DomainNet_i 17 | CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/mixstyle/DomainNet_p 18 | CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/mixstyle/DomainNet_q 19 | CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/mixstyle/DomainNet_r 20 | CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/mixstyle/DomainNet_s 21 | -------------------------------------------------------------------------------- /tllib/alignment/d_adapt/modeling/matcher.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Junguang Jiang 3 | @contact: JiangJunguang1123@outlook.com 4 | """ 5 | import torch 6 | from torch import 
Tensor, nn 7 | 8 | from detectron2.layers import ShapeSpec, batched_nms, cat, get_norm, nonzero_tuple 9 | 10 | 11 | class MaxOverlapMatcher(object): 12 | """ 13 | This class assigns to each predicted "element" (e.g., a box) a ground-truth 14 | element. Each predicted element will have exactly zero or one matches; each 15 | ground-truth element may be matched to one or more predicted elements. 16 | """ 17 | 18 | def __init__(self): 19 | pass 20 | 21 | def __call__(self, match_quality_matrix): 22 | """ 23 | Args: 24 | match_quality_matrix (Tensor[float]): an MxN tensor, containing the 25 | pairwise quality between M ground-truth elements and N predicted 26 | elements. All elements must be >= 0 (due to the use of `torch.nonzero` 27 | for selecting indices in :meth:`set_low_quality_matches_`). 28 | 29 | Returns: 30 | matches (Tensor[int64]): a vector of length N, where matches[i] is a matched 31 | ground-truth index in [0, M) 32 | match_labels (Tensor[int8]): a vector of length N, where match_labels[i] indicates 33 | whether a prediction is a true or false positive or ignored 34 | """ 35 | assert match_quality_matrix.dim() == 2 36 | # match_quality_matrix is M (gt) x N (predicted) 37 | # Max over gt elements (dim 0) to find best gt candidate for each prediction 38 | _, matched_idxs = match_quality_matrix.max(dim=0) 39 | 40 | anchor_labels = match_quality_matrix.new_full( 41 | (match_quality_matrix.size(1),), -1, dtype=torch.int8 42 | ) 43 | 44 | # For each gt, find the prediction with which it has highest quality 45 | highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1) 46 | # Find the highest quality match available, even if it is low, including ties. 47 | # Note that the match qualities must be positive due to the use of 48 | # `torch.nonzero`. 49 | _, pred_inds_with_highest_quality = nonzero_tuple( 50 | match_quality_matrix == highest_quality_foreach_gt[:, None] 51 | ) 52 | anchor_labels[pred_inds_with_highest_quality] = 1 53 | 54 | return matched_idxs, anchor_labels 55 | -------------------------------------------------------------------------------- /docs/tllib/alignment/domain_adversarial.rst: -------------------------------------------------------------------------------- 1 | ========================================== 2 | Domain Adversarial Training 3 | ========================================== 4 | 5 | 6 | .. _DANN: 7 | 8 | DANN: Domain Adversarial Neural Network 9 | ---------------------------------------- 10 | 11 | .. autoclass:: tllib.alignment.dann.DomainAdversarialLoss 12 | 13 | .. _CDAN: 14 | 15 | CDAN: Conditional Domain Adversarial Network 16 | ----------------------------------------------- 17 | 18 | .. autoclass:: tllib.alignment.cdan.ConditionalDomainAdversarialLoss 19 | 20 | 21 | .. autoclass:: tllib.alignment.cdan.RandomizedMultiLinearMap 22 | 23 | 24 | .. autoclass:: tllib.alignment.cdan.MultiLinearMap 25 | 26 | 27 | .. _ADDA: 28 | 29 | ADDA: Adversarial Discriminative Domain Adaptation 30 | ----------------------------------------------------- 31 | 32 | .. autoclass:: tllib.alignment.adda.DomainAdversarialLoss 33 | 34 | .. note:: 35 | ADDAgrl is also implemented and benchmarked. You can find code 36 | `here `_. 37 | 38 | 39 | .. _BSP: 40 | 41 | BSP: Batch Spectral Penalization 42 | ----------------------------------- 43 | 44 | .. autoclass:: tllib.alignment.bsp.BatchSpectralPenalizationLoss 45 | 46 | 47 | .. _OSBP: 48 | 49 | OSBP: Open Set Domain Adaptation by Backpropagation 50 | ---------------------------------------------------- 51 | 52 | ..
autoclass:: tllib.alignment.osbp.UnknownClassBinaryCrossEntropy 53 | 54 | 55 | .. _ADVENT: 56 | 57 | ADVENT: Adversarial Entropy Minimization for Semantic Segmentation 58 | ------------------------------------------------------------------ 59 | 60 | .. autoclass:: tllib.alignment.advent.Discriminator 61 | 62 | .. autoclass:: tllib.alignment.advent.DomainAdversarialEntropyLoss 63 | :members: 64 | 65 | 66 | .. _DADAPT: 67 | 68 | D-adapt: Decoupled Adaptation for Cross-Domain Object Detection 69 | ---------------------------------------------------------------- 70 | `Origin Paper `_. 71 | 72 | .. autoclass:: tllib.alignment.d_adapt.proposal.Proposal 73 | 74 | .. autoclass:: tllib.alignment.d_adapt.proposal.PersistentProposalList 75 | 76 | .. autoclass:: tllib.alignment.d_adapt.proposal.ProposalDataset 77 | 78 | .. autoclass:: tllib.alignment.d_adapt.modeling.meta_arch.DecoupledGeneralizedRCNN 79 | 80 | .. autoclass:: tllib.alignment.d_adapt.modeling.meta_arch.DecoupledRetinaNet 81 | 82 | -------------------------------------------------------------------------------- /tllib/vision/models/digits.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Junguang Jiang 3 | @contact: JiangJunguang1123@outlook.com 4 | """ 5 | import torch.nn as nn 6 | 7 | 8 | class LeNet(nn.Sequential): 9 | def __init__(self, num_classes=10): 10 | super(LeNet, self).__init__( 11 | nn.Conv2d(1, 20, kernel_size=5), 12 | nn.MaxPool2d(2), 13 | nn.ReLU(), 14 | nn.Conv2d(20, 50, kernel_size=5), 15 | nn.Dropout2d(p=0.5), 16 | nn.MaxPool2d(2), 17 | nn.ReLU(), 18 | nn.Flatten(start_dim=1), 19 | nn.Linear(50 * 4 * 4, 500), 20 | nn.ReLU(), 21 | nn.Dropout(p=0.5), 22 | ) 23 | self.num_classes = num_classes 24 | self.out_features = 500 25 | 26 | def copy_head(self): 27 | return nn.Linear(500, self.num_classes) 28 | 29 | 30 | class DTN(nn.Sequential): 31 | def __init__(self, num_classes=10): 32 | super(DTN, self).__init__( 33 | nn.Conv2d(3, 64, kernel_size=5, stride=2, padding=2), 34 | nn.BatchNorm2d(64), 35 | nn.Dropout2d(0.1), 36 | nn.ReLU(), 37 | nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2), 38 | nn.BatchNorm2d(128), 39 | nn.Dropout2d(0.3), 40 | nn.ReLU(), 41 | nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2), 42 | nn.BatchNorm2d(256), 43 | nn.Dropout2d(0.5), 44 | nn.ReLU(), 45 | nn.Flatten(start_dim=1), 46 | nn.Linear(256 * 4 * 4, 512), 47 | nn.BatchNorm1d(512), 48 | nn.ReLU(), 49 | nn.Dropout(), 50 | ) 51 | self.num_classes = num_classes 52 | self.out_features = 512 53 | 54 | def copy_head(self): 55 | return nn.Linear(512, self.num_classes) 56 | 57 | 58 | 59 | def lenet(pretrained=False, **kwargs): 60 | """LeNet model from 61 | `"Gradient-based learning applied to document recognition" `_ 62 | 63 | Args: 64 | num_classes (int): number of classes. Default: 10 65 | 66 | .. note:: 67 | The input image size must be 28 x 28. 68 | 69 | """ 70 | return LeNet(**kwargs) 71 | 72 | 73 | def dtn(pretrained=False, **kwargs): 74 | """ DTN model 75 | 76 | Args: 77 | num_classes (int): number of classes. Default: 10 78 | 79 | .. note:: 80 | The input image size must be 32 x 32. 
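Examples (a minimal usage sketch, assuming a batch of 32 x 32 RGB inputs; the shapes follow the note above)::

            >>> import torch
            >>> model = dtn(num_classes=10)
            >>> images = torch.randn(4, 3, 32, 32)  # a batch of 32 x 32 RGB digit images
            >>> features = model(images)            # (4, 512) bottleneck features
            >>> head = model.copy_head()            # a fresh nn.Linear(512, 10) task head
            >>> logits = head(features)             # (4, 10) class scores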
81 | 82 | """ 83 | return DTN(**kwargs) -------------------------------------------------------------------------------- /examples/domain_adaptation/object_detection/oracle.sh: -------------------------------------------------------------------------------- 1 | # Faster RCNN: WaterColor 2 | CUDA_VISIBLE_DEVICES=0 python source_only.py \ 3 | --config-file config/faster_rcnn_R_101_C4_voc.yaml \ 4 | -s WaterColor datasets/watercolor -t WaterColor datasets/watercolor \ 5 | --test WaterColorTest datasets/watercolor --finetune \ 6 | OUTPUT_DIR logs/oracle/faster_rcnn_R_101_C4/watercolor MODEL.ROI_HEADS.NUM_CLASSES 6 7 | 8 | # Faster RCNN: Comic 9 | CUDA_VISIBLE_DEVICES=0 python source_only.py \ 10 | --config-file config/faster_rcnn_R_101_C4_voc.yaml \ 11 | -s Comic datasets/comic -t Comic datasets/comic \ 12 | --test ComicTest datasets/comic --finetune \ 13 | OUTPUT_DIR logs/oracle/faster_rcnn_R_101_C4/comic MODEL.ROI_HEADS.NUM_CLASSES 6 14 | 15 | # ResNet101 Based Faster RCNN: Cityscapes->Foggy Cityscapes 16 | CUDA_VISIBLE_DEVICES=0 python source_only.py \ 17 | --config-file config/faster_rcnn_R_101_C4_cityscapes.yaml \ 18 | -s FoggyCityscapes datasets/foggy_cityscapes_in_voc -t FoggyCityscapes datasets/foggy_cityscapes_in_voc \ 19 | --test FoggyCityscapesTest datasets/foggy_cityscapes_in_voc --finetune \ 20 | OUTPUT_DIR logs/oracle/faster_rcnn_R_101_C4/cityscapes2foggy 21 | 22 | # VGG16 Based Faster RCNN: Cityscapes->Foggy Cityscapes 23 | CUDA_VISIBLE_DEVICES=0 python source_only.py \ 24 | --config-file config/faster_rcnn_vgg_16_cityscapes.yaml \ 25 | -s FoggyCityscapes datasets/foggy_cityscapes_in_voc -t FoggyCityscapes datasets/foggy_cityscapes_in_voc \ 26 | --test FoggyCityscapesTest datasets/foggy_cityscapes_in_voc --finetune \ 27 | OUTPUT_DIR logs/oracle/faster_rcnn_vgg_16/cityscapes2foggy 28 | 29 | # ResNet101 Based Faster RCNN: Sim10k -> Cityscapes Car 30 | CUDA_VISIBLE_DEVICES=0 python source_only.py \ 31 | --config-file config/faster_rcnn_R_101_C4_cityscapes.yaml \ 32 | -s CityscapesCar datasets/cityscapes_in_voc/ -t CityscapesCar datasets/cityscapes_in_voc/ \ 33 | --test CityscapesCarTest datasets/cityscapes_in_voc/ --finetune \ 34 | OUTPUT_DIR logs/oracle/faster_rcnn_R_101_C4/cityscapes_car MODEL.ROI_HEADS.NUM_CLASSES 1 35 | 36 | # VGG16 Based Faster RCNN: Sim10k -> Cityscapes Car 37 | CUDA_VISIBLE_DEVICES=0 python source_only.py \ 38 | --config-file config/faster_rcnn_vgg_16_cityscapes.yaml \ 39 | -s CityscapesCar datasets/cityscapes_in_voc/ -t CityscapesCar datasets/cityscapes_in_voc/ \ 40 | --test CityscapesCarTest datasets/cityscapes_in_voc/ --finetune \ 41 | OUTPUT_DIR logs/oracle/faster_rcnn_vgg_16/cityscapes_car MODEL.ROI_HEADS.NUM_CLASSES 1 42 | 43 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | import re 3 | from os import path 4 | 5 | here = path.abspath(path.dirname(__file__)) 6 | 7 | # Get the version string 8 | with open(path.join(here, 'tllib', '__init__.py')) as f: 9 | version = re.search(r'__version__ = \'(.*?)\'', f.read()).group(1) 10 | 11 | # Get all runtime requirements 12 | REQUIRES = [] 13 | with open('requirements.txt') as f: 14 | for line in f: 15 | line, _, _ = line.partition('#') 16 | line = line.strip() 17 | REQUIRES.append(line) 18 | 19 | if __name__ == '__main__': 20 | setup( 21 | name="tllib", # Replace with your own username 22 | version=version, 23 | author="THUML", 24 | 
author_email="JiangJunguang1123@outlook.com", 25 | keywords="domain adaptation, task adaptation, domain generalization, " 26 | "transfer learning, deep learning, pytorch", 27 | description="A Transfer Learning Library for Domain Adaptation, Task Adaptation, and Domain Generalization", 28 | long_description=open('README.md', encoding='utf8').read(), 29 | long_description_content_type="text/markdown", 30 | url="https://github.com/thuml/Transfer-Learning-Library", 31 | packages=find_packages(exclude=['docs', 'examples']), 32 | classifiers=[ 33 | # How mature is this project? Common values are 34 | # 3 - Alpha 35 | # 4 - Beta 36 | # 5 - Production/Stable 37 | 'Development Status :: 3 - Alpha', 38 | # Indicate who your project is intended for 39 | 'Intended Audience :: Science/Research', 40 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 41 | 'Topic :: Software Development :: Libraries :: Python Modules', 42 | # Pick your license as you wish (should match "license" above) 43 | 'License :: OSI Approved :: MIT License', 44 | # Specify the Python versions you support here. In particular, ensure 45 | # that you indicate whether you support Python 2, Python 3 or both. 46 | 'Programming Language :: Python :: 3.6', 47 | 'Programming Language :: Python :: 3.7', 48 | 'Programming Language :: Python :: 3.8', 49 | ], 50 | python_requires='>=3.6', 51 | install_requires=REQUIRES, 52 | extras_require={ 53 | 'dev': [ 54 | 'Sphinx', 55 | 'sphinx_rtd_theme', 56 | ] 57 | }, 58 | ) 59 | -------------------------------------------------------------------------------- /tllib/vision/datasets/segmentation/synthia.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Junguang Jiang 3 | @contact: JiangJunguang1123@outlook.com 4 | """ 5 | import os 6 | from .segmentation_list import SegmentationList 7 | from .cityscapes import Cityscapes 8 | from .._util import download as download_data 9 | 10 | 11 | class Synthia(SegmentationList): 12 | """`SYNTHIA `_ 13 | 14 | Args: 15 | root (str): Root directory of dataset 16 | split (str, optional): The dataset split, supports ``train``. 17 | data_folder (str, optional): Sub-directory of the image. Default: 'RGB'. 18 | label_folder (str, optional): Sub-directory of the label. Default: 'synthia_mapped_to_cityscapes'. 19 | mean (seq[float]): mean BGR value. Normalize the image if not None. Default: None. 20 | transforms (callable, optional): A function/transform that takes in (PIL image, label) pair \ 21 | and returns a transformed version. E.g, :class:`~tllib.vision.transforms.segmentation.Resize`. 22 | 23 | .. note:: You need to download SYNTHIA manually. 24 | Ensure that the following directories exist in the `root` directory before using this class.
25 | :: 26 | RGB/ 27 | synthia_mapped_to_cityscapes/ 28 | """ 29 | ID_TO_TRAIN_ID = { 30 | 3: 0, 4: 1, 2: 2, 21: 3, 5: 4, 7: 5, 31 | 15: 6, 9: 7, 6: 8, 16: 9, 1: 10, 10: 11, 17: 12, 32 | 8: 13, 18: 14, 19: 15, 20: 16, 12: 17, 11: 18 33 | } 34 | download_list = [ 35 | ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/1c652d518e0347e2800d/?dl=1"), 36 | ] 37 | 38 | def __init__(self, root, split='train', data_folder='RGB', label_folder='synthia_mapped_to_cityscapes', **kwargs): 39 | assert split in ['train'] 40 | # download meta information from Internet 41 | list(map(lambda args: download_data(root, *args), self.download_list)) 42 | data_list_file = os.path.join(root, "image_list", "{}.txt".format(split)) 43 | super(Synthia, self).__init__(root, Cityscapes.CLASSES, data_list_file, data_list_file, data_folder, 44 | label_folder, id_to_train_id=Synthia.ID_TO_TRAIN_ID, 45 | train_id_to_color=Cityscapes.TRAIN_ID_TO_COLOR, **kwargs) 46 | 47 | @property 48 | def evaluate_classes(self): 49 | return [ 50 | 'road', 'sidewalk', 'building', 'traffic light', 'traffic sign', 51 | 'vegetation', 'sky', 'person', 'rider', 'car', 'bus', 'motorcycle', 'bicycle' 52 | ] 53 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | db.sqlite3-journal 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | docs/build/* 73 | docs/pytorch_sphinx_theme/* 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 
93 | #Pipfile.lock 94 | 95 | # celery beat schedule file 96 | celerybeat-schedule 97 | 98 | # SageMath parsed files 99 | *.sage.py 100 | 101 | # Environments 102 | .env 103 | .venv 104 | env/ 105 | venv/ 106 | ENV/ 107 | env.bak/ 108 | venv.bak/ 109 | 110 | # Spyder project settings 111 | .spyderproject 112 | .spyproject 113 | 114 | # Rope project settings 115 | .ropeproject 116 | 117 | # mkdocs documentation 118 | /site 119 | 120 | # mypy 121 | .mypy_cache/ 122 | .dmypy.json 123 | dmypy.json 124 | 125 | # Pyre type checker 126 | .pyre/ 127 | 128 | .idea/ 129 | 130 | exp/* 131 | trash/* 132 | examples/domain_adaptation/digits/logs/* 133 | examples/domain_adaptation/digits/data/* 134 | +.DS_Store 135 | */.DS_Store 136 | -------------------------------------------------------------------------------- /docs/tllib/vision/models.rst: -------------------------------------------------------------------------------- 1 | Models 2 | =========================== 3 | 4 | ------------------------------ 5 | Image Classification 6 | ------------------------------ 7 | 8 | ResNets 9 | --------------------------------- 10 | 11 | .. automodule:: tllib.vision.models.resnet 12 | :members: 13 | 14 | LeNet 15 | -------------------------- 16 | 17 | .. automodule:: tllib.vision.models.digits.lenet 18 | :members: 19 | 20 | DTN 21 | -------------------------- 22 | 23 | .. automodule:: tllib.vision.models.digits.dtn 24 | :members: 25 | 26 | ---------------------------------- 27 | Object Detection 28 | ---------------------------------- 29 | 30 | .. autoclass:: tllib.vision.models.object_detection.meta_arch.TLGeneralizedRCNN 31 | :members: 32 | 33 | .. autoclass:: tllib.vision.models.object_detection.meta_arch.TLRetinaNet 34 | :members: 35 | 36 | .. autoclass:: tllib.vision.models.object_detection.proposal_generator.rpn.TLRPN 37 | 38 | .. autoclass:: tllib.vision.models.object_detection.roi_heads.TLRes5ROIHeads 39 | :members: 40 | 41 | .. autoclass:: tllib.vision.models.object_detection.roi_heads.TLStandardROIHeads 42 | :members: 43 | 44 | ---------------------------------- 45 | Semantic Segmentation 46 | ---------------------------------- 47 | 48 | .. autofunction:: tllib.vision.models.segmentation.deeplabv2.deeplabv2_resnet101 49 | 50 | 51 | ---------------------------------- 52 | Keypoint Detection 53 | ---------------------------------- 54 | 55 | PoseResNet 56 | -------------------------- 57 | 58 | .. autofunction:: tllib.vision.models.keypoint_detection.pose_resnet.pose_resnet101 59 | 60 | .. autoclass:: tllib.vision.models.keypoint_detection.pose_resnet.PoseResNet 61 | 62 | .. autoclass:: tllib.vision.models.keypoint_detection.pose_resnet.Upsampling 63 | 64 | 65 | Joint Loss 66 | ---------------------------------- 67 | 68 | .. autoclass:: tllib.vision.models.keypoint_detection.loss.JointsMSELoss 69 | 70 | .. autoclass:: tllib.vision.models.keypoint_detection.loss.JointsKLLoss 71 | 72 | 73 | ----------------------------------- 74 | Re-Identification 75 | ----------------------------------- 76 | 77 | Models 78 | --------------- 79 | .. autoclass:: tllib.vision.models.reid.resnet.ReidResNet 80 | 81 | .. automodule:: tllib.vision.models.reid.resnet 82 | :members: 83 | 84 | .. autoclass:: tllib.vision.models.reid.identifier.ReIdentifier 85 | :members: 86 | 87 | Loss 88 | ----------------------------------- 89 | .. autoclass:: tllib.vision.models.reid.loss.TripletLoss 90 | 91 | Sampler 92 | ----------------------------------- 93 | .. 
autoclass:: tllib.utils.data.RandomMultipleGallerySampler 94 | -------------------------------------------------------------------------------- /tllib/modules/kernels.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Junguang Jiang 3 | @contact: JiangJunguang1123@outlook.com 4 | """ 5 | from typing import Optional 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | __all__ = ['GaussianKernel'] 11 | 12 | 13 | class GaussianKernel(nn.Module): 14 | r"""Gaussian Kernel Matrix 15 | 16 | Gaussian Kernel k is defined by 17 | 18 | .. math:: 19 | k(x_1, x_2) = \exp \left( - \dfrac{\| x_1 - x_2 \|^2}{2\sigma^2} \right) 20 | 21 | where :math:`x_1, x_2 \in R^d` are 1-d tensors. 22 | 23 | Gaussian Kernel Matrix K is defined on input group :math:`X=(x_1, x_2, ..., x_m),` 24 | 25 | .. math:: 26 | K(X)_{i,j} = k(x_i, x_j) 27 | 28 | Also by default, during training this layer keeps running estimates of the 29 | mean of L2 distances, which are then used to set hyperparameter :math:`\sigma`. 30 | Mathematically, the estimation is :math:`\sigma^2 = \dfrac{\alpha}{n^2}\sum_{i,j} \| x_i - x_j \|^2`. 31 | If :attr:`track_running_stats` is set to ``False``, this layer then does not 32 | keep running estimates, and use a fixed :math:`\sigma` instead. 33 | 34 | Args: 35 | sigma (float, optional): bandwidth :math:`\sigma`. Default: None 36 | track_running_stats (bool, optional): If ``True``, this module tracks the running mean of :math:`\sigma^2`. 37 | Otherwise, it won't track such statistics and always uses fix :math:`\sigma^2`. Default: ``True`` 38 | alpha (float, optional): :math:`\alpha` which decides the magnitude of :math:`\sigma^2` when track_running_stats is set to ``True`` 39 | 40 | Inputs: 41 | - X (tensor): input group :math:`X` 42 | 43 | Shape: 44 | - Inputs: :math:`(minibatch, F)` where F means the dimension of input features. 45 | - Outputs: :math:`(minibatch, minibatch)` 46 | """ 47 | 48 | def __init__(self, sigma: Optional[float] = None, track_running_stats: Optional[bool] = True, 49 | alpha: Optional[float] = 1.): 50 | super(GaussianKernel, self).__init__() 51 | assert track_running_stats or sigma is not None 52 | self.sigma_square = torch.tensor(sigma * sigma) if sigma is not None else None 53 | self.track_running_stats = track_running_stats 54 | self.alpha = alpha 55 | 56 | def forward(self, X: torch.Tensor) -> torch.Tensor: 57 | l2_distance_square = ((X.unsqueeze(0) - X.unsqueeze(1)) ** 2).sum(2) 58 | 59 | if self.track_running_stats: 60 | self.sigma_square = self.alpha * torch.mean(l2_distance_square.detach()) 61 | 62 | return torch.exp(-l2_distance_square / (2 * self.sigma_square)) -------------------------------------------------------------------------------- /tllib/translation/cycada.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Junguang Jiang 3 | @contact: JiangJunguang1123@outlook.com 4 | """ 5 | import torch.nn as nn 6 | from torch import Tensor 7 | 8 | 9 | class SemanticConsistency(nn.Module): 10 | """ 11 | Semantic consistency loss is introduced by 12 | `CyCADA: Cycle-Consistent Adversarial Domain Adaptation (ICML 2018) `_ 13 | 14 | This helps to prevent label flipping during image translation. 15 | 16 | Args: 17 | ignore_index (tuple, optional): Specifies target values that are ignored 18 | and do not contribute to the input gradient. When :attr:`size_average` is 19 | ``True``, the loss is averaged over non-ignored targets. Default: (). 
20 | reduction (string, optional): Specifies the reduction to apply to the output: 21 | ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will 22 | be applied, ``'mean'``: the weighted mean of the output is taken, 23 | ``'sum'``: the output will be summed. Note: :attr:`size_average` 24 | and :attr:`reduce` are in the process of being deprecated, and in 25 | the meantime, specifying either of those two args will override 26 | :attr:`reduction`. Default: ``'mean'`` 27 | 28 | Shape: 29 | - Input: :math:`(N, C)` where `C = number of classes`, or 30 | :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1` 31 | in the case of `K`-dimensional loss. 32 | - Target: :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1`, or 33 | :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of 34 | K-dimensional loss. 35 | - Output: scalar. 36 | If :attr:`reduction` is ``'none'``, then the same size as the target: 37 | :math:`(N)`, or 38 | :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case 39 | of K-dimensional loss. 40 | 41 | Examples:: 42 | 43 | >>> loss = SemanticConsistency() 44 | >>> input = torch.randn(3, 5, requires_grad=True) 45 | >>> target = torch.empty(3, dtype=torch.long).random_(5) 46 | >>> output = loss(input, target) 47 | >>> output.backward() 48 | """ 49 | def __init__(self, ignore_index=(), reduction='mean'): 50 | super(SemanticConsistency, self).__init__() 51 | self.ignore_index = ignore_index 52 | self.loss = nn.CrossEntropyLoss(ignore_index=-1, reduction=reduction) 53 | 54 | def forward(self, input: Tensor, target: Tensor) -> Tensor: 55 | for class_idx in self.ignore_index: 56 | target[target == class_idx] = -1 57 | return self.loss(input, target) 58 | -------------------------------------------------------------------------------- /examples/domain_adaptation/re_identification/spgan.sh: -------------------------------------------------------------------------------- 1 | # Market1501 -> Duke 2 | # step1: train SPGAN 3 | CUDA_VISIBLE_DEVICES=0 python spgan.py data -s Market1501 -t DukeMTMC \ 4 | --log logs/spgan/Market2Duke --translated-root data/spganM2D --seed 0 5 | # step2: train baseline on translated source dataset 6 | CUDA_VISIBLE_DEVICES=0 python baseline.py data/spganM2D data -s Market1501 -t DukeMTMC -a reid_resnet50 \ 7 | --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/spgan/Market2Duke 8 | 9 | # Duke -> Market1501 10 | # step1: train SPGAN 11 | CUDA_VISIBLE_DEVICES=0 python spgan.py data -s DukeMTMC -t Market1501 \ 12 | --log logs/spgan/Duke2Market --translated-root data/spganD2M --seed 0 13 | # step2: train baseline on translated source dataset 14 | CUDA_VISIBLE_DEVICES=0 python baseline.py data/spganD2M data -s DukeMTMC -t Market1501 -a reid_resnet50 \ 15 | --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/spgan/Duke2Market 16 | 17 | # Market1501 -> MSMT17 18 | # step1: train SPGAN 19 | CUDA_VISIBLE_DEVICES=0 python spgan.py data -s Market1501 -t MSMT17 \ 20 | --log logs/spgan/Market2MSMT --translated-root data/spganM2S --seed 0 21 | # step2: train baseline on translated source dataset 22 | CUDA_VISIBLE_DEVICES=0 python baseline.py data/spganM2S data -s Market1501 -t MSMT17 -a reid_resnet50 \ 23 | --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/spgan/Market2MSMT 24 | 25 | # MSMT -> Market1501 26 | # step1: train SPGAN 27 | CUDA_VISIBLE_DEVICES=0 python spgan.py data -s MSMT17 -t Market1501 \ 28 | --log logs/spgan/MSMT2Market --translated-root data/spganS2M 
--seed 0 29 | # step2: train baseline on translated source dataset 30 | CUDA_VISIBLE_DEVICES=0 python baseline.py data/spganS2M data -s MSMT17 -t Market1501 -a reid_resnet50 \ 31 | --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/spgan/MSMT2Market 32 | 33 | # Duke -> MSMT 34 | # step1: train SPGAN 35 | CUDA_VISIBLE_DEVICES=0 python spgan.py data -s DukeMTMC -t MSMT17 \ 36 | --log logs/spgan/Duke2MSMT --translated-root data/spganD2S --seed 0 37 | # step2: train baseline on translated source dataset 38 | CUDA_VISIBLE_DEVICES=0 python baseline.py data/spganD2S data -s DukeMTMC -t MSMT17 -a reid_resnet50 \ 39 | --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/spgan/Duke2MSMT 40 | 41 | # MSMT -> Duke 42 | # step1: train SPGAN 43 | CUDA_VISIBLE_DEVICES=0 python spgan.py data -s MSMT17 -t DukeMTMC \ 44 | --log logs/spgan/MSMT2Duke --translated-root data/spganS2D --seed 0 45 | # step2: train baseline on translated source dataset 46 | CUDA_VISIBLE_DEVICES=0 python baseline.py data/spganS2D data -s MSMT17 -t DukeMTMC -a reid_resnet50 \ 47 | --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/spgan/MSMT2Duke 48 | -------------------------------------------------------------------------------- /tllib/vision/datasets/visda2017.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Junguang Jiang 3 | @contact: JiangJunguang1123@outlook.com 4 | """ 5 | import os 6 | from typing import Optional 7 | from .imagelist import ImageList 8 | from ._util import download as download_data, check_exits 9 | 10 | 11 | class VisDA2017(ImageList): 12 | """`VisDA-2017 `_ Dataset 13 | 14 | Args: 15 | root (str): Root directory of dataset 16 | task (str): The task (domain) to create dataset. Choices include ``'Synthetic'``: synthetic images and \ 17 | ``'Real'``: real-world images. 18 | download (bool, optional): If true, downloads the dataset from the internet and puts it \ 19 | in root directory. If dataset is already downloaded, it is not downloaded again. 20 | transform (callable, optional): A function/transform that takes in a PIL image and returns a \ 21 | transformed version. E.g, ``transforms.RandomCrop``. 22 | target_transform (callable, optional): A function/transform that takes in the target and transforms it. 23 | 24 | .. note:: In `root`, the following files will exist after downloading. 25 | :: 26 | train/ 27 | aeroplane/ 28 | *.png 29 | ...
30 | validation/ 31 | image_list/ 32 | train.txt 33 | validation.txt 34 | """ 35 | download_list = [ 36 | ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/c107de37b8094c5398dc/?dl=1"), 37 | ("train", "train.tgz", "https://cloud.tsinghua.edu.cn/f/c5f3ce59139144ec8221/?dl=1"), 38 | ("validation", "validation.tgz", "https://cloud.tsinghua.edu.cn/f/da70e4b1cf514ecea562/?dl=1") 39 | ] 40 | image_list = { 41 | "Synthetic": "image_list/train.txt", 42 | "Real": "image_list/validation.txt" 43 | } 44 | CLASSES = ['aeroplane', 'bicycle', 'bus', 'car', 'horse', 'knife', 45 | 'motorcycle', 'person', 'plant', 'skateboard', 'train', 'truck'] 46 | 47 | def __init__(self, root: str, task: str, download: Optional[bool] = False, **kwargs): 48 | assert task in self.image_list 49 | data_list_file = os.path.join(root, self.image_list[task]) 50 | 51 | if download: 52 | list(map(lambda args: download_data(root, *args), self.download_list)) 53 | else: 54 | list(map(lambda file_name, _: check_exits(root, file_name), self.download_list)) 55 | 56 | super(VisDA2017, self).__init__(root, VisDA2017.CLASSES, data_list_file=data_list_file, **kwargs) 57 | 58 | @classmethod 59 | def domains(cls): 60 | return list(cls.image_list.keys()) -------------------------------------------------------------------------------- /tllib/normalization/mixstyle/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/KaiyangZhou/mixstyle-release 3 | @author: Baixu Chen 4 | @contact: cbx_99_hasta@outlook.com 5 | """ 6 | import random 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | class MixStyle(nn.Module): 12 | r"""MixStyle module from `DOMAIN GENERALIZATION WITH MIXSTYLE (ICLR 2021) `_. 13 | Given input :math:`x`, we first compute mean :math:`\mu(x)` and standard deviation :math:`\sigma(x)` across spatial 14 | dimension. Then we permute :math:`x` and get :math:`\tilde{x}`, corresponding mean :math:`\mu(\tilde{x})` and 15 | standard deviation :math:`\sigma(\tilde{x})`. `MixUp` is performed using mean and standard deviation 16 | 17 | .. math:: 18 | \gamma_{mix} = \lambda\sigma(x) + (1-\lambda)\sigma(\tilde{x}) 19 | 20 | .. math:: 21 | \beta_{mix} = \lambda\mu(x) + (1-\lambda)\mu(\tilde{x}) 22 | 23 | where :math:`\lambda` is instance-wise weight sampled from `Beta distribution`. MixStyle is then 24 | 25 | .. math:: 26 | MixStyle(x) = \gamma_{mix}\frac{x-\mu(x)}{\sigma(x)} + \beta_{mix} 27 | 28 | Args: 29 | p (float): probability of using MixStyle. 30 | alpha (float): parameter of the `Beta distribution`. 31 | eps (float): scaling parameter to avoid numerical issues. 
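Examples (a minimal usage sketch; placing the module after an early convolutional stage follows the paper and is an assumption here, and ``p=1.0`` is used only so that mixing is always applied in this snippet)::

            >>> import torch
            >>> mixstyle = MixStyle(p=1.0, alpha=0.1)
            >>> features = torch.randn(8, 64, 56, 56)  # (N, C, H, W) feature maps from an early layer
            >>> mixed = mixstyle(features)             # same shape, channel statistics mixed across the batch
            >>> mixstyle.eval()
            >>> assert torch.equal(mixstyle(features), features)  # identity at evaluation time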
32 | """ 33 | 34 | def __init__(self, p=0.5, alpha=0.1, eps=1e-6): 35 | super().__init__() 36 | self.p = p 37 | self.beta = torch.distributions.Beta(alpha, alpha) 38 | self.eps = eps 39 | self.alpha = alpha 40 | 41 | def forward(self, x): 42 | if not self.training: 43 | return x 44 | 45 | if random.random() > self.p: 46 | return x 47 | 48 | batch_size = x.size(0) 49 | 50 | mu = x.mean(dim=[2, 3], keepdim=True) 51 | var = x.var(dim=[2, 3], keepdim=True) 52 | sigma = (var + self.eps).sqrt() 53 | mu, sigma = mu.detach(), sigma.detach() 54 | x_normed = (x - mu) / sigma 55 | 56 | interpolation = self.beta.sample((batch_size, 1, 1, 1)) 57 | interpolation = interpolation.to(x.device) 58 | 59 | # split into two halves and swap the order 60 | perm = torch.arange(batch_size - 1, -1, -1) # inverse index 61 | perm_b, perm_a = perm.chunk(2) 62 | perm_b = perm_b[torch.randperm(batch_size // 2)] 63 | perm_a = perm_a[torch.randperm(batch_size // 2)] 64 | perm = torch.cat([perm_b, perm_a], 0) 65 | 66 | mu_perm, sigma_perm = mu[perm], sigma[perm] 67 | mu_mix = mu * interpolation + mu_perm * (1 - interpolation) 68 | sigma_mix = sigma * interpolation + sigma_perm * (1 - interpolation) 69 | 70 | return x_normed * sigma_mix + mu_mix 71 | -------------------------------------------------------------------------------- /tllib/modules/gl.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Junguang Jiang 3 | @contact: JiangJunguang1123@outlook.com 4 | """ 5 | from typing import Optional, Any, Tuple 6 | import numpy as np 7 | import torch.nn as nn 8 | from torch.autograd import Function 9 | import torch 10 | 11 | 12 | class GradientFunction(Function): 13 | 14 | @staticmethod 15 | def forward(ctx: Any, input: torch.Tensor, coeff: Optional[float] = 1.) -> torch.Tensor: 16 | ctx.coeff = coeff 17 | output = input * 1.0 18 | return output 19 | 20 | @staticmethod 21 | def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, Any]: 22 | return grad_output * ctx.coeff, None 23 | 24 | 25 | class WarmStartGradientLayer(nn.Module): 26 | """Warm Start Gradient Layer :math:`\mathcal{R}(x)` with warm start 27 | 28 | The forward and backward behaviours are: 29 | 30 | .. math:: 31 | \mathcal{R}(x) = x, 32 | 33 | \dfrac{ d\mathcal{R}} {dx} = \lambda I. 34 | 35 | :math:`\lambda` is initiated at :math:`lo` and is gradually changed to :math:`hi` using the following schedule: 36 | 37 | .. math:: 38 | \lambda = \dfrac{2(hi-lo)}{1+\exp(- α \dfrac{i}{N})} - (hi-lo) + lo 39 | 40 | where :math:`i` is the iteration step. 41 | 42 | Parameters: 43 | - **alpha** (float, optional): :math:`α`. Default: 1.0 44 | - **lo** (float, optional): Initial value of :math:`\lambda`. Default: 0.0 45 | - **hi** (float, optional): Final value of :math:`\lambda`. Default: 1.0 46 | - **max_iters** (int, optional): :math:`N`. Default: 1000 47 | - **auto_step** (bool, optional): If True, increase :math:`i` each time `forward` is called. 48 | Otherwise use function `step` to increase :math:`i`. 
Default: False 49 | """ 50 | 51 | def __init__(self, alpha: Optional[float] = 1.0, lo: Optional[float] = 0.0, hi: Optional[float] = 1., 52 | max_iters: Optional[int] = 1000., auto_step: Optional[bool] = False): 53 | super(WarmStartGradientLayer, self).__init__() 54 | self.alpha = alpha 55 | self.lo = lo 56 | self.hi = hi 57 | self.iter_num = 0 58 | self.max_iters = max_iters 59 | self.auto_step = auto_step 60 | 61 | def forward(self, input: torch.Tensor) -> torch.Tensor: 62 | """""" 63 | coeff = np.float( 64 | 2.0 * (self.hi - self.lo) / (1.0 + np.exp(-self.alpha * self.iter_num / self.max_iters)) 65 | - (self.hi - self.lo) + self.lo 66 | ) 67 | if self.auto_step: 68 | self.step() 69 | return GradientFunction.apply(input, coeff) 70 | 71 | def step(self): 72 | """Increase iteration number :math:`i` by 1""" 73 | self.iter_num += 1 74 | -------------------------------------------------------------------------------- /tllib/utils/scheduler.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/yxgeee/MMT 3 | @author: Baixu Chen 4 | @contact: cbx_99_hasta@outlook.com 5 | """ 6 | import torch 7 | from bisect import bisect_right 8 | 9 | 10 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 11 | r"""Starts with a warm-up phase, then decays the learning rate of each parameter group by gamma once the 12 | number of epoch reaches one of the milestones. When last_epoch=-1, sets initial lr as lr. 13 | 14 | Args: 15 | optimizer (Optimizer): Wrapped optimizer. 16 | milestones (list): List of epoch indices. Must be increasing. 17 | gamma (float): Multiplicative factor of learning rate decay. 18 | Default: 0.1. 19 | warmup_factor (float): a float number :math:`k` between 0 and 1, the start learning rate of warmup phase 20 | will be set to :math:`k*initial\_lr` 21 | warmup_steps (int): number of warm-up steps. 22 | warmup_method (str): "constant" denotes a constant learning rate during warm-up phase and "linear" denotes a 23 | linear-increasing learning rate during warm-up phase. 24 | last_epoch (int): The index of last epoch. Default: -1. 25 | """ 26 | 27 | def __init__( 28 | self, 29 | optimizer, 30 | milestones, 31 | gamma=0.1, 32 | warmup_factor=1.0 / 3, 33 | warmup_steps=500, 34 | warmup_method="linear", 35 | last_epoch=-1, 36 | ): 37 | if not list(milestones) == sorted(milestones): 38 | raise ValueError( 39 | "Milestones should be a list of" " increasing integers. 
Got {}", 40 | milestones, 41 | ) 42 | 43 | if warmup_method not in ("constant", "linear"): 44 | raise ValueError( 45 | "Only 'constant' or 'linear' warmup_method accepted" 46 | "got {}".format(warmup_method) 47 | ) 48 | self.milestones = milestones 49 | self.gamma = gamma 50 | self.warmup_factor = warmup_factor 51 | self.warmup_steps = warmup_steps 52 | self.warmup_method = warmup_method 53 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 54 | 55 | def get_lr(self): 56 | warmup_factor = 1 57 | if self.last_epoch < self.warmup_steps: 58 | if self.warmup_method == "constant": 59 | warmup_factor = self.warmup_factor 60 | elif self.warmup_method == "linear": 61 | alpha = float(self.last_epoch) / float(self.warmup_steps) 62 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 63 | return [ 64 | base_lr 65 | * warmup_factor 66 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 67 | for base_lr in self.base_lrs 68 | ] 69 | -------------------------------------------------------------------------------- /examples/domain_adaptation/openset_domain_adaptation/erm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Office31 3 | CUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 0 --log logs/erm/Office31_A2W 4 | CUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 20 --seed 0 --log logs/erm/Office31_D2W 5 | CUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 20 --seed 0 --log logs/erm/Office31_W2D 6 | CUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 20 --seed 0 --log logs/erm/Office31_A2D 7 | CUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 20 --seed 0 --log logs/erm/Office31_D2A 8 | CUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 20 --seed 0 --log logs/erm/Office31_W2A 9 | 10 | # Office-Home 11 | CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Ar2Cl 12 | CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Ar2Pr 13 | CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Ar2Rw 14 | CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Cl2Ar 15 | CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Cl2Pr 16 | CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Cl2Rw 17 | CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Pr2Ar 18 | CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Pr2Cl 19 | CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Pr2Rw 20 | CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log 
logs/erm/OfficeHome_Rw2Ar 21 | CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Rw2Cl 22 | CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Rw2Pr 23 | 24 | # VisDA-2017 25 | CUDA_VISIBLE_DEVICES=0 python erm.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet50 \ 26 | --epochs 30 -i 500 --seed 0 --train-resizing cen.crop --per-class-eval --log logs/erm/VisDA2017_S2R 27 | -------------------------------------------------------------------------------- /tllib/reweight/groupdro.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/facebookresearch/DomainBed 3 | @author: Baixu Chen 4 | @contact: cbx_99_hasta@outlook.com 5 | """ 6 | import torch 7 | 8 | 9 | class AutomaticUpdateDomainWeightModule(object): 10 | r""" 11 | Maintains group weights based on the loss history of all domains, according 12 | to `Distributionally Robust Neural Networks for Group Shifts: On the Importance of Regularization for Worst-Case 13 | Generalization (ICLR 2020) `_. 14 | 15 | Suppose we have :math:`N` domains. During each iteration, we first calculate the unweighted loss of all 16 | domains, resulting in :math:`loss\in R^N`. Then we update the domain weights by 17 | 18 | .. math:: 19 | w_k = w_k * \text{exp}(\eta \cdot loss_k), \forall k \in [1, N] 20 | 21 | where :math:`\eta` is the hyperparameter that controls how smoothly the weights change. 22 | As :math:`w \in R^N` denotes a distribution, we `normalize` 23 | :math:`w` by its sum. Finally, the weighted loss is calculated as our objective 24 | 25 | .. math:: 26 | objective = \sum_{k=1}^N w_k * loss_k 27 | 28 | Args: 29 | num_domains (int): The number of source domains. 30 | eta (float): Hyperparameter :math:`\eta`. 31 | device (torch.device): The device to run on. 32 | """ 33 | 34 | def __init__(self, num_domains: int, eta: float, device): 35 | self.domain_weight = torch.ones(num_domains).to(device) / num_domains 36 | self.eta = eta 37 | 38 | def get_domain_weight(self, sampled_domain_idxes): 39 | """Get domain weight to calculate final objective. 40 | 41 | Inputs: 42 | - sampled_domain_idxes (list): sampled domain indexes in current mini-batch 43 | 44 | Shape: 45 | - sampled_domain_idxes: :math:`(D, )` where D means the number of sampled domains in current mini-batch 46 | - Outputs: :math:`(D, )` 47 | """ 48 | domain_weight = self.domain_weight[sampled_domain_idxes] 49 | domain_weight = domain_weight / domain_weight.sum() 50 | return domain_weight 51 | 52 | def update(self, sampled_domain_losses: torch.Tensor, sampled_domain_idxes): 53 | """Update domain weight using loss of current mini-batch.
54 | 55 | Inputs: 56 | - sampled_domain_losses (tensor): loss of among sampled domains in current mini-batch 57 | - sampled_domain_idxes (list): sampled domain indexes in current mini-batch 58 | 59 | Shape: 60 | - sampled_domain_losses: :math:`(D, )` where D means the number of sampled domains in current mini-batch 61 | - sampled_domain_idxes: :math:`(D, )` 62 | """ 63 | sampled_domain_losses = sampled_domain_losses.detach() 64 | 65 | for loss, idx in zip(sampled_domain_losses, sampled_domain_idxes): 66 | self.domain_weight[idx] *= (self.eta * loss).exp() 67 | -------------------------------------------------------------------------------- /docs/tllib/translation.rst: -------------------------------------------------------------------------------- 1 | ======================================= 2 | Domain Translation 3 | ======================================= 4 | 5 | 6 | .. _CycleGAN: 7 | 8 | ------------------------------------------------ 9 | CycleGAN: Cycle-Consistent Adversarial Networks 10 | ------------------------------------------------ 11 | 12 | Discriminator 13 | -------------- 14 | 15 | .. autofunction:: tllib.translation.cyclegan.pixel 16 | 17 | .. autofunction:: tllib.translation.cyclegan.patch 18 | 19 | Generator 20 | -------------- 21 | 22 | .. autofunction:: tllib.translation.cyclegan.resnet_9 23 | 24 | .. autofunction:: tllib.translation.cyclegan.resnet_6 25 | 26 | .. autofunction:: tllib.translation.cyclegan.unet_256 27 | 28 | .. autofunction:: tllib.translation.cyclegan.unet_128 29 | 30 | 31 | GAN Loss 32 | -------------- 33 | 34 | .. autoclass:: tllib.translation.cyclegan.LeastSquaresGenerativeAdversarialLoss 35 | 36 | .. autoclass:: tllib.translation.cyclegan.VanillaGenerativeAdversarialLoss 37 | 38 | .. autoclass:: tllib.translation.cyclegan.WassersteinGenerativeAdversarialLoss 39 | 40 | Translation 41 | -------------- 42 | 43 | .. autoclass:: tllib.translation.cyclegan.Translation 44 | 45 | 46 | Util 47 | ---------------- 48 | 49 | .. autoclass:: tllib.translation.cyclegan.util.ImagePool 50 | :members: 51 | 52 | .. autofunction:: tllib.translation.cyclegan.util.set_requires_grad 53 | 54 | 55 | 56 | 57 | .. _Cycada: 58 | 59 | -------------------------------------------------------------- 60 | CyCADA: Cycle-Consistent Adversarial Domain Adaptation 61 | -------------------------------------------------------------- 62 | 63 | .. autoclass:: tllib.translation.cycada.SemanticConsistency 64 | 65 | 66 | 67 | .. _SPGAN: 68 | 69 | ----------------------------------------------------------- 70 | SPGAN: Similarity Preserving Generative Adversarial Network 71 | ----------------------------------------------------------- 72 | `Image-Image Domain Adaptation with Preserved Self-Similarity and Domain-Dissimilarity for Person Re-identification 73 | `_. SPGAN is based on CycleGAN. An additional Siamese network is adopted to force 74 | the generator to produce images different from identities in target dataset. 75 | 76 | Siamese Network 77 | ------------------- 78 | 79 | .. autoclass:: tllib.translation.spgan.siamese.SiameseNetwork 80 | 81 | Contrastive Loss 82 | ------------------- 83 | 84 | .. autoclass:: tllib.translation.spgan.loss.ContrastiveLoss 85 | 86 | 87 | .. _FDA: 88 | 89 | ------------------------------------------------ 90 | FDA: Fourier Domain Adaptation 91 | ------------------------------------------------ 92 | 93 | .. autoclass:: tllib.translation.fourier_transform.FourierTransform 94 | 95 | .. 
autofunction:: tllib.translation.fourier_transform.low_freq_mutate 96 | 97 | 98 | 99 | 100 | 101 | -------------------------------------------------------------------------------- /examples/domain_adaptation/openset_domain_adaptation/osbp.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Office31 3 | CUDA_VISIBLE_DEVICES=0 python osbp.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 0 --log logs/osbp/Office31_A2W 4 | CUDA_VISIBLE_DEVICES=0 python osbp.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 20 --seed 0 --log logs/osbp/Office31_D2W 5 | CUDA_VISIBLE_DEVICES=0 python osbp.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 20 --seed 0 --log logs/osbp/Office31_W2D 6 | CUDA_VISIBLE_DEVICES=0 python osbp.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 20 --seed 0 --log logs/osbp/Office31_A2D 7 | CUDA_VISIBLE_DEVICES=0 python osbp.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 20 --seed 0 --log logs/osbp/Office31_D2A 8 | CUDA_VISIBLE_DEVICES=0 python osbp.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 20 --seed 0 --log logs/osbp/Office31_W2A 9 | 10 | # Office-Home 11 | CUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Ar2Cl 12 | CUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Ar2Pr 13 | CUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Ar2Rw 14 | CUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Cl2Ar 15 | CUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Cl2Pr 16 | CUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Cl2Rw 17 | CUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Pr2Ar 18 | CUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Pr2Cl 19 | CUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Pr2Rw 20 | CUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Rw2Ar 21 | CUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Rw2Cl 22 | CUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Rw2Pr 23 | 24 | # VisDA-2017 25 | CUDA_VISIBLE_DEVICES=0 python osbp.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet50 \ 26 | --epochs 30 -i 1000 --seed 0 --train-resizing cen.crop --per-class-eval --log logs/osbp/VisDA2017_S2R -------------------------------------------------------------------------------- /tllib/alignment/adda.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Baixu Chen 3 | @contact: cbx_99_hasta@outlook.com 4 | """ 
5 | from typing import Optional, List, Dict 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from tllib.modules.classifier import Classifier as ClassifierBase 10 | 11 | 12 | class DomainAdversarialLoss(nn.Module): 13 | r"""Domain adversarial loss from `Adversarial Discriminative Domain Adaptation (CVPR 2017) 14 | `_. 15 | Similar to the original `GAN `_ paper, ADDA argues that replacing 16 | :math:`\text{log}(1-p)` with :math:`-\text{log}(p)` in the adversarial loss provides better gradient qualities. Detailed 17 | optimization process can be found `here 18 | `_. 19 | 20 | Inputs: 21 | - domain_pred (tensor): predictions of domain discriminator 22 | - domain_label (str, optional): whether the data comes from source or target. 23 | Must be 'source' or 'target'. Default: 'source' 24 | 25 | Shape: 26 | - domain_pred: :math:`(minibatch,)`. 27 | - Outputs: scalar. 28 | 29 | """ 30 | 31 | def __init__(self): 32 | super(DomainAdversarialLoss, self).__init__() 33 | 34 | def forward(self, domain_pred, domain_label='source'): 35 | assert domain_label in ['source', 'target'] 36 | if domain_label == 'source': 37 | return F.binary_cross_entropy(domain_pred, torch.ones_like(domain_pred).to(domain_pred.device)) 38 | else: 39 | return F.binary_cross_entropy(domain_pred, torch.zeros_like(domain_pred).to(domain_pred.device)) 40 | 41 | 42 | class ImageClassifier(ClassifierBase): 43 | def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs): 44 | bottleneck = nn.Sequential( 45 | # nn.AdaptiveAvgPool2d(output_size=(1, 1)), 46 | # nn.Flatten(), 47 | nn.Linear(backbone.out_features, bottleneck_dim), 48 | nn.BatchNorm1d(bottleneck_dim), 49 | nn.ReLU() 50 | ) 51 | super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs) 52 | 53 | def freeze_bn(self): 54 | for m in self.modules(): 55 | if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): 56 | m.eval() 57 | 58 | def get_parameters(self, base_lr=1.0, optimize_head=True) -> List[Dict]: 59 | params = [ 60 | {"params": self.backbone.parameters(), "lr": 0.1 * base_lr if self.finetune else 1.0 * base_lr}, 61 | {"params": self.bottleneck.parameters(), "lr": 1.0 * base_lr} 62 | ] 63 | if optimize_head: 64 | params.append({"params": self.head.parameters(), "lr": 1.0 * base_lr}) 65 | 66 | return params 67 | --------------------------------------------------------------------------------
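The adversarial loss defined in `tllib/alignment/adda.py` above can be paired with a small domain discriminator. A minimal usage sketch follows (the discriminator architecture, feature dimension, and batch size are assumptions for illustration, not the exact code of the ADDA example scripts):

```python
import torch
import torch.nn as nn

# DomainAdversarialLoss as defined in tllib/alignment/adda.py above
adv_loss = DomainAdversarialLoss()

# a hypothetical discriminator mapping 256-d bottleneck features to a probability
discriminator = nn.Sequential(
    nn.Linear(256, 1024), nn.ReLU(),
    nn.Linear(1024, 1), nn.Sigmoid(),
)

f_s = torch.randn(32, 256)  # features of source images (from the frozen source model)
f_t = torch.randn(32, 256)  # features of target images (from the adapting target model)

# discriminator step: distinguish source features from target features
d_loss = adv_loss(discriminator(f_s).squeeze(1), 'source') + \
         adv_loss(discriminator(f_t).squeeze(1), 'target')

# target-encoder step: fool the discriminator by labelling target features as 'source',
# i.e. the -log(p) objective that ADDA argues gives better gradients
g_loss = adv_loss(discriminator(f_t).squeeze(1), 'source')
```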