├── deeptlf.egg-info ├── dependency_links.txt ├── top_level.txt ├── requires.txt ├── SOURCES.txt └── PKG-INFO ├── .coverage ├── pipeline.png ├── dist ├── deeptlf-0.3.1.tar.gz └── deeptlf-0.3.1-py3-none-any.whl ├── .gitignore ├── deeptlf ├── __init__.py ├── tde.py └── deeptlf.py ├── LICENSE ├── tests ├── conftest.py ├── README.md └── test_deeptlf.py ├── examples ├── README.md ├── minimal_example.py └── DeepTLF_example.ipynb ├── README.md └── pyproject.toml /deeptlf.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /deeptlf.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | deeptlf 2 | -------------------------------------------------------------------------------- /.coverage: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unnir/DeepTLF/HEAD/.coverage -------------------------------------------------------------------------------- /pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unnir/DeepTLF/HEAD/pipeline.png -------------------------------------------------------------------------------- /dist/deeptlf-0.3.1.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unnir/DeepTLF/HEAD/dist/deeptlf-0.3.1.tar.gz -------------------------------------------------------------------------------- /dist/deeptlf-0.3.1-py3-none-any.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unnir/DeepTLF/HEAD/dist/deeptlf-0.3.1-py3-none-any.whl -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.pyc 3 | *.pt 4 | 5 | .idea/ 6 | .ipynb_checkpoints/ 7 | 8 | catboost_info/ 9 | logs/ 10 | output/ 11 | 12 | *.db 13 | -------------------------------------------------------------------------------- /deeptlf/__init__.py: -------------------------------------------------------------------------------- 1 | from .deeptlf import DeepTFL 2 | from .tde import TreeDrivenEncoder 3 | 4 | __version__ = "0.3.1" 5 | __author__ = "Vadim Borisov" 6 | 7 | __all__ = ["DeepTFL", "TreeDrivenEncoder"] 8 | -------------------------------------------------------------------------------- /deeptlf.egg-info/requires.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.19.0 2 | torch>=1.7.0 3 | scikit-learn>=0.24.0 4 | xgboost>=1.3.0 5 | tqdm>=4.50.0 6 | 7 | [examples] 8 | jupyter>=1.0.0 9 | matplotlib>=3.3.0 10 | 11 | [test] 12 | pytest>=7.0.0 13 | pytest-cov>=4.0.0 14 | -------------------------------------------------------------------------------- /deeptlf.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | .gitignore 2 | LICENSE 3 | README.md 4 | pipeline.png 5 | pyproject.toml 6 | deeptlf/__init__.py 7 | deeptlf/deeptlf.py 8 | deeptlf/tde.py 9 | deeptlf.egg-info/PKG-INFO 10 | deeptlf.egg-info/SOURCES.txt 11 | deeptlf.egg-info/dependency_links.txt 12 | deeptlf.egg-info/requires.txt 13 | deeptlf.egg-info/top_level.txt 14 | tests/test_deeptlf.py
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Vadim Borisov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | import torch 4 | 5 | @pytest.fixture 6 | def sample_data(): 7 | """Generate sample data for testing.""" 8 | np.random.seed(42) 9 | X = np.random.randn(100, 10) # 100 samples, 10 features 10 | y_reg = np.random.randn(100, 1) # Regression target with shape (n_samples, 1) 11 | y_class = np.random.randint(0, 3, 100) # Classification target (3 classes) 12 | return { 13 | 'X': X, 14 | 'y_reg': y_reg, 15 | 'y_class': y_class 16 | } 17 | 18 | @pytest.fixture 19 | def device(): 20 | """Return available device (cuda if available, else cpu).""" 21 | return torch.device('cuda' if torch.cuda.is_available() else 'cpu') 22 | 23 | @pytest.fixture 24 | def model_params(): 25 | """Return default model parameters for testing.""" 26 | return { 27 | 'n_est': 5, # Small number for quick testing 28 | 'max_depth': 3, 29 | 'drop': 0.1, 30 | 'xgb_lr': 0.1, 31 | 'batch_size': 32, 32 | 'n_epoch': 2, # Small number for quick testing 33 | 'hidden_dim': 64, 34 | 'n_layers': 2, 35 | 'debug': True 36 | } -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | # DeepTLF Tests 2 | 3 | This directory contains tests for the DeepTLF package. The tests cover both CPU and CUDA functionality, along with various edge cases and error conditions. 4 | 5 | ## Test Structure 6 | 7 | - `conftest.py`: Contains pytest fixtures for data generation and model configuration 8 | - `test_deeptlf.py`: Main test file containing all test cases 9 | 10 | ## Running Tests 11 | 12 | To run all tests: 13 | ```bash 14 | pytest tests/ 15 | ``` 16 | 17 | To run tests with detailed output: 18 | ```bash 19 | pytest tests/ -v 20 | ``` 21 | 22 | To run a specific test: 23 | ```bash 24 | pytest tests/test_deeptlf.py::test_name 25 | ``` 26 | 27 | ## Test Coverage 28 | 29 | The tests cover: 30 | 1. Model initialization and parameter validation 31 | 2. Classification and regression training 32 | 3. 
Device compatibility (CPU/CUDA) 33 | 4. Input validation and error handling 34 | 5. Model saving and loading 35 | 6. Edge cases (empty inputs, NaN values) 36 | 37 | ## Requirements 38 | 39 | Make sure you have pytest installed: 40 | ```bash 41 | pip install pytest 42 | ``` 43 | 44 | ## Notes 45 | 46 | - Tests use small datasets and few epochs to run quickly 47 | - CUDA tests will automatically skip if no GPU is available 48 | - Some tests use temporary directories for model checkpoints -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # DeepTLF Examples 2 | 3 | This directory contains examples demonstrating how to use the DeepTLF package. 4 | 5 | ## Files 6 | 7 | 1. `minimal_example.py`: A minimal Python script demonstrating both classification and regression tasks using synthetic data. 8 | 2. `DeepTLF_example.ipynb`: A comprehensive Jupyter notebook with detailed examples and explanations. 9 | 10 | ## Running the Examples 11 | 12 | ### Minimal Example 13 | 14 | To run the minimal example: 15 | 16 | ```bash 17 | python minimal_example.py 18 | ``` 19 | 20 | This script will: 21 | 1. Generate synthetic data for both classification and regression 22 | 2. Train DeepTLF models on both tasks 23 | 3. Make predictions and show performance metrics 24 | 25 | ### Jupyter Notebook 26 | 27 | To run the notebook: 28 | 29 | ```bash 30 | jupyter notebook DeepTLF_example.ipynb 31 | ``` 32 | 33 | The notebook contains: 34 | - Detailed explanations of DeepTLF functionality 35 | - Real-world dataset examples 36 | - Visualization of results 37 | - Parameter tuning examples 38 | 39 | ## Requirements 40 | 41 | Make sure you have DeepTLF installed: 42 | ```bash 43 | pip install deeptlf 44 | ``` 45 | 46 | Additional requirements for the examples: 47 | - numpy 48 | - scikit-learn 49 | - jupyter (for notebook example) 50 | - matplotlib (for visualizations in notebook) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepTLF: A Framework for Enhanced Deep Learning on Tabular Data 2 | 3 | ![DeepTLF Pipeline](pipeline.png) 4 | 5 | ## Overview 6 | 7 | **DeepTLF** significantly outperforms traditional Deep Neural Networks (DNNs) in handling tabular data. Using our novel TreeDrivenEncoder, we transform complex, heterogeneous data into a format highly compatible with DNNs. This enables a 19.6% average performance increase compared to conventional DNNs. 
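## How It Works

`DeepTFL` first fits a gradient-boosted tree ensemble (XGBoost), then its `TreeDrivenEncoder` turns every split condition learned by the trees into one binary feature, and finally a multilayer network is trained on that encoding. The sketch below reproduces just the encoding step using the package's public API; `X_train` and the hyperparameter values are illustrative:

```python
import xgboost as xgb
from deeptlf import TreeDrivenEncoder

# 1) Fit an XGBoost model and dump its trees as text.
xgb_model = xgb.XGBRegressor(n_estimators=23, max_depth=4)
xgb_model.fit(X_train, X_train)  # self-supervised fit, as DeepTFL does internally
trees = xgb_model.get_booster().get_dump(with_stats=False)

# 2) Turn every split "feature f < threshold t" into one binary column.
tde = TreeDrivenEncoder()
tde.fit(trees)
X_binary = tde.transform(X_train)  # shape: (n_samples, n_kept_split_conditions)
```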
8 | 9 | ## Installation 10 | 11 | You can install DeepTLF directly from PyPI: 12 | 13 | ```bash 14 | pip install deeptlf 15 | ``` 16 | 17 | ## Quick Start 18 | 19 | Seamlessly integrate DeepTLF into your workflow through its scikit-learn-compatible API: 20 | 21 | ```python 22 | from deeptlf import DeepTFL 23 | 24 | # Initialize and train model 25 | dtlf_model = DeepTFL(n_est=23, max_depth=3, drop=0.23, n_layers=4, task='class') 26 | dtlf_model.fit(X_train, y_train) 27 | 28 | # Make predictions 29 | dtlf_y_hat = dtlf_model.predict(X_test) 30 | ``` 31 | 32 | ## Features 33 | - Transforms heterogeneous data into a DNN-friendly format 34 | - Supports multimodal learning 35 | - Adheres to the scikit-learn API for effortless integration 36 | - Offers advanced options such as custom layer counts, dropout rates, and more 37 | 38 | 39 | ## Citation 40 | To cite DeepTLF in your work: 41 | ```bib 42 | @article{borisov2022deeptlf, 43 | title={DeepTLF: robust deep neural networks for heterogeneous tabular data}, 44 | author={Borisov, Vadim and Broelemann, Klaus and Kasneci, Enkelejda and Kasneci, Gjergji}, 45 | journal={International Journal of Data Science and Analytics}, 46 | pages={1--16}, 47 | year={2022}, 48 | publisher={Springer} 49 | } 50 | ``` -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=45", "wheel", "setuptools-scm>=6.2"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "deeptlf" 7 | version = "0.3.1" 8 | description = "Deep Tabular Learning Framework" 9 | readme = "README.md" 10 | license = {file = "LICENSE"} 11 | authors = [ 12 | {name = "Vadim Borisov", email = "vadim@tabularis.ai"} 13 | ] 14 | maintainers = [ 15 | {name = "Vadim Borisov", email = "vadim@tabularis.ai"} 16 | ] 17 | requires-python = ">=3.7" 18 | keywords = ["deep-learning", "tabular-data", "machine-learning", "neural-networks"] 19 | classifiers = [ 20 | "Development Status :: 4 - Beta", 21 | "License :: OSI Approved :: MIT License", 22 | "Operating System :: OS Independent", 23 | "Programming Language :: Python :: 3", 24 | "Programming Language :: Python :: 3.7", 25 | "Programming Language :: Python :: 3.8", 26 | "Programming Language :: Python :: 3.9", 27 | "Topic :: Scientific/Engineering :: Artificial Intelligence" 28 | ] 29 | dependencies = [ 30 | "numpy>=1.19.0", 31 | "torch>=1.7.0", 32 | "scikit-learn>=0.24.0", 33 | "xgboost>=1.3.0", 34 | "tqdm>=4.50.0" 35 | ] 36 | 37 | [project.optional-dependencies] 38 | test = [ 39 | "pytest>=7.0.0", 40 | "pytest-cov>=4.0.0", 41 | ] 42 | examples = [ 43 | "jupyter>=1.0.0", 44 | "matplotlib>=3.3.0", 45 | ] 46 | 47 | [project.urls] 48 | Homepage = "https://github.com/unnir/deeptlf" 49 | Documentation = "https://github.com/unnir/deeptlf#readme" 50 | Repository = "https://github.com/unnir/deeptlf.git" 51 | Issues = "https://github.com/unnir/deeptlf/issues" 52 | 53 | [tool.setuptools] 54 | packages = ["deeptlf"] 55 | 56 | [tool.pytest.ini_options] 57 | testpaths = ["tests"] 58 | python_files = ["test_*.py"] 59 | addopts = "-v --cov=deeptlf --cov-report=term-missing" 60 | -------------------------------------------------------------------------------- /examples/minimal_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | Minimal example demonstrating DeepTLF usage for both classification and regression tasks. 
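Both tasks use small synthetic datasets (scikit-learn's make_classification and make_regression) and few training epochs, so the script runs quickly.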
3 | """ 4 | 5 | import numpy as np 6 | from sklearn.datasets import make_classification, make_regression 7 | from sklearn.model_selection import train_test_split 8 | from deeptlf import DeepTFL 9 | 10 | def run_classification_example(): 11 | """Run classification example using synthetic data.""" 12 | print("\n=== Classification Example ===") 13 | 14 | # Generate synthetic classification data 15 | X, y = make_classification( 16 | n_samples=1000, 17 | n_features=20, 18 | n_informative=15, 19 | n_redundant=5, 20 | n_classes=3, 21 | random_state=42 22 | ) 23 | 24 | # Split data 25 | X_train, X_test, y_train, y_test = train_test_split( 26 | X, y, test_size=0.2, random_state=42 27 | ) 28 | 29 | # Initialize and train model 30 | model = DeepTFL( 31 | task='class', 32 | n_est=10, 33 | max_depth=3, 34 | n_epoch=5, 35 | hidden_dim=64, 36 | drop=0.1, 37 | debug=True 38 | ) 39 | 40 | print("Training classification model...") 41 | model.fit(X_train, y_train) 42 | 43 | # Make predictions 44 | y_pred = model.predict(X_test) 45 | accuracy = np.mean(y_pred == y_test) 46 | print(f"Classification accuracy: {accuracy:.4f}") 47 | 48 | def run_regression_example(): 49 | """Run regression example using synthetic data.""" 50 | print("\n=== Regression Example ===") 51 | 52 | # Generate synthetic regression data 53 | X, y = make_regression( 54 | n_samples=1000, 55 | n_features=20, 56 | n_informative=15, 57 | noise=0.1, 58 | random_state=42 59 | ) 60 | 61 | y = y.reshape(-1, 1) 62 | 63 | # Split data 64 | X_train, X_test, y_train, y_test = train_test_split( 65 | X, y, test_size=0.2, random_state=42 66 | ) 67 | 68 | # Initialize and train model 69 | model = DeepTFL( 70 | task='reg', 71 | n_est=10, 72 | max_depth=3, 73 | n_epoch=5, 74 | hidden_dim=64, 75 | drop=0.1, 76 | debug=True 77 | ) 78 | 79 | print("Training regression model...") 80 | model.fit(X_train, y_train) 81 | 82 | # Make predictions 83 | y_pred = model.predict(X_test) 84 | mse = np.mean((y_pred - y_test) ** 2) 85 | print(f"Regression MSE: {mse:.4f}") 86 | 87 | if __name__ == "__main__": 88 | # Run both examples 89 | run_classification_example() 90 | run_regression_example() 91 | -------------------------------------------------------------------------------- /deeptlf/tde.py: -------------------------------------------------------------------------------- 1 | import re 2 | import numpy as np 3 | from typing import List 4 | from collections import defaultdict, Counter 5 | 6 | node_info_re = re.compile(r"f(\d+)<([0-9\.]+)") 7 | 8 | 9 | def parse_node_info(node_info: str): 10 | matched = node_info_re.findall(node_info) 11 | if not matched: 12 | return None, None 13 | feature, threshold = matched[0] 14 | return int(feature), float(threshold) 15 | 16 | 17 | def extract_node_data_from_tree(tree_number, tree): 18 | node_data = defaultdict(dict) 19 | for node in tree.split("\n"): 20 | cleaned_node = re.sub(r"\s+", "", node) 21 | if ":" not in cleaned_node: 22 | continue 23 | node_id, node_info = map(str.strip, cleaned_node.split(":")) 24 | feature, threshold = parse_node_info(node_info) 25 | if feature is not None and threshold is not None: 26 | node_data[(tree_number, int(node_id))] = { 27 | "feature": feature, 28 | "threshold": threshold, 29 | } 30 | return node_data 31 | 32 | 33 | class TreeDrivenEncoder: 34 | def __init__(self, min_freq=2): 35 | self.all_conditions = [] 36 | self.min_freq = min_freq 37 | 38 | def fit(self, trees: List[str]): 39 | node_data = defaultdict(dict) 40 | for tree_number, tree in enumerate(trees): 41 | 
node_data.update(extract_node_data_from_tree(tree_number, tree)) 42 | 43 | feature_counter = Counter(data["feature"] for data in node_data.values()) 44 | 45 | self.all_conditions = [ 46 | data 47 | for key, data in sorted(node_data.items()) 48 | if feature_counter[data["feature"]] >= self.min_freq 49 | ] 50 | 51 | def transform(self, X): 52 | X = np.array(X) 53 | 54 | # Validate input 55 | if X.ndim != 2: 56 | raise ValueError("Input X must be a 2-dimensional array") 57 | if not self.all_conditions: 58 | raise ValueError("TreeDrivenEncoder has not been fitted. Call fit() first.") 59 | 60 | # Get the maximum feature index used in conditions 61 | max_feature_idx = max(condition["feature"] for condition in self.all_conditions) 62 | if X.shape[1] <= max_feature_idx: 63 | raise ValueError(f"Input X has {X.shape[1]} features, but the encoder requires at least {max_feature_idx + 1} features") 64 | 65 | n_conditions = len(self.all_conditions) 66 | encoded_X = np.zeros((X.shape[0], n_conditions), dtype=int) 67 | 68 | try: 69 | for i, condition in enumerate(self.all_conditions): 70 | encoded_X[:, i] = ( 71 | X[:, condition["feature"]] < condition["threshold"] 72 | ).astype(int) 73 | except Exception as e: 74 | raise RuntimeError(f"Error during transformation: {str(e)}") 75 | 76 | return encoded_X 77 | -------------------------------------------------------------------------------- /deeptlf.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: deeptlf 3 | Version: 0.3.1 4 | Summary: Deep Tabular Learning Framework 5 | Author-email: Vadim Borisov 6 | Maintainer-email: Vadim Borisov 7 | License: MIT License 8 | 9 | Copyright (c) 2024 Vadim Borisov 10 | 11 | Permission is hereby granted, free of charge, to any person obtaining a copy 12 | of this software and associated documentation files (the "Software"), to deal 13 | in the Software without restriction, including without limitation the rights 14 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 15 | copies of the Software, and to permit persons to whom the Software is 16 | furnished to do so, subject to the following conditions: 17 | 18 | The above copyright notice and this permission notice shall be included in all 19 | copies or substantial portions of the Software. 20 | 21 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 22 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 23 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 24 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 25 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 26 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 27 | SOFTWARE. 
28 | 29 | Project-URL: Homepage, https://github.com/unnir/deeptlf 30 | Project-URL: Documentation, https://github.com/unnir/deeptlf#readme 31 | Project-URL: Repository, https://github.com/unnir/deeptlf.git 32 | Project-URL: Issues, https://github.com/unnir/deeptlf/issues 33 | Keywords: deep-learning,tabular-data,machine-learning,neural-networks 34 | Classifier: Development Status :: 4 - Beta 35 | Classifier: License :: OSI Approved :: MIT License 36 | Classifier: Operating System :: OS Independent 37 | Classifier: Programming Language :: Python :: 3 38 | Classifier: Programming Language :: Python :: 3.7 39 | Classifier: Programming Language :: Python :: 3.8 40 | Classifier: Programming Language :: Python :: 3.9 41 | Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence 42 | Requires-Python: >=3.7 43 | Description-Content-Type: text/markdown 44 | License-File: LICENSE 45 | Requires-Dist: numpy>=1.19.0 46 | Requires-Dist: torch>=1.7.0 47 | Requires-Dist: scikit-learn>=0.24.0 48 | Requires-Dist: xgboost>=1.3.0 49 | Requires-Dist: tqdm>=4.50.0 50 | Provides-Extra: test 51 | Requires-Dist: pytest>=7.0.0; extra == "test" 52 | Requires-Dist: pytest-cov>=4.0.0; extra == "test" 53 | Provides-Extra: examples 54 | Requires-Dist: jupyter>=1.0.0; extra == "examples" 55 | Requires-Dist: matplotlib>=3.3.0; extra == "examples" 56 | 57 | # DeepTLF: A Framework for Enhanced Deep Learning on Tabular Data 58 | 59 | ![DeepTLF Pipeline](pipeline.png) 60 | 61 | ## Overview 62 | 63 | **DeepTLF** significantly outperforms traditional Deep Neural Networks (DNNs) in handling tabular data. Using our novel TreeDrivenEncoder, we transform complex, heterogeneous data into a format highly compatible with DNNs. This enables a 19.6% average performance increase compared to conventional DNNs. 
64 | 65 | ## Installation 66 | 67 | You can install DeepTLF directly from PyPI: 68 | 69 | ```bash 70 | pip install deeptlf 71 | ``` 72 | 73 | ## Quick Start 74 | 75 | Seamlessly integrate DeepTLF into your workflow through its scikit-learn-compatible API: 76 | 77 | ```python 78 | from deeptlf import DeepTFL 79 | 80 | # Initialize and train model 81 | dtlf_model = DeepTFL(n_est=23, max_depth=3, drop=0.23, n_layers=4, task='class') 82 | dtlf_model.fit(X_train, y_train) 83 | 84 | # Make predictions 85 | dtlf_y_hat = dtlf_model.predict(X_test) 86 | ``` 87 | 88 | ## Features 89 | - Transforms heterogeneous data into a DNN-friendly format 90 | - Supports multimodal learning 91 | - Adheres to the scikit-learn API for effortless integration 92 | - Offers advanced options such as custom layer counts, dropout rates, and more 93 | 94 | 95 | ## Citation 96 | To cite DeepTLF in your work: 97 | ```bib 98 | @article{borisov2022deeptlf, 99 | title={DeepTLF: robust deep neural networks for heterogeneous tabular data}, 100 | author={Borisov, Vadim and Broelemann, Klaus and Kasneci, Enkelejda and Kasneci, Gjergji}, 101 | journal={International Journal of Data Science and Analytics}, 102 | pages={1--16}, 103 | year={2022}, 104 | publisher={Springer} 105 | } 106 | ``` -------------------------------------------------------------------------------- /tests/test_deeptlf.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | import torch 4 | from deeptlf import DeepTFL 5 | 6 | def test_model_initialization(model_params): 7 | """Test if model initializes correctly with different parameters.""" 8 | # Test classification 9 | model_class = DeepTFL(task='class', **model_params) 10 | assert model_class.task == 'class' 11 | 12 | # Test regression 13 | model_reg = DeepTFL(task='reg', **model_params) 14 | assert model_reg.task == 'reg' 15 | 16 | # Test device assignment 17 | assert hasattr(model_class, 'device') 18 | assert isinstance(model_class.device, torch.device) 19 | 20 | def test_invalid_parameters(): 21 | """Test if model raises appropriate errors for invalid parameters.""" 22 | with pytest.raises(ValueError): 23 | DeepTFL(n_est=-1) # Invalid n_est 24 | 25 | with pytest.raises(ValueError): 26 | DeepTFL(drop=2.0) # Invalid dropout 27 | 28 | with pytest.raises(ValueError): 29 | DeepTFL(task='invalid') # Invalid task 30 | 31 | def test_classification_training(sample_data, model_params, device): 32 | """Test classification model training.""" 33 | X, y = sample_data['X'], sample_data['y_class'] 34 | 35 | model = DeepTFL(task='class', **model_params) 36 | model.fit(X, y) 37 | 38 | # Test if model components are properly initialized 39 | assert model.xgb_model is not None 40 | assert model.nn_model is not None 41 | assert model.TDE_encoder is not None 42 | 43 | # Test predictions 44 | y_pred = model.predict(X) 45 | assert len(y_pred) == len(y) 46 | assert np.all(y_pred >= 0) and np.all(y_pred < 3) # 3 classes 47 | 48 | def test_regression_training(sample_data, model_params, device): 49 | """Test regression model training.""" 50 | X, y = sample_data['X'], sample_data['y_reg'] 51 | # Reshape y to match model output shape 52 | y = y.reshape(-1, 1) 53 | 54 | model = DeepTFL(task='reg', **model_params) 55 | model.fit(X, y) 56 | 57 | # Test predictions 58 | y_pred = model.predict(X) 59 | assert len(y_pred) == len(y) 60 | assert y_pred.shape == y.shape 61 | assert isinstance(y_pred, np.ndarray) 62 | 63 | def test_empty_input(model_params): 64 | """Test model behavior 
with empty input.""" 65 | model = DeepTFL(**model_params) 66 | 67 | with pytest.raises(ValueError, match="Empty input data"): 68 | model.fit(np.array([]), np.array([])) 69 | 70 | with pytest.raises(ValueError, match="Empty input data"): 71 | model.predict(np.array([])) 72 | 73 | def test_nan_input(sample_data, model_params): 74 | """Test model behavior with NaN input.""" 75 | X, y = sample_data['X'].copy(), sample_data['y_class'].copy() 76 | X[0, 0] = np.nan 77 | 78 | model = DeepTFL(**model_params) 79 | 80 | with pytest.raises(ValueError, match="Input contains NaN values"): 81 | model.fit(X, y) 82 | 83 | def test_input_validation(sample_data, model_params): 84 | """Test input validation during fit and predict.""" 85 | X, y = sample_data['X'], sample_data['y_class'] 86 | model = DeepTFL(**model_params) 87 | 88 | # Test mismatched lengths 89 | with pytest.raises(ValueError): 90 | model.fit(X, y[:-1]) 91 | 92 | # Test wrong dimensions 93 | with pytest.raises(ValueError): 94 | model.fit(X.reshape(-1), y) 95 | 96 | def test_device_compatibility(sample_data, model_params, device): 97 | """Test model works on both CPU and CUDA (if available).""" 98 | X, y = sample_data['X'], sample_data['y_class'] 99 | model = DeepTFL(**model_params) 100 | model.fit(X, y) 101 | 102 | # Check if model is on the correct device 103 | assert next(model.nn_model.parameters()).device == device 104 | 105 | # Test predictions work on the device 106 | y_pred = model.predict(X) 107 | assert isinstance(y_pred, np.ndarray) 108 | 109 | def test_model_save_load(sample_data, model_params, tmp_path): 110 | """Test model checkpoint saving and loading.""" 111 | X, y = sample_data['X'], sample_data['y_class'] 112 | checkpoint_path = tmp_path / "model_checkpoint.pt" 113 | 114 | # Train and save model 115 | model = DeepTFL(checkpoint_name=str(checkpoint_path), **model_params) 116 | model.fit(X, y) 117 | 118 | # Load model and make predictions 119 | new_model = DeepTFL(checkpoint_name=str(checkpoint_path), **model_params) 120 | new_model.TDE_encoder = model.TDE_encoder # Need to share the encoder 121 | new_model.input_shape = model.input_shape 122 | new_model.nn_model = model.nn_model.__class__( 123 | model.input_shape, 124 | model.hidden_dim, 125 | model.n_layers, 126 | len(np.unique(y)), 127 | model.drop 128 | ).to(model.device) 129 | 130 | y_pred = new_model.predict(X) 131 | assert len(y_pred) == len(y) 132 | 133 | def test_output_shapes(sample_data, model_params, tmp_path): 134 | """Test output shapes for both regression and classification.""" 135 | # Test regression shapes 136 | X, y_reg = sample_data['X'], sample_data['y_reg'] 137 | reg_checkpoint = str(tmp_path / "reg_checkpoint.pt") 138 | reg_model = DeepTFL(task='reg', checkpoint_name=reg_checkpoint, **model_params) 139 | reg_model.fit(X, y_reg) 140 | reg_pred = reg_model.predict(X) 141 | assert reg_pred.shape == (len(X), 1) # Regression output should be (n_samples, 1) 142 | 143 | # Test classification shapes 144 | y_class = sample_data['y_class'] 145 | class_checkpoint = str(tmp_path / "class_checkpoint.pt") 146 | class_model = DeepTFL(task='class', checkpoint_name=class_checkpoint, **model_params) 147 | class_model.fit(X, y_class) 148 | class_pred = class_model.predict(X) 149 | assert class_pred.shape == (len(X),) # Classification output should be (n_samples,) 150 | 151 | # Test single sample prediction shapes 152 | single_X = X[0:1] # Single sample with shape (1, n_features) 153 | reg_single_pred = reg_model.predict(single_X) 154 | class_single_pred = 
class_model.predict(single_X) 155 | assert reg_single_pred.shape == (1, 1) # Single regression prediction 156 | assert class_single_pred.shape == (1,) # Single classification prediction -------------------------------------------------------------------------------- /deeptlf/deeptlf.py: -------------------------------------------------------------------------------- 1 | import re 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.utils.data import DataLoader, Dataset 7 | 8 | from sklearn.base import BaseEstimator 9 | from sklearn.model_selection import train_test_split 10 | 11 | import xgboost as xgb 12 | from tqdm import tqdm 13 | 14 | from .tde import TreeDrivenEncoder 15 | 16 | 17 | class DeepTFL(BaseEstimator): 18 | """ 19 | A deep learning model based on XGBoost and a custom neural network. 20 | 21 | Parameters 22 | ---------- 23 | n_est : int, optional 24 | Number of estimators for XGBoost model, default is 23. 25 | max_depth : int, optional 26 | Max depth for each tree in XGBoost, default is 4. 27 | n_epoch : int, optional 28 | Number of epochs for neural network training, default is 100. 29 | hidden_dim : int, optional 30 | Hidden layer dimensions for neural network, default is 256. 31 | drop : float, optional 32 | Dropout rate for neural network, default is 0.23. 33 | xgb_lr : float, optional 34 | Learning rate for XGBoost model, default is 0.5. 35 | n_layers : int, optional 36 | Number of layers in the neural network, default is 4. 37 | checkpoint_name : str, optional 38 | File name to save the neural network model, default is 'checkpoint.pt'. 39 | batch_size : int, optional 40 | Batch size for neural network training, default is 320. 41 | task : str, optional 42 | Type of machine learning task ('class' for classification, 'reg' for regression), default is 'class'. 43 | debug : bool, optional 44 | Whether to print debugging information, default is False. 45 | 46 | Attributes 47 | ---------- 48 | xgb_model : XGBRegressor 49 | Fitted XGBoost model (always a regressor; trained self-supervised on the features). 50 | nn_model : NeuralNet 51 | Fitted neural network model. 52 | TDE_encoder : TreeDrivenEncoder 53 | Fitted Tree-Driven Encoder. 54 | input_shape : int 55 | Shape of the input feature space. 56 | device : torch.device 57 | Device used for computations ('cuda' or 'cpu'). 
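    Examples
    --------
    A minimal usage sketch; `X_train`, `y_train`, and `X_test` are assumed to be
    NumPy arrays, and the hyperparameter values are illustrative.

    >>> model = DeepTFL(n_est=10, max_depth=3, n_epoch=5, task='class')
    >>> model.fit(X_train, y_train)      # doctest: +SKIP
    >>> y_hat = model.predict(X_test)    # doctest: +SKIP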
58 | """ 59 | 60 | def __init__( 61 | self, 62 | n_est=23, 63 | max_depth=4, 64 | drop=0.23, 65 | xgb_lr=0.5, 66 | batch_size=320, 67 | n_epoch=100, 68 | hidden_dim=256, 69 | n_layers=4, 70 | task="class", 71 | debug=False, 72 | checkpoint_name="checkpoint.pt", 73 | ): 74 | # Validate input parameters 75 | if not isinstance(n_est, int) or n_est <= 0: 76 | raise ValueError("n_est must be a positive integer") 77 | if not isinstance(max_depth, int) or max_depth <= 0: 78 | raise ValueError("max_depth must be a positive integer") 79 | if not isinstance(drop, (int, float)) or not 0 <= drop <= 1: 80 | raise ValueError("drop must be a float between 0 and 1") 81 | if not isinstance(xgb_lr, (int, float)) or xgb_lr <= 0: 82 | raise ValueError("xgb_lr must be a positive float") 83 | if not isinstance(batch_size, int) or batch_size <= 0: 84 | raise ValueError("batch_size must be a positive integer") 85 | if not isinstance(n_epoch, int) or n_epoch <= 0: 86 | raise ValueError("n_epoch must be a positive integer") 87 | if not isinstance(hidden_dim, int) or hidden_dim <= 0: 88 | raise ValueError("hidden_dim must be a positive integer") 89 | if not isinstance(n_layers, int) or n_layers <= 0: 90 | raise ValueError("n_layers must be a positive integer") 91 | if task not in ["class", "reg"]: 92 | raise ValueError("task must be either 'class' or 'reg'") 93 | if not isinstance(debug, bool): 94 | raise ValueError("debug must be a boolean") 95 | if not isinstance(checkpoint_name, str) or not checkpoint_name: 96 | raise ValueError("checkpoint_name must be a non-empty string") 97 | 98 | self.n_est = n_est 99 | self.max_depth = max_depth 100 | self.n_epoch = n_epoch 101 | self.hidden_dim = hidden_dim 102 | self.drop = drop 103 | self.debug = debug 104 | self.xgb_lr = xgb_lr 105 | self.n_layers = n_layers 106 | self.checkpoint_name = checkpoint_name 107 | self.batch_size = batch_size 108 | self.task = task 109 | self.xgb_model = None 110 | self.nn_model = None 111 | self.TDE_encoder = TreeDrivenEncoder() 112 | self.input_shape = None 113 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 114 | 115 | def fit(self, X_train, y_train, X_val=None, y_val=None): 116 | # Validate inputs 117 | if len(X_train) == 0 or len(y_train) == 0: 118 | raise ValueError("Empty input data") 119 | if len(X_train) != len(y_train): 120 | raise ValueError("X_train and y_train must have the same length") 121 | if np.any(np.isnan(X_train)) or np.any(np.isnan(y_train)): 122 | raise ValueError("Input contains NaN values") 123 | if X_val is not None and y_val is not None: 124 | if len(X_val) != len(y_val): 125 | raise ValueError("X_val and y_val must have the same length") 126 | if np.any(np.isnan(X_val)) or np.any(np.isnan(y_val)): 127 | raise ValueError("Validation data contains NaN values") 128 | 129 | # Store target values for classification tasks 130 | if self.task == "class": 131 | self.last_y = y_train 132 | 133 | self.fit_xgb(X_train) 134 | trees = self.xgb_model.get_booster().get_dump(with_stats=False) 135 | self.TDE_encoder.fit(trees) 136 | enc_X_train = self.TDE_encoder.transform(X_train) 137 | self.input_shape = enc_X_train.shape[1] 138 | self.fit_nn(enc_X_train, y_train, X_val, y_val) 139 | 140 | def fit_xgb(self, X_train): 141 | # Using XGBRegressor for self-supervised learning. 142 | self.xgb_model = xgb.XGBRegressor( 143 | learning_rate=self.xgb_lr, 144 | n_jobs=-1, 145 | max_depth=self.max_depth, 146 | n_estimators=self.n_est, 147 | ) 148 | # Using X_train as target for self-supervised learning. 
149 | self.xgb_model.fit(X_train, X_train) 150 | 151 | def fit_nn(self, enc_X_train, y_train, X_val=None, y_val=None): 152 | if X_val is not None: 153 | enc_X_val = self.TDE_encoder.transform(X_val) 154 | else: 155 | enc_X_train, enc_X_val, y_train, y_val = train_test_split( 156 | enc_X_train, y_train, test_size=0.2, random_state=42 157 | ) 158 | 159 | train_loader = DataLoader( 160 | dataset=myDataset(enc_X_train, y_train), 161 | batch_size=self.batch_size, 162 | shuffle=True, 163 | ) 164 | val_loader = DataLoader( 165 | dataset=myDataset(enc_X_val, y_val), 166 | batch_size=self.batch_size, 167 | shuffle=False, 168 | ) 169 | criterion = nn.CrossEntropyLoss() if self.task == "class" else nn.MSELoss() 170 | num_of_outputs = len(set(y_train)) if self.task == "class" else 1 171 | self.nn_model = NeuralNet( 172 | self.input_shape, self.hidden_dim, self.n_layers, num_of_outputs, self.drop 173 | ).to(self.device) 174 | optimizer = torch.optim.AdamW(self.nn_model.parameters(), lr=1e-3) 175 | early_stopping = EarlyStopping( 176 | patience=20, verbose=self.debug, path=self.checkpoint_name 177 | ) 178 | 179 | for epoch in tqdm(range(self.n_epoch), desc="Epochs"): 180 | self.nn_model.train() 181 | for batch_X, batch_y in train_loader: 182 | if self.task == "class": 183 | batch_X, batch_y = batch_X.float().to( 184 | self.device 185 | ), batch_y.long().to(self.device) 186 | else: 187 | batch_X, batch_y = batch_X.float().to( 188 | self.device 189 | ), batch_y.float().to(self.device) 190 | outputs = self.nn_model(batch_X) 191 | loss = criterion(outputs, batch_y) 192 | optimizer.zero_grad() 193 | loss.backward() 194 | optimizer.step() 195 | 196 | # if X_val is not None: 197 | self.nn_model.eval() 198 | val_loss = 0.0 199 | with torch.no_grad(): 200 | for batch_X, batch_y in val_loader: 201 | if self.task == "class": 202 | batch_X, batch_y = batch_X.float().to( 203 | self.device 204 | ), batch_y.long().to(self.device) 205 | else: 206 | batch_X, batch_y = batch_X.float().to( 207 | self.device 208 | ), batch_y.float().to(self.device) 209 | outputs = self.nn_model(batch_X) 210 | loss = criterion(outputs, batch_y) 211 | val_loss += loss.item() 212 | 213 | val_loss /= len(val_loader) # Average validation loss 214 | early_stopping(val_loss, self.nn_model) 215 | 216 | if early_stopping.early_stop: 217 | print("Early stopping") 218 | break # Break out of the epoch loop 219 | 220 | def predict(self, X): 221 | # Validate input 222 | if len(X) == 0: 223 | raise ValueError("Empty input data") 224 | if np.any(np.isnan(X)): 225 | raise ValueError("Input contains NaN values") 226 | 227 | # Check if model exists and is initialized 228 | if self.nn_model is None: 229 | try: 230 | self.nn_model = NeuralNet( 231 | self.input_shape, 232 | self.hidden_dim, 233 | self.n_layers, 234 | 1 if self.task == "reg" else len(np.unique(self.last_y)), 235 | self.drop 236 | ).to(self.device) 237 | except AttributeError: 238 | raise RuntimeError("Model not fitted. 
Call fit() before predict()") 239 | 240 | try: 241 | self.nn_model.load_state_dict(torch.load(self.checkpoint_name, map_location=self.device)) # map_location lets a GPU-trained checkpoint load on a CPU-only machine 242 | except FileNotFoundError: 243 | raise RuntimeError(f"Model checkpoint not found at {self.checkpoint_name}") 244 | except Exception as e: 245 | raise RuntimeError(f"Error loading model: {str(e)}") 246 | 247 | self.nn_model.eval() # Set model to evaluation mode 248 | n_samples = len(X) 249 | n_batches = (n_samples + self.batch_size - 1) // self.batch_size # Ceiling division 250 | y_hats = [] 251 | 252 | with torch.no_grad(): # Disable gradient computation for prediction 253 | for i in range(n_batches): 254 | start_idx = i * self.batch_size 255 | end_idx = min((i + 1) * self.batch_size, n_samples) 256 | 257 | batch_X = X[start_idx:end_idx] 258 | enc_X_batch = self.TDE_encoder.transform(batch_X) 259 | 260 | try: 261 | if self.task == "class": 262 | y_hat = self.nn_model(torch.Tensor(enc_X_batch).to(self.device)) 263 | y_hat = torch.argmax(y_hat, dim=1).cpu().numpy() 264 | else: 265 | y_hat = self.nn_model(torch.Tensor(enc_X_batch).to(self.device)) 266 | y_hat = y_hat.detach().cpu().numpy() 267 | except Exception as e: 268 | raise RuntimeError(f"Error during prediction: {str(e)}") 269 | 270 | y_hats.append(y_hat) 271 | 272 | predictions = np.concatenate(y_hats) 273 | # Ensure consistent output shapes 274 | if self.task == "reg": 275 | return predictions.reshape(-1, 1) # Shape: (n_samples, 1) 276 | else: 277 | return predictions.reshape(-1) # Shape: (n_samples,) 278 | 279 | 280 | class myDataset(Dataset): 281 | def __init__(self, data, labels): 282 | self.data = data 283 | self.labels = labels 284 | 285 | def __len__(self): 286 | return len(self.data) 287 | 288 | def __getitem__(self, idx): 289 | return self.data[idx], self.labels[idx] 290 | 291 | 292 | class EarlyStopping: 293 | """Early stops the training if validation loss doesn't improve after a given patience.""" 294 | 295 | # https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py 296 | 297 | def __init__( 298 | self, 299 | patience=7, 300 | verbose=False, 301 | delta=0.001, 302 | path="checkpoint.pt", 303 | trace_func=print, 304 | ): 305 | """ 306 | Args: 307 | patience (int): How long to wait after last time validation loss improved. 308 | Default: 7 309 | verbose (bool): If True, prints a message for each validation loss improvement. 310 | Default: False 311 | delta (float): Minimum change in the monitored quantity to qualify as an improvement. 312 | Default: 0.001 313 | path (str): Path for the checkpoint to be saved to. 314 | Default: 'checkpoint.pt' 315 | trace_func (function): trace print function. 
316 | Default: print 317 | """ 318 | self.patience = patience 319 | self.verbose = verbose 320 | self.counter = 0 321 | self.best_score = None 322 | self.early_stop = False 323 | self.val_loss_min = np.inf 324 | self.delta = delta 325 | self.path = path 326 | self.trace_func = trace_func 327 | 328 | def __call__(self, val_loss, model): 329 | score = -val_loss 330 | 331 | if self.best_score is None: 332 | self.best_score = score 333 | self.save_checkpoint(val_loss, model) 334 | elif score < self.best_score + self.delta: 335 | self.counter += 1 336 | # self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}') 337 | if self.counter >= self.patience: 338 | self.early_stop = True 339 | else: 340 | self.best_score = score 341 | self.save_checkpoint(val_loss, model) 342 | self.counter = 0 343 | 344 | def save_checkpoint(self, val_loss, model): 345 | """Saves model when validation loss decreases.""" 346 | if self.verbose: 347 | self.trace_func( 348 | f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ..." 349 | ) 350 | torch.save(model.state_dict(), self.path) 351 | self.val_loss_min = val_loss 352 | 353 | 354 | class NeuralNet(nn.Module): 355 | def __init__(self, input_dim, hidden_dim, n_layers, num_classes, drop): 356 | super(NeuralNet, self).__init__() 357 | 358 | layers = [] 359 | layers.append(nn.Linear(input_dim, hidden_dim)) 360 | layers.append(nn.BatchNorm1d(hidden_dim)) 361 | layers.append(nn.PReLU()) 362 | layers.append(nn.Dropout(drop)) 363 | 364 | for _ in range(n_layers): 365 | layers.append(nn.Linear(hidden_dim, hidden_dim)) 366 | layers.append(nn.BatchNorm1d(hidden_dim)) 367 | layers.append(nn.ReLU()) 368 | layers.append(nn.Dropout(drop)) 369 | 370 | layers.append(nn.Linear(hidden_dim, num_classes)) 371 | 372 | self.model = nn.Sequential(*layers) 373 | 374 | def forward(self, x): 375 | return self.model(x) 376 | -------------------------------------------------------------------------------- /examples/DeepTLF_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "quarterly-accessory", 6 | "metadata": {}, 7 | "source": [ 8 | "# DeepTLF (pre-release alpha version)\n", 9 | "- DeepTLF follows the scikit-learn API\n", 10 | "- PyTorch Implementation" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "id": "a2827b81-dad7-4273-ad92-9cc5767fea9c", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "%load_ext autoreload\n", 21 | "%autoreload 2" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 2, 27 | "id": "universal-flashing", 28 | "metadata": { 29 | "tags": [] 30 | }, 31 | "outputs": [ 32 | { 33 | "ename": "ModuleNotFoundError", 34 | "evalue": "No module named 'src'", 35 | "output_type": "error", 36 | "traceback": [ 37 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 38 | "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", 39 | "Cell \u001b[0;32mIn[2], line 8\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msklearn\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdatasets\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m load_breast_cancer, fetch_california_housing\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mxgboost\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m 
\u001b[38;5;21;01mxgb\u001b[39;00m\n\u001b[0;32m----> 8\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msrc\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m DeepTFL\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mwarnings\u001b[39;00m\n\u001b[1;32m 11\u001b[0m warnings\u001b[38;5;241m.\u001b[39mfilterwarnings(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mignore\u001b[39m\u001b[38;5;124m'\u001b[39m)\n", 40 | "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'src'" 41 | ] 42 | } 43 | ], 44 | "source": [ 45 | "import numpy as np\n", 46 | "\n", 47 | "from sklearn.model_selection import train_test_split\n", 48 | "from sklearn.metrics import accuracy_score, mean_squared_error\n", 49 | "from sklearn.datasets import load_breast_cancer, fetch_california_housing\n", 50 | "\n", 51 | "import xgboost as xgb\n", 52 | "from deeptlf import DeepTFL\n", 53 | "\n", 54 | "import warnings\n", 55 | "warnings.filterwarnings('ignore')" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "id": "elder-large", 61 | "metadata": {}, 62 | "source": [ 63 | "# DeepTLF model" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 3, 69 | "id": "nasty-volume", 70 | "metadata": { 71 | "tags": [] 72 | }, 73 | "outputs": [ 74 | { 75 | "data": { 76 | "text/plain": [ 77 | "\u001b[0;31mInit signature:\u001b[0m\n", 78 | "\u001b[0mDeepTFL\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n", 79 | "\u001b[0;34m\u001b[0m \u001b[0mn_est\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m23\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", 80 | "\u001b[0;34m\u001b[0m \u001b[0mmax_depth\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", 81 | "\u001b[0;34m\u001b[0m \u001b[0mdrop\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.23\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", 82 | "\u001b[0;34m\u001b[0m \u001b[0mxgb_lr\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.5\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", 83 | "\u001b[0;34m\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m320\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", 84 | "\u001b[0;34m\u001b[0m \u001b[0mn_epoch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", 85 | "\u001b[0;34m\u001b[0m \u001b[0mhidden_dim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m256\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", 86 | "\u001b[0;34m\u001b[0m \u001b[0mn_layers\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", 87 | "\u001b[0;34m\u001b[0m \u001b[0mtask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'class'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", 88 | "\u001b[0;34m\u001b[0m \u001b[0mdebug\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", 89 | "\u001b[0;34m\u001b[0m \u001b[0mcheckpoint_name\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'checkpoint.pt'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", 90 | "\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 91 | "\u001b[0;31mDocstring:\u001b[0m \n", 92 | "A deep learning model based on XGBoost and a custom neural network.\n", 93 | "\n", 94 | "Parameters\n", 95 | "----------\n", 96 | "n_est : int, optional\n", 97 | " Number of estimators for XGBoost model, default is 23.\n", 98 | "max_depth : int, optional\n", 99 | " Max depth for each tree in XGBoost, default 
is 4.\n", 100 | "n_epoch : int, optional\n", 101 | " Number of epochs for neural network training, default is 100.\n", 102 | "hidden_dim : int, optional\n", 103 | " Hidden layer dimensions for neural network, default is 128.\n", 104 | "drop : float, optional\n", 105 | " Dropout rate for neural network, default is 0.23.\n", 106 | "xgb_lr : float, optional\n", 107 | " Learning rate for XGBoost model, default is 0.5.\n", 108 | "n_layers : int, optional\n", 109 | " Number of layers in the neural network, default is 4.\n", 110 | "checkpoint_name : str, optional\n", 111 | " File name to save the neural network model, default is 'checkpoint.pt'.\n", 112 | "batch_size : int, optional\n", 113 | " Batch size for neural network training, default is 320.\n", 114 | "task : str, optional\n", 115 | " Type of machine learning task ('class' for classification, other values for regression), default is 'class'.\n", 116 | "debug : bool, optional\n", 117 | " Whether to print debugging information, default is False.\n", 118 | " \n", 119 | "Attributes\n", 120 | "----------\n", 121 | "xgb_model : XGBClassifier or XGBRegressor\n", 122 | " Fitted XGBoost model.\n", 123 | "nn_model : NeuralNet\n", 124 | " Fitted neural network model.\n", 125 | "TDE_encoder : TreeDrivenEncoder\n", 126 | " Fitted Tree-Driven Encoder.\n", 127 | "input_shape : int\n", 128 | " Shape of the input feature space.\n", 129 | "device : torch.device\n", 130 | " Device used for computations ('cuda' or 'cpu').\n", 131 | "\u001b[0;31mFile:\u001b[0m /notebooks/deeptlf/DeepTLF/src/deeptlf.py\n", 132 | "\u001b[0;31mType:\u001b[0m type\n", 133 | "\u001b[0;31mSubclasses:\u001b[0m \n" 134 | ] 135 | }, 136 | "metadata": {}, 137 | "output_type": "display_data" 138 | } 139 | ], 140 | "source": [ 141 | "DeepTFL?" 142 | ] 143 | }, 144 | { 145 | "cell_type": "markdown", 146 | "id": "df5f1b7d-f7ba-41b0-bcd2-ca658f9ec19c", 147 | "metadata": {}, 148 | "source": [ 149 | "# Example: Classification - Breast Cancer Dataset " 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 4, 155 | "id": "055532c4-626e-4236-b8e3-7951c077c84b", 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "X, y = load_breast_cancer(return_X_y=True)\n", 160 | "\n", 161 | "X_train, X_test, y_train, y_test = train_test_split(X, y, \n", 162 | " test_size=0.1,\n", 163 | " random_state=42)" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 5, 169 | "id": "primary-helmet", 170 | "metadata": { 171 | "tags": [] 172 | }, 173 | "outputs": [ 174 | { 175 | "name": "stderr", 176 | "output_type": "stream", 177 | "text": [ 178 | "Epochs: 36%|███▌ | 36/100 [00:02<00:03, 17.43it/s]" 179 | ] 180 | }, 181 | { 182 | "name": "stdout", 183 | "output_type": "stream", 184 | "text": [ 185 | "Early stopping\n" 186 | ] 187 | }, 188 | { 189 | "name": "stderr", 190 | "output_type": "stream", 191 | "text": [ 192 | "\n" 193 | ] 194 | } 195 | ], 196 | "source": [ 197 | "dtlf_model = DeepTFL(n_est=30, max_depth=3, drop=0.5, n_layers=3, task='class')\n", 198 | "dtlf_model.fit(X_train=X_train, y_train=y_train)\n", 199 | "dtlf_y_hat = dtlf_model.predict(X_test)\n", 200 | "dtlf_y_hat = np.clip(dtlf_y_hat,0,1)\n", 201 | "acc_DeepTLF = accuracy_score(y_test, dtlf_y_hat)" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": 6, 207 | "id": "voluntary-mandate", 208 | "metadata": { 209 | "tags": [] 210 | }, 211 | "outputs": [], 212 | "source": [ 213 | "xgb_model = xgb.XGBClassifier()\n", 214 | "xgb_model.fit(X_train,y_train)\n", 215 | 
"xgb_y_hat = xgb_model.predict(X_test)\n", 216 | "xgb_y_hat = np.clip(xgb_y_hat,0,1)\n", 217 | "acc_xgb = accuracy_score(y_test, xgb_y_hat)" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": 7, 223 | "id": "united-triple", 224 | "metadata": { 225 | "tags": [] 226 | }, 227 | "outputs": [ 228 | { 229 | "name": "stdout", 230 | "output_type": "stream", 231 | "text": [ 232 | "Results:\n", 233 | "GBDT accuracy: \t\t 96.49122807017544 %\n", 234 | "DeepTLF accuracy:\t 98.24561403508771 %\n" 235 | ] 236 | } 237 | ], 238 | "source": [ 239 | "print('Results:')\n", 240 | "print('GBDT accuracy: \\t\\t', acc_xgb*100,'%')\n", 241 | "print('DeepTLF accuracy:\\t', acc_DeepTLF*100,'%')" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": null, 247 | "id": "666ee242-ab76-4bef-a529-b45e8bc72a35", 248 | "metadata": {}, 249 | "outputs": [], 250 | "source": [] 251 | }, 252 | { 253 | "cell_type": "markdown", 254 | "id": "5a805987-cd4c-4e65-b901-a4066e674105", 255 | "metadata": {}, 256 | "source": [ 257 | "# Example: Regression - California Housing Dataset " 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": 8, 263 | "id": "6ec3cf69-c595-464c-a84a-cf46d63d29b9", 264 | "metadata": {}, 265 | "outputs": [], 266 | "source": [ 267 | "X, y = fetch_california_housing(return_X_y=True)\n", 268 | "# split data \n", 269 | "X_train, X_test, y_train, y_test = train_test_split(X, y, \n", 270 | " test_size=0.1,\n", 271 | " random_state=42)" 272 | ] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "execution_count": 9, 277 | "id": "5812dcab-f987-4f4a-b0f8-d6ebe6f2dd93", 278 | "metadata": { 279 | "tags": [] 280 | }, 281 | "outputs": [ 282 | { 283 | "name": "stderr", 284 | "output_type": "stream", 285 | "text": [ 286 | "Epochs: 0%| | 1/1000 [00:02<40:00, 2.40s/it]" 287 | ] 288 | }, 289 | { 290 | "name": "stdout", 291 | "output_type": "stream", 292 | "text": [ 293 | "Validation loss decreased (inf --> 1.429115). Saving model ...\n" 294 | ] 295 | }, 296 | { 297 | "name": "stderr", 298 | "output_type": "stream", 299 | "text": [ 300 | "Epochs: 0%| | 2/1000 [00:04<41:44, 2.51s/it]" 301 | ] 302 | }, 303 | { 304 | "name": "stdout", 305 | "output_type": "stream", 306 | "text": [ 307 | "Validation loss decreased (1.429115 --> 1.309111). Saving model ...\n" 308 | ] 309 | }, 310 | { 311 | "name": "stderr", 312 | "output_type": "stream", 313 | "text": [ 314 | "Epochs: 0%| | 3/1000 [00:07<41:16, 2.48s/it]" 315 | ] 316 | }, 317 | { 318 | "name": "stdout", 319 | "output_type": "stream", 320 | "text": [ 321 | "Validation loss decreased (1.309111 --> 1.288790). Saving model ...\n" 322 | ] 323 | }, 324 | { 325 | "name": "stderr", 326 | "output_type": "stream", 327 | "text": [ 328 | "Epochs: 1%| | 6/1000 [00:14<36:37, 2.21s/it]" 329 | ] 330 | }, 331 | { 332 | "name": "stdout", 333 | "output_type": "stream", 334 | "text": [ 335 | "Validation loss decreased (1.288790 --> 1.285344). Saving model ...\n" 336 | ] 337 | }, 338 | { 339 | "name": "stderr", 340 | "output_type": "stream", 341 | "text": [ 342 | "Epochs: 1%|▏ | 13/1000 [00:30<38:30, 2.34s/it]" 343 | ] 344 | }, 345 | { 346 | "name": "stdout", 347 | "output_type": "stream", 348 | "text": [ 349 | "Validation loss decreased (1.285344 --> 1.284007). 
Saving model ...\n" 350 | ] 351 | }, 352 | { 353 | "name": "stderr", 354 | "output_type": "stream", 355 | "text": [ 356 | "Epochs: 2%|▏ | 17/1000 [00:40<39:34, 2.42s/it]" 357 | ] 358 | }, 359 | { 360 | "name": "stdout", 361 | "output_type": "stream", 362 | "text": [ 363 | "Validation loss decreased (1.284007 --> 1.282636). Saving model ...\n" 364 | ] 365 | }, 366 | { 367 | "name": "stderr", 368 | "output_type": "stream", 369 | "text": [ 370 | "Epochs: 4%|▎ | 36/1000 [01:29<40:02, 2.49s/it]" 371 | ] 372 | }, 373 | { 374 | "name": "stdout", 375 | "output_type": "stream", 376 | "text": [ 377 | "Early stopping\n" 378 | ] 379 | }, 380 | { 381 | "name": "stderr", 382 | "output_type": "stream", 383 | "text": [ 384 | "\n" 385 | ] 386 | } 387 | ], 388 | "source": [ 389 | "# DeepTLF \n", 390 | "\n", 391 | "dtlf_model = DeepTFL(n_est=100, \n", 392 | " max_depth=5, \n", 393 | " debug=True,\n", 394 | " n_epoch = 1000,\n", 395 | " task='reg')\n", 396 | "dtlf_model.fit(X_train=X_train, y_train=y_train)\n", 397 | "dtlf_y_hat = dtlf_model.predict(X_test)\n", 398 | "mse_DeepTLF = mean_squared_error(y_test, dtlf_y_hat)" 399 | ] 400 | }, 401 | { 402 | "cell_type": "code", 403 | "execution_count": 10, 404 | "id": "e4b8ca73-9dcb-42f6-9477-0cec46dc022f", 405 | "metadata": {}, 406 | "outputs": [], 407 | "source": [ 408 | "# XGBoost \n", 409 | "\n", 410 | "xgb_model = xgb.XGBRFRegressor()\n", 411 | "xgb_model.fit(X_train,y_train)\n", 412 | "xgb_y_hat = xgb_model.predict(X_test)\n", 413 | "mse_xgb = mean_squared_error(y_test, xgb_y_hat)" 414 | ] 415 | }, 416 | { 417 | "cell_type": "code", 418 | "execution_count": 11, 419 | "id": "b041045a-08fd-4e3c-8c0b-2ebbc2557051", 420 | "metadata": {}, 421 | "outputs": [ 422 | { 423 | "name": "stdout", 424 | "output_type": "stream", 425 | "text": [ 426 | "Results:\n", 427 | "GBDT MSE: \t 0.4234836546668566\n", 428 | "DeepTLF MSE:\t 1.3378848061009587\n" 429 | ] 430 | } 431 | ], 432 | "source": [ 433 | "print('Results:')\n", 434 | "print('GBDT MSE: \\t', mse_xgb)\n", 435 | "print('DeepTLF MSE:\\t', mse_DeepTLF)" 436 | ] 437 | } 438 | ], 439 | "metadata": { 440 | "kernelspec": { 441 | "display_name": "Python 3 (ipykernel)", 442 | "language": "python", 443 | "name": "python3" 444 | }, 445 | "language_info": { 446 | "codemirror_mode": { 447 | "name": "ipython", 448 | "version": 3 449 | }, 450 | "file_extension": ".py", 451 | "mimetype": "text/x-python", 452 | "name": "python", 453 | "nbconvert_exporter": "python", 454 | "pygments_lexer": "ipython3", 455 | "version": "3.11.5" 456 | } 457 | }, 458 | "nbformat": 4, 459 | "nbformat_minor": 5 460 | } 461 | --------------------------------------------------------------------------------