├── tests ├── __init__.py ├── utils │ ├── __init__.py │ ├── benchmarking │ │ ├── __init__.py │ │ ├── test_topk.py │ │ └── test_metric_callback.py │ ├── test_version_compare.py │ └── test_dist.py ├── models │ ├── __init__.py │ └── modules │ │ ├── __init__.py │ │ └── test_center.py ├── transforms │ ├── __init__.py │ ├── test_multi_crop_transform.py │ ├── test_rfft2d_transform.py │ ├── test_irfft2d_transform.py │ ├── test_gaussian_mixture_masks.py │ ├── test_jigsaw.py │ ├── test_mae_transform.py │ ├── test_Solarize.py │ ├── test_random_frequency_mask_transform.py │ ├── test_densecl_transform.py │ ├── test_pirl_transform.py │ ├── test_torchvision_v2compatibility.py │ ├── test_fastsiam_transform.py │ ├── test_multi_view_transform.py │ ├── test_mmcr_transform.py │ ├── test_wmse_transform.py │ ├── test_amplitude_rescale_transform.py │ ├── test_gaussian_blur.py │ ├── test_simclr_transform.py │ ├── test_vicreg_transform.py │ ├── test_simsiam_transform.py │ ├── test_swav_transform.py │ ├── test_smog_transform.py │ ├── test_msn_transform.py │ ├── test_dino_transform.py │ ├── test_moco_transform.py │ ├── test_fda_transform.py │ ├── test_byol_transform.py │ ├── test_multi_view_transform_v2.py │ ├── test_phase_shift_transform.py │ ├── test_rotation.py │ ├── test_location_to_NxN_grid.py │ └── test_vicregl_transform.py ├── api_workflow │ ├── __init__.py │ ├── utils.py │ └── test_api_workflow_collaboration.py ├── cli │ ├── test_cli_get_lighty_config.py │ └── test_cli_version.py ├── REAMDE.md ├── helpers.py ├── loss │ ├── test_hypersphere_loss.py │ ├── test_emp_ssl_loss.py │ ├── test_negative_cosine_similarity.py │ ├── test_koleo_loss.py │ ├── test_MMCR_loss.py │ ├── test_wmse_loss.py │ ├── test_co2.py │ └── test_ibot_loss.py ├── UNMOCKED_end2end_tests │ ├── delete_datasets_test_unmocked_cli.py │ ├── create_custom_metadata_from_input_dir.py │ ├── scripts_for_reproducing_problems │ │ └── test_api_latency.py │ ├── run_all_unmocked_tests.sh │ └── README.md ├── test_requirements.py ├── 
data │ └── test_multi_view_collate.py ├── api │ ├── test_rest_parser.py │ └── test_swagger_api_client.py └── embedding │ └── test_callbacks.py ├── lightly ├── transforms │ ├── py.typed │ ├── utils.py │ ├── multi_view_transform.py │ ├── rfft2d_transform.py │ ├── multi_view_transform_v2.py │ ├── amplitude_rescale_transform.py │ ├── irfft2d_transform.py │ ├── solarize.py │ ├── random_frequency_mask_transform.py │ ├── aim_transform.py │ └── tico_transform.py ├── utils │ ├── __init__.py │ ├── cropping │ │ ├── __init__.py │ │ └── read_yolo_label_file.py │ ├── benchmarking │ │ ├── __init__.py │ │ ├── topk.py │ │ └── online_linear_classifier.py │ ├── version_compare.py │ ├── reordering.py │ └── hipify.py ├── openapi_generated │ ├── __init__.py │ └── swagger_client │ │ ├── api_response.py │ │ ├── models │ │ ├── docker_worker_type.py │ │ ├── sample_sort_by.py │ │ ├── shared_access_type.py │ │ ├── image_type.py │ │ ├── dataset_type.py │ │ ├── file_output_format.py │ │ ├── team_role.py │ │ ├── user_type.py │ │ ├── expiry_handling_strategy_v3.py │ │ ├── job_state.py │ │ ├── lightly_trainer_precision_v2.py │ │ ├── lightly_trainer_precision_v3.py │ │ ├── docker_worker_state.py │ │ ├── sample_type.py │ │ ├── sampling_method.py │ │ ├── job_status_upload_method.py │ │ ├── lightly_docker_selection_method.py │ │ ├── creator.py │ │ ├── docker_run_scheduled_state.py │ │ ├── sample_partial_mode.py │ │ ├── docker_run_scheduled_priority.py │ │ ├── tag_change_data_operation_method.py │ │ ├── dimensionality_reduction_method.py │ │ ├── docker_run_artifact_storage_location.py │ │ ├── job_result_type.py │ │ ├── selection_input_type.py │ │ ├── selection_input_predictions_name.py │ │ ├── dataset_creator.py │ │ ├── lightly_model_v2.py │ │ ├── lightly_model_v3.py │ │ ├── selection_strategy_type.py │ │ ├── tag_arithmetics_operation.py │ │ ├── docker_run_log_level.py │ │ ├── usage.py │ │ ├── selection_strategy_type_v3.py │ │ ├── task_type.py │ │ ├── selection_strategy_threshold_operation.py │ │ ├── 
tag_creator.py │ │ ├── docker_run_artifact_type.py │ │ ├── datasource_purpose.py │ │ ├── file_name_format.py │ │ └── sector.py │ │ └── api │ │ └── __init__.py ├── loss │ ├── memory_bank.py │ ├── regularizer │ │ └── __init__.py │ ├── __init__.py │ └── negative_cosine_similarity.py ├── cli │ ├── config │ │ ├── __init__.py │ │ ├── get_config.py │ │ └── lightly-serve.yaml │ ├── __init__.py │ ├── version_cli.py │ └── _cli_simclr.py ├── active_learning │ ├── config │ │ └── __init__.py │ └── __init__.py ├── embedding │ └── __init__.py ├── api │ └── __init__.py ├── data │ └── __init__.py └── models │ ├── __init__.py │ ├── zoo.py │ └── modules │ └── __init__.py ├── .gitattributes ├── docs ├── .gcloudignore ├── source │ ├── favicon.png │ ├── lightly.core.rst │ ├── tutorials_source │ │ ├── package │ │ │ ├── README.txt │ │ │ └── Base-RCNN-FPN.yaml │ │ └── platform │ │ │ ├── README.txt │ │ │ └── images │ │ │ ├── final_tag_tree.png │ │ │ ├── selection_create_request.png │ │ │ ├── custom_metadata_scatter_plot.png │ │ │ ├── clothing-dataset-small-structure.png │ │ │ ├── custom_metadata_configuration_aquarium.png │ │ │ ├── custom_metadata_scatter_plot_with_tag.png │ │ │ ├── tutorial_cropped_objects_metadata │ │ │ ├── sku110k_select_nn.jpg │ │ │ ├── sku110k_lightly_magic.jpg │ │ │ ├── sku110k_config_metadata.jpg │ │ │ ├── sku110k_selected_images.jpg │ │ │ └── sku110k_explore_sort_objectness.jpg │ │ │ └── tutorial_export_labelstudio │ │ │ ├── labelstudio_import_dialog.jpg │ │ │ ├── labelstudio_imported_file.jpg │ │ │ └── labelstudio_import_finished.jpg │ ├── lightly.rst │ ├── lightly.models.utils.rst │ ├── getting_started │ │ ├── images │ │ │ ├── cli_example.png │ │ │ ├── webapp_tags.gif │ │ │ ├── dataset_home.jpg │ │ │ ├── webapp_token.gif │ │ │ ├── al_accuracy_plot.png │ │ │ ├── dataset_overview.jpg │ │ │ ├── lightly_overview.png │ │ │ ├── webapp_embedding.jpg │ │ │ ├── dino_augmentations.jpg │ │ │ ├── webapp_dataset_id.jpg │ │ │ ├── webapp_image_upload.jpg │ │ │ ├── 
simclr_augmentations.jpg │ │ │ ├── webapp_create_selection.gif │ │ │ ├── webapp_selection_new_tag.gif │ │ │ ├── simclr_augmentations_no_blur.jpg │ │ │ ├── custom_metadata_weather_temperature.png │ │ │ └── cifar10_benchmark_knn_accuracy_800_epochs.png │ │ ├── code_examples │ │ │ └── plot_image_augmentations.py │ │ └── install.rst │ ├── _templates │ │ └── autosummary │ │ │ ├── base.rst │ │ │ ├── module.rst │ │ │ └── class.rst │ ├── _static │ │ └── my-styles.css │ ├── docker_archive │ │ └── overview.rst │ ├── lightly.api.rst │ ├── lightly.data.rst │ ├── lightly.models.rst │ ├── examples │ │ ├── models.rst │ │ └── smog.rst │ ├── lightly.utils.rst │ └── lightly.loss.rst ├── logos │ ├── lightly_logo.png │ ├── lightly_logo_crop.png │ ├── lightly_logo_square.png │ ├── lightly_SSL_logo_crop.png │ ├── lightly_logo_crop_white_text.png │ └── lightly_SSL_logo_crop_white_text.png ├── app.yaml └── make.bat ├── examples ├── pytorch_lightning │ └── ijepa.py ├── pytorch_lightning_distributed │ └── ijepa.py ├── notebooks │ ├── pytorch_lightning │ │ └── ijepa.ipynb │ └── pytorch_lightning_distributed │ │ └── ijepa.ipynb └── README.md ├── MANIFEST.in ├── requirements ├── openapi.txt ├── base.txt └── README.md ├── setup.py ├── codecov.yml ├── .github ├── PULL_REQUEST_TEMPLATE │ ├── PR_template_minimal.md │ ├── PR_template_checklist.md │ └── PR_template_checklist_full.md └── workflows │ ├── test_code_format.yml │ ├── tests_unmocked.yml │ ├── check_example_nbs.yml │ ├── weekly_dependency_test.yml │ ├── discord_release_notification.yml │ ├── test_api_deps_only.yml │ └── test_minimal_deps.yml ├── CITATION.cff ├── .readthedocs.yml ├── MAINTAINING.md ├── .gitignore ├── LICENSE.txt └── .pre-commit-config.yaml /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/utils/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lightly/transforms/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lightly/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/api_workflow/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/models/modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lightly/openapi_generated/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lightly/utils/cropping/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/utils/benchmarking/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitattributes: 
-------------------------------------------------------------------------------- 1 | lightly/openapi_generated linguist-generated=true 2 | -------------------------------------------------------------------------------- /docs/.gcloudignore: -------------------------------------------------------------------------------- 1 | source 2 | _data 3 | logos 4 | **lightning_logs 5 | .* -------------------------------------------------------------------------------- /examples/pytorch_lightning/ijepa.py: -------------------------------------------------------------------------------- 1 | # TODO, for now please refer to our pure pytorch example 2 | -------------------------------------------------------------------------------- /docs/source/favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/lightly/master/docs/source/favicon.png -------------------------------------------------------------------------------- /examples/pytorch_lightning_distributed/ijepa.py: -------------------------------------------------------------------------------- 1 | # TODO, for now please refer to our pure pytorch example 2 | -------------------------------------------------------------------------------- /docs/source/lightly.core.rst: -------------------------------------------------------------------------------- 1 | .core 2 | ================== 3 | 4 | .. automodule:: lightly.core 5 | :members: -------------------------------------------------------------------------------- /docs/source/tutorials_source/package/README.txt: -------------------------------------------------------------------------------- 1 | This is my gallery 2 | ================== 3 | 4 | Coming soon! 
-------------------------------------------------------------------------------- /docs/source/tutorials_source/platform/README.txt: -------------------------------------------------------------------------------- 1 | This is my gallery 2 | ================== 3 | 4 | Coming soon! -------------------------------------------------------------------------------- /docs/logos/lightly_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/lightly/master/docs/logos/lightly_logo.png -------------------------------------------------------------------------------- /lightly/transforms/utils.py: -------------------------------------------------------------------------------- 1 | IMAGENET_NORMALIZE = {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]} 2 | -------------------------------------------------------------------------------- /tests/transforms/test_multi_crop_transform.py: -------------------------------------------------------------------------------- 1 | from lightly.transforms.multi_crop_transform import MultiCropTranform 2 | -------------------------------------------------------------------------------- /docs/logos/lightly_logo_crop.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/lightly/master/docs/logos/lightly_logo_crop.png -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | # exclude benchmarks, docs, examples, and tests 2 | prune docs 3 | prune benchmarks 4 | prune examples 5 | prune tests -------------------------------------------------------------------------------- /docs/logos/lightly_logo_square.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/maxprogrammer007/lightly/master/docs/logos/lightly_logo_square.png -------------------------------------------------------------------------------- /docs/logos/lightly_SSL_logo_crop.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/lightly/master/docs/logos/lightly_SSL_logo_crop.png -------------------------------------------------------------------------------- /docs/source/lightly.rst: -------------------------------------------------------------------------------- 1 | lightly 2 | ====================== 3 | 4 | .. automodule:: lightly 5 | 6 | 7 | 8 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /requirements/openapi.txt: -------------------------------------------------------------------------------- 1 | python_dateutil >= 2.5.3 2 | setuptools >= 21.0.0 3 | urllib3 >= 1.25.3 4 | pydantic >= 1.10.5 5 | aenum >= 3.1.11 6 | -------------------------------------------------------------------------------- /docs/logos/lightly_logo_crop_white_text.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/lightly/master/docs/logos/lightly_logo_crop_white_text.png -------------------------------------------------------------------------------- /docs/source/lightly.models.utils.rst: -------------------------------------------------------------------------------- 1 | lightly.models.utils 2 | ===================== 3 | 4 | .. 
automodule:: lightly.models.utils 5 | :members: -------------------------------------------------------------------------------- /docs/logos/lightly_SSL_logo_crop_white_text.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/lightly/master/docs/logos/lightly_SSL_logo_crop_white_text.png -------------------------------------------------------------------------------- /docs/source/getting_started/images/cli_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/lightly/master/docs/source/getting_started/images/cli_example.png -------------------------------------------------------------------------------- /docs/source/getting_started/images/webapp_tags.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/lightly/master/docs/source/getting_started/images/webapp_tags.gif -------------------------------------------------------------------------------- /docs/source/_templates/autosummary/base.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. 
auto{{ objtype }}:: {{ objname }} -------------------------------------------------------------------------------- /docs/source/getting_started/images/dataset_home.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/lightly/master/docs/source/getting_started/images/dataset_home.jpg -------------------------------------------------------------------------------- /docs/source/getting_started/images/webapp_token.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/lightly/master/docs/source/getting_started/images/webapp_token.gif -------------------------------------------------------------------------------- /docs/source/_static/my-styles.css: -------------------------------------------------------------------------------- 1 | body{ 2 | font-family:Arial, Helvetica, sans-serif; 3 | } 4 | h1{ 5 | font-family:Arial, Helvetica, sans-serif !important; 6 | } -------------------------------------------------------------------------------- /docs/source/getting_started/images/al_accuracy_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/lightly/master/docs/source/getting_started/images/al_accuracy_plot.png -------------------------------------------------------------------------------- /docs/source/getting_started/images/dataset_overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/lightly/master/docs/source/getting_started/images/dataset_overview.jpg -------------------------------------------------------------------------------- /docs/source/getting_started/images/lightly_overview.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/maxprogrammer007/lightly/master/docs/source/getting_started/images/lightly_overview.png -------------------------------------------------------------------------------- /docs/source/getting_started/images/webapp_embedding.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/lightly/master/docs/source/getting_started/images/webapp_embedding.jpg -------------------------------------------------------------------------------- /docs/source/getting_started/images/dino_augmentations.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/lightly/master/docs/source/getting_started/images/dino_augmentations.jpg -------------------------------------------------------------------------------- /docs/source/getting_started/images/webapp_dataset_id.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/lightly/master/docs/source/getting_started/images/webapp_dataset_id.jpg -------------------------------------------------------------------------------- /docs/source/getting_started/images/webapp_image_upload.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/lightly/master/docs/source/getting_started/images/webapp_image_upload.jpg -------------------------------------------------------------------------------- /docs/source/getting_started/images/simclr_augmentations.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/lightly/master/docs/source/getting_started/images/simclr_augmentations.jpg -------------------------------------------------------------------------------- /lightly/loss/memory_bank.py: 
-------------------------------------------------------------------------------- 1 | # For backwards compatibility as memory_bank module was previously in loss module. 2 | from lightly.models.modules.memory_bank import MemoryBankModule 3 | -------------------------------------------------------------------------------- /docs/source/getting_started/images/webapp_create_selection.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/lightly/master/docs/source/getting_started/images/webapp_create_selection.gif -------------------------------------------------------------------------------- /docs/source/getting_started/images/webapp_selection_new_tag.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/lightly/master/docs/source/getting_started/images/webapp_selection_new_tag.gif -------------------------------------------------------------------------------- /docs/source/tutorials_source/platform/images/final_tag_tree.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/lightly/master/docs/source/tutorials_source/platform/images/final_tag_tree.png -------------------------------------------------------------------------------- /docs/source/getting_started/images/simclr_augmentations_no_blur.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/lightly/master/docs/source/getting_started/images/simclr_augmentations_no_blur.jpg -------------------------------------------------------------------------------- /docs/source/tutorials_source/platform/images/selection_create_request.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/maxprogrammer007/lightly/master/docs/source/tutorials_source/platform/images/selection_create_request.png -------------------------------------------------------------------------------- /docs/source/getting_started/images/custom_metadata_weather_temperature.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/lightly/master/docs/source/getting_started/images/custom_metadata_weather_temperature.png -------------------------------------------------------------------------------- /docs/source/tutorials_source/platform/images/custom_metadata_scatter_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/lightly/master/docs/source/tutorials_source/platform/images/custom_metadata_scatter_plot.png -------------------------------------------------------------------------------- /docs/source/getting_started/images/cifar10_benchmark_knn_accuracy_800_epochs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/lightly/master/docs/source/getting_started/images/cifar10_benchmark_knn_accuracy_800_epochs.png -------------------------------------------------------------------------------- /docs/source/tutorials_source/platform/images/clothing-dataset-small-structure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/lightly/master/docs/source/tutorials_source/platform/images/clothing-dataset-small-structure.png -------------------------------------------------------------------------------- /tests/api_workflow/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | _CHARACTER_SET = "abcdef0123456789" 4 | 5 | 6 | def generate_id(length: int = 24) -> str: 
7 | return "".join([random.choice(_CHARACTER_SET) for i in range(length)]) 8 | -------------------------------------------------------------------------------- /docs/source/tutorials_source/platform/images/custom_metadata_configuration_aquarium.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/lightly/master/docs/source/tutorials_source/platform/images/custom_metadata_configuration_aquarium.png -------------------------------------------------------------------------------- /docs/source/tutorials_source/platform/images/custom_metadata_scatter_plot_with_tag.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/lightly/master/docs/source/tutorials_source/platform/images/custom_metadata_scatter_plot_with_tag.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | # For compatibility with legacy builds or versions of tools 4 | # that don’t support certain packaging standards (e.g. PEP 517 5 | # or PEP 660). The configuration is kept in pyproject.toml 6 | setup() 7 | -------------------------------------------------------------------------------- /lightly/cli/config/__init__.py: -------------------------------------------------------------------------------- 1 | """ lightly.cli.config 2 | 3 | The lightly.cli.config module holds default configs for the command-line 4 | interface. 5 | """ 6 | 7 | # Copyright (c) 2020. Lightly AG and its affiliates. 
8 | # All Rights Reserved 9 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | project: 4 | default: 5 | target: auto # This sets the target to your last upload (useful for incremental changes) 6 | threshold: 0.5% # This allows for a decrease of up to 0.5% in coverage 7 | -------------------------------------------------------------------------------- /docs/source/tutorials_source/platform/images/tutorial_cropped_objects_metadata/sku110k_select_nn.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/lightly/master/docs/source/tutorials_source/platform/images/tutorial_cropped_objects_metadata/sku110k_select_nn.jpg -------------------------------------------------------------------------------- /docs/source/tutorials_source/platform/images/tutorial_export_labelstudio/labelstudio_import_dialog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/lightly/master/docs/source/tutorials_source/platform/images/tutorial_export_labelstudio/labelstudio_import_dialog.jpg -------------------------------------------------------------------------------- /docs/source/tutorials_source/platform/images/tutorial_export_labelstudio/labelstudio_imported_file.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/lightly/master/docs/source/tutorials_source/platform/images/tutorial_export_labelstudio/labelstudio_imported_file.jpg -------------------------------------------------------------------------------- /lightly/active_learning/config/__init__.py: -------------------------------------------------------------------------------- 1 | """ Collection of Selection Configurations """ 2 | 3 | 
# Copyright (c) 2020. Lightly AG and its affiliates. 4 | # All Rights Reserved 5 | 6 | from lightly.active_learning.config.selection_config import SelectionConfig 7 | -------------------------------------------------------------------------------- /docs/source/tutorials_source/platform/images/tutorial_cropped_objects_metadata/sku110k_lightly_magic.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/lightly/master/docs/source/tutorials_source/platform/images/tutorial_cropped_objects_metadata/sku110k_lightly_magic.jpg -------------------------------------------------------------------------------- /docs/source/tutorials_source/platform/images/tutorial_export_labelstudio/labelstudio_import_finished.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/lightly/master/docs/source/tutorials_source/platform/images/tutorial_export_labelstudio/labelstudio_import_finished.jpg -------------------------------------------------------------------------------- /docs/source/tutorials_source/platform/images/tutorial_cropped_objects_metadata/sku110k_config_metadata.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/lightly/master/docs/source/tutorials_source/platform/images/tutorial_cropped_objects_metadata/sku110k_config_metadata.jpg -------------------------------------------------------------------------------- /docs/source/tutorials_source/platform/images/tutorial_cropped_objects_metadata/sku110k_selected_images.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/lightly/master/docs/source/tutorials_source/platform/images/tutorial_cropped_objects_metadata/sku110k_selected_images.jpg 
-------------------------------------------------------------------------------- /docs/source/tutorials_source/platform/images/tutorial_cropped_objects_metadata/sku110k_explore_sort_objectness.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/lightly/master/docs/source/tutorials_source/platform/images/tutorial_cropped_objects_metadata/sku110k_explore_sort_objectness.jpg -------------------------------------------------------------------------------- /lightly/loss/regularizer/__init__.py: -------------------------------------------------------------------------------- 1 | """The lightly.loss.regularizer package provides regularizers for self-supervised learning. """ 2 | 3 | # Copyright (c) 2020. Lightly AG and its affiliates. 4 | # All Rights Reserved 5 | 6 | from lightly.loss.regularizer.co2 import CO2Regularizer 7 | -------------------------------------------------------------------------------- /tests/transforms/test_rfft2d_transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from lightly.transforms import RFFT2DTransform 4 | 5 | 6 | def test() -> None: 7 | transform = RFFT2DTransform() 8 | image = torch.rand(3, 32, 32) 9 | output = transform(image) 10 | assert output.shape == (3, 32, 17) 11 | -------------------------------------------------------------------------------- /tests/transforms/test_irfft2d_transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from lightly.transforms import IRFFT2DTransform 4 | 5 | 6 | def test() -> None: 7 | transform = IRFFT2DTransform((32, 32)) 8 | image = torch.rand(3, 32, 17) 9 | output = transform(image) 10 | assert output.shape == (3, 32, 32) 11 | -------------------------------------------------------------------------------- /tests/transforms/test_gaussian_mixture_masks.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from lightly.transforms import GaussianMixtureMask 4 | 5 | 6 | def test() -> None: 7 | transform = GaussianMixtureMask(20, (10, 15)) 8 | image = torch.rand(3, 32, 17) 9 | output = transform(image) 10 | assert output.shape == image.shape 11 | -------------------------------------------------------------------------------- /tests/transforms/test_jigsaw.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from PIL import Image 4 | 5 | from lightly.transforms.jigsaw import Jigsaw 6 | 7 | 8 | class TestJigsaw(unittest.TestCase): 9 | def test_on_pil_image(self) -> None: 10 | crop = Jigsaw() 11 | sample = Image.new("RGB", (255, 255)) 12 | crop(sample) 13 | -------------------------------------------------------------------------------- /requirements/base.txt: -------------------------------------------------------------------------------- 1 | # Minimal dependencies to use the package with the API client. 2 | # See requirements/README.md for more information. 3 | certifi>=14.05.14 4 | hydra-core>=1.0.0 5 | lightly_utils~=0.0.0 6 | numpy>=1.18.1 7 | python_dateutil>=2.5.3 8 | requests>=2.23.0 9 | six>=1.10 10 | tqdm>=4.44 11 | urllib3 >= 1.25.3 12 | pydantic >= 1.10.5 13 | aenum >= 3.1.11 -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE/PR_template_minimal.md: -------------------------------------------------------------------------------- 1 | closes #issue_number 2 | 3 | ## Description 4 | Please_describe_what_you_changed_and_why___You_do_not_need_to_repeat_stuff_from_the_issue 5 | 6 | ## Documentation 7 | - [ ] I have updated the documentation. 8 | - [ ] I need help on it. 9 | 10 | ## Tests 11 | - [ ] I have updated the tests. 12 | - [ ] I need help on it. 
-------------------------------------------------------------------------------- /docs/source/docker_archive/overview.rst: -------------------------------------------------------------------------------- 1 | Lightly\ **One** Worker 2 | =================================== 3 | 4 | .. warning:: 5 | **The Docker Archive documentation is deprecated** 6 | 7 | The old workflow described in these docs will not be supported with new Lightly\ **One** Worker versions above 2.6. 8 | Please switch to our `new documentation page `_ instead. -------------------------------------------------------------------------------- /tests/cli/test_cli_get_lighty_config.py: -------------------------------------------------------------------------------- 1 | from lightly.cli.config.get_config import get_lightly_config 2 | 3 | 4 | def test_get_lightly_config() -> None: 5 | conf = get_lightly_config() 6 | # Assert some default values 7 | assert conf.checkpoint == "" 8 | assert conf.loader.batch_size == 16 9 | assert conf.trainer.weights_summary is None 10 | assert conf.summary_callback.max_depth == 1 11 | -------------------------------------------------------------------------------- /docs/source/lightly.api.rst: -------------------------------------------------------------------------------- 1 | lightly.api 2 | =================== 3 | 4 | .. automodule:: lightly.api 5 | 6 | .. automodule:: lightly.api.api_workflow_client 7 | :members: 8 | :inherited-members: 9 | 10 | .. autoclass:: lightly.api.api_workflow_compute_worker.ComputeWorkerRunInfo 11 | :members: 12 | 13 | .. 
autoclass:: lightly.api.api_workflow_compute_worker.InvalidConfigurationError 14 | :members: 15 | -------------------------------------------------------------------------------- /tests/transforms/test_mae_transform.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | from lightly.transforms.mae_transform import MAETransform 4 | 5 | 6 | def test_multi_view_on_pil_image() -> None: 7 | multi_view_transform = MAETransform(input_size=32) 8 | sample = Image.new("RGB", (100, 100)) 9 | output = multi_view_transform(sample) 10 | assert len(output) == 1 11 | assert output[0].shape == (3, 32, 32) 12 | -------------------------------------------------------------------------------- /lightly/embedding/__init__.py: -------------------------------------------------------------------------------- 1 | """The lightly.embedding module provides trainable embedding strategies. 2 | 3 | The embedding models use a pre-trained ResNet but should be finetuned on each 4 | dataset instance. 5 | 6 | """ 7 | 8 | # Copyright (c) 2020. Lightly AG and its affiliates. 9 | # All Rights Reserved 10 | 11 | 12 | from lightly.embedding._base import BaseEmbedding 13 | from lightly.embedding.embedding import SelfSupervisedEmbedding 14 | -------------------------------------------------------------------------------- /lightly/active_learning/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | 4 | def raise_active_learning_deprecation_warning(): 5 | warnings.warn( 6 | "Using active learning via the lightly package is deprecated and will be removed soon. " 7 | "Please use the Lightly Solution instead. 
" 8 | "See https://docs.lightly.ai for more information and tutorials on doing active learning.", 9 | DeprecationWarning, 10 | stacklevel=2, 11 | ) 12 | -------------------------------------------------------------------------------- /docs/source/lightly.data.rst: -------------------------------------------------------------------------------- 1 | lightly.data 2 | =================== 3 | 4 | .. automodule:: lightly.data 5 | 6 | .dataset 7 | --------------- 8 | .. automodule:: lightly.data.dataset 9 | :members: 10 | 11 | .multi_view_collate 12 | ------------------- 13 | .. autoclass:: lightly.data.multi_view_collate.MultiViewCollate 14 | :members: 15 | :special-members: __call__ 16 | 17 | .collate: 18 | --------- 19 | .. automodule:: lightly.data.collate 20 | :members: 21 | -------------------------------------------------------------------------------- /examples/notebooks/pytorch_lightning/ijepa.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "0", 6 | "metadata": {}, 7 | "source": [ 8 | "TODO, for now please refer to our pure pytorch example" 9 | ] 10 | } 11 | ], 12 | "metadata": { 13 | "jupytext": { 14 | "cell_metadata_filter": "-all", 15 | "main_language": "python", 16 | "notebook_metadata_filter": "-all" 17 | } 18 | }, 19 | "nbformat": 4, 20 | "nbformat_minor": 5 21 | } 22 | -------------------------------------------------------------------------------- /tests/transforms/test_Solarize.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from PIL import Image 4 | 5 | from lightly.transforms.solarize import RandomSolarization 6 | 7 | 8 | class TestRandomSolarization(unittest.TestCase): 9 | def test_on_pil_image(self) -> None: 10 | for w in [32, 64, 128]: 11 | for h in [32, 64, 128]: 12 | solarization = RandomSolarization(0.5) 13 | sample = Image.new("RGB", (w, h)) 14 | solarization(sample) 15 | 
-------------------------------------------------------------------------------- /tests/transforms/test_random_frequency_mask_transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from lightly.transforms import RandomFrequencyMaskTransform, RFFT2DTransform 4 | 5 | 6 | def test() -> None: 7 | rfm_transform = RandomFrequencyMaskTransform() 8 | rfft2d_transform = RFFT2DTransform() 9 | image = torch.randn(3, 64, 64) 10 | fft_image = rfft2d_transform(image) 11 | transformed_image = rfm_transform(fft_image) 12 | 13 | assert transformed_image.shape == fft_image.shape 14 | -------------------------------------------------------------------------------- /examples/notebooks/pytorch_lightning_distributed/ijepa.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "0", 6 | "metadata": {}, 7 | "source": [ 8 | "TODO, for now please refer to our pure pytorch example" 9 | ] 10 | } 11 | ], 12 | "metadata": { 13 | "jupytext": { 14 | "cell_metadata_filter": "-all", 15 | "main_language": "python", 16 | "notebook_metadata_filter": "-all" 17 | } 18 | }, 19 | "nbformat": 4, 20 | "nbformat_minor": 5 21 | } 22 | -------------------------------------------------------------------------------- /lightly/cli/config/get_config.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from omegaconf import DictConfig, OmegaConf 4 | 5 | 6 | def get_lightly_config() -> DictConfig: 7 | config_path = Path(__file__).with_name("config.yaml") 8 | conf = OmegaConf.load(config_path) 9 | # TODO(Huan, 05.04.2023): remove this when hydra is completely dropped 10 | if conf.get("hydra"): 11 | # This config entry is only for hydra; not referenced in any logic 12 | del conf["hydra"] 13 | return conf 14 | -------------------------------------------------------------------------------- 
/lightly/cli/__init__.py: -------------------------------------------------------------------------------- 1 | """ The lightly.cli module provides a console interface 2 | for training self-supervised models, embedding, 3 | and filtering datasets 4 | """ 5 | 6 | # Copyright (c) 2020. Lightly AG and its affiliates. 7 | # All Rights Reserved 8 | 9 | from lightly.cli.crop_cli import crop_cli 10 | from lightly.cli.download_cli import download_cli 11 | from lightly.cli.embed_cli import embed_cli 12 | from lightly.cli.lightly_cli import lightly_cli 13 | from lightly.cli.train_cli import train_cli 14 | -------------------------------------------------------------------------------- /tests/transforms/test_densecl_transform.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | from lightly.transforms import DenseCLTransform 4 | 5 | from .. import helpers 6 | 7 | 8 | def test_multi_view_on_pil_image() -> None: 9 | multi_view_transform = DenseCLTransform(input_size=32) 10 | sample = Image.new("RGB", (100, 100)) 11 | output = helpers.assert_list_tensor(multi_view_transform(sample)) 12 | assert len(output) == 2 13 | assert output[0].shape == (3, 32, 32) 14 | assert output[1].shape == (3, 32, 32) 15 | -------------------------------------------------------------------------------- /lightly/utils/benchmarking/__init__.py: -------------------------------------------------------------------------------- 1 | from lightly.utils.benchmarking.benchmark_module import BenchmarkModule 2 | from lightly.utils.benchmarking.knn import knn_predict 3 | from lightly.utils.benchmarking.knn_classifier import KNNClassifier 4 | from lightly.utils.benchmarking.linear_classifier import ( 5 | FinetuneClassifier, 6 | LinearClassifier, 7 | ) 8 | from lightly.utils.benchmarking.metric_callback import MetricCallback 9 | from lightly.utils.benchmarking.online_linear_classifier import OnlineLinearClassifier 10 | 
-------------------------------------------------------------------------------- /tests/transforms/test_pirl_transform.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | from lightly.transforms.pirl_transform import PIRLTransform 4 | 5 | from .. import helpers 6 | 7 | 8 | def test_multi_view_on_pil_image() -> None: 9 | multi_view_transform = PIRLTransform(input_size=32) 10 | sample = Image.new("RGB", (100, 100)) 11 | output = helpers.assert_list_tensor(multi_view_transform(sample)) 12 | assert len(output) == 2 13 | assert output[0].shape == (3, 32, 32) 14 | assert output[1].shape == (9, 3, 10, 10) 15 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 3 | authors: 4 | - family-names: Susmelj 5 | given-names: Igor 6 | - family-names: Heller 7 | given-names: Matthias 8 | - family-names: Wirth 9 | given-names: Philipp 10 | - family-names: Prescott 11 | given-names: Jeremy 12 | - family-names: Ebner 13 | given-names: Malte 14 | - name: "et al." 15 | title: "Lightly" 16 | type: software 17 | url: "https://github.com/lightly-ai/lightly" 18 | date-released: 2020 19 | -------------------------------------------------------------------------------- /docs/source/lightly.models.rst: -------------------------------------------------------------------------------- 1 | lightly.models 2 | =================== 3 | 4 | .. automodule:: lightly.models 5 | 6 | .resnet 7 | --------------- 8 | .. automodule:: lightly.models.resnet 9 | :members: 10 | 11 | .zoo 12 | --------------- 13 | .. automodule:: lightly.models.zoo 14 | :members: 15 | 16 | .. automodule:: lightly.models.modules 17 | 18 | .nn_memory_bank 19 | --------------- 20 | .. 
automodule:: lightly.models.modules.nn_memory_bank 21 | :members: 22 | 23 | .heads 24 | -------- 25 | .. automodule:: lightly.models.modules.heads 26 | :members: 27 | -------------------------------------------------------------------------------- /tests/utils/benchmarking/test_topk.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from lightly.utils.benchmarking import topk 4 | 5 | 6 | def test_mean_topk_accuracy() -> None: 7 | predicted_classes = torch.tensor( 8 | [ 9 | [1, 2, 3, 4], 10 | [4, 1, 10, 0], 11 | [3, 1, 5, 8], 12 | ] 13 | ) 14 | targets = torch.tensor([1, 10, 8]) 15 | assert topk.mean_topk_accuracy(predicted_classes, targets, k=(1, 2, 3, 4, 5)) == { 16 | 1: 1 / 3, 17 | 2: 1 / 3, 18 | 3: 2 / 3, 19 | 4: 1.0, 20 | 5: 1.0, 21 | } 22 | -------------------------------------------------------------------------------- /tests/transforms/test_torchvision_v2compatibility.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from PIL import Image 4 | from torch import Tensor 5 | 6 | from lightly.transforms import ToTensor 7 | 8 | 9 | def test_ToTensor() -> None: 10 | img_np = np.random.randint(0, 255, (20, 30, 3), dtype=np.uint8) 11 | img_pil = Image.fromarray(img_np) 12 | img_tens = ToTensor()(img_pil) 13 | assert isinstance(img_tens, Tensor) 14 | assert img_tens.shape == (3, 20, 30) 15 | assert img_tens.dtype == torch.float32 16 | assert img_tens.max() <= 1.0 and img_tens.min() >= 0.0 17 | -------------------------------------------------------------------------------- /lightly/cli/config/lightly-serve.yaml: -------------------------------------------------------------------------------- 1 | input_mount: '' # Path to the input directory. 2 | lightly_mount: '' # Path to the lightly directory. 3 | host: 'localhost' # Hostname for serving the data. 4 | port: 3456 # Port for serving the data. 
5 | ssl_cert: '' # Path to the SSL certificate. 6 | ssl_key: '' # Path to the SSL key. 7 | 8 | # Disable Hydra log directories. 9 | defaults: 10 | - _self_ 11 | - override hydra/hydra_logging: disabled 12 | - override hydra/job_logging: disabled 13 | 14 | hydra: 15 | output_subdir: null 16 | run: 17 | dir: . -------------------------------------------------------------------------------- /docs/source/_templates/autosummary/module.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline }} 2 | 3 | .. rubric:: Description 4 | 5 | .. automodule:: {{ fullname }} 6 | 7 | .. currentmodule:: {{ fullname }} 8 | 9 | {% if classes %} 10 | .. rubric:: Classes 11 | 12 | .. autosummary:: 13 | :toctree: . 14 | {% for class in classes %} 15 | {{ class }} 16 | {% endfor %} 17 | 18 | {% endif %} 19 | 20 | {% if functions %} 21 | .. rubric:: Functions 22 | 23 | .. autosummary:: 24 | :toctree: . 25 | {% for function in functions %} 26 | {{ function }} 27 | {% endfor %} 28 | 29 | {% endif %} -------------------------------------------------------------------------------- /tests/transforms/test_fastsiam_transform.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | from lightly.transforms.fast_siam_transform import FastSiamTransform 4 | 5 | from .. 
import helpers 6 | 7 | 8 | def test_multi_view_on_pil_image() -> None: 9 | multi_view_transform = FastSiamTransform(num_views=3, input_size=32) 10 | sample = Image.new("RGB", (100, 100)) 11 | output = helpers.assert_list_tensor(multi_view_transform(sample)) 12 | assert len(output) == 3 13 | assert output[0].shape == (3, 32, 32) 14 | assert output[1].shape == (3, 32, 32) 15 | assert output[2].shape == (3, 32, 32) 16 | -------------------------------------------------------------------------------- /tests/transforms/test_multi_view_transform.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torchvision.transforms as T 4 | from PIL import Image 5 | 6 | from lightly.transforms.multi_view_transform import MultiViewTransform 7 | 8 | 9 | def test_multi_view_on_pil_image() -> None: 10 | multi_view_transform = MultiViewTransform( 11 | [ 12 | T.RandomHorizontalFlip(p=0.1), 13 | T.RandomVerticalFlip(p=0.5), 14 | T.RandomGrayscale(p=0.3), 15 | ] 16 | ) 17 | sample = Image.new("RGB", (10, 10)) 18 | output = multi_view_transform(sample) 19 | assert len(output) == 3 20 | -------------------------------------------------------------------------------- /lightly/api/__init__.py: -------------------------------------------------------------------------------- 1 | """ The lightly.api module provides access to the Lightly API.""" 2 | 3 | # Copyright (c) 2020. Lightly AG and its affiliates. 4 | # All Rights Reserved 5 | from lightly.api import patch as _patch 6 | from lightly.api.api_workflow_artifacts import ArtifactNotExist 7 | from lightly.api.api_workflow_client import ApiWorkflowClient 8 | from lightly.openapi_generated.swagger_client.api_client import ( 9 | Configuration as _Configuration, 10 | ) 11 | 12 | # Make ApiWorkflowClient and swagger classes picklable. 
13 | _patch.make_swagger_configuration_picklable( 14 | configuration_cls=_Configuration, 15 | ) 16 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Build documentation in the docs/ directory with Sphinx 9 | sphinx: 10 | configuration: docs/source/conf.py 11 | 12 | # Optionally build your docs in additional formats such as PDF 13 | formats: 14 | - pdf 15 | 16 | # Optionally set the version of Python and requirements required to build your docs 17 | python: 18 | version: 3.7 19 | install: 20 | - method: pip 21 | path: . 22 | extra_requirements: 23 | - dev 24 | -------------------------------------------------------------------------------- /docs/app.yaml: -------------------------------------------------------------------------------- 1 | # When using the first time you need to authenticate your machine 2 | # running `gcloud auth login` in your terminal. 3 | # 4 | # To deploy the app simply run `gcloud app deploy` from a terminal 5 | # within the docs folder. 
6 | 7 | runtime: python27 8 | api_version: 1 9 | threadsafe: true 10 | service: docs 11 | 12 | handlers: 13 | - url: / 14 | secure: always 15 | redirect_http_response_code: 301 16 | static_files: build/html/index.html 17 | upload: build/html/index.html 18 | 19 | - url: /(.*) 20 | secure: always 21 | redirect_http_response_code: 301 22 | static_files: build/html/\1 23 | upload: build/html/(.*) 24 | -------------------------------------------------------------------------------- /tests/transforms/test_mmcr_transform.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from PIL import Image 3 | 4 | from lightly.transforms.mmcr_transform import MMCRTransform 5 | 6 | 7 | def test_raise_value_error() -> None: 8 | with pytest.raises(ValueError): 9 | MMCRTransform(k=0) 10 | 11 | 12 | def test_num_views() -> None: 13 | multi_view_transform = MMCRTransform(k=3) 14 | assert len(multi_view_transform.transforms) == 3 15 | 16 | 17 | def test_multi_view_on_pil_image() -> None: 18 | multi_view_transform = MMCRTransform(k=3) 19 | sample = Image.new("RGB", (100, 100)) 20 | output = multi_view_transform(sample) 21 | assert len(output) == 3 22 | -------------------------------------------------------------------------------- /tests/REAMDE.md: -------------------------------------------------------------------------------- 1 | # How to write tests for the API 2 | 3 | There are three different locations to write tests for the API, each with its 4 | own advantages and disadvantages: 5 | 6 | 1. In tests/api_workflow_client: 7 | This is for testing the api_workflow_client directly with the ability to configure its mocked version fully. 8 | Furthermore, you have a (partly) stateful api_workflow_client. 9 | 2. In tests/cli: 10 | This is for testing the cli commands. However, it will use a new api_workflow_client 11 | for every new cli command. It does not allow configuring the mocked api_workflow_client. 12 | 3. 
In tests/UNMOCKED_end2end_tests: 13 | This runs tests against the live API. -------------------------------------------------------------------------------- /tests/helpers.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List 2 | 3 | from torch import Tensor 4 | 5 | 6 | def assert_list_tensor(items: Any) -> List[Tensor]: 7 | """Makes sure that the input is a list of tensors. 8 | 9 | Should be used in tests where functions return Union[List[Tensor], List[Image]] and 10 | we want to make sure that the output is a list of tensors. 11 | 12 | Example: 13 | >>> output: Union[List[Tensor], List[Image]] = transform(images) 14 | >>> tensors: List[Tensor] = assert_list_tensor(output) 15 | 16 | """ 17 | assert isinstance(items, list) 18 | assert all(isinstance(item, Tensor) for item in items) 19 | return items 20 | -------------------------------------------------------------------------------- /tests/loss/test_hypersphere_loss.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from lightly.loss.hypersphere_loss import HypersphereLoss 5 | 6 | 7 | class TestHyperSphereLoss: 8 | # NOTE: skipping bsz==1 case as its not relevant to this loss, and will produce nan-values 9 | @pytest.mark.parametrize("bsz", range(2, 20)) 10 | def test_forward_pass(self, bsz: int) -> None: 11 | loss = HypersphereLoss() 12 | 13 | batch_1 = torch.randn((bsz, 32)) 14 | batch_2 = torch.randn((bsz, 32)) 15 | 16 | # symmetry 17 | l1 = loss(batch_1, batch_2) 18 | l2 = loss(batch_2, batch_1) 19 | 20 | assert l1 == pytest.approx(l2) 21 | -------------------------------------------------------------------------------- /tests/transforms/test_wmse_transform.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from PIL import Image 3 | 4 | from lightly.transforms.wmse_transform import WMSETransform 5 | 6 | 7 | def 
test_raise_value_error() -> None: 8 | with pytest.raises(ValueError): 9 | WMSETransform(num_samples=0) 10 | 11 | 12 | def test_num_views() -> None: 13 | multi_view_transform = WMSETransform(num_samples=3) 14 | assert len(multi_view_transform.transforms) == 3 15 | 16 | 17 | def test_multi_view_on_pil_image() -> None: 18 | multi_view_transform = WMSETransform(num_samples=3) 19 | sample = Image.new("RGB", (100, 100)) 20 | output = multi_view_transform(sample) 21 | assert len(output) == 3 22 | -------------------------------------------------------------------------------- /docs/source/examples/models.rst: -------------------------------------------------------------------------------- 1 | .. _models: 2 | 3 | Models 4 | ====== 5 | 6 | We provide example implementations for self-supervised learning models 7 | for PyTorch and PyTorch Lightning to give you a headstart when implementing your own model! 8 | 9 | 10 | .. toctree:: 11 | :maxdepth: 1 12 | 13 | aim.rst 14 | barlowtwins.rst 15 | byol.rst 16 | dcl.rst 17 | densecl.rst 18 | dino.rst 19 | fastsiam.rst 20 | mae.rst 21 | mmcr.rst 22 | msn.rst 23 | moco.rst 24 | nnclr.rst 25 | pmsn.rst 26 | simclr.rst 27 | simmim.rst 28 | simsiam.rst 29 | smog.rst 30 | swav.rst 31 | tico.rst 32 | vicreg.rst 33 | vicregl.rst 34 | -------------------------------------------------------------------------------- /tests/loss/test_emp_ssl_loss.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from lightly.loss.emp_ssl_loss import EMPSSLLoss 5 | 6 | 7 | class TestEMPSSLLoss: 8 | def test_forward(self) -> None: 9 | bs = 512 10 | dim = 128 11 | num_views = 100 12 | 13 | loss_fn = EMPSSLLoss() 14 | x = [torch.randn(bs, dim)] * num_views 15 | 16 | loss_fn(x) 17 | 18 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="No cuda") 19 | def test_forward_cuda(self) -> None: 20 | bs = 512 21 | dim = 128 22 | num_views = 100 23 | 24 | loss_fn = EMPSSLLoss().cuda() 25 
| x = [torch.randn(bs, dim).cuda()] * num_views 26 | 27 | loss_fn(x) 28 | -------------------------------------------------------------------------------- /tests/transforms/test_amplitude_rescale_transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from lightly.transforms import AmplitudeRescaleTransform, RFFT2DTransform 4 | 5 | 6 | # Testing function image -> FFT -> AmplitudeRescale. 7 | # Compare shapes of source and result. 8 | def test() -> None: 9 | image = torch.randn(3, 64, 64) 10 | 11 | rfftTransform = RFFT2DTransform() 12 | rfft = rfftTransform(image) 13 | 14 | ampRescaleTf_1 = AmplitudeRescaleTransform() 15 | rescaled_rfft_1 = ampRescaleTf_1(rfft) 16 | 17 | ampRescaleTf_2 = AmplitudeRescaleTransform(range=(1.0, 2.0)) 18 | rescaled_rfft_2 = ampRescaleTf_2(rfft) 19 | 20 | assert rescaled_rfft_1.shape == rfft.shape 21 | assert rescaled_rfft_2.shape == rfft.shape 22 | -------------------------------------------------------------------------------- /tests/transforms/test_gaussian_blur.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from PIL import Image 4 | 5 | from lightly.transforms.gaussian_blur import GaussianBlur 6 | 7 | 8 | class TestGaussianBlur(unittest.TestCase): 9 | def test_on_pil_image(self) -> None: 10 | for w in range(1, 100): 11 | for h in range(1, 100): 12 | gaussian_blur = GaussianBlur() 13 | sample = Image.new("RGB", (w, h)) 14 | gaussian_blur(sample) 15 | 16 | def test_raise_kernel_size_deprecation(self) -> None: 17 | with self.assertWarns(DeprecationWarning): 18 | GaussianBlur(kernel_size=2) 19 | 20 | def test_raise_scale_deprecation(self) -> None: 21 | with self.assertWarns(DeprecationWarning): 22 | GaussianBlur(scale=0.1) 23 | -------------------------------------------------------------------------------- /tests/transforms/test_simclr_transform.py: 
-------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | from lightly.transforms.simclr_transform import SimCLRTransform, SimCLRViewTransform 4 | 5 | from .. import helpers 6 | 7 | 8 | def test_view_on_pil_image() -> None: 9 | single_view_transform = SimCLRViewTransform(input_size=32) 10 | sample = Image.new("RGB", (100, 100)) 11 | output = single_view_transform(sample) 12 | assert output.shape == (3, 32, 32) 13 | 14 | 15 | def test_multi_view_on_pil_image() -> None: 16 | multi_view_transform = SimCLRTransform(input_size=32) 17 | sample = Image.new("RGB", (100, 100)) 18 | output = helpers.assert_list_tensor(multi_view_transform(sample)) 19 | assert len(output) == 2 20 | assert output[0].shape == (3, 32, 32) 21 | assert output[1].shape == (3, 32, 32) 22 | -------------------------------------------------------------------------------- /tests/transforms/test_vicreg_transform.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | from lightly.transforms.vicreg_transform import VICRegTransform, VICRegViewTransform 4 | 5 | from .. 
import helpers 6 | 7 | 8 | def test_view_on_pil_image() -> None: 9 | single_view_transform = VICRegViewTransform(input_size=32) 10 | sample = Image.new("RGB", (100, 100)) 11 | output = single_view_transform(sample) 12 | assert output.shape == (3, 32, 32) 13 | 14 | 15 | def test_multi_view_on_pil_image() -> None: 16 | multi_view_transform = VICRegTransform(input_size=32) 17 | sample = Image.new("RGB", (100, 100)) 18 | output = helpers.assert_list_tensor(multi_view_transform(sample)) 19 | assert len(output) == 2 20 | assert output[0].shape == (3, 32, 32) 21 | assert output[1].shape == (3, 32, 32) 22 | -------------------------------------------------------------------------------- /lightly/data/__init__.py: -------------------------------------------------------------------------------- 1 | """The lightly.data module provides a dataset wrapper and collate functions. """ 2 | 3 | # Copyright (c) 2020. Lightly AG and its affiliates. 4 | # All Rights Reserved 5 | 6 | from lightly.data._video import ( 7 | EmptyVideoError, 8 | NonIncreasingTimestampError, 9 | UnseekableTimestampError, 10 | VideoError, 11 | ) 12 | from lightly.data.collate import ( 13 | BaseCollateFunction, 14 | DINOCollateFunction, 15 | ImageCollateFunction, 16 | MAECollateFunction, 17 | MoCoCollateFunction, 18 | MSNCollateFunction, 19 | MultiCropCollateFunction, 20 | PIRLCollateFunction, 21 | SimCLRCollateFunction, 22 | SwaVCollateFunction, 23 | VICRegLCollateFunction, 24 | imagenet_normalize, 25 | ) 26 | from lightly.data.dataset import LightlyDataset 27 | -------------------------------------------------------------------------------- /tests/transforms/test_simsiam_transform.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | from lightly.transforms.simsiam_transform import SimSiamTransform, SimSiamViewTransform 4 | 5 | from .. 
import helpers 6 | 7 | 8 | def test_view_on_pil_image() -> None: 9 | single_view_transform = SimSiamViewTransform(input_size=32) 10 | sample = Image.new("RGB", (100, 100)) 11 | output = single_view_transform(sample) 12 | assert output.shape == (3, 32, 32) 13 | 14 | 15 | def test_multi_view_on_pil_image() -> None: 16 | multi_view_transform = SimSiamTransform(input_size=32) 17 | sample = Image.new("RGB", (100, 100)) 18 | output = helpers.assert_list_tensor(multi_view_transform(sample)) 19 | assert len(output) == 2 20 | assert output[0].shape == (3, 32, 32) 21 | assert output[1].shape == (3, 32, 32) 22 | -------------------------------------------------------------------------------- /lightly/cli/version_cli.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """**Lightly Version:** Show the version of the installed package. 3 | 4 | Example: 5 | >>> # show the version of the installed package 6 | >>> lightly-version 7 | """ 8 | 9 | # Copyright (c) 2021. Lightly AG and its affiliates. 10 | # All Rights Reserved 11 | 12 | import hydra 13 | 14 | import lightly 15 | from lightly.cli._helpers import fix_hydra_arguments 16 | 17 | 18 | def _version_cli(): 19 | version = lightly.__version__ 20 | print(f"lightly version {version}", flush=True) 21 | 22 | 23 | @hydra.main(**fix_hydra_arguments(config_path="config", config_name="config")) 24 | def version_cli(cfg): 25 | """Prints the version of the used lightly package to the terminal.""" 26 | _version_cli() 27 | 28 | 29 | def entry(): 30 | version_cli() 31 | -------------------------------------------------------------------------------- /tests/transforms/test_swav_transform.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | from lightly.transforms.swav_transform import SwaVTransform, SwaVViewTransform 4 | 5 | from .. 
import helpers 6 | 7 | 8 | def test_view_on_pil_image() -> None: 9 | single_view_transform = SwaVViewTransform() 10 | sample = Image.new("RGB", (100, 100)) 11 | output = single_view_transform(sample) 12 | assert output.shape == (3, 100, 100) 13 | 14 | 15 | def test_multi_view_on_pil_image() -> None: 16 | multi_view_transform = SwaVTransform(crop_sizes=(32, 8)) 17 | sample = Image.new("RGB", (100, 100)) 18 | output = helpers.assert_list_tensor(multi_view_transform(sample)) 19 | assert len(output) == 8 20 | assert all(out.shape == (3, 32, 32) for out in output[:2]) 21 | assert all(out.shape == (3, 8, 8) for out in output[2:]) 22 | -------------------------------------------------------------------------------- /docs/source/lightly.utils.rst: -------------------------------------------------------------------------------- 1 | lightly.utils 2 | =================== 3 | 4 | .. automodule:: lightly.utils 5 | 6 | .io 7 | --------------- 8 | .. automodule:: lightly.utils.io 9 | :members: 10 | 11 | .embeddings_2d 12 | --------------- 13 | .. automodule:: lightly.utils.embeddings_2d 14 | :members: 15 | 16 | .benchmarking 17 | --------------- 18 | .. automodule:: lightly.utils.benchmarking 19 | :members: 20 | 21 | .debug 22 | --------------- 23 | .. automodule:: lightly.utils.debug 24 | :members: 25 | 26 | .dist 27 | --------------- 28 | .. automodule:: lightly.utils.dist 29 | :members: 30 | 31 | .reordering 32 | --------------- 33 | .. automodule:: lightly.utils.reordering 34 | :members: 35 | 36 | .version_compare 37 | ----------------- 38 | .. automodule:: lightly.utils.version_compare 39 | :members: 40 | 41 | -------------------------------------------------------------------------------- /tests/transforms/test_smog_transform.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | from lightly.transforms.smog_transform import SMoGTransform, SmoGViewTransform 4 | 5 | from .. 
import helpers 6 | 7 | 8 | def test_view_on_pil_image() -> None: 9 | single_view_transform = SmoGViewTransform(crop_size=32) 10 | sample = Image.new("RGB", (100, 100)) 11 | output = single_view_transform(sample) 12 | assert output.shape == (3, 32, 32) 13 | 14 | 15 | def test_multi_view_on_pil_image() -> None: 16 | multi_view_transform = SMoGTransform(crop_sizes=(32, 8)) 17 | sample = Image.new("RGB", (100, 100)) 18 | output = helpers.assert_list_tensor(multi_view_transform(sample)) 19 | assert len(output) == 8 20 | assert all(out.shape == (3, 32, 32) for out in output[:4]) 21 | assert all(out.shape == (3, 8, 8) for out in output[4:]) 22 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE/PR_template_checklist.md: -------------------------------------------------------------------------------- 1 | closes #issue_number 2 | 3 | ## Description 4 | - [ ] My change is breaking 5 | Please_describe_what_you_changed_and_why___You_do_not_need_to_repeat_stuff_from_the_issue 6 | 7 | ## Tests 8 | - [ ] My change is covered by existing tests. 9 | - [ ] My change needs new tests. 10 | - [ ] I have added/adapted the tests accordingly. 11 | - [ ] I have manually tested the change. if_yes_describe_how 12 | 13 | ## Documentation 14 | - [ ] I have added docstrings to all public functions/methods. 15 | - [ ] My change requires a change to the documentation ( `.rst` files). 16 | - [ ] I have updated the documentation accordingly. 17 | - [ ] The autodocs update the documentation accordingly. 
18 | 19 | ## Implications / comments / further issues 20 | - #e_g_link_to_issue_to_cover_breaking_changes 21 | 22 | -------------------------------------------------------------------------------- /tests/UNMOCKED_end2end_tests/delete_datasets_test_unmocked_cli.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from lightly.api import ApiWorkflowClient 4 | 5 | if __name__ == "__main__": 6 | if len(sys.argv) == 1 + 3: 7 | num_datasets, token, date_time = (sys.argv[1 + i] for i in range(3)) 8 | else: 9 | raise ValueError( 10 | "ERROR in number of command line arguments, must be 3." 11 | "Example: python delete_datasets_test_unmocked_cli.py 6 LIGHTLY_TOKEN 2022-09-29-13-41-24" 12 | ) 13 | 14 | api_workflow_client = ApiWorkflowClient(token=token) 15 | 16 | num_datasets = int(num_datasets) 17 | for i in range(1, num_datasets + 1): 18 | dataset_name = f"test_unmocked_cli_{i}_{date_time}" 19 | api_workflow_client.set_dataset_id_by_name(dataset_name) 20 | api_workflow_client.delete_dataset_by_id(api_workflow_client.dataset_id) 21 | -------------------------------------------------------------------------------- /tests/transforms/test_msn_transform.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | from lightly.transforms.msn_transform import MSNTransform, MSNViewTransform 4 | 5 | from .. 
import helpers 6 | 7 | 8 | def test_view_on_pil_image() -> None: 9 | single_view_transform = MSNViewTransform(crop_size=32) 10 | sample = Image.new("RGB", (100, 100)) 11 | output = single_view_transform(sample) 12 | assert output.shape == (3, 32, 32) 13 | 14 | 15 | def test_multi_view_on_pil_image() -> None: 16 | multi_view_transform = MSNTransform(random_size=32, focal_size=8) 17 | sample = Image.new("RGB", (100, 100)) 18 | output = helpers.assert_list_tensor(multi_view_transform(sample)) 19 | assert len(output) == 12 20 | # global views 21 | assert all(out.shape == (3, 32, 32) for out in output[:2]) 22 | # local views 23 | assert all(out.shape == (3, 8, 8) for out in output[2:]) 24 | -------------------------------------------------------------------------------- /tests/transforms/test_dino_transform.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | from lightly.transforms.dino_transform import DINOTransform, DINOViewTransform 4 | 5 | from .. 
import helpers 6 | 7 | 8 | def test_view_on_pil_image() -> None: 9 | single_view_transform = DINOViewTransform(crop_size=32) 10 | sample = Image.new("RGB", (100, 100)) 11 | output = single_view_transform(sample) 12 | assert output.shape == (3, 32, 32) 13 | 14 | 15 | def test_multi_view_on_pil_image() -> None: 16 | multi_view_transform = DINOTransform(global_crop_size=32, local_crop_size=8) 17 | sample = Image.new("RGB", (100, 100)) 18 | output = helpers.assert_list_tensor(multi_view_transform(sample)) 19 | assert len(output) == 8 20 | # global views 21 | assert all(out.shape == (3, 32, 32) for out in output[:2]) 22 | # local views 23 | assert all(out.shape == (3, 8, 8) for out in output[2:]) 24 | -------------------------------------------------------------------------------- /tests/transforms/test_moco_transform.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | from lightly.transforms.moco_transform import MoCoV1Transform, MoCoV2Transform 4 | 5 | from .. 
import helpers 6 | 7 | 8 | def test_moco_v1_multi_view_on_pil_image() -> None: 9 | multi_view_transform = MoCoV1Transform(input_size=32) 10 | sample = Image.new("RGB", (100, 100)) 11 | output = helpers.assert_list_tensor(multi_view_transform(sample)) 12 | assert len(output) == 2 13 | assert output[0].shape == (3, 32, 32) 14 | assert output[1].shape == (3, 32, 32) 15 | 16 | 17 | def test_moco_v2_multi_view_on_pil_image() -> None: 18 | multi_view_transform = MoCoV2Transform(input_size=32) 19 | sample = Image.new("RGB", (100, 100)) 20 | output = helpers.assert_list_tensor(multi_view_transform(sample)) 21 | assert len(output) == 2 22 | assert output[0].shape == (3, 32, 32) 23 | assert output[1].shape == (3, 32, 32) 24 | -------------------------------------------------------------------------------- /tests/UNMOCKED_end2end_tests/create_custom_metadata_from_input_dir.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from lightly.data import LightlyDataset 4 | from lightly.utils.io import save_custom_metadata 5 | 6 | if __name__ == "__main__": 7 | if len(sys.argv) == 1 + 2: 8 | input_dir, metadata_filename = (sys.argv[1 + i] for i in range(2)) 9 | else: 10 | raise ValueError( 11 | "ERROR in number of command line arguments, must be 2." 
12 | "Example: python create_custom_metadata_from_input_dir.py input_dir metadata_filename" 13 | ) 14 | 15 | dataset = LightlyDataset(input_dir) 16 | 17 | # create a list of pairs of (filename, metadata) 18 | custom_metadata = [] 19 | for index, filename in enumerate(dataset.get_filenames()): 20 | metadata = {"index": index} 21 | custom_metadata.append((filename, metadata)) 22 | 23 | save_custom_metadata(metadata_filename, custom_metadata) 24 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /tests/transforms/test_fda_transform.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | from lightly.transforms.fda_transform import ( 4 | FDATransform, 5 | FDAView1Transform, 6 | FDAView2Transform, 7 | ) 8 | 9 | from .. 
import helpers 10 | 11 | 12 | def test_view_on_pil_image() -> None: 13 | single_view_transform = FDAView1Transform(input_size=32) 14 | sample = Image.new("RGB", (100, 100)) 15 | output = single_view_transform(sample) 16 | assert output.shape == (3, 32, 32) 17 | 18 | 19 | def test_multi_view_on_pil_image() -> None: 20 | multi_view_transform = FDATransform( 21 | view_1_transform=FDAView1Transform(input_size=32), 22 | view_2_transform=FDAView2Transform(input_size=32), 23 | ) 24 | sample = Image.new("RGB", (100, 100)) 25 | output = helpers.assert_list_tensor(multi_view_transform(sample)) 26 | assert len(output) == 2 27 | assert output[0].shape == (3, 32, 32) 28 | assert output[1].shape == (3, 32, 32) 29 | -------------------------------------------------------------------------------- /tests/transforms/test_byol_transform.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | from lightly.transforms.byol_transform import ( 4 | BYOLTransform, 5 | BYOLView1Transform, 6 | BYOLView2Transform, 7 | ) 8 | 9 | from .. 
import helpers 10 | 11 | 12 | def test_view_on_pil_image() -> None: 13 | single_view_transform = BYOLView1Transform(input_size=32) 14 | sample = Image.new("RGB", (100, 100)) 15 | output = single_view_transform(sample) 16 | assert output.shape == (3, 32, 32) 17 | 18 | 19 | def test_multi_view_on_pil_image() -> None: 20 | multi_view_transform = BYOLTransform( 21 | view_1_transform=BYOLView1Transform(input_size=32), 22 | view_2_transform=BYOLView2Transform(input_size=32), 23 | ) 24 | sample = Image.new("RGB", (100, 100)) 25 | output = helpers.assert_list_tensor(multi_view_transform(sample)) 26 | assert len(output) == 2 27 | assert output[0].shape == (3, 32, 32) 28 | assert output[1].shape == (3, 32, 32) 29 | -------------------------------------------------------------------------------- /tests/transforms/test_multi_view_transform_v2.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import pytest 4 | from PIL import Image 5 | 6 | from lightly.utils.dependency import torchvision_transforms_v2_available 7 | 8 | if not torchvision_transforms_v2_available(): 9 | pytest.skip("torchvision.transforms.v2 not available", allow_module_level=True) 10 | import torch 11 | import torchvision.transforms.v2 as T 12 | from torchvision.tv_tensors import Mask 13 | 14 | from lightly.transforms.multi_view_transform_v2 import MultiViewTransformV2 15 | 16 | 17 | def test_multi_view_on_pil_image() -> None: 18 | multi_view_transform = MultiViewTransformV2( 19 | [ 20 | T.RandomHorizontalFlip(p=0.1), 21 | T.RandomVerticalFlip(p=0.5), 22 | T.RandomGrayscale(p=0.3), 23 | ] 24 | ) 25 | sample = {"img": Image.new("RGB", (10, 10)), "mask": Mask(torch.zeros(10, 10))} 26 | output = multi_view_transform(sample) 27 | assert len(output) == 3 28 | -------------------------------------------------------------------------------- /tests/transforms/test_phase_shift_transform.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from lightly.transforms import IRFFT2DTransform, PhaseShiftTransform, RFFT2DTransform 4 | 5 | 6 | # Testing function image -> RFFT -> PhaseShift. 7 | # Compare shapes of source and result. 8 | def test() -> None: 9 | image = torch.randn(3, 64, 64) 10 | 11 | rfftTransform = RFFT2DTransform() 12 | rfft = rfftTransform(image) 13 | 14 | phaseShiftTf_1 = PhaseShiftTransform() 15 | rescaled_rfft_1 = phaseShiftTf_1(rfft) 16 | 17 | phaseShiftTf_2 = PhaseShiftTransform(range=(1.0, 2.0)) 18 | rescaled_rfft_2 = phaseShiftTf_2(rfft) 19 | 20 | phaseShiftTf_3 = PhaseShiftTransform( 21 | range=(1.0, 2.0), include_negatives=True, sign_probability=0.8 22 | ) 23 | rescaled_rfft_3 = phaseShiftTf_3(rfft) 24 | 25 | assert rescaled_rfft_1.shape == rfft.shape 26 | assert rescaled_rfft_2.shape == rfft.shape 27 | assert rescaled_rfft_3.shape == rfft.shape 28 | -------------------------------------------------------------------------------- /docs/source/getting_started/code_examples/plot_image_augmentations.py: -------------------------------------------------------------------------------- 1 | import glob 2 | 3 | from PIL import Image 4 | 5 | import lightly 6 | 7 | # let's get all jpg filenames from a folder 8 | glob_to_data = "/datasets/clothing-dataset/images/*.jpg" 9 | fnames = glob.glob(glob_to_data) 10 | 11 | # load the first two images using pillow 12 | input_images = [Image.open(fname) for fname in fnames[:2]] 13 | 14 | # create our collate function 15 | collate_fn_simclr = lightly.data.SimCLRCollateFunction() 16 | 17 | # plot the images 18 | fig = lightly.utils.debug.plot_augmented_images(input_images, collate_fn_simclr) 19 | 20 | # let's disable blur 21 | collate_fn_simclr_no_blur = lightly.data.SimCLRCollateFunction() 22 | fig = lightly.utils.debug.plot_augmented_images(input_images, collate_fn_simclr_no_blur) 23 | 24 | # we can also use the DINO collate function instead 25 |
collate_fn_dino = lightly.data.DINOCollateFunction() 26 | fig = lightly.utils.debug.plot_augmented_images(input_images, collate_fn_dino) 27 | -------------------------------------------------------------------------------- /MAINTAINING.md: -------------------------------------------------------------------------------- 1 | # How to Maintain LightlySSL 2 | 3 | This document is intended for maintainers of Lightly**SSL**. If you would like to 4 | contribute, please refer to the [CONTRIBUTING.md](./CONTRIBUTING.md) document. 5 | 6 | ## Update PR from a Fork 7 | 8 | Sometimes it is necessary to update a PR from a forked Lightly**SSL** repository, 9 | for example to add some finishing touches (format, cleanup, docs), to fix some difficult 10 | issue, or to finish an incomplete feature. 11 | 12 | Let's assume there is a PR from `:`. To update the PR, follow 13 | these steps: 14 | 15 | ``` 16 | git remote add https://github.com//lightly.git 17 | git fetch 18 | git checkout -b / 19 | ``` 20 | 21 | Now you can make changes and push them to the PR branch with `git push`. 22 | 23 | If there's an update from LightlySSL's master branch, you can merge it into the PR branch 24 | 25 | ``` 26 | git fetch upstream 27 | git merge upstream/master 28 | ``` 29 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | We provide example implementations for self-supervised learning models for PyTorch and PyTorch Lightning to give you a headstart when implementing your own model! 4 | 5 | 6 | All examples can be run from the terminal with: 7 | 8 | ``` 9 | python 10 | ``` 11 | 12 | The examples should also run on [Google Colab](https://colab.research.google.com/). Remember to activate the GPU otherwise training will be very slow! 
You can simply copy paste the code and add the following line at the beginning of the notebook to install lightly: 13 | 14 | ``` 15 | !pip install lightly 16 | 17 | # add code from example below 18 | ``` 19 | 20 | 21 | You can find additional information for each model in our [Documentation](https://docs.lightly.ai//examples/models.html#) 22 | 23 | > [!IMPORTANT] 24 | > The examples notebooks are generated using the [`create_example_nbs.py`](./create_example_nbs.py) script and should not be modified manually. All changes must be made to the respective example. 25 | 26 | -------------------------------------------------------------------------------- /tests/loss/test_negative_cosine_similarity.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from lightly.loss import NegativeCosineSimilarity 5 | 6 | 7 | class TestNegativeCosineSimilarity: 8 | @pytest.mark.parametrize("bsz", range(1, 20)) 9 | def test_forward_pass(self, bsz: int) -> None: 10 | loss = NegativeCosineSimilarity() 11 | x0 = torch.randn((bsz, 32)) 12 | x1 = torch.randn((bsz, 32)) 13 | 14 | # symmetry 15 | l1 = loss(x0, x1) 16 | l2 = loss(x1, x0) 17 | assert l1 == pytest.approx(l2, abs=1e-5) 18 | 19 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="No cuda") 20 | @pytest.mark.parametrize("bsz", range(1, 20)) 21 | def test_forward_pass_cuda(self, bsz: int) -> None: 22 | loss = NegativeCosineSimilarity() 23 | x0 = torch.randn((bsz, 32)).cuda() 24 | x1 = torch.randn((bsz, 32)).cuda() 25 | 26 | # symmetry 27 | l1 = loss(x0, x1) 28 | l2 = loss(x1, x0) 29 | assert l1 == pytest.approx(l2, abs=1e-5) 30 | -------------------------------------------------------------------------------- /tests/api_workflow/test_api_workflow_collaboration.py: -------------------------------------------------------------------------------- 1 | from tests.api_workflow import utils 2 | from tests.api_workflow.mocked_api_workflow_client import ( 3 |
MockedApiWorkflowClient, 4 | MockedApiWorkflowSetup, 5 | ) 6 | 7 | 8 | class TestApiWorkflowDatasets(MockedApiWorkflowSetup): 9 | def setUp(self) -> None: 10 | self.api_workflow_client = MockedApiWorkflowClient(token="token_xyz") 11 | 12 | def test_share_empty_dataset(self): 13 | self.api_workflow_client.share_dataset_only_with( 14 | dataset_id=utils.generate_id(), user_emails=[] 15 | ) 16 | 17 | def test_share_dataset(self): 18 | self.api_workflow_client.share_dataset_only_with( 19 | dataset_id=utils.generate_id(), user_emails=["someone@something.com"] 20 | ) 21 | 22 | def test_get_shared_users(self): 23 | user_emails = self.api_workflow_client.get_shared_users( 24 | dataset_id=utils.generate_id() 25 | ) 26 | assert user_emails == ["user1@gmail.com", "user2@something.com"] 27 | -------------------------------------------------------------------------------- /tests/cli/test_cli_version.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import sys 4 | import tempfile 5 | 6 | import hydra 7 | import pytest 8 | from hydra.experimental import compose 9 | 10 | try: 11 | from hydra import initialize 12 | except ImportError: 13 | from hydra.experimental import initialize 14 | 15 | from lightly.cli.version_cli import version_cli 16 | from tests.api_workflow.mocked_api_workflow_client import ( 17 | MockedApiWorkflowClient, 18 | MockedApiWorkflowSetup, 19 | ) 20 | 21 | 22 | class TestCLIVersion(MockedApiWorkflowSetup): 23 | def setUp(self): 24 | MockedApiWorkflowSetup.setUp(self) 25 | with initialize(config_path="../../lightly/cli/config", job_name="test_app"): 26 | self.cfg = compose(config_name="config") 27 | 28 | @pytest.fixture(autouse=True) 29 | def capsys(self, capsys): 30 | self.capsys = capsys 31 | 32 | def test_checkpoint_created(self): 33 | version_cli(self.cfg) 34 | out, err = self.capsys.readouterr() 35 | assert out.startswith("lightly version") 36 | 
-------------------------------------------------------------------------------- /lightly/transforms/multi_view_transform.py: -------------------------------------------------------------------------------- 1 | from typing import List, Sequence, Union 2 | 3 | from PIL.Image import Image 4 | from torch import Tensor 5 | 6 | from lightly.transforms.torchvision_v2_compatibility import torchvision_transforms as T 7 | 8 | 9 | class MultiViewTransform: 10 | """Transforms an image into multiple views. 11 | 12 | Args: 13 | transforms: 14 | A sequence of transforms. Every transform creates a new view. 15 | 16 | """ 17 | 18 | def __init__(self, transforms: Sequence[T.Compose]): 19 | self.transforms = transforms 20 | 21 | def __call__(self, image: Union[Tensor, Image]) -> Union[List[Tensor], List[Image]]: 22 | """Transforms an image into multiple views. 23 | 24 | Every transform in self.transforms creates a new view. 25 | 26 | Args: 27 | image: 28 | Image to be transformed into multiple views. 29 | 30 | Returns: 31 | List of views. 32 | 33 | """ 34 | return [transform(image) for transform in self.transforms] 35 | -------------------------------------------------------------------------------- /lightly/transforms/rfft2d_transform.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | 7 | class RFFT2DTransform: 8 | """2D Fast Fourier Transform (RFFT2D) Transformation. 9 | 10 | This transformation applies the 2D Fast Fourier Transform (RFFT2D) 11 | to an image, converting it from the spatial domain to the frequency domain. 12 | 13 | Input: 14 | - Tensor of shape (C, H, W), where C is the number of channels. 15 | 16 | Output: 17 | - Tensor of shape (C, H, W) in the frequency domain, where C is the number of channels. 18 | """ 19 | 20 | def __call__(self, image: Tensor) -> Tensor: 21 | """Applies the 2D Fast Fourier Transform (RFFT2D) to the input image. 
22 | 23 | Args: 24 | image: Input image as a Tensor of shape (C, H, W). 25 | 26 | Returns: 27 | Tensor: The image in the frequency domain after applying RFFT2D, of shape (C, H, W). 28 | """ 29 | 30 | rfft_image: Tensor = torch.fft.rfft2(image) 31 | return rfft_image 32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **__pycache__** 2 | **build 3 | **settings.json 4 | lightning_logs/ 5 | **lightning_logs/ 6 | **/__MACOSX 7 | datasets/ 8 | dist/ 9 | docs/source/sg_execution_times.rst 10 | docs/source/tutorials/package/* 11 | docs/source/tutorials/platform/* 12 | docs/source/tutorials_source/platform/data 13 | docs/source/tutorials_source/platform/pizzas 14 | docs/source/docker/resources 15 | docs/source/getting_started/resources 16 | !docs/source/tutorials/package.rst 17 | !docs/source/tutorials/platform.rst 18 | !docs/source/tutorials/package/structure_your_input.rst 19 | **.zip 20 | *.egg-info 21 | 22 | 23 | 24 | #ignore venv 25 | venv 26 | .venv 27 | #ignore pycharm IDE 28 | .idea 29 | 30 | #ignore lightly outputs 31 | lightly_outputs 32 | 33 | #ignore .DS_Store 34 | **.DS_Store 35 | 36 | #ignore eggs 37 | .eggs 38 | tests/UNMOCKED_end2end_tests/call_test_api.py 39 | tests/UNMOCKED_end2end_tests/get_versions_all_apis.py 40 | tests/UNMOCKED_end2end_tests/call_benchmark_upload_private.py 41 | tests/UNMOCKED_end2end_tests/test_magic_on_branch.sh 42 | 43 | # ignore VSCode 44 | .vscode 45 | *.code-workspace 46 | -------------------------------------------------------------------------------- /docs/source/_templates/autosummary/class.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autoclass:: {{ objname }} 6 | 7 | {% block methods %} 8 | {% block attributes %} 9 | {% if attributes %} 10 | .. 
HACK -- the point here is that we don't want this to appear in the output, but the autosummary should still generate the pages. 11 | .. autosummary:: 12 | :toctree: 13 | {% for item in all_attributes %} 14 | {%- if not item.startswith('_') %} 15 | {{ name }}.{{ item }} 16 | {%- endif -%} 17 | {%- endfor %} 18 | {% endif %} 19 | {% endblock %} 20 | 21 | {% if methods %} 22 | .. HACK -- the point here is that we don't want this to appear in the output, but the autosummary should still generate the pages. 23 | .. autosummary:: 24 | :toctree: 25 | {% for item in all_methods %} 26 | {%- if not item.startswith('_') or item in ['__call__'] %} 27 | {{ name }}.{{ item }} 28 | {%- endif -%} 29 | {%- endfor %} 30 | {% endif %} 31 | {% endblock %} -------------------------------------------------------------------------------- /lightly/utils/version_compare.py: -------------------------------------------------------------------------------- 1 | """ Utility method for comparing versions of libraries """ 2 | 3 | # Copyright (c) 2020. Lightly AG and its affiliates. 4 | # All Rights Reserved 5 | 6 | 7 | def version_compare(v0: str, v1: str) -> int: 8 | """Returns 1 if version of v0 is larger than v1 and -1 otherwise 9 | 10 | Use this method to compare Python package versions and see which one is 11 | newer. 12 | 13 | Examples: 14 | 15 | >>> # compare two versions 16 | >>> version_compare('1.2.0', '1.1.2') 17 | >>> 1 18 | """ 19 | v0_parsed = [int(n) for n in v0.split(".")][::-1] 20 | v1_parsed = [int(n) for n in v1.split(".")][::-1] 21 | if len(v0_parsed) != 3 or len(v1_parsed) != 3: 22 | raise ValueError( 23 | f"Length of version strings is not 3 (expected pattern `x.y.z`) but is " 24 | f"{v0_parsed} and {v1_parsed}." 
25 | ) 26 | pairs = list(zip(v0_parsed, v1_parsed))[::-1] 27 | for x, y in pairs: 28 | if x < y: 29 | return -1 30 | if x > y: 31 | return 1 32 | return 0 33 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018 The Python Packaging Authority 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 
-------------------------------------------------------------------------------- /lightly/openapi_generated/swagger_client/api_response.py: -------------------------------------------------------------------------------- 1 | """API response object.""" 2 | 3 | from __future__ import annotations 4 | from typing import Any, Dict, Optional 5 | try: 6 | # Pydantic >=v1.10.17 7 | from pydantic.v1 import Field, StrictInt, StrictStr 8 | except ImportError: 9 | # Pydantic v1 10 | from pydantic import Field, StrictInt, StrictStr 11 | 12 | class ApiResponse: 13 | """ 14 | API response object 15 | """ 16 | 17 | status_code: Optional[StrictInt] = Field(None, description="HTTP status code") 18 | headers: Optional[Dict[StrictStr, StrictStr]] = Field(None, description="HTTP headers") 19 | data: Optional[Any] = Field(None, description="Deserialized data given the data type") 20 | raw_data: Optional[Any] = Field(None, description="Raw data (HTTP response body)") 21 | 22 | def __init__(self, 23 | status_code=None, 24 | headers=None, 25 | data=None, 26 | raw_data=None): 27 | self.status_code = status_code 28 | self.headers = headers 29 | self.data = data 30 | self.raw_data = raw_data 31 | -------------------------------------------------------------------------------- /lightly/openapi_generated/swagger_client/models/docker_worker_type.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | Lightly API 5 | 6 | Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 7 | 8 | The version of the OpenAPI document: 1.0.0 9 | Contact: support@lightly.ai 10 | Generated by OpenAPI Generator (https://openapi-generator.tech) 11 | 12 | Do not edit the class manually. 
13 | """ 14 | 15 | 16 | import json 17 | import pprint 18 | import re # noqa: F401 19 | from enum import Enum 20 | from aenum import no_arg # type: ignore 21 | 22 | 23 | 24 | 25 | 26 | class DockerWorkerType(str, Enum): 27 | """ 28 | DockerWorkerType 29 | """ 30 | 31 | """ 32 | allowed enum values 33 | """ 34 | FULL = 'FULL' 35 | 36 | @classmethod 37 | def from_json(cls, json_str: str) -> 'DockerWorkerType': 38 | """Create an instance of DockerWorkerType from a JSON string""" 39 | return DockerWorkerType(json.loads(json_str)) 40 | 41 | 42 | -------------------------------------------------------------------------------- /lightly/openapi_generated/swagger_client/models/sample_sort_by.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | Lightly API 5 | 6 | Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 7 | 8 | The version of the OpenAPI document: 1.0.0 9 | Contact: support@lightly.ai 10 | Generated by OpenAPI Generator (https://openapi-generator.tech) 11 | 12 | Do not edit the class manually. 
13 | """ 14 | 15 | 16 | import json 17 | import pprint 18 | import re # noqa: F401 19 | from enum import Enum 20 | from aenum import no_arg # type: ignore 21 | 22 | 23 | 24 | 25 | 26 | class SampleSortBy(str, Enum): 27 | """ 28 | SampleSortBy 29 | """ 30 | 31 | """ 32 | allowed enum values 33 | """ 34 | ID = '_id' 35 | INDEX = 'index' 36 | 37 | @classmethod 38 | def from_json(cls, json_str: str) -> 'SampleSortBy': 39 | """Create an instance of SampleSortBy from a JSON string""" 40 | return SampleSortBy(json.loads(json_str)) 41 | 42 | 43 | -------------------------------------------------------------------------------- /lightly/openapi_generated/swagger_client/models/shared_access_type.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | Lightly API 5 | 6 | Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 7 | 8 | The version of the OpenAPI document: 1.0.0 9 | Contact: support@lightly.ai 10 | Generated by OpenAPI Generator (https://openapi-generator.tech) 11 | 12 | Do not edit the class manually. 
13 | """ 14 | 15 | 16 | import json 17 | import pprint 18 | import re # noqa: F401 19 | from enum import Enum 20 | from aenum import no_arg # type: ignore 21 | 22 | 23 | 24 | 25 | 26 | class SharedAccessType(str, Enum): 27 | """ 28 | SharedAccessType 29 | """ 30 | 31 | """ 32 | allowed enum values 33 | """ 34 | WRITE = 'WRITE' 35 | 36 | @classmethod 37 | def from_json(cls, json_str: str) -> 'SharedAccessType': 38 | """Create an instance of SharedAccessType from a JSON string""" 39 | return SharedAccessType(json.loads(json_str)) 40 | 41 | 42 | -------------------------------------------------------------------------------- /.github/workflows/test_code_format.yml: -------------------------------------------------------------------------------- 1 | name: Code Format Check 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | workflow_dispatch: 9 | 10 | jobs: 11 | test: 12 | name: Check 13 | runs-on: ubuntu-22.04 14 | strategy: 15 | matrix: 16 | python: ["3.7", "3.12"] 17 | steps: 18 | - name: Checkout Code 19 | uses: actions/checkout@v4 20 | - name: Hack to get setup-python to work on nektos/act 21 | run: | 22 | if [ ! 
-f "/etc/lsb-release" ] ; then 23 | echo "DISTRIB_RELEASE=18.04" > /etc/lsb-release 24 | fi 25 | - name: Set Up Python 26 | uses: actions/setup-python@v5 27 | with: 28 | python-version: ${{ matrix.python }} 29 | - name: Set Up Environment 30 | run: | 31 | make install-uv reset-venv 32 | source .venv/bin/activate 33 | make install-pinned-extras 34 | - name: Run Format Check 35 | run: | 36 | source .venv/bin/activate 37 | make format-check 38 | - name: Run Type Check 39 | run: | 40 | source .venv/bin/activate 41 | make type-check 42 | -------------------------------------------------------------------------------- /lightly/openapi_generated/swagger_client/models/image_type.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | Lightly API 5 | 6 | Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 7 | 8 | The version of the OpenAPI document: 1.0.0 9 | Contact: support@lightly.ai 10 | Generated by OpenAPI Generator (https://openapi-generator.tech) 11 | 12 | Do not edit the class manually. 
13 | """ 14 | 15 | 16 | import json 17 | import pprint 18 | import re # noqa: F401 19 | from enum import Enum 20 | from aenum import no_arg # type: ignore 21 | 22 | 23 | 24 | 25 | 26 | class ImageType(str, Enum): 27 | """ 28 | ImageType 29 | """ 30 | 31 | """ 32 | allowed enum values 33 | """ 34 | FULL = 'full' 35 | THUMBNAIL = 'thumbnail' 36 | META = 'meta' 37 | 38 | @classmethod 39 | def from_json(cls, json_str: str) -> 'ImageType': 40 | """Create an instance of ImageType from a JSON string""" 41 | return ImageType(json.loads(json_str)) 42 | 43 | 44 | -------------------------------------------------------------------------------- /tests/UNMOCKED_end2end_tests/scripts_for_reproducing_problems/test_api_latency.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import numpy as np 5 | from tqdm import tqdm 6 | 7 | from lightly.api import ApiWorkflowClient 8 | from lightly.openapi_generated.swagger_client import ApiClient, Configuration, QuotaApi 9 | 10 | if __name__ == "__main__": 11 | token = os.getenv("LIGHTLY_TOKEN") 12 | api_client = ApiWorkflowClient(token=token).api_client 13 | quota_api = QuotaApi(api_client) 14 | 15 | n_iters = 200 16 | 17 | latencies = np.zeros(n_iters) 18 | for i in tqdm(range(n_iters)): 19 | start = time.time() 20 | quota_api.get_quota_maximum_dataset_size() 21 | duration = time.time() - start 22 | latencies[i] = duration 23 | 24 | def format_latency(latency: float): 25 | return f"{latency*1000:.1f}ms" 26 | 27 | values = [("min")] 28 | print( 29 | f"Latencies: min: {format_latency(np.min(latencies))}, mean: {format_latency(np.mean(latencies))}, max: {format_latency(np.max(latencies))}" 30 | ) 31 | print(f"\nPINGING TO GOOGLE") 32 | response = os.system("ping -c 1 " + "google.com") 33 | -------------------------------------------------------------------------------- /lightly/openapi_generated/swagger_client/models/dataset_type.py: 
-------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | Lightly API 5 | 6 | Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 7 | 8 | The version of the OpenAPI document: 1.0.0 9 | Contact: support@lightly.ai 10 | Generated by OpenAPI Generator (https://openapi-generator.tech) 11 | 12 | Do not edit the class manually. 13 | """ 14 | 15 | 16 | import json 17 | import pprint 18 | import re # noqa: F401 19 | from enum import Enum 20 | from aenum import no_arg # type: ignore 21 | 22 | 23 | 24 | 25 | 26 | class DatasetType(str, Enum): 27 | """ 28 | DatasetType 29 | """ 30 | 31 | """ 32 | allowed enum values 33 | """ 34 | CROPS = 'Crops' 35 | IMAGES = 'Images' 36 | VIDEOS = 'Videos' 37 | 38 | @classmethod 39 | def from_json(cls, json_str: str) -> 'DatasetType': 40 | """Create an instance of DatasetType from a JSON string""" 41 | return DatasetType(json.loads(json_str)) 42 | 43 | 44 | -------------------------------------------------------------------------------- /lightly/openapi_generated/swagger_client/models/file_output_format.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | Lightly API 5 | 6 | Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 7 | 8 | The version of the OpenAPI document: 1.0.0 9 | Contact: support@lightly.ai 10 | Generated by OpenAPI Generator (https://openapi-generator.tech) 11 | 12 | Do not edit the class manually. 
13 | """ 14 | 15 | 16 | import json 17 | import pprint 18 | import re # noqa: F401 19 | from enum import Enum 20 | from aenum import no_arg # type: ignore 21 | 22 | 23 | 24 | 25 | 26 | class FileOutputFormat(str, Enum): 27 | """ 28 | FileOutputFormat 29 | """ 30 | 31 | """ 32 | allowed enum values 33 | """ 34 | JSON = 'JSON' 35 | PLAIN = 'PLAIN' 36 | 37 | @classmethod 38 | def from_json(cls, json_str: str) -> 'FileOutputFormat': 39 | """Create an instance of FileOutputFormat from a JSON string""" 40 | return FileOutputFormat(json.loads(json_str)) 41 | 42 | 43 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE/PR_template_checklist_full.md: -------------------------------------------------------------------------------- 1 | closes #issue_number 2 | 3 | ## Description 4 | - [ ] My change is breaking 5 | Please_describe_what_you_changed_and_why___You_do_not_need_to_repeat_stuff_from_the_issue 6 | 7 | ## Tests 8 | - [ ] My change is covered by existing tests 9 | - [ ] My change needs new tests 10 | - [ ] I have added/adapted tests accordingly. 11 | - [ ] I have manually tested the change. 12 | 13 | If applicable, describe the manual test procedure, e.g: 14 | ```bash 15 | pip uninstall lightly 16 | export BRANCH_NAME="branch_name" 17 | pip install "git+https://github.com/lightly-ai/lightly.git@$BRANCH_NAME" 18 | lightly-cli_do_something_command 19 | ``` 20 | 21 | ## Documentation 22 | - [ ] I have added docstrings to all changed/added public functions/methods. 23 | - [ ] My change requires a change to the documentation ( `.rst` files). 24 | - [ ] I have updated the documentation accordingly. 
25 | - [ ] The autodocs update the documentation accordingly. 26 | 27 | ## Improvements put into another issue: 28 | - #issue_number 29 | 30 | ## Issues covering the breaking change: 31 | - #link_to_issue_in_other_repo to adapt the other side of the breaking change -------------------------------------------------------------------------------- /lightly/openapi_generated/swagger_client/models/team_role.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | Lightly API 5 | 6 | Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 7 | 8 | The version of the OpenAPI document: 1.0.0 9 | Contact: support@lightly.ai 10 | Generated by OpenAPI Generator (https://openapi-generator.tech) 11 | 12 | Do not edit the class manually. 13 | """ 14 | 15 | 16 | import json 17 | import pprint 18 | import re # noqa: F401 19 | from enum import Enum 20 | from aenum import no_arg # type: ignore 21 | 22 | 23 | 24 | 25 | 26 | class TeamRole(str, Enum): 27 | """ 28 | TeamRole 29 | """ 30 | 31 | """ 32 | allowed enum values 33 | """ 34 | OWNER = 'OWNER' 35 | ADMIN = 'ADMIN' 36 | MEMBER = 'MEMBER' 37 | SERVICEACCOUNT = 'SERVICEACCOUNT' 38 | 39 | @classmethod 40 | def from_json(cls, json_str: str) -> 'TeamRole': 41 | """Create an instance of TeamRole from a JSON string""" 42 | return TeamRole(json.loads(json_str)) 43 | 44 | 45 | -------------------------------------------------------------------------------- /lightly/openapi_generated/swagger_client/models/user_type.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | Lightly API 5 | 6 | Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. 
The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 7 | 8 | The version of the OpenAPI document: 1.0.0 9 | Contact: support@lightly.ai 10 | Generated by OpenAPI Generator (https://openapi-generator.tech) 11 | 12 | Do not edit the class manually. 13 | """ 14 | 15 | 16 | import json 17 | import pprint 18 | import re # noqa: F401 19 | from enum import Enum 20 | from aenum import no_arg # type: ignore 21 | 22 | 23 | 24 | 25 | 26 | class UserType(str, Enum): 27 | """ 28 | UserType 29 | """ 30 | 31 | """ 32 | allowed enum values 33 | """ 34 | FREE = 'FREE' 35 | PROFESSIONAL = 'PROFESSIONAL' 36 | ENTERPRISE = 'ENTERPRISE' 37 | ADMIN = 'ADMIN' 38 | 39 | @classmethod 40 | def from_json(cls, json_str: str) -> 'UserType': 41 | """Create an instance of UserType from a JSON string""" 42 | return UserType(json.loads(json_str)) 43 | 44 | 45 | -------------------------------------------------------------------------------- /lightly/models/__init__.py: -------------------------------------------------------------------------------- 1 | """The lightly.models package provides model implementations. 2 | 3 | Note that the high-level building blocks will be deprecated with 4 | lightly version 1.3.0. Instead, use low-level building blocks to build the 5 | models yourself. 6 | 7 | Example implementations for all models can be found here: 8 | `Model Examples `_ 9 | 10 | The package contains an implementation of the commonly used ResNet and 11 | adaptations of the architecture which make self-supervised learning simpler. 12 | 13 | The package also hosts the Lightly model zoo - a list of downloadable ResNet 14 | checkpoints. 15 | 16 | """ 17 | 18 | # Copyright (c) 2020. Lightly AG and its affiliates. 
19 | # All Rights Reserved 20 | 21 | from lightly.models import utils 22 | from lightly.models.barlowtwins import BarlowTwins 23 | from lightly.models.byol import BYOL 24 | from lightly.models.moco import MoCo 25 | from lightly.models.nnclr import NNCLR 26 | from lightly.models.resnet import ResNetGenerator 27 | from lightly.models.simclr import SimCLR 28 | from lightly.models.simsiam import SimSiam 29 | from lightly.models.zoo import ZOO, checkpoints 30 | -------------------------------------------------------------------------------- /lightly/openapi_generated/swagger_client/models/expiry_handling_strategy_v3.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | Lightly API 5 | 6 | Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 7 | 8 | The version of the OpenAPI document: 1.0.0 9 | Contact: support@lightly.ai 10 | Generated by OpenAPI Generator (https://openapi-generator.tech) 11 | 12 | Do not edit the class manually. 
13 | """ 14 | 15 | 16 | import json 17 | import pprint 18 | import re # noqa: F401 19 | from enum import Enum 20 | from aenum import no_arg # type: ignore 21 | 22 | 23 | 24 | 25 | 26 | class ExpiryHandlingStrategyV3(str, Enum): 27 | """ 28 | ExpiryHandlingStrategyV3 29 | """ 30 | 31 | """ 32 | allowed enum values 33 | """ 34 | SKIP = 'SKIP' 35 | ABORT = 'ABORT' 36 | 37 | @classmethod 38 | def from_json(cls, json_str: str) -> 'ExpiryHandlingStrategyV3': 39 | """Create an instance of ExpiryHandlingStrategyV3 from a JSON string""" 40 | return ExpiryHandlingStrategyV3(json.loads(json_str)) 41 | 42 | 43 | -------------------------------------------------------------------------------- /lightly/openapi_generated/swagger_client/models/job_state.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | Lightly API 5 | 6 | Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 7 | 8 | The version of the OpenAPI document: 1.0.0 9 | Contact: support@lightly.ai 10 | Generated by OpenAPI Generator (https://openapi-generator.tech) 11 | 12 | Do not edit the class manually. 
13 | """ 14 | 15 | 16 | import json 17 | import pprint 18 | import re # noqa: F401 19 | from enum import Enum 20 | from aenum import no_arg # type: ignore 21 | 22 | 23 | 24 | 25 | 26 | class JobState(str, Enum): 27 | """ 28 | JobState 29 | """ 30 | 31 | """ 32 | allowed enum values 33 | """ 34 | UNKNOWN = 'UNKNOWN' 35 | WAITING = 'WAITING' 36 | RUNNING = 'RUNNING' 37 | FAILED = 'FAILED' 38 | FINISHED = 'FINISHED' 39 | 40 | @classmethod 41 | def from_json(cls, json_str: str) -> 'JobState': 42 | """Create an instance of JobState from a JSON string""" 43 | return JobState(json.loads(json_str)) 44 | 45 | 46 | -------------------------------------------------------------------------------- /tests/UNMOCKED_end2end_tests/run_all_unmocked_tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | # Get the parameters 5 | export LIGHTLY_TOKEN=$1 6 | [[ -z "$LIGHTLY_TOKEN" ]] && { echo "Error: token is empty" ; exit 1; } 7 | echo "############################### token: ${LIGHTLY_TOKEN}" 8 | 9 | DATE_TIME=$(date +%Y-%m-%d-%H-%M-%S) 10 | 11 | echo "############################### Download the clothing_dataset_small" 12 | 13 | DIR_DATASET=clothing_dataset_small 14 | if [ -d $DIR_DATASET ]; then 15 | echo "Skipping download of dataset, it already exists." 
16 | else 17 | git clone https://github.com/alexeygrigorev/clothing-dataset-small $DIR_DATASET 18 | fi 19 | INPUT_DIR="${DIR_DATASET}/test/dress" 20 | CUSTOM_METADATA_FILENAME="${DIR_DATASET}/custom_metadata.json" 21 | python tests/UNMOCKED_end2end_tests/create_custom_metadata_from_input_dir.py $INPUT_DIR $CUSTOM_METADATA_FILENAME 22 | 23 | # Run the tests 24 | echo "############################### Test 1" 25 | lightly-magic input_dir=$INPUT_DIR trainer.max_epochs=0 26 | 27 | echo "############################### Test 2" 28 | lightly-magic input_dir=$INPUT_DIR trainer.max_epochs=1 29 | 30 | echo "############################### Delete dataset again" 31 | rm -rf $DIR_DATASET 32 | -------------------------------------------------------------------------------- /lightly/openapi_generated/swagger_client/models/lightly_trainer_precision_v2.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | Lightly API 5 | 6 | Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 7 | 8 | The version of the OpenAPI document: 1.0.0 9 | Contact: support@lightly.ai 10 | Generated by OpenAPI Generator (https://openapi-generator.tech) 11 | 12 | Do not edit the class manually. 
13 | """ 14 | 15 | 16 | import json 17 | import pprint 18 | import re # noqa: F401 19 | from enum import Enum 20 | from aenum import no_arg # type: ignore 21 | 22 | 23 | 24 | 25 | 26 | class LightlyTrainerPrecisionV2(int, Enum): 27 | """ 28 | LightlyTrainerPrecisionV2 29 | """ 30 | 31 | """ 32 | allowed enum values 33 | """ 34 | NUMBER_16 = 16 35 | NUMBER_32 = 32 36 | 37 | @classmethod 38 | def from_json(cls, json_str: str) -> 'LightlyTrainerPrecisionV2': 39 | """Create an instance of LightlyTrainerPrecisionV2 from a JSON string""" 40 | return LightlyTrainerPrecisionV2(json.loads(json_str)) 41 | 42 | 43 | -------------------------------------------------------------------------------- /lightly/openapi_generated/swagger_client/models/lightly_trainer_precision_v3.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | Lightly API 5 | 6 | Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 7 | 8 | The version of the OpenAPI document: 1.0.0 9 | Contact: support@lightly.ai 10 | Generated by OpenAPI Generator (https://openapi-generator.tech) 11 | 12 | Do not edit the class manually. 
13 | """ 14 | 15 | 16 | import json 17 | import pprint 18 | import re # noqa: F401 19 | from enum import Enum 20 | from aenum import no_arg # type: ignore 21 | 22 | 23 | 24 | 25 | 26 | class LightlyTrainerPrecisionV3(int, Enum): 27 | """ 28 | LightlyTrainerPrecisionV3 29 | """ 30 | 31 | """ 32 | allowed enum values 33 | """ 34 | NUMBER_16 = 16 35 | NUMBER_32 = 32 36 | 37 | @classmethod 38 | def from_json(cls, json_str: str) -> 'LightlyTrainerPrecisionV3': 39 | """Create an instance of LightlyTrainerPrecisionV3 from a JSON string""" 40 | return LightlyTrainerPrecisionV3(json.loads(json_str)) 41 | 42 | 43 | -------------------------------------------------------------------------------- /lightly/transforms/multi_view_transform_v2.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Sequence 2 | 3 | from lightly.transforms.torchvision_v2_compatibility import torchvision_transforms as T 4 | 5 | 6 | class MultiViewTransformV2: 7 | """Transforms an image into multiple views and is compatible with transforms v2. 8 | 9 | Args: 10 | transforms: 11 | A sequence of v2 transforms. Every transform creates a new view. 12 | """ 13 | 14 | def __init__(self, transforms: Sequence[T.Compose]): 15 | self.transforms = transforms 16 | 17 | def __call__(self, *args: Any) -> List[Any]: 18 | """Transforms a data structure containing images, bounding boxes and masks 19 | into a sequence of multiple views. 20 | 21 | Every transform in self.transforms creates a new view. 22 | 23 | Args: 24 | *args: 25 | Arbitary positional arguments consisting of arbitrary data structures 26 | containing images, bounding boxes and masks. 27 | 28 | Returns: 29 | A list of views, where each view is a transformed version of `*args`. 
30 | 31 | """ 32 | return [transform(*args) for transform in self.transforms] 33 | -------------------------------------------------------------------------------- /lightly/openapi_generated/swagger_client/models/docker_worker_state.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | Lightly API 5 | 6 | Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 7 | 8 | The version of the OpenAPI document: 1.0.0 9 | Contact: support@lightly.ai 10 | Generated by OpenAPI Generator (https://openapi-generator.tech) 11 | 12 | Do not edit the class manually. 13 | """ 14 | 15 | 16 | import json 17 | import pprint 18 | import re # noqa: F401 19 | from enum import Enum 20 | from aenum import no_arg # type: ignore 21 | 22 | 23 | 24 | 25 | 26 | class DockerWorkerState(str, Enum): 27 | """ 28 | DockerWorkerState 29 | """ 30 | 31 | """ 32 | allowed enum values 33 | """ 34 | OFFLINE = 'OFFLINE' 35 | CRASHED = 'CRASHED' 36 | IDLE = 'IDLE' 37 | BUSY = 'BUSY' 38 | 39 | @classmethod 40 | def from_json(cls, json_str: str) -> 'DockerWorkerState': 41 | """Create an instance of DockerWorkerState from a JSON string""" 42 | return DockerWorkerState(json.loads(json_str)) 43 | 44 | 45 | -------------------------------------------------------------------------------- /lightly/openapi_generated/swagger_client/models/sample_type.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | Lightly API 5 | 6 | Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. 
The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 7 | 8 | The version of the OpenAPI document: 1.0.0 9 | Contact: support@lightly.ai 10 | Generated by OpenAPI Generator (https://openapi-generator.tech) 11 | 12 | Do not edit the class manually. 13 | """ 14 | 15 | 16 | import json 17 | import pprint 18 | import re # noqa: F401 19 | from enum import Enum 20 | from aenum import no_arg # type: ignore 21 | 22 | 23 | 24 | 25 | 26 | class SampleType(str, Enum): 27 | """ 28 | Type of the sample (VideoFrame vs IMAGE vs CROP). Determined by the API! 29 | """ 30 | 31 | """ 32 | allowed enum values 33 | """ 34 | CROP = 'CROP' 35 | IMAGE = 'IMAGE' 36 | VIDEO_FRAME = 'VIDEO_FRAME' 37 | 38 | @classmethod 39 | def from_json(cls, json_str: str) -> 'SampleType': 40 | """Create an instance of SampleType from a JSON string""" 41 | return SampleType(json.loads(json_str)) 42 | 43 | 44 | -------------------------------------------------------------------------------- /lightly/openapi_generated/swagger_client/models/sampling_method.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | Lightly API 5 | 6 | Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 7 | 8 | The version of the OpenAPI document: 1.0.0 9 | Contact: support@lightly.ai 10 | Generated by OpenAPI Generator (https://openapi-generator.tech) 11 | 12 | Do not edit the class manually. 
13 | """ 14 | 15 | 16 | import json 17 | import pprint 18 | import re # noqa: F401 19 | from enum import Enum 20 | from aenum import no_arg # type: ignore 21 | 22 | 23 | 24 | 25 | 26 | class SamplingMethod(str, Enum): 27 | """ 28 | SamplingMethod 29 | """ 30 | 31 | """ 32 | allowed enum values 33 | """ 34 | ACTIVE_LEARNING = 'ACTIVE_LEARNING' 35 | CORAL = 'CORAL' 36 | CORESET = 'CORESET' 37 | RANDOM = 'RANDOM' 38 | 39 | @classmethod 40 | def from_json(cls, json_str: str) -> 'SamplingMethod': 41 | """Create an instance of SamplingMethod from a JSON string""" 42 | return SamplingMethod(json.loads(json_str)) 43 | 44 | 45 | -------------------------------------------------------------------------------- /tests/transforms/test_rotation.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Union 2 | 3 | from PIL import Image 4 | 5 | from lightly.transforms.rotation import ( 6 | RandomRotate, 7 | RandomRotateDegrees, 8 | random_rotation_transform, 9 | ) 10 | 11 | 12 | def test_RandomRotate_on_pil_image() -> None: 13 | random_rotate = RandomRotate() 14 | sample = Image.new("RGB", (100, 100)) 15 | random_rotate(sample) 16 | 17 | 18 | def test_RandomRotateDegrees_on_pil_image() -> None: 19 | all_degrees: List[Union[float, Tuple[float, float]]] = [0, 1, 45, (0, 0), (-15, 30)] 20 | for degrees in all_degrees: 21 | random_rotate = RandomRotateDegrees(prob=0.5, degrees=degrees) 22 | sample = Image.new("RGB", (100, 100)) 23 | random_rotate(sample) 24 | 25 | 26 | def test_random_rotation_transform() -> None: 27 | transform = random_rotation_transform(rr_prob=1.0, rr_degrees=None) 28 | assert isinstance(transform, RandomRotate) 29 | transform = random_rotation_transform(rr_prob=1.0, rr_degrees=45) 30 | assert isinstance(transform, RandomRotateDegrees) 31 | transform = random_rotation_transform(rr_prob=1.0, rr_degrees=(30, 45)) 32 | assert isinstance(transform, RandomRotateDegrees) 33 | 
-------------------------------------------------------------------------------- /lightly/openapi_generated/swagger_client/models/job_status_upload_method.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | Lightly API 5 | 6 | Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 7 | 8 | The version of the OpenAPI document: 1.0.0 9 | Contact: support@lightly.ai 10 | Generated by OpenAPI Generator (https://openapi-generator.tech) 11 | 12 | Do not edit the class manually. 13 | """ 14 | 15 | 16 | import json 17 | import pprint 18 | import re # noqa: F401 19 | from enum import Enum 20 | from aenum import no_arg # type: ignore 21 | 22 | 23 | 24 | 25 | 26 | class JobStatusUploadMethod(str, Enum): 27 | """ 28 | JobStatusUploadMethod 29 | """ 30 | 31 | """ 32 | allowed enum values 33 | """ 34 | USER_WEBAPP = 'USER_WEBAPP' 35 | USER_PIP = 'USER_PIP' 36 | INTERNAL = 'INTERNAL' 37 | 38 | @classmethod 39 | def from_json(cls, json_str: str) -> 'JobStatusUploadMethod': 40 | """Create an instance of JobStatusUploadMethod from a JSON string""" 41 | return JobStatusUploadMethod(json.loads(json_str)) 42 | 43 | 44 | -------------------------------------------------------------------------------- /lightly/openapi_generated/swagger_client/models/lightly_docker_selection_method.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | Lightly API 5 | 6 | Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. 
The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 7 | 8 | The version of the OpenAPI document: 1.0.0 9 | Contact: support@lightly.ai 10 | Generated by OpenAPI Generator (https://openapi-generator.tech) 11 | 12 | Do not edit the class manually. 13 | """ 14 | 15 | 16 | import json 17 | import pprint 18 | import re # noqa: F401 19 | from enum import Enum 20 | from aenum import no_arg # type: ignore 21 | 22 | 23 | 24 | 25 | 26 | class LightlyDockerSelectionMethod(str, Enum): 27 | """ 28 | LightlyDockerSelectionMethod 29 | """ 30 | 31 | """ 32 | allowed enum values 33 | """ 34 | CORESET = 'coreset' 35 | RANDOM = 'random' 36 | 37 | @classmethod 38 | def from_json(cls, json_str: str) -> 'LightlyDockerSelectionMethod': 39 | """Create an instance of LightlyDockerSelectionMethod from a JSON string""" 40 | return LightlyDockerSelectionMethod(json.loads(json_str)) 41 | 42 | 43 | -------------------------------------------------------------------------------- /tests/UNMOCKED_end2end_tests/README.md: -------------------------------------------------------------------------------- 1 | This repository contains scripts to test the python package with a server. 2 | 3 | ## Testing the Server API with CLI commands and active learning 4 | You only need an account on the server. 5 | Once you have a token from our production server `https://app.lightly.ai`, you can run: 6 | ```bash 7 | cd ../../../lightly # ensure you are in the top directory 8 | pip uninstall lightly -y 9 | 10 | pip install . 11 | bash tests/UNMOCKED_end2end_tests/run_all_unmocked_tests.sh LIGHTLY_TOKEN 12 | ``` 13 | 14 | ## Testing the Server API with CLI commands 15 | You only need an account on the server and a dataset. 
16 | Once you have a token from our production server `https://app.lightly.ai`, you can run: 17 | ```bash 18 | bash test_api_on_branch.sh path/to/dataset LIGHTLY_TOKEN 19 | ``` 20 | 21 | ## Testing the API latency 22 | This needs a token, but no dataset 23 | ```bash 24 | LIGHTLY_TOKEN="MY_TOKEN" && python tests/UNMOCKED_end2end_tests/test_api_latency.py 25 | ``` 26 | 27 | ## Testing the upload speed 28 | Use the pycharm profile with yappi to run the function [benchmark_upload.py:benchmark_upload()](benchmark_upload.py). 29 | You can use the following script for example: [call_benchmark_upload](call_benchmark_upload.py) -------------------------------------------------------------------------------- /lightly/loss/__init__.py: -------------------------------------------------------------------------------- 1 | """The lightly.loss package provides loss functions for self-supervised learning. """ 2 | 3 | # Copyright (c) 2020. Lightly AG and its affiliates. 4 | # All Rights Reserved 5 | from lightly.loss.barlow_twins_loss import BarlowTwinsLoss 6 | from lightly.loss.dcl_loss import DCLLoss, DCLWLoss 7 | from lightly.loss.detcon_loss import DetConBLoss, DetConSLoss 8 | from lightly.loss.dino_loss import DINOLoss 9 | from lightly.loss.emp_ssl_loss import EMPSSLLoss 10 | from lightly.loss.ibot_loss import IBOTPatchLoss 11 | from lightly.loss.koleo_loss import KoLeoLoss 12 | from lightly.loss.macl_loss import MACLLoss 13 | from lightly.loss.mmcr_loss import MMCRLoss 14 | from lightly.loss.msn_loss import MSNLoss 15 | from lightly.loss.negative_cosine_similarity import NegativeCosineSimilarity 16 | from lightly.loss.ntx_ent_loss import NTXentLoss 17 | from lightly.loss.pmsn_loss import PMSNCustomLoss, PMSNLoss 18 | from lightly.loss.swav_loss import SwaVLoss 19 | from lightly.loss.sym_neg_cos_sim_loss import SymNegCosineSimilarityLoss 20 | from lightly.loss.tico_loss import TiCoLoss 21 | from lightly.loss.vicreg_loss import VICRegLoss 22 | from lightly.loss.vicregl_loss 
class Creator(str, Enum):
    """Origin of a resource created through the Lightly API.

    String-valued enum: each member compares equal to its raw string.

    NOTE: generated by OpenAPI Generator; do not edit manually.
    """

    # Allowed enum values.
    UNKNOWN = "UNKNOWN"
    USER_WEBAPP = "USER_WEBAPP"
    USER_PIP = "USER_PIP"
    USER_PIP_LIGHTLY_MAGIC = "USER_PIP_LIGHTLY_MAGIC"
    USER_WORKER = "USER_WORKER"

    @classmethod
    def from_json(cls, json_str: str) -> "Creator":
        """Build a Creator from a JSON string value."""
        return cls(json.loads(json_str))
5 | # - We have many unittests anyway 6 | # - We ran into quota problems with too many tests (3000mins for the pip package) 7 | # - The tests are unreliable / not deterministic, as they depend on the API from staging 8 | # - The tests take 10mins, which can slow down development 9 | 10 | jobs: 11 | test: 12 | name: Run unmocked tests 13 | runs-on: ubuntu-22.04 14 | steps: 15 | - name: Checkout Code 16 | uses: actions/checkout@v4 17 | - name: Set Up Python 18 | uses: actions/setup-python@v5 19 | with: 20 | python-version: 3.7 21 | cache: pip 22 | cache-dependency-path: | 23 | pyproject.toml 24 | - name: Set Up Environment 25 | run: | 26 | make install-uv reset-venv 27 | source .venv/bin/activate 28 | make install-pinned 29 | - name: Run Unmocked Tests 30 | run: | 31 | source .venv/bin/activate 32 | export LIGHTLY_SERVER_LOCATION=${{ secrets.LIGHTLY_SERVER_LOCATION }} 33 | bash tests/UNMOCKED_end2end_tests/run_all_unmocked_tests.sh ${{ secrets.DUMMY_USER_TOKEN_STAGING }} 34 | -------------------------------------------------------------------------------- /lightly/openapi_generated/swagger_client/models/docker_run_scheduled_state.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | Lightly API 5 | 6 | Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 7 | 8 | The version of the OpenAPI document: 1.0.0 9 | Contact: support@lightly.ai 10 | Generated by OpenAPI Generator (https://openapi-generator.tech) 11 | 12 | Do not edit the class manually. 
13 | """ 14 | 15 | 16 | import json 17 | import pprint 18 | import re # noqa: F401 19 | from enum import Enum 20 | from aenum import no_arg # type: ignore 21 | 22 | 23 | 24 | 25 | 26 | class DockerRunScheduledState(str, Enum): 27 | """ 28 | DockerRunScheduledState 29 | """ 30 | 31 | """ 32 | allowed enum values 33 | """ 34 | OPEN = 'OPEN' 35 | LOCKED = 'LOCKED' 36 | DONE = 'DONE' 37 | CANCELED = 'CANCELED' 38 | 39 | @classmethod 40 | def from_json(cls, json_str: str) -> 'DockerRunScheduledState': 41 | """Create an instance of DockerRunScheduledState from a JSON string""" 42 | return DockerRunScheduledState(json.loads(json_str)) 43 | 44 | 45 | -------------------------------------------------------------------------------- /lightly/openapi_generated/swagger_client/models/sample_partial_mode.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | Lightly API 5 | 6 | Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 7 | 8 | The version of the OpenAPI document: 1.0.0 9 | Contact: support@lightly.ai 10 | Generated by OpenAPI Generator (https://openapi-generator.tech) 11 | 12 | Do not edit the class manually. 
13 | """ 14 | 15 | 16 | import json 17 | import pprint 18 | import re # noqa: F401 19 | from enum import Enum 20 | from aenum import no_arg # type: ignore 21 | 22 | 23 | 24 | 25 | 26 | class SamplePartialMode(str, Enum): 27 | """ 28 | ids: return only the id fileNames: return the id and fileName full: return all data 29 | """ 30 | 31 | """ 32 | allowed enum values 33 | """ 34 | IDS = 'ids' 35 | FILENAMES = 'fileNames' 36 | FULL = 'full' 37 | 38 | @classmethod 39 | def from_json(cls, json_str: str) -> 'SamplePartialMode': 40 | """Create an instance of SamplePartialMode from a JSON string""" 41 | return SamplePartialMode(json.loads(json_str)) 42 | 43 | 44 | -------------------------------------------------------------------------------- /lightly/transforms/amplitude_rescale_transform.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | from torch import Tensor 5 | from torch.distributions import Uniform 6 | 7 | 8 | class AmplitudeRescaleTransform: 9 | """Implementation of amplitude rescaling transformation. 10 | 11 | This transform will rescale the amplitude of the Fourier Spectrum (`freq_image`) of the image and return it. 12 | 13 | Attributes: 14 | dist: 15 | Uniform distribution in `[m, n)` from which the scaling value will be selected. 16 | """ 17 | 18 | def __init__(self, range: Tuple[float, float] = (0.8, 1.75)) -> None: 19 | self.dist = Uniform(range[0], range[1]) 20 | 21 | def __call__(self, freq_image: Tensor) -> Tensor: 22 | amplitude = torch.sqrt(freq_image.real**2 + freq_image.imag**2) 23 | 24 | phase = torch.atan2(freq_image.imag, freq_image.real) 25 | # p with shape (H, W) 26 | p = self.dist.sample(freq_image.shape[1:]).to(freq_image.device) 27 | # Unsqueeze to add channel dimension. 
28 | amplitude *= p.unsqueeze(0) 29 | real = amplitude * torch.cos(phase) 30 | imag = amplitude * torch.sin(phase) 31 | output = torch.complex(real, imag) 32 | 33 | return output 34 | -------------------------------------------------------------------------------- /.github/workflows/check_example_nbs.yml: -------------------------------------------------------------------------------- 1 | name: "Check Example Notebooks" 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | paths: 7 | - "examples/**/*.py" 8 | pull_request: 9 | branches: [master] 10 | paths: 11 | - "examples/**/*.py" 12 | workflow_dispatch: 13 | 14 | jobs: 15 | convert-to-nbs: 16 | name: "Check Example Notebooks" 17 | runs-on: ubuntu-22.04 18 | steps: 19 | - name: Checkout Code 20 | uses: actions/checkout@v4 21 | - name: Set Up Python 22 | uses: actions/setup-python@v5 23 | with: 24 | python-version: 3.7 25 | - name: Set Up Environment 26 | run: | 27 | make install-uv reset-venv 28 | source .venv/bin/activate 29 | make install-pinned-extras 30 | - name: Convert using Script 31 | run: | 32 | source .venv/bin/activate 33 | make generate-example-notebooks 34 | - name: Check for diff 35 | run: | 36 | source .venv/bin/activate 37 | git add examples/notebooks/ 38 | if ! git diff --cached --exit-code; then 39 | echo "Notebooks have changed! Please run 'make generate-example-notebooks' and commit the changes." 40 | exit 1 41 | fi 42 | -------------------------------------------------------------------------------- /lightly/openapi_generated/swagger_client/models/docker_run_scheduled_priority.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | Lightly API 5 | 6 | Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. 
class DockerRunScheduledPriority(str, Enum):
    """Priority of a scheduled docker run.

    String-valued enum: each member compares equal to its raw string.

    NOTE: generated by OpenAPI Generator; do not edit manually.
    """

    # Allowed enum values.
    LOW = "LOW"
    MID = "MID"
    HIGH = "HIGH"
    CRITICAL = "CRITICAL"

    @classmethod
    def from_json(cls, json_str: str) -> "DockerRunScheduledPriority":
        """Build a DockerRunScheduledPriority from a JSON string value."""
        return cls(json.loads(json_str))
13 | """ 14 | 15 | 16 | import json 17 | import pprint 18 | import re # noqa: F401 19 | from enum import Enum 20 | from aenum import no_arg # type: ignore 21 | 22 | 23 | 24 | 25 | 26 | class TagChangeDataOperationMethod(str, Enum): 27 | """ 28 | TagChangeDataOperationMethod 29 | """ 30 | 31 | """ 32 | allowed enum values 33 | """ 34 | SELECTED = 'selected' 35 | ADDED = 'added' 36 | REMOVED = 'removed' 37 | 38 | @classmethod 39 | def from_json(cls, json_str: str) -> 'TagChangeDataOperationMethod': 40 | """Create an instance of TagChangeDataOperationMethod from a JSON string""" 41 | return TagChangeDataOperationMethod(json.loads(json_str)) 42 | 43 | 44 | -------------------------------------------------------------------------------- /lightly/openapi_generated/swagger_client/models/dimensionality_reduction_method.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | Lightly API 5 | 6 | Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 7 | 8 | The version of the OpenAPI document: 1.0.0 9 | Contact: support@lightly.ai 10 | Generated by OpenAPI Generator (https://openapi-generator.tech) 11 | 12 | Do not edit the class manually. 
13 | """ 14 | 15 | 16 | import json 17 | import pprint 18 | import re # noqa: F401 19 | from enum import Enum 20 | from aenum import no_arg # type: ignore 21 | 22 | 23 | 24 | 25 | 26 | class DimensionalityReductionMethod(str, Enum): 27 | """ 28 | Method which was used to create the 2d embeddings 29 | """ 30 | 31 | """ 32 | allowed enum values 33 | """ 34 | PCA = 'PCA' 35 | TSNE = 'TSNE' 36 | UMAP = 'UMAP' 37 | 38 | @classmethod 39 | def from_json(cls, json_str: str) -> 'DimensionalityReductionMethod': 40 | """Create an instance of DimensionalityReductionMethod from a JSON string""" 41 | return DimensionalityReductionMethod(json.loads(json_str)) 42 | 43 | 44 | -------------------------------------------------------------------------------- /lightly/openapi_generated/swagger_client/models/docker_run_artifact_storage_location.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | Lightly API 5 | 6 | Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 7 | 8 | The version of the OpenAPI document: 1.0.0 9 | Contact: support@lightly.ai 10 | Generated by OpenAPI Generator (https://openapi-generator.tech) 11 | 12 | Do not edit the class manually. 
13 | """ 14 | 15 | 16 | import json 17 | import pprint 18 | import re # noqa: F401 19 | from enum import Enum 20 | from aenum import no_arg # type: ignore 21 | 22 | 23 | 24 | 25 | 26 | class DockerRunArtifactStorageLocation(str, Enum): 27 | """ 28 | DockerRunArtifactStorageLocation 29 | """ 30 | 31 | """ 32 | allowed enum values 33 | """ 34 | LIGHTLY = 'LIGHTLY' 35 | DATASOURCE = 'DATASOURCE' 36 | 37 | @classmethod 38 | def from_json(cls, json_str: str) -> 'DockerRunArtifactStorageLocation': 39 | """Create an instance of DockerRunArtifactStorageLocation from a JSON string""" 40 | return DockerRunArtifactStorageLocation(json.loads(json_str)) 41 | 42 | 43 | -------------------------------------------------------------------------------- /lightly/openapi_generated/swagger_client/models/job_result_type.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | Lightly API 5 | 6 | Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 7 | 8 | The version of the OpenAPI document: 1.0.0 9 | Contact: support@lightly.ai 10 | Generated by OpenAPI Generator (https://openapi-generator.tech) 11 | 12 | Do not edit the class manually. 
13 | """ 14 | 15 | 16 | import json 17 | import pprint 18 | import re # noqa: F401 19 | from enum import Enum 20 | from aenum import no_arg # type: ignore 21 | 22 | 23 | 24 | 25 | 26 | class JobResultType(str, Enum): 27 | """ 28 | JobResultType 29 | """ 30 | 31 | """ 32 | allowed enum values 33 | """ 34 | DATASET_PROCESSING = 'DATASET_PROCESSING' 35 | IMAGEMETA = 'IMAGEMETA' 36 | EMBEDDING = 'EMBEDDING' 37 | EMBEDDINGS2D = 'EMBEDDINGS2D' 38 | SAMPLING = 'SAMPLING' 39 | 40 | @classmethod 41 | def from_json(cls, json_str: str) -> 'JobResultType': 42 | """Create an instance of JobResultType from a JSON string""" 43 | return JobResultType(json.loads(json_str)) 44 | 45 | 46 | -------------------------------------------------------------------------------- /lightly/openapi_generated/swagger_client/models/selection_input_type.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | Lightly API 5 | 6 | Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 7 | 8 | The version of the OpenAPI document: 1.0.0 9 | Contact: support@lightly.ai 10 | Generated by OpenAPI Generator (https://openapi-generator.tech) 11 | 12 | Do not edit the class manually. 
13 | """ 14 | 15 | 16 | import json 17 | import pprint 18 | import re # noqa: F401 19 | from enum import Enum 20 | from aenum import no_arg # type: ignore 21 | 22 | 23 | 24 | 25 | 26 | class SelectionInputType(str, Enum): 27 | """ 28 | SelectionInputType 29 | """ 30 | 31 | """ 32 | allowed enum values 33 | """ 34 | EMBEDDINGS = 'EMBEDDINGS' 35 | SCORES = 'SCORES' 36 | METADATA = 'METADATA' 37 | PREDICTIONS = 'PREDICTIONS' 38 | RANDOM = 'RANDOM' 39 | 40 | @classmethod 41 | def from_json(cls, json_str: str) -> 'SelectionInputType': 42 | """Create an instance of SelectionInputType from a JSON string""" 43 | return SelectionInputType(json.loads(json_str)) 44 | 45 | 46 | -------------------------------------------------------------------------------- /lightly/openapi_generated/swagger_client/models/selection_input_predictions_name.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | Lightly API 5 | 6 | Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 7 | 8 | The version of the OpenAPI document: 1.0.0 9 | Contact: support@lightly.ai 10 | Generated by OpenAPI Generator (https://openapi-generator.tech) 11 | 12 | Do not edit the class manually. 
13 | """ 14 | 15 | 16 | import json 17 | import pprint 18 | import re # noqa: F401 19 | from enum import Enum 20 | from aenum import no_arg # type: ignore 21 | 22 | 23 | 24 | 25 | 26 | class SelectionInputPredictionsName(str, Enum): 27 | """ 28 | SelectionInputPredictionsName 29 | """ 30 | 31 | """ 32 | allowed enum values 33 | """ 34 | CLASS_DISTRIBUTION = 'CLASS_DISTRIBUTION' 35 | CATEGORY_COUNT = 'CATEGORY_COUNT' 36 | 37 | @classmethod 38 | def from_json(cls, json_str: str) -> 'SelectionInputPredictionsName': 39 | """Create an instance of SelectionInputPredictionsName from a JSON string""" 40 | return SelectionInputPredictionsName(json.loads(json_str)) 41 | 42 | 43 | -------------------------------------------------------------------------------- /tests/transforms/test_location_to_NxN_grid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import lightly.transforms.random_crop_and_flip_with_grid as test_module 4 | 5 | 6 | def test_location_to_NxN_grid() -> None: 7 | # create a test instance of the Location class 8 | test_location = test_module.Location( 9 | left=10, 10 | top=20, 11 | width=100, 12 | height=200, 13 | image_height=244, 14 | image_width=244, 15 | horizontal_flip=True, 16 | vertical_flip=False, 17 | ) 18 | 19 | # create an instance of the class containing the function 20 | test_class = test_module.RandomResizedCropAndFlip(grid_size=3) 21 | 22 | # call the function with the test location 23 | result = test_class.location_to_NxN_grid(test_location) 24 | 25 | # create a tensor representing the expected output 26 | expected_output = torch.tensor( 27 | [ 28 | [[126.6667, 53.3333], [76.6667, 53.3333], [26.6667, 53.3333]], 29 | [[126.6667, 153.3333], [76.6667, 153.3333], [26.6667, 153.3333]], 30 | [[126.6667, 253.3333], [76.6667, 253.3333], [26.6667, 253.3333]], 31 | ] 32 | ) 33 | # check that the function output matches the expected output 34 | assert torch.allclose(result, expected_output) 35 | 
class DatasetCreator(str, Enum):
    """Origin of a dataset created through the Lightly API.

    String-valued enum: each member compares equal to its raw string.

    NOTE: generated by OpenAPI Generator; do not edit manually.
    """

    # Allowed enum values.
    UNKNOWN = "UNKNOWN"
    USER_WEBAPP = "USER_WEBAPP"
    USER_PIP = "USER_PIP"
    USER_PIP_LIGHTLY_MAGIC = "USER_PIP_LIGHTLY_MAGIC"
    USER_WORKER = "USER_WORKER"

    @classmethod
    def from_json(cls, json_str: str) -> "DatasetCreator":
        """Build a DatasetCreator from a JSON string value."""
        decoded = json.loads(json_str)
        return cls(decoded)
def mean_topk_accuracy(
    predicted_classes: Tensor, targets: Tensor, k: Sequence[int]
) -> Dict[int, Tensor]:
    """Computes the mean accuracy for the specified values of k.

    The mean is taken over the batch dimension.

    Args:
        predicted_classes:
            Tensor of shape (batch_size, num_classes) with the predicted
            classes sorted in descending order of confidence.
        targets:
            Tensor of shape (batch_size) containing the target classes.
        k:
            Values of k for which the accuracy is computed.

    Returns:
        Dictionary mapping each k to the mean accuracy. For example for
        k=(1, 5) the result could look like {1: 0.4, 5: 0.6}.
    """
    batch_size = targets.shape[0]
    # Trailing dim lets targets broadcast against the first k prediction
    # columns inside torch.eq.
    expanded_targets = targets.unsqueeze(1)
    with torch.no_grad():
        return {
            num_k: torch.eq(predicted_classes[:, :num_k], expanded_targets)
            .float()
            .sum()
            / batch_size
            for num_k in k
        }
class LightlyModelV3(str, Enum):
    """Backbone model identifiers (v3 API).

    String-valued enum: each member compares equal to its raw string.

    NOTE: generated by OpenAPI Generator; do not edit manually.
    """

    # Allowed enum values.
    RESNET_MINUS_18 = "resnet-18"
    RESNET_MINUS_34 = "resnet-34"
    RESNET_MINUS_50 = "resnet-50"
    RESNET_MINUS_101 = "resnet-101"
    RESNET_MINUS_152 = "resnet-152"

    @classmethod
    def from_json(cls, json_str: str) -> "LightlyModelV3":
        """Build a LightlyModelV3 from a JSON string value."""
        return cls(json.loads(json_str))
13 | """ 14 | 15 | 16 | import json 17 | import pprint 18 | import re # noqa: F401 19 | from enum import Enum 20 | from aenum import no_arg # type: ignore 21 | 22 | 23 | 24 | 25 | 26 | class SelectionStrategyType(str, Enum): 27 | """ 28 | SelectionStrategyType 29 | """ 30 | 31 | """ 32 | allowed enum values 33 | """ 34 | DIVERSITY = 'DIVERSITY' 35 | WEIGHTS = 'WEIGHTS' 36 | THRESHOLD = 'THRESHOLD' 37 | BALANCE = 'BALANCE' 38 | SIMILARITY = 'SIMILARITY' 39 | 40 | @classmethod 41 | def from_json(cls, json_str: str) -> 'SelectionStrategyType': 42 | """Create an instance of SelectionStrategyType from a JSON string""" 43 | return SelectionStrategyType(json.loads(json_str)) 44 | 45 | 46 | -------------------------------------------------------------------------------- /lightly/openapi_generated/swagger_client/models/tag_arithmetics_operation.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | Lightly API 5 | 6 | Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 7 | 8 | The version of the OpenAPI document: 1.0.0 9 | Contact: support@lightly.ai 10 | Generated by OpenAPI Generator (https://openapi-generator.tech) 11 | 12 | Do not edit the class manually. 13 | """ 14 | 15 | 16 | import json 17 | import pprint 18 | import re # noqa: F401 19 | from enum import Enum 20 | from aenum import no_arg # type: ignore 21 | 22 | 23 | 24 | 25 | 26 | class TagArithmeticsOperation(str, Enum): 27 | """ 28 | The possible arithmetic operations that can be done between multiple tags. 
29 | """ 30 | 31 | """ 32 | allowed enum values 33 | """ 34 | UNION = 'UNION' 35 | INTERSECTION = 'INTERSECTION' 36 | DIFFERENCE = 'DIFFERENCE' 37 | 38 | @classmethod 39 | def from_json(cls, json_str: str) -> 'TagArithmeticsOperation': 40 | """Create an instance of TagArithmeticsOperation from a JSON string""" 41 | return TagArithmeticsOperation(json.loads(json_str)) 42 | 43 | 44 | -------------------------------------------------------------------------------- /requirements/README.md: -------------------------------------------------------------------------------- 1 | All dependencies are tracked in `pyproject.toml` and no new files should be added to 2 | the `requirements` directory. 3 | 4 | We maintain `base.txt` to allow installing the package only with the dependencies 5 | necessary to use the API part of the package. The package can be installed with API 6 | only dependencies by running: 7 | ``` 8 | pip install -r requirements/base.txt 9 | pip install lightly --no-deps 10 | ``` 11 | This is also documented in our Lightly Worker docs: 12 | https://docs.lightly.ai/docs/install-lightly#install-the-lightly-python-client 13 | 14 | It is currently not possible to move these dependencies to an optional dependency 15 | group in `pyproject.toml` because pip does not support installing only optional 16 | dependencies. See https://github.com/pypa/pip/issues/11440 17 | 18 | `openapi.txt` is automatically created by the API generator and should not be modified 19 | manually. 20 | 21 | There are tests in [`tests/test_requirements.py`](../tests/test_requirements.py) that 22 | check that the dependencies in `base.txt` are in sync with the dependencies in 23 | `pyproject.toml` and `openapi.txt`. 24 | 25 | There is also a [GitHub Action](../.github/workflows/test_api_deps_only.yml) that 26 | verifies that installing only the API part of the package works correctly. 
class DockerRunLogLevel(str, Enum):
    """Log verbosity level of a docker run.

    String-valued enum: each member compares equal to its raw string.

    NOTE: generated by OpenAPI Generator; do not edit manually.
    """

    # Allowed enum values.
    VERBOSE = "VERBOSE"
    NOTSET = "NOTSET"
    DEBUG = "DEBUG"
    INFO = "INFO"
    WARNING = "WARNING"
    ERROR = "ERROR"
    CRITICAL = "CRITICAL"

    @classmethod
    def from_json(cls, json_str: str) -> "DockerRunLogLevel":
        """Build a DockerRunLogLevel from a JSON string value."""
        return cls(json.loads(json_str))
class Usage(str, Enum):
    """Declared usage intent for the platform.

    String-valued enum: each member compares equal to its raw string.

    NOTE: generated by OpenAPI Generator; do not edit manually.
    """

    # Allowed enum values.
    LABELING = "LABELING"
    SELECTING = "SELECTING"
    VISUALIZATION = "VISUALIZATION"
    MODEL_TRAINING = "MODEL_TRAINING"
    DATASET_MANAGING = "DATASET_MANAGING"
    CHECKING_THE_TOOL = "CHECKING_THE_TOOL"
    OTHER = "OTHER"

    @classmethod
    def from_json(cls, json_str: str) -> "Usage":
        """Build a Usage from a JSON string value."""
        decoded = json.loads(json_str)
        return cls(decoded)
13 | """ 14 | 15 | 16 | import json 17 | import pprint 18 | import re # noqa: F401 19 | from enum import Enum 20 | from aenum import no_arg # type: ignore 21 | 22 | 23 | 24 | 25 | 26 | class SelectionStrategyTypeV3(str, Enum): 27 | """ 28 | SelectionStrategyTypeV3 29 | """ 30 | 31 | """ 32 | allowed enum values 33 | """ 34 | DIVERSITY = 'DIVERSITY' 35 | WEIGHTS = 'WEIGHTS' 36 | THRESHOLD = 'THRESHOLD' 37 | BALANCE = 'BALANCE' 38 | SIMILARITY = 'SIMILARITY' 39 | TYPICALITY = 'TYPICALITY' 40 | 41 | @classmethod 42 | def from_json(cls, json_str: str) -> 'SelectionStrategyTypeV3': 43 | """Create an instance of SelectionStrategyTypeV3 from a JSON string""" 44 | return SelectionStrategyTypeV3(json.loads(json_str)) 45 | 46 | 47 | -------------------------------------------------------------------------------- /lightly/openapi_generated/swagger_client/models/task_type.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | Lightly API 5 | 6 | Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 7 | 8 | The version of the OpenAPI document: 1.0.0 9 | Contact: support@lightly.ai 10 | Generated by OpenAPI Generator (https://openapi-generator.tech) 11 | 12 | Do not edit the class manually. 
13 | """ 14 | 15 | 16 | import json 17 | import pprint 18 | import re # noqa: F401 19 | from enum import Enum 20 | from aenum import no_arg # type: ignore 21 | 22 | 23 | 24 | 25 | 26 | class TaskType(str, Enum): 27 | """ 28 | The type of the prediction or label task 29 | """ 30 | 31 | """ 32 | allowed enum values 33 | """ 34 | CLASSIFICATION = 'CLASSIFICATION' 35 | OBJECT_DETECTION = 'OBJECT_DETECTION' 36 | SEMANTIC_SEGMENTATION = 'SEMANTIC_SEGMENTATION' 37 | INSTANCE_SEGMENTATION = 'INSTANCE_SEGMENTATION' 38 | KEYPOINT_DETECTION = 'KEYPOINT_DETECTION' 39 | 40 | @classmethod 41 | def from_json(cls, json_str: str) -> 'TaskType': 42 | """Create an instance of TaskType from a JSON string""" 43 | return TaskType(json.loads(json_str)) 44 | 45 | 46 | -------------------------------------------------------------------------------- /lightly/openapi_generated/swagger_client/models/selection_strategy_threshold_operation.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | Lightly API 5 | 6 | Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 7 | 8 | The version of the OpenAPI document: 1.0.0 9 | Contact: support@lightly.ai 10 | Generated by OpenAPI Generator (https://openapi-generator.tech) 11 | 12 | Do not edit the class manually. 
13 | """ 14 | 15 | 16 | import json 17 | import pprint 18 | import re # noqa: F401 19 | from enum import Enum 20 | from aenum import no_arg # type: ignore 21 | 22 | 23 | 24 | 25 | 26 | class SelectionStrategyThresholdOperation(str, Enum): 27 | """ 28 | SelectionStrategyThresholdOperation 29 | """ 30 | 31 | """ 32 | allowed enum values 33 | """ 34 | SMALLER = 'SMALLER' 35 | SMALLER_EQUAL = 'SMALLER_EQUAL' 36 | BIGGER = 'BIGGER' 37 | BIGGER_EQUAL = 'BIGGER_EQUAL' 38 | 39 | @classmethod 40 | def from_json(cls, json_str: str) -> 'SelectionStrategyThresholdOperation': 41 | """Create an instance of SelectionStrategyThresholdOperation from a JSON string""" 42 | return SelectionStrategyThresholdOperation(json.loads(json_str)) 43 | 44 | 45 | -------------------------------------------------------------------------------- /lightly/transforms/solarize.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021. Lightly AG and its affiliates. 2 | # All Rights Reserved 3 | 4 | import numpy as np 5 | from PIL import ImageOps 6 | from PIL.Image import Image as PILImage 7 | 8 | 9 | class RandomSolarization(object): 10 | """Implementation of random image Solarization. 11 | 12 | Utilizes the integrated image operation `solarize` from Pillow. Solarization 13 | inverts all pixel values above a threshold (default: 128). 14 | 15 | Attributes: 16 | probability: 17 | Probability to apply the transformation 18 | threshold: 19 | Threshold for solarization. 20 | """ 21 | 22 | def __init__(self, prob: float = 0.5, threshold: int = 128): 23 | self.prob = prob 24 | self.threshold = threshold 25 | 26 | def __call__(self, sample: PILImage) -> PILImage: 27 | """Solarizes the given input image 28 | 29 | Args: 30 | sample: 31 | PIL image to which solarize will be applied. 32 | 33 | Returns: 34 | Solarized image or original image. 
35 | 36 | """ 37 | prob = np.random.random_sample() 38 | if prob < self.prob: 39 | # return solarized image 40 | return ImageOps.solarize(sample, threshold=self.threshold) 41 | # return original image 42 | return sample 43 | -------------------------------------------------------------------------------- /lightly/cli/_cli_simclr.py: -------------------------------------------------------------------------------- 1 | """ SimCLR Model """ 2 | 3 | # Copyright (c) 2020. Lightly AG and its affiliates. 4 | # All Rights Reserved 5 | 6 | import warnings 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | from lightly.models.modules import SimCLRProjectionHead 12 | 13 | 14 | class _SimCLR(nn.Module): 15 | """Implementation of SimCLR used by the command-line interface. 16 | 17 | Provides backwards compatability with old checkpoints. 18 | """ 19 | 20 | def __init__(self, backbone: nn.Module, num_ftrs: int = 32, out_dim: int = 128): 21 | super(_SimCLR, self).__init__() 22 | 23 | self.backbone = backbone 24 | self.projection_head = SimCLRProjectionHead( 25 | num_ftrs, num_ftrs, out_dim, batch_norm=False 26 | ) 27 | 28 | def forward(self, x0: torch.Tensor, x1: torch.Tensor = None): 29 | """Embeds and projects the input images.""" 30 | 31 | # forward pass of first input x0 32 | f0 = self.backbone(x0).flatten(start_dim=1) 33 | out0 = self.projection_head(f0) 34 | 35 | # return out0 if x1 is None 36 | if x1 is None: 37 | return out0 38 | 39 | # forward pass of second input x1 40 | f1 = self.backbone(x1).flatten(start_dim=1) 41 | out1 = self.projection_head(f1) 42 | 43 | # return both outputs 44 | return out0, out1 45 | -------------------------------------------------------------------------------- /tests/transforms/test_vicregl_transform.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from PIL import Image 4 | from torch import Tensor 5 | 6 | from lightly.transforms.vicregl_transform import 
def test_multi_view_on_pil_image() -> None:
    """Checks counts and shapes of the views and grids produced by VICRegLTransform."""
    transform = VICRegLTransform(
        global_crop_size=32,
        local_crop_size=8,
        n_local_views=6,
        global_grid_size=4,
        local_grid_size=2,
    )
    image = Image.new("RGB", (100, 100))
    outputs = helpers.assert_list_tensor(transform(image))
    # 2 global views + 6 local views + 2 global grids + 6 local grids = 16.
    assert len(outputs) == 16
    for global_view in outputs[:2]:
        assert global_view.shape == (3, 32, 32)
    for local_view in outputs[2:8]:
        assert local_view.shape == (3, 8, 8)
    for global_grid in outputs[8:10]:
        assert global_grid.shape == (4, 4, 2)
    for local_grid in outputs[10:]:
        assert local_grid.shape == (2, 2, 2)
def test_requirements_base__openapi() -> None:
    """Check that all dependencies in requirements/openapi.txt are also in
    requirements/base.txt."""
    base_deps = _requirements("base")
    # setuptools is deliberately excluded from the comparison.
    openapi_deps = set()
    for dep in _requirements("openapi"):
        if not dep.startswith("setuptools"):
            openapi_deps.add(dep)
    missing_deps = openapi_deps - base_deps
    assert not missing_deps
_normalize_dependencies(toml.load("pyproject.toml")["project"][name]) 27 | 28 | 29 | def _requirements(name: str) -> Set[str]: 30 | """Returns dependencies from requirements files.""" 31 | path = Path(__file__).parent.parent / "requirements" / f"{name}.txt" 32 | return _normalize_dependencies(path.read_text().splitlines()) 33 | 34 | 35 | def _normalize_dependencies(deps: List[str]) -> Set[str]: 36 | """Remove spaces, empty lines, and comments.""" 37 | return {dep.replace(" ", "") for dep in deps if dep and not dep.startswith("#")} 38 | -------------------------------------------------------------------------------- /tests/utils/test_version_compare.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from lightly.utils import version_compare 4 | 5 | 6 | class TestVersionCompare(unittest.TestCase): 7 | def test_valid_versions(self) -> None: 8 | # general test of smaller than version numbers 9 | self.assertEqual(version_compare.version_compare("0.1.4", "1.2.0"), -1) 10 | self.assertEqual(version_compare.version_compare("1.1.0", "1.2.0"), -1) 11 | 12 | # test bigger than 13 | self.assertEqual(version_compare.version_compare("1.2.0", "1.1.0"), 1) 14 | self.assertEqual(version_compare.version_compare("1.2.0", "0.1.4"), 1) 15 | 16 | # test equal 17 | self.assertEqual(version_compare.version_compare("1.2.0", "1.2.0"), 0) 18 | 19 | def test_invalid_versions(self) -> None: 20 | with self.assertRaises(ValueError): 21 | version_compare.version_compare("1.2", "1.1.0") 22 | 23 | with self.assertRaises(ValueError): 24 | version_compare.version_compare("1.2.0.1", "1.1.0") 25 | 26 | # test within same minor version and with special cases 27 | with self.assertRaises(ValueError): 28 | self.assertEqual(version_compare.version_compare("1.0.7", "1.1.0.dev1"), -1) 29 | 30 | with self.assertRaises(ValueError): 31 | self.assertEqual( 32 | version_compare.version_compare("1.1.0.dev1", "1.1.0rc1"), -1 33 | ) 34 | 
class RandomFrequencyMaskTransform:
    """2D random frequency mask transformation.

    Applies a single binary mask to the Fourier transform of an image. The
    mask is shared across all channels, so the same proportion k of
    frequencies is zeroed out in every channel.

    Input
        - Tensor: RFFT of a 2D image (C, H, W), C -> number of channels.
    Output
        - Tensor: The masked RFFT of the image.

    """

    def __init__(self, k: Tuple[float, float] = (0.01, 0.1)) -> None:
        # Range from which the masked proportion is drawn on each call.
        self.k = k

    def __call__(self, fft_image: Tensor) -> Tensor:
        low, high = self.k
        drop_fraction = np.random.uniform(low=low, high=high)

        # One (H, W) mask shared by all channels: True keeps a frequency,
        # False turns it off (sets it to zero).
        keep = torch.rand(fft_image.shape[1:], device=fft_image.device) > drop_fraction

        # Do not mask the zero frequency mode to retain the majority of the
        # semantic information. Please refer https://arxiv.org/abs/2312.02205
        keep[0, 0] = 1

        # Broadcast the mask over the channel dimension.
        return fft_image * keep.unsqueeze(0)
def sort_items_by_keys(
    keys: Sequence[_K], items: Sequence[_V], sorted_keys: Sequence[_K]
) -> List[_V]:
    """Reorders the items to follow the order given by sorted_keys.

    Args:
        keys:
            Keys by which items can be identified.
        items:
            Items to sort.
        sorted_keys:
            Keys in sorted order.

    Returns:
        The list of sorted items.

    Raises:
        ValueError: If keys, items, and sorted_keys differ in length.

    Examples:
        >>> sort_items_by_keys([3, 2, 1], ['!', 'world', 'hello'], [1, 2, 3])
        ['hello', 'world', '!']

    """
    if not (len(keys) == len(items) == len(sorted_keys)):
        raise ValueError(
            f"All inputs (keys, items and sorted_keys) "
            f"must have the same length, "
            f"but their lengths are: ({len(keys)},"
            f"{len(items)} and {len(sorted_keys)})."
        )
    key_to_item = dict(zip(keys, items))
    return [key_to_item[key] for key in sorted_keys]
def read_yolo_label_file(
    filepath: str, padding: float, separator: str = " "
) -> Tuple[List[int], List[BoundingBox]]:
    """Reads a file in the yolo file format.

    Args:
        filepath:
            The path to the yolo file, usually a .txt file.
        padding:
            The relative passepartout / padding to be added around the bounding box.
        separator:
            The separator character between the 5 values (class id, x, y, w, h)
            per row in the label file.

    Returns:
        Two lists, each with one element per row in the label file:
        The class indices.
        The bounding boxes.

    """
    with open(filepath, "r") as f:
        lines = f.readlines()

    class_indices: List[int] = []
    bounding_boxes: List[BoundingBox] = []
    for line in lines:
        # Skip blank lines (e.g. a trailing newline at the end of the file)
        # which would otherwise fail the 5-value unpacking below.
        if not line.strip():
            continue
        values = line.split(sep=separator)
        class_id, x_norm, y_norm, w_norm, h_norm = (float(val) for val in values)
        class_indices.append(int(class_id))

        # Grow the box by the requested relative padding.
        w_norm *= 1 + padding
        h_norm *= 1 + padding
        bbox = BoundingBox.from_yolo_label(x_norm, y_norm, w_norm, h_norm)
        bounding_boxes.append(bbox)
    return class_indices, bounding_boxes
32 | 33 | """ 34 | 35 | def __init__( 36 | self, 37 | input_size: Union[int, Tuple[int, int]] = 224, 38 | min_scale: float = 0.4, 39 | normalize: Dict[str, List[float]] = IMAGENET_NORMALIZE, 40 | ): 41 | super().__init__( 42 | input_size=input_size, 43 | min_scale=min_scale, 44 | normalize=normalize, 45 | ) 46 | -------------------------------------------------------------------------------- /lightly/openapi_generated/swagger_client/models/docker_run_artifact_type.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | Lightly API 5 | 6 | Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 7 | 8 | The version of the OpenAPI document: 1.0.0 9 | Contact: support@lightly.ai 10 | Generated by OpenAPI Generator (https://openapi-generator.tech) 11 | 12 | Do not edit the class manually. 
13 | """ 14 | 15 | 16 | import json 17 | import pprint 18 | import re # noqa: F401 19 | from enum import Enum 20 | from aenum import no_arg # type: ignore 21 | 22 | 23 | 24 | 25 | 26 | class DockerRunArtifactType(str, Enum): 27 | """ 28 | DockerRunArtifactType 29 | """ 30 | 31 | """ 32 | allowed enum values 33 | """ 34 | LOG = 'LOG' 35 | MEMLOG = 'MEMLOG' 36 | CHECKPOINT = 'CHECKPOINT' 37 | REPORT_PDF = 'REPORT_PDF' 38 | REPORT_JSON = 'REPORT_JSON' 39 | REPORT_V2_JSON = 'REPORT_V2_JSON' 40 | CORRUPTNESS_CHECK_INFORMATION = 'CORRUPTNESS_CHECK_INFORMATION' 41 | SEQUENCE_INFORMATION = 'SEQUENCE_INFORMATION' 42 | RELEVANT_FILENAMES = 'RELEVANT_FILENAMES' 43 | 44 | @classmethod 45 | def from_json(cls, json_str: str) -> 'DockerRunArtifactType': 46 | """Create an instance of DockerRunArtifactType from a JSON string""" 47 | return DockerRunArtifactType(json.loads(json_str)) 48 | 49 | 50 | -------------------------------------------------------------------------------- /docs/source/tutorials_source/package/Base-RCNN-FPN.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "GeneralizedRCNN" 3 | BACKBONE: 4 | NAME: "build_resnet_fpn_backbone" 5 | RESNETS: 6 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 7 | FPN: 8 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 9 | ANCHOR_GENERATOR: 10 | SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map 11 | ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps) 12 | RPN: 13 | IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"] 14 | PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level 15 | PRE_NMS_TOPK_TEST: 1000 # Per FPN level 16 | # Detectron1 uses 2000 proposals per-batch, 17 | # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue) 18 | # which is approximately 1000 proposals per-image since the default batch size for FPN is 2. 
class TestCenter:
    """Tests for the Center module used e.g. by DINO-style centering."""

    def test__init__invalid_mode(self) -> None:
        # Only known centering modes are accepted.
        with pytest.raises(ValueError):
            Center(size=(1, 32), mode="invalid")

    def test_value(self) -> None:
        # A freshly created center starts at zero.
        center = Center(size=(1, 32), mode="mean")
        assert torch.all(center.value == 0)

    @pytest.mark.parametrize(
        "x, expected",
        [
            (torch.tensor([[0.0, 0.0], [0.0, 0.0]]), torch.tensor([0.0, 0.0])),
            (torch.tensor([[1.0, 2.0], [3.0, 4.0]]), torch.tensor([2.0, 3.0])),
        ],
    )
    def test_update(self, x: Tensor, expected: Tensor) -> None:
        # With momentum=0.0 the center is replaced by the batch mean.
        center = Center(size=(1, 2), mode="mean", momentum=0.0)
        center.update(x)
        assert torch.all(center.value == expected)

    @pytest.mark.parametrize(
        "momentum, expected",
        [
            (0.0, torch.tensor([1.0, 2.0])),
            (0.1, torch.tensor([0.9, 1.8])),
            (0.5, torch.tensor([0.5, 1.0])),
            (1.0, torch.tensor([0.0, 0.0])),
        ],
    )
    def test_update__momentum(self, momentum: float, expected: Tensor) -> None:
        # NOTE(review): expectations are consistent with an EMA of the form
        # momentum * old + (1 - momentum) * batch_mean, starting from zeros —
        # confirm against the Center implementation.
        center = Center(size=(1, 2), mode="mean", momentum=momentum)
        center.update(torch.tensor([[1.0, 2.0]]))
        assert torch.all(center.value == expected)
code-block:: bash 39 | 40 | pip install "lightly[timm]" 41 | 42 | Next Steps 43 | ------------ 44 | 45 | Start with one of our tutorials: :ref:`input-structure-label` 46 | -------------------------------------------------------------------------------- /.github/workflows/test_api_deps_only.yml: -------------------------------------------------------------------------------- 1 | 2 | name: Install API Dependencies Only 3 | 4 | on: 5 | push: 6 | branches: 7 | - master 8 | pull_request: 9 | workflow_dispatch: 10 | 11 | jobs: 12 | detect-code-changes: 13 | name: Detect Code Changes 14 | runs-on: ubuntu-22.04 15 | outputs: 16 | run-tests: ${{ steps.filter.outputs.run-tests }} 17 | steps: 18 | - uses: actions/checkout@v4 19 | - uses: dorny/paths-filter@v3 20 | id: filter 21 | with: 22 | list-files: shell 23 | filters: | 24 | run-tests: 25 | - '!docs/**' 26 | - '!examples/**' 27 | - '!benchmarks/**' 28 | 29 | test: 30 | name: Test 31 | needs: detect-code-changes 32 | if: needs.detect-code-changes.outputs.run-tests == 'true' 33 | runs-on: ubuntu-22.04 34 | 35 | steps: 36 | - name: Checkout Code 37 | uses: actions/checkout@v4 38 | - name: Set Up Python 39 | uses: actions/setup-python@v5 40 | with: 41 | python-version: "3.12" 42 | - name: Set Up Environment 43 | run: | 44 | make install-uv reset-venv 45 | source .venv/bin/activate 46 | make install-api-only 47 | - name: Test API Client 48 | run: | 49 | source .venv/bin/activate 50 | # Test if ApiWorkflowClient can be created. 
class TestKoLeoLoss:
    """Tests for KoLeoLoss against reference values."""

    # Test values generated using the original implementation from:
    # https://github.com/facebookresearch/dinov2/blob/main/dinov2/loss/koleo_loss.py
    @pytest.mark.parametrize(
        "x, expected_loss",
        [
            (torch.tensor([[1.0]]), 17.7275),
            (torch.tensor([[0.0, 1.0], [0.0, -1.0]]), -math.log(2)),
            (torch.tensor([[0.0, 1.0], [1.0, 0.0]]), -math.log(2**0.5)),
            (
                torch.tensor(
                    [
                        [0.0, 1.0],
                        [1.0, 0.0],
                        [1.0, 1.0],
                        [-1.0, 0.0],
                        [0.0, -1.0],
                        [-1.0, -1.0],
                    ]
                ),
                0.2674,
            ),
        ],
    )
    # Run every case on CPU and, when available, on CUDA.
    @pytest.mark.parametrize("device", ["cpu", "cuda"])
    def test_forward(self, x: Tensor, expected_loss: float, device: str) -> None:
        if device == "cuda" and not torch.cuda.is_available():
            pytest.skip("CUDA not available, skipping test.")

        x = x.to(device)
        loss = KoLeoLoss().to(device)
        # Loose relative tolerance since reference values are rounded to 4 digits.
        assert loss(x).item() == pytest.approx(expected_loss, rel=1e-4)
def test_empty_batch() -> None:
    """Collating an empty batch must yield empty views, labels, and filenames."""
    multi_view_collate = MultiViewCollate()
    batch = []
    views, labels, fnames = multi_view_collate(batch)
    assert len(views) == 0
    assert len(labels) == 0
    assert len(fnames) == 0
def test_multiple_item_batch():
    """Two samples with 5 views each collate into 5 stacked (2, C, H, W) tensors."""
    collate_fn = MultiViewCollate()
    entries = [
        ([torch.randn((3, 224, 224)) for _ in range(5)], 1, "image1.jpg"),
        ([torch.randn((3, 224, 224)) for _ in range(5)], 2, "image2.jpg"),
    ]
    views, labels, fnames = collate_fn(entries)
    # One stacked tensor per view index, with the batch dimension first.
    assert len(views) == 5
    assert views[0].shape == (2, 3, 224, 224)
    assert torch.equal(labels, torch.tensor([1, 2], dtype=torch.long))
    assert fnames == ["image1.jpg", "image2.jpg"]
23 | rev: 1.8.7 24 | hooks: 25 | - id: nbqa-black 26 | - id: nbqa-isort 27 | args: ["--profile=black"] 28 | - repo: local 29 | hooks: 30 | - id: pytest-check # run all tests 31 | name: pytest-check 32 | entry: make test-fast 33 | language: system 34 | pass_filenames: false 35 | stages: [push] 36 | # Avoid running tests if non-tested files have changed. 37 | # The regex follows the pattern in the docs: https://pre-commit.com/#regular-expressions 38 | exclude: | 39 | (?x)^( 40 | benchmark_logs/.*| 41 | docs/.*| 42 | examples/.*| 43 | \.gitignore| 44 | CONTRIBUTING\.md| 45 | LICENSE\.txt| 46 | README\.md 47 | PRECOMMITHOOKS\.md| 48 | )$ 49 | -------------------------------------------------------------------------------- /lightly/openapi_generated/swagger_client/models/datasource_purpose.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | Lightly API 5 | 6 | Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 7 | 8 | The version of the OpenAPI document: 1.0.0 9 | Contact: support@lightly.ai 10 | Generated by OpenAPI Generator (https://openapi-generator.tech) 11 | 12 | Do not edit the class manually. 13 | """ 14 | 15 | 16 | import json 17 | import pprint 18 | import re # noqa: F401 19 | from enum import Enum 20 | from aenum import no_arg # type: ignore 21 | 22 | 23 | 24 | 25 | 26 | class DatasourcePurpose(str, Enum): 27 | """ 28 | The datasource purpose and for which use-cases it is needed. 
class DatasourcePurpose(str, Enum):
    """The datasource purpose and for which use-cases it is needed.

    - INPUT_OUTPUT: Is used as source of raw data and predictions/metadata
      within .lightly as well as destination for writing thumbnails, crops or
      video frames within .lightly.
    - INPUT: Is only used as source of raw data.
    - LIGHTLY: Is used as source of predictions/metadata within .lightly as
      well as destination for writing thumbnails, crops or video frames within
      .lightly.
    """

    # Allowed enum values.
    INPUT_OUTPUT = 'INPUT_OUTPUT'
    INPUT = 'INPUT'
    LIGHTLY = 'LIGHTLY'

    @classmethod
    def from_json(cls, json_str: str) -> 'DatasourcePurpose':
        """Create an instance of DatasourcePurpose from a JSON string."""
        decoded = json.loads(json_str)
        return cls(decoded)
image:: https://img.shields.io/badge/Open%20in%20Colab-blue?logo=googlecolab&label=%20&labelColor=5c5c5c 31 | :target: https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch_lightning/smog.ipynb 32 | 33 | This example can be run from the command line with:: 34 | 35 | python lightly/examples/pytorch_lightning/smog.py 36 | 37 | .. literalinclude:: ../../../examples/pytorch_lightning/smog.py 38 | -------------------------------------------------------------------------------- /lightly/openapi_generated/swagger_client/models/file_name_format.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | Lightly API 5 | 6 | Lightly.ai enables you to do self-supervised learning in an easy and intuitive way. The lightly.ai OpenAPI spec defines how one can interact with our REST API to unleash the full potential of lightly.ai # noqa: E501 7 | 8 | The version of the OpenAPI document: 1.0.0 9 | Contact: support@lightly.ai 10 | Generated by OpenAPI Generator (https://openapi-generator.tech) 11 | 12 | Do not edit the class manually. 13 | """ 14 | 15 | 16 | import json 17 | import pprint 18 | import re # noqa: F401 19 | from enum import Enum 20 | from aenum import no_arg # type: ignore 21 | 22 | 23 | 24 | 25 | 26 | class FileNameFormat(str, Enum): 27 | """ 28 | When the filename is output, which format shall be used. 
class FileNameFormat(str, Enum):
    """When the filename is output, which format shall be used.

    E.g. for a sample called 'frame0.png' that was uploaded from a datasource
    's3://my_bucket/datasets/for_lightly/' in the folder 'car/green/':

    - NAME: car/green/frame0.png
    - DATASOURCE_FULL: s3://my_bucket/datasets/for_lightly/car/green/frame0.png
    - REDIRECTED_READ_URL: https://api.lightly.ai/v1/datasets/{datasetId}/samples/{sampleId}/readurlRedirect?publicToken={jsonWebToken}
    """

    # Allowed enum values.
    NAME = 'NAME'
    DATASOURCE_FULL = 'DATASOURCE_FULL'
    REDIRECTED_READ_URL = 'REDIRECTED_READ_URL'

    @classmethod
    def from_json(cls, json_str: str) -> 'FileNameFormat':
        """Create an instance of FileNameFormat from a JSON string."""
        decoded = json.loads(json_str)
        return cls(decoded)
class TestMMCRLoss:
    def test_forward(self) -> None:
        """The forward pass runs on random CPU tensors without error."""
        batch_size, num_views, embed_dim = 3, 32, 128
        criterion = MMCRLoss()
        criterion(
            torch.randn(batch_size, num_views, embed_dim),
            torch.randn(batch_size, num_views, embed_dim),
        )

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="No cuda")
    def test_forward_cuda(self) -> None:
        """The forward pass runs on random CUDA tensors without error."""
        batch_size, num_views, embed_dim = 3, 32, 128
        criterion = MMCRLoss()
        criterion(
            torch.randn(batch_size, num_views, embed_dim).cuda(),
            torch.randn(batch_size, num_views, embed_dim).cuda(),
        )

    def test_loss_value(self) -> None:
        """If all values are zero, the loss should be zero."""
        batch_size, num_views, embed_dim = 3, 32, 128
        criterion = MMCRLoss()
        zeros = torch.zeros(batch_size, num_views, embed_dim)
        assert criterion(zeros, torch.zeros_like(zeros)) == 0.0

    def test_lambda_value_error(self) -> None:
        """If lambda is negative, a ValueError should be raised."""
        with pytest.raises(ValueError):
            MMCRLoss(lmda=-1)

    def test_shape_assertion_forward(self) -> None:
        """Mismatched embedding dimensions must trigger an AssertionError."""
        batch_size, num_views, embed_dim = 3, 32, 128
        criterion = MMCRLoss()
        online = torch.randn(batch_size, num_views, embed_dim)
        momentum = torch.randn(batch_size, num_views, embed_dim + 1)
        with pytest.raises(AssertionError):
            criterion(online, momentum)
-------------------------------------------------------------------------------- 1 | name: Minimal Dependency Tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | workflow_dispatch: 9 | 10 | jobs: 11 | detect-code-changes: 12 | name: Detect Code Changes 13 | runs-on: ubuntu-22.04 14 | outputs: 15 | run-tests: ${{ steps.filter.outputs.run-tests }} 16 | steps: 17 | - uses: actions/checkout@v4 18 | - uses: dorny/paths-filter@v3 19 | id: filter 20 | with: 21 | list-files: shell 22 | filters: | 23 | run-tests: 24 | - '!docs/**' 25 | - '!examples/**' 26 | - '!benchmarks/**' 27 | 28 | test: 29 | name: Test 30 | needs: detect-code-changes 31 | if: needs.detect-code-changes.outputs.run-tests == 'true' 32 | runs-on: ubuntu-22.04 33 | strategy: 34 | matrix: 35 | python: ["3.7"] 36 | dependencies: ["minimal", "minimal-extras"] 37 | steps: 38 | - uses: actions/checkout@v4 39 | - name: Set Up Python 40 | uses: actions/setup-python@v5 41 | with: 42 | python-version: ${{ matrix.python }} 43 | - name: Set Up Environment 44 | run: | 45 | # Install libavformat-dev and libavdevice-dev for PyAV as versions before 9.0 46 | # do not provide wheels including FFMPEG. 47 | sudo apt-get install libavformat-dev libavdevice-dev 48 | make install-uv reset-venv 49 | source .venv/bin/activate 50 | make install-${{ matrix.dependencies }} 51 | - name: Run Tests 52 | run: | 53 | source .venv/bin/activate 54 | export LIGHTLY_SERVER_LOCATION="localhost:-1" 55 | python -m pytest -s -v --runslow 56 | -------------------------------------------------------------------------------- /lightly/models/zoo.py: -------------------------------------------------------------------------------- 1 | """ Lightly Model Zoo """ 2 | 3 | # Copyright (c) 2020. Lightly AG and its affiliates. 
# Mapping from a human-readable model descriptor
# ("<architecture>/<method>/d<num_ftrs>/w<width>") to the checkpoint URL.
ZOO = {
    "resnet-9/simclr/d16/w0.0625": "https://storage.googleapis.com/models_boris/whattolabel-resnet9-simclr-d16-w0.0625-i-ce0d6bd9.pth",
    "resnet-9/simclr/d16/w0.125": "https://storage.googleapis.com/models_boris/whattolabel-resnet9-simclr-d16-w0.125-i-7269c38d.pth",
    "resnet-18/simclr/d16/w1.0": "https://storage.googleapis.com/models_boris/whattolabel-resnet18-simclr-d16-w1.0-i-58852cb9.pth",
    "resnet-18/simclr/d32/w1.0": "https://storage.googleapis.com/models_boris/whattolabel-resnet18-simclr-d32-w1.0-i-085d0693.pth",
    "resnet-34/simclr/d16/w1.0": "https://storage.googleapis.com/models_boris/whattolabel-resnet34-simclr-d16-w1.0-i-6e80d963.pth",
    "resnet-34/simclr/d32/w1.0": "https://storage.googleapis.com/models_boris/whattolabel-resnet34-simclr-d32-w1.0-i-9f185b45.pth",
}


def checkpoints() -> List[str]:
    """Returns the Lightly model zoo as a list of checkpoints.

    Checkpoints:
        ResNet-9:
            SimCLR with width = 0.0625 and num_ftrs = 16
        ResNet-9:
            SimCLR with width = 0.125 and num_ftrs = 16
        ResNet-18:
            SimCLR with width = 1.0 and num_ftrs = 16
        ResNet-18:
            SimCLR with width = 1.0 and num_ftrs = 32
        ResNet-34:
            SimCLR with width = 1.0 and num_ftrs = 16
        ResNet-34:
            SimCLR with width = 1.0 and num_ftrs = 32

    Returns:
        A list of available checkpoints as URLs.

    """
    # Idiomatic form of the previous `[item for key, item in ZOO.items()]`,
    # which iterated pairs only to discard the key.
    return list(ZOO.values())
# Skip the whole module when torch.linalg.solve_triangular is unavailable.
# NOTE: the previous guard used `import torch.linalg.solve_triangular`, but
# `solve_triangular` is a function, not a submodule, so that import raises
# ModuleNotFoundError on EVERY torch version — the tests were always skipped.
# Probing the attribute tests for the actual capability instead.
if not hasattr(torch.linalg, "solve_triangular"):
    pytest.skip("torch.linalg.solve_triangular not available", allow_module_level=True)
class _DummyModule(LightningModule):
    """Minimal module that logs epoch-derived values through both log APIs."""

    def __init__(self) -> None:
        super().__init__()

    def training_step(self, batch, batch_idx) -> None:
        # Log the current epoch via log() and log_dict() so the MetricCallback
        # is exercised through both logging code paths.
        epoch = self.trainer.current_epoch
        self.log("train_epoch", epoch)
        self.log_dict({"train_epoch_dict": epoch})

    def validation_step(self, batch, batch_idx) -> None:
        # Validation logs twice the epoch so it is distinguishable from training.
        doubled = self.trainer.current_epoch * 2
        self.log("val_epoch", doubled)
        self.log_dict({"val_epoch_dict": doubled})

    def configure_optimizers(self) -> None:
        # No optimizer required; the test never performs a real update step.
        return None
class Sector(str, Enum):
    """
    Sector
    """

    # Allowed enum values.
    ADVERTISING = 'ADVERTISING'
    AGRICULTURE = 'AGRICULTURE'
    AUTOMOTIVE = 'AUTOMOTIVE'
    EDUCATION = 'EDUCATION'
    ENERGY = 'ENERGY'
    ENTERTAINMENT = 'ENTERTAINMENT'
    ENVIRONMENTAL = 'ENVIRONMENTAL'
    FINANCE = 'FINANCE'
    FOOD = 'FOOD'
    HEALTHCARE = 'HEALTHCARE'
    INTERNET_OF_THINGS = 'INTERNET_OF_THINGS'
    LOGISTICS = 'LOGISTICS'
    MACHINE_LEARNING = 'MACHINE_LEARNING'
    MANUFACTURING = 'MANUFACTURING'
    MEDICINE = 'MEDICINE'
    RECYCLING = 'RECYCLING'
    RETAIL = 'RETAIL'
    ROBOTICS = 'ROBOTICS'
    SECURITY = 'SECURITY'
    SOFTWARE_DEVELOPMENT = 'SOFTWARE_DEVELOPMENT'
    SPORTS = 'SPORTS'
    SURVEILLANCE = 'SURVEILLANCE'
    TRANSPORTATION = 'TRANSPORTATION'
    OTHER = 'OTHER'

    @classmethod
    def from_json(cls, json_str: str) -> 'Sector':
        """Create an instance of Sector from a JSON string."""
        decoded = json.loads(json_str)
        return cls(decoded)
autoclass:: lightly.loss.dcl_loss.DCLLoss 12 | :members: 13 | 14 | .. autoclass:: lightly.loss.dcl_loss.DCLWLoss 15 | :members: 16 | 17 | .. autoclass:: lightly.loss.detcon_loss.DetConBLoss 18 | :members: 19 | 20 | .. autoclass:: lightly.loss.detcon_loss.DetConSLoss 21 | :members: 22 | 23 | .. autoclass:: lightly.loss.dino_loss.DINOLoss 24 | :members: 25 | 26 | .. autoclass:: lightly.loss.hypersphere_loss.HypersphereLoss 27 | :members: 28 | 29 | .. autoclass:: lightly.loss.ibot_loss.IBOTPatchLoss 30 | :members: 31 | 32 | .. autoclass:: lightly.loss.koleo_loss.KoLeoLoss 33 | :members: 34 | 35 | .. autoclass:: lightly.loss.memory_bank.MemoryBankModule 36 | :members: 37 | 38 | .. autoclass:: lightly.loss.mmcr_loss.MMCRLoss 39 | :members: 40 | 41 | .. autoclass:: lightly.loss.msn_loss.MSNLoss 42 | :members: 43 | 44 | .. autoclass:: lightly.loss.negative_cosine_similarity.NegativeCosineSimilarity 45 | :members: 46 | 47 | .. autoclass:: lightly.loss.ntx_ent_loss.NTXentLoss 48 | :members: 49 | 50 | .. autoclass:: lightly.loss.pmsn_loss.PMSNLoss 51 | :members: 52 | 53 | .. autoclass:: lightly.loss.pmsn_loss.PMSNCustomLoss 54 | :members: 55 | 56 | .. autoclass:: lightly.loss.regularizer.co2.CO2Regularizer 57 | :members: 58 | 59 | .. autoclass:: lightly.loss.swav_loss.SwaVLoss 60 | :members: 61 | 62 | .. autoclass:: lightly.loss.sym_neg_cos_sim_loss.SymNegCosineSimilarityLoss 63 | :members: 64 | 65 | .. autoclass:: lightly.loss.tico_loss.TiCoLoss 66 | :members: 67 | 68 | .. autoclass:: lightly.loss.vicreg_loss.VICRegLoss 69 | :members: 70 | 71 | .. 
def test_create_summary_callback():
    """The max_depth from the summary config is forwarded to the callback."""
    summary_cb = callbacks.create_summary_callback(
        summary_callback_config=OmegaConf.create({"max_depth": 99}),
        trainer_config=OmegaConf.create(),
    )
    assert summary_cb._max_depth == 99
class NegativeCosineSimilarity(torch.nn.Module):
    """Implementation of the Negative Cosine Simililarity used in the SimSiam[0] paper.

    - [0] SimSiam, 2020, https://arxiv.org/abs/2011.10566

    Examples:
        >>> # initialize loss function
        >>> loss_fn = NegativeCosineSimilarity()
        >>>
        >>> # generate two representation tensors
        >>> # with batch size 10 and dimension 128
        >>> x0 = torch.randn(10, 128)
        >>> x1 = torch.randn(10, 128)
        >>>
        >>> # calculate loss
        >>> loss = loss_fn(x0, x1)
    """

    def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
        """Initializes the NegativeCosineSimilarity module the specified parameters.

        Same parameters as in torch.nn.CosineSimilarity

        Args:
            dim:
                Dimension where cosine similarity is computed.
            eps:
                Small value to avoid division by zero.
        """
        super().__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x0: torch.Tensor, x1: torch.Tensor) -> torch.Tensor:
        """Computes the negative cosine similarity between two tensors.

        Args:
            x0:
                First input tensor.
            x1:
                Second input tensor.

        Returns:
            The mean negative cosine similarity.
        """
        similarity = cosine_similarity(x0, x1, self.dim, self.eps)
        return -similarity.mean()
class TestDist(unittest.TestCase):
    def test_eye_rank_undist(self):
        """Without a distributed setup, eye_rank behaves like torch.eye."""
        self.assertTrue(torch.all(dist.eye_rank(3) == torch.eye(3)))

    def test_eye_rank_dist(self):
        """eye_rank places the identity block at the column offset of the rank."""
        n = 3
        for world_size in [1, 3]:
            for rank in range(0, world_size):
                with mock.patch(
                    "torch.distributed.is_initialized", lambda: True
                ), mock.patch(
                    "lightly.utils.dist.world_size", lambda: world_size
                ), mock.patch(
                    "lightly.utils.dist.rank", lambda: rank
                ):
                    # Build the expectation: zero blocks everywhere except an
                    # identity block at this rank's position.
                    blocks = [torch.zeros((n, n)).bool() for _ in range(world_size)]
                    blocks[rank] = torch.eye(n).bool()
                    expected = torch.cat(blocks, dim=1)
                    self.assertTrue(torch.all(dist.eye_rank(n) == expected))
31 | assert isinstance(new_client.__dict__["configuration"], Configuration) 32 | assert isinstance( 33 | new_client.__dict__["rest_client"], LightlySwaggerRESTClientObject 34 | ) 35 | # Last reponse is completely removed from client object and is only dynamically 36 | # reassigned in the ApiClient.__call_api method. 37 | assert not hasattr(new_client, "last_response") 38 | -------------------------------------------------------------------------------- /lightly/models/modules/__init__.py: -------------------------------------------------------------------------------- 1 | """The lightly.models.modules package provides reusable modules. 2 | 3 | This package contains reusable modules such as the NNmemoryBankModule which 4 | can be combined with any lightly model. 5 | 6 | """ 7 | 8 | # Copyright (c) 2021. Lightly AG and its affiliates. 9 | # All Rights Reserved 10 | 11 | 12 | from lightly.models.modules.heads import ( 13 | BarlowTwinsProjectionHead, 14 | BYOLPredictionHead, 15 | BYOLProjectionHead, 16 | DenseCLProjectionHead, 17 | DINOProjectionHead, 18 | MMCRProjectionHead, 19 | MoCoProjectionHead, 20 | NNCLRPredictionHead, 21 | NNCLRProjectionHead, 22 | SimCLRProjectionHead, 23 | SimSiamPredictionHead, 24 | SimSiamProjectionHead, 25 | SMoGPredictionHead, 26 | SMoGProjectionHead, 27 | SMoGPrototypes, 28 | SwaVProjectionHead, 29 | SwaVPrototypes, 30 | ) 31 | from lightly.models.modules.nn_memory_bank import NNMemoryBankModule 32 | from lightly.utils import dependency as _dependency 33 | 34 | if _dependency.torchvision_vit_available(): 35 | # Requires torchvision >=0.12 36 | from lightly.models.modules.masked_autoencoder import ( 37 | MAEBackbone, 38 | MAEDecoder, 39 | MAEEncoder, 40 | ) 41 | from lightly.models.modules.masked_vision_transformer_torchvision import ( 42 | MaskedVisionTransformerTorchvision, 43 | ) 44 | 45 | if _dependency.timm_vit_available(): 46 | # Requires timm >= 0.9.9 47 | from lightly.models.modules.heads_timm import AIMPredictionHead 48 | from 
lightly.models.modules.ijepa_timm import IJEPAPredictorTIMM 49 | from lightly.models.modules.masked_autoencoder_timm import MAEDecoderTIMM 50 | from lightly.models.modules.masked_causal_vision_transformer import ( 51 | MaskedCausalVisionTransformer, 52 | ) 53 | from lightly.models.modules.masked_vision_transformer_timm import ( 54 | MaskedVisionTransformerTIMM, 55 | ) 56 | -------------------------------------------------------------------------------- /tests/loss/test_ibot_loss.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from lightly.loss.ibot_loss import IBOTPatchLoss 5 | 6 | 7 | class TestIBOTPatchLoss: 8 | @pytest.mark.parametrize("device", ["cpu", "cuda"]) 9 | def test_forward(self, device: str) -> None: 10 | if not torch.cuda.is_available() and device == "cuda": 11 | pytest.skip("CUDA not available") 12 | 13 | criterion = IBOTPatchLoss( 14 | output_dim=2, 15 | teacher_temp=0.1, 16 | student_temp=0.2, 17 | center_mode="mean", 18 | center_momentum=0.9, 19 | ) 20 | teacher_out = torch.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]) 21 | student_out = torch.tensor([[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]) 22 | mask = torch.tensor( 23 | [ 24 | [[True, False], [True, False]], 25 | [[False, False], [False, True]], 26 | [[False, False], [False, False]], 27 | ] 28 | ) 29 | 30 | loss = criterion.forward( 31 | teacher_out=teacher_out, student_out=student_out, mask=mask 32 | ) 33 | assert loss == pytest.approx(0.4057, rel=0.0001) 34 | expected_center = 0.1 * teacher_out.mean(0) 35 | assert torch.all(torch.isclose(criterion.center.value, expected_center)) 36 | # Loss value was calculated with the original implementation from: 37 | # https://github.com/facebookresearch/dinov2/blob/main/dinov2/loss/ibot_patch_loss.py 38 | # 39 | # Code: 40 | # orig_criterion = iBOTPatchLoss(patch_out_dim=2, student_temp=0.2) 41 | # orig_t_center = orig_criterion.softmax_center_teacher(teacher_out, 0.1) 42 | # 
class bcolors:
    """ANSI escape sequences for colored terminal output."""

    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"
    ENDC = "\033[0m"
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"


def print_as_warning(message: str, warning_class: Type[Warning] = UserWarning) -> None:
    """Prints a warning message with custom formatting.

    Temporarily overrides the default warning format to apply custom styling, then
    restores the original formatting after the warning is printed.

    Args:
        message:
            The warning message to print.
        warning_class:
            The type of warning to raise.

    """
    # copy.copy on a function returns the same object, so a plain reference is
    # enough to remember the previous formatter.
    old_format = warnings.formatwarning
    warnings.formatwarning = _custom_formatwarning
    try:
        warnings.warn(message, warning_class)
    finally:
        # Restore the previous formatter even if the warning was escalated to an
        # exception by an "error" warning filter; otherwise the override would
        # leak into all subsequent warnings in the process.
        warnings.formatwarning = old_format


def _custom_formatwarning(
    message: Union[str, Warning],
    category: Type[Warning],
    filename: str,
    lineno: int,
    line: Optional[str] = None,
) -> str:
    """Custom format for warning messages.

    Only the warning message is printed, with additional styling applied.

    Args:
        message:
            The warning message or warning object.
        category:
            The warning class.
        filename:
            The file where the warning originated.
        lineno:
            The line number where the warning occurred.
        line:
            The line of code that triggered the warning (if available).

    Returns:
        str: The formatted warning message.

    """
    # Terminate with ENDC so the terminal color is reset after the message; the
    # previous version repeated WARNING here, which left the color bleeding into
    # all following terminal output.
    return f"{bcolors.WARNING}{message}{bcolors.ENDC}\n"
Dict[str, Tensor]]: 43 | loss, topk = self.shared_step(batch=batch, batch_idx=batch_idx) 44 | log_dict = {"val_online_cls_loss": loss} 45 | log_dict.update({f"val_online_cls_top{k}": acc for k, acc in topk.items()}) 46 | return loss, log_dict 47 | --------------------------------------------------------------------------------