├── .github └── workflows │ ├── docs.yml │ └── publish.yml ├── .gitignore ├── .ipynb_checkpoints ├── LICENSE-checkpoint ├── README-checkpoint.md ├── __init__-checkpoint.py ├── hellokan-checkpoint.ipynb └── setup-checkpoint.py ├── LICENSE ├── README.md ├── __init__.py ├── docs ├── .ipynb_checkpoints │ ├── API_1_indexing_-checkpoint.ipynb │ ├── API_2_plotting_-checkpoint.ipynb │ ├── API_3_grid_-checkpoint.ipynb │ ├── API_4_extract_activations_-checkpoint.ipynb │ ├── API_5_initialization_hyperparameter_-checkpoint.ipynb │ ├── API_6_training_hyperparameter_-checkpoint.ipynb │ ├── API_7_pruning_-checkpoint.ipynb │ ├── API_8_checkpoint_-checkpoint.ipynb │ ├── API_9_video_-checkpoint.ipynb │ ├── Example_10_relativity-addition_-checkpoint.ipynb │ ├── Example_11_encouraing_linear_-checkpoint.ipynb │ ├── Example_12_unsupervised_learning-checkpoint.ipynb │ ├── Example_13_phase_transition-checkpoint.ipynb │ ├── Example_1_function_fitting_-checkpoint.ipynb │ ├── Example_2_deep_formula_-checkpoint.ipynb │ ├── Example_3_classfication_-checkpoint.ipynb │ ├── Example_4_symbolic_regression_-checkpoint.ipynb │ ├── Example_5_special_functions_-checkpoint.ipynb │ ├── Example_6_PDE_-checkpoint.ipynb │ ├── Example_7_continual_learning_-checkpoint.ipynb │ ├── Example_8_scaling_-checkpoint.ipynb │ ├── Example_9_singularity_-checkpoint.ipynb │ ├── Makefile-checkpoint │ ├── community-checkpoint.rst │ ├── conf-checkpoint.py │ ├── demos-checkpoint.rst │ ├── examples-checkpoint.rst │ ├── index-checkpoint.rst │ ├── interp-checkpoint.rst │ ├── intro-checkpoint.ipynb │ ├── intro-checkpoint.rst │ ├── kan-checkpoint.rst │ ├── kan_plot-checkpoint.png │ ├── make-checkpoint.bat │ ├── modules-checkpoint.rst │ └── physics-checkpoint.rst ├── API_demo │ ├── .ipynb_checkpoints │ │ └── API_6_training_hyperparameter-checkpoint.rst │ ├── API_10_device.ipynb │ ├── API_10_device.rst │ ├── API_11_create_dataset.ipynb │ ├── API_11_create_dataset.rst │ ├── API_12_checkpoint_save_load_model.ipynb │ ├── 
API_12_checkpoint_save_load_model.rst │ ├── API_12_checkpoint_save_load_model_files │ │ ├── API_12_checkpoint_save_load_model_11_1.png │ │ ├── API_12_checkpoint_save_load_model_13_2.png │ │ ├── API_12_checkpoint_save_load_model_3_1.png │ │ ├── API_12_checkpoint_save_load_model_7_2.png │ │ └── API_12_checkpoint_save_load_model_9_1.png │ ├── API_1_indexing.ipynb │ ├── API_1_indexing.rst │ ├── API_1_indexing_files │ │ ├── API_1_indexing_12_0.png │ │ ├── API_1_indexing_14_0.png │ │ ├── API_1_indexing_16_0.png │ │ ├── API_1_indexing_1_1.png │ │ ├── API_1_indexing_4_1.png │ │ ├── API_1_indexing_5_1.png │ │ ├── API_1_indexing_6_1.png │ │ ├── API_1_indexing_7_1.png │ │ └── API_1_indexing_8_1.png │ ├── API_2_plotting.ipynb │ ├── API_2_plotting.rst │ ├── API_2_plotting_files │ │ ├── API_2_plotting_10_0.png │ │ ├── API_2_plotting_11_0.png │ │ ├── API_2_plotting_13_0.png │ │ ├── API_2_plotting_14_0.png │ │ ├── API_2_plotting_15_0.png │ │ ├── API_2_plotting_17_1.png │ │ ├── API_2_plotting_19_0.png │ │ ├── API_2_plotting_20_0.png │ │ ├── API_2_plotting_21_0.png │ │ ├── API_2_plotting_23_0.png │ │ ├── API_2_plotting_25_0.png │ │ ├── API_2_plotting_28_0.png │ │ ├── API_2_plotting_31_0.png │ │ ├── API_2_plotting_4_0.png │ │ ├── API_2_plotting_5_0.png │ │ └── API_2_plotting_9_0.png │ ├── API_3_extract_activations.ipynb │ ├── API_3_extract_activations.rst │ ├── API_3_extract_activations_files │ │ ├── API_3_extract_activations_1_1.png │ │ └── API_3_extract_activations_2_0.png │ ├── API_4_initialization.ipynb │ ├── API_4_initialization.rst │ ├── API_4_initialization_files │ │ ├── API_4_initialization_10_1.png │ │ ├── API_4_initialization_11_1.png │ │ ├── API_4_initialization_3_1.png │ │ ├── API_4_initialization_5_1.png │ │ ├── API_4_initialization_7_1.png │ │ └── API_4_initialization_8_1.png │ ├── API_5_grid.ipynb │ ├── API_5_grid.rst │ ├── API_5_grid_files │ │ └── API_5_grid_2_1.png │ ├── API_6_training_hyperparameter.ipynb │ ├── API_6_training_hyperparameter.rst │ ├── 
API_6_training_hyperparameter_files │ │ ├── API_6_training_hyperparameter_12_3.png │ │ ├── API_6_training_hyperparameter_14_3.png │ │ ├── API_6_training_hyperparameter_17_3.png │ │ ├── API_6_training_hyperparameter_4_3.png │ │ ├── API_6_training_hyperparameter_7_3.png │ │ └── API_6_training_hyperparameter_9_3.png │ ├── API_7_pruning.ipynb │ ├── API_7_pruning.rst │ ├── API_7_pruning_files │ │ ├── API_7_pruning_10_3.png │ │ ├── API_7_pruning_11_1.png │ │ ├── API_7_pruning_2_3.png │ │ ├── API_7_pruning_3_1.png │ │ ├── API_7_pruning_5_3.png │ │ └── API_7_pruning_7_0.png │ ├── API_8_regularization.ipynb │ ├── API_8_regularization.rst │ ├── API_8_regularization_files │ │ ├── API_8_regularization_4_3.png │ │ └── API_8_regularization_6_0.png │ ├── API_9_video.ipynb │ └── API_9_video.rst ├── Community │ ├── Community_1_physics_informed_kan.ipynb │ ├── Community_1_physics_informed_kan.rst │ ├── Community_1_physics_informed_kan_files │ │ ├── Community_1_physics_informed_kan_3_0.png │ │ ├── Community_1_physics_informed_kan_4_0.png │ │ └── Community_1_physics_informed_kan_5_0.png │ ├── Community_2_protein_sequence_classification.ipynb │ ├── Community_2_protein_sequence_classification.rst │ └── Community_2_protein_sequence_classification_files │ │ └── Community_2_protein_sequence_classification_13_0.png ├── Example │ ├── Example_10_relativity-addition.ipynb │ ├── Example_10_relativity-addition.rst │ ├── Example_10_relativity-addition_files │ │ ├── Example_10_relativity-addition_12_0.png │ │ ├── Example_10_relativity-addition_17_0.png │ │ └── Example_10_relativity-addition_6_0.png │ ├── Example_11_encouraing_linear.ipynb │ ├── Example_11_encouraing_linear.rst │ ├── Example_11_encouraing_linear_files │ │ ├── Example_11_encouraing_linear_5_0.png │ │ └── Example_11_encouraing_linear_8_0.png │ ├── Example_12_unsupervised_learning.ipynb │ ├── Example_12_unsupervised_learning.rst │ ├── Example_12_unsupervised_learning_files │ │ ├── Example_12_unsupervised_learning_4_0.png │ │ ├── 
Example_12_unsupervised_learning_6_0.png │ │ └── Example_12_unsupervised_learning_8_0.png │ ├── Example_13_phase_transition.ipynb │ ├── Example_13_phase_transition.rst │ ├── Example_13_phase_transition_files │ │ ├── Example_13_phase_transition_5_0.png │ │ ├── Example_13_phase_transition_7_0.png │ │ └── Example_13_phase_transition_9_0.png │ ├── Example_14_knot_supervised.ipynb │ ├── Example_14_knot_supervised.rst │ ├── Example_15_knot_unsupervised.ipynb │ ├── Example_15_knot_unsupervised.rst │ ├── Example_1_function_fitting.ipynb │ ├── Example_1_function_fitting.rst │ ├── Example_1_function_fitting_files │ │ ├── Example_1_function_fitting_12_0.png │ │ └── Example_1_function_fitting_14_1.png │ ├── Example_3_deep_formula.ipynb │ ├── Example_3_deep_formula.rst │ ├── Example_3_deep_formula_files │ │ ├── Example_3_deep_formula_11_1.png │ │ ├── Example_3_deep_formula_4_0.png │ │ ├── Example_3_deep_formula_7_1.png │ │ └── Example_3_deep_formula_9_3.png │ ├── Example_4_classfication.ipynb │ ├── Example_4_classfication.rst │ ├── Example_4_classfication_files │ │ ├── Example_4_classfication_12_1.png │ │ └── Example_4_classfication_3_2.png │ ├── Example_5_special_functions.ipynb │ ├── Example_5_special_functions.rst │ ├── Example_5_special_functions_files │ │ ├── Example_5_special_functions_4_0.png │ │ └── Example_5_special_functions_6_0.png │ ├── Example_6_PDE_interpretation.ipynb │ ├── Example_6_PDE_interpretation.rst │ ├── Example_6_PDE_interpretation_files │ │ └── Example_6_PDE_interpretation_4_0.png │ ├── Example_7_PDE_accuracy.ipynb │ ├── Example_7_PDE_accuracy.rst │ ├── Example_7_PDE_accuracy_files │ │ └── Example_7_PDE_accuracy_3_1.png │ ├── Example_8_continual_learning.ipynb │ ├── Example_8_continual_learning.rst │ ├── Example_8_continual_learning_files │ │ ├── Example_8_continual_learning_2_1.png │ │ ├── Example_8_continual_learning_4_0.png │ │ └── Example_8_continual_learning_8_0.png │ ├── Example_9_singularity.ipynb │ ├── Example_9_singularity.rst │ └── 
Example_9_singularity_files │ │ ├── Example_9_singularity_3_0.png │ │ └── Example_9_singularity_9_0.png ├── Interp │ ├── .ipynb_checkpoints │ │ └── Interp_11_sparse_init-checkpoint.rst │ ├── Interp_10A_swap.ipynb │ ├── Interp_10A_swap.rst │ ├── Interp_10A_swap_files │ │ ├── Interp_10A_swap_11_0.png │ │ ├── Interp_10A_swap_13_0.png │ │ ├── Interp_10A_swap_6_0.png │ │ └── Interp_10A_swap_8_0.png │ ├── Interp_10B_swap.ipynb │ ├── Interp_10B_swap.rst │ ├── Interp_10B_swap_files │ │ ├── Interp_10B_swap_11_0.png │ │ ├── Interp_10B_swap_3_0.png │ │ ├── Interp_10B_swap_5_0.png │ │ ├── Interp_10B_swap_7_0.png │ │ └── Interp_10B_swap_9_0.png │ ├── Interp_10_hessian.ipynb │ ├── Interp_10_hessian.rst │ ├── Interp_10_hessian_files │ │ ├── Interp_10_hessian_4_0.png │ │ └── Interp_10_hessian_6_0.png │ ├── Interp_11_sparse_init.ipynb │ ├── Interp_11_sparse_init.rst │ ├── Interp_11_sparse_init_files │ │ ├── Interp_11_sparse_init_1_1.png │ │ └── Interp_11_sparse_init_2_1.png │ ├── Interp_1_Hello, MultKAN.ipynb │ ├── Interp_1_Hello, MultKAN.rst │ ├── Interp_1_Hello, MultKAN_files │ │ ├── Interp_1_Hello, MultKAN_11_0.png │ │ ├── Interp_1_Hello, MultKAN_4_0.png │ │ ├── Interp_1_Hello, MultKAN_7_1.png │ │ └── Interp_1_Hello, MultKAN_9_0.png │ ├── Interp_2_Advanced MultKAN.ipynb │ ├── Interp_2_Advanced MultKAN.rst │ ├── Interp_2_Advanced MultKAN_files │ │ ├── Interp_2_Advanced MultKAN_11_1.png │ │ ├── Interp_2_Advanced MultKAN_2_1.png │ │ ├── Interp_2_Advanced MultKAN_4_1.png │ │ ├── Interp_2_Advanced MultKAN_6_1.png │ │ └── Interp_2_Advanced MultKAN_9_1.png │ ├── Interp_3_KAN_Compiler.ipynb │ ├── Interp_3_KAN_Compiler.rst │ ├── Interp_3_KAN_Compiler_files │ │ ├── Interp_3_KAN_Compiler_11_0.png │ │ ├── Interp_3_KAN_Compiler_13_0.png │ │ ├── Interp_3_KAN_Compiler_15_0.png │ │ ├── Interp_3_KAN_Compiler_16_0.png │ │ ├── Interp_3_KAN_Compiler_2_1.png │ │ ├── Interp_3_KAN_Compiler_4_0.png │ │ └── Interp_3_KAN_Compiler_9_0.png │ ├── Interp_4_feature_attribution.ipynb │ ├── 
Interp_4_feature_attribution.rst │ ├── Interp_4_feature_attribution_files │ │ ├── Interp_4_feature_attribution_10_1.png │ │ ├── Interp_4_feature_attribution_13_0.png │ │ ├── Interp_4_feature_attribution_15_1.png │ │ ├── Interp_4_feature_attribution_17_1.png │ │ ├── Interp_4_feature_attribution_3_0.png │ │ ├── Interp_4_feature_attribution_7_1.png │ │ └── Interp_4_feature_attribution_8_1.png │ ├── Interp_5_test_symmetry.ipynb │ ├── Interp_5_test_symmetry.rst │ ├── Interp_5_test_symmetry_files │ │ ├── Interp_5_test_symmetry_16_0.png │ │ ├── Interp_5_test_symmetry_17_0.png │ │ ├── Interp_5_test_symmetry_18_0.png │ │ └── Interp_5_test_symmetry_19_0.png │ ├── Interp_6_test_symmetry_NN.ipynb │ ├── Interp_6_test_symmetry_NN.rst │ ├── Interp_6_test_symmetry_NN_files │ │ ├── Interp_6_test_symmetry_NN_1_0.png │ │ └── Interp_6_test_symmetry_NN_3_0.png │ ├── Interp_8_adding_auxillary_variables.ipynb │ ├── Interp_8_adding_auxillary_variables.rst │ ├── Interp_8_adding_auxillary_variables_files │ │ ├── Interp_8_adding_auxillary_variables_4_0.png │ │ ├── Interp_8_adding_auxillary_variables_6_0.png │ │ └── Interp_8_adding_auxillary_variables_8_0.png │ ├── Interp_9_different_plotting_metrics.ipynb │ ├── Interp_9_different_plotting_metrics.rst │ └── Interp_9_different_plotting_metrics_files │ │ ├── Interp_9_different_plotting_metrics_3_0.png │ │ ├── Interp_9_different_plotting_metrics_4_0.png │ │ └── Interp_9_different_plotting_metrics_5_0.png ├── Makefile ├── Physics │ ├── Physics_1_Lagrangian.ipynb │ ├── Physics_1_Lagrangian.rst │ ├── Physics_1_Lagrangian_files │ │ ├── Physics_1_Lagrangian_10_0.png │ │ ├── Physics_1_Lagrangian_12_0.png │ │ ├── Physics_1_Lagrangian_15_0.png │ │ ├── Physics_1_Lagrangian_3_0.png │ │ ├── Physics_1_Lagrangian_4_1.png │ │ └── Physics_1_Lagrangian_6_0.png │ ├── Physics_2A_conservation_law.ipynb │ ├── Physics_2A_conservation_law.rst │ ├── Physics_2A_conservation_law_files │ │ └── Physics_2A_conservation_law_2_0.png │ ├── Physics_2B_conservation_law_2D.ipynb 
│ ├── Physics_2B_conservation_law_2D.rst │ ├── Physics_2B_conservation_law_2D_files │ │ ├── Physics_2B_conservation_law_2D_12_0.png │ │ ├── Physics_2B_conservation_law_2D_2_0.png │ │ └── Physics_2B_conservation_law_2D_6_0.png │ ├── Physics_3_blackhole.ipynb │ ├── Physics_3_blackhole.rst │ ├── Physics_3_blackhole_files │ │ ├── Physics_3_blackhole_10_1.png │ │ ├── Physics_3_blackhole_5_1.png │ │ ├── Physics_3_blackhole_6_1.png │ │ └── Physics_3_blackhole_7_1.png │ ├── Physics_4A_constitutive_laws_P11.ipynb │ ├── Physics_4A_constitutive_laws_P11.rst │ ├── Physics_4A_constitutive_laws_P11_files │ │ ├── Physics_4A_constitutive_laws_P11_10_0.png │ │ ├── Physics_4A_constitutive_laws_P11_11_1.png │ │ ├── Physics_4A_constitutive_laws_P11_14_0.png │ │ ├── Physics_4A_constitutive_laws_P11_2_1.png │ │ ├── Physics_4A_constitutive_laws_P11_3_1.png │ │ ├── Physics_4A_constitutive_laws_P11_5_0.png │ │ └── Physics_4A_constitutive_laws_P11_8_0.png │ ├── Physics_4B_constitutive_laws_P12_with_prior.ipynb │ ├── Physics_4B_constitutive_laws_P12_with_prior.rst │ ├── Physics_4B_constitutive_laws_P12_with_prior_files │ │ ├── Physics_4B_constitutive_laws_P12_with_prior_10_1.png │ │ ├── Physics_4B_constitutive_laws_P12_with_prior_13_0.png │ │ ├── Physics_4B_constitutive_laws_P12_with_prior_2_1.png │ │ ├── Physics_4B_constitutive_laws_P12_with_prior_3_1.png │ │ ├── Physics_4B_constitutive_laws_P12_with_prior_4_1.png │ │ ├── Physics_4B_constitutive_laws_P12_with_prior_6_0.png │ │ └── Physics_4B_constitutive_laws_P12_with_prior_9_0.png │ ├── Physics_4C_constitutive_laws_P12_without_prior.ipynb │ ├── Physics_4C_constitutive_laws_P12_without_prior.rst │ └── Physics_4C_constitutive_laws_P12_without_prior_files │ │ ├── Physics_4C_constitutive_laws_P12_without_prior_3_1.png │ │ ├── Physics_4C_constitutive_laws_P12_without_prior_4_0.png │ │ └── Physics_4C_constitutive_laws_P12_without_prior_6_0.png ├── community.rst ├── conf.py ├── demos.rst ├── examples.rst ├── index.rst ├── interp.rst ├── 
intro.ipynb ├── intro.rst ├── intro_files │ ├── intro_10_0.png │ ├── intro_12_0.png │ ├── intro_14_0.png │ ├── intro_15_0.png │ ├── intro_17_0.png │ ├── intro_19_0.png │ ├── intro_21_0.png │ ├── intro_23_0.png │ ├── intro_26_0.png │ └── intro_6_0.png ├── kan.rst ├── kan_plot.png ├── make.bat ├── modules.rst └── physics.rst ├── hellokan.ipynb ├── kan ├── .ipynb_checkpoints │ ├── KANLayer-checkpoint.py │ ├── LBFGS-checkpoint.py │ ├── MLP-checkpoint.py │ ├── MultKAN-checkpoint.py │ ├── Symbolic_KANLayer-checkpoint.py │ ├── __init__-checkpoint.py │ ├── compiler-checkpoint.py │ ├── experiment-checkpoint.py │ ├── feynman-checkpoint.py │ ├── hypothesis-checkpoint.py │ ├── spline-checkpoint.py │ └── utils-checkpoint.py ├── KANLayer.py ├── LBFGS.py ├── MLP.py ├── MultKAN.py ├── Symbolic_KANLayer.py ├── __init__.py ├── assets │ └── img │ │ ├── mult_symbol.png │ │ └── sum_symbol.png ├── compiler.py ├── experiment.py ├── experiments │ └── experiment1.ipynb ├── feynman.py ├── hypothesis.py ├── spline.py └── utils.py ├── model ├── 0.0_cache_data ├── 0.0_config.yml ├── 0.0_state └── history.txt ├── pykan.egg-info ├── PKG-INFO ├── SOURCES.txt ├── dependency_links.txt └── top_level.txt ├── requirements.txt ├── setup.py └── tutorials ├── .ipynb_checkpoints ├── API_10_device-checkpoint.ipynb ├── API_11_create_dataset-checkpoint.ipynb ├── API_12_checkpoint_save_load_model-checkpoint.ipynb ├── API_1_indexing-checkpoint.ipynb ├── API_2_plotting-checkpoint.ipynb ├── API_3_extract_activations-checkpoint.ipynb ├── API_4_initialization-checkpoint.ipynb ├── API_5_grid-checkpoint.ipynb ├── API_6_training_hyperparameter-checkpoint.ipynb ├── API_7_pruning-checkpoint.ipynb ├── API_8_regularization-checkpoint.ipynb ├── API_9_video-checkpoint.ipynb ├── Community_1_physics_informed_kan-checkpoint.ipynb ├── Example_10_relativity-addition-checkpoint.ipynb ├── Example_11_encouraing_linear-checkpoint.ipynb ├── Example_12_unsupervised_learning-checkpoint.ipynb ├── 
Example_13_phase_transition-checkpoint.ipynb ├── Example_14_knot_supervised-checkpoint.ipynb ├── Example_15_knot_unsupervised-checkpoint.ipynb ├── Example_1_function_fitting-checkpoint.ipynb ├── Example_3_deep_formula-checkpoint.ipynb ├── Example_4_classfication-checkpoint.ipynb ├── Example_5_special_functions-checkpoint.ipynb ├── Example_6_PDE_interpretation-checkpoint.ipynb ├── Example_7_PDE_accuracy-checkpoint.ipynb ├── Example_8_continual_learning-checkpoint.ipynb ├── Example_9_singularity-checkpoint.ipynb ├── Interp_10A_swap-checkpoint.ipynb ├── Interp_10B_swap-checkpoint.ipynb ├── Interp_10_hessian-checkpoint.ipynb ├── Interp_11_sparse_init-checkpoint.ipynb ├── Interp_1_Hello, MultKAN-checkpoint.ipynb ├── Interp_2_Advanced MultKAN-checkpoint.ipynb ├── Interp_3_KAN_Compiler-checkpoint.ipynb ├── Interp_4_feature_attribution-checkpoint.ipynb ├── Interp_5_test_symmetry-checkpoint.ipynb ├── Interp_6_test_symmetry_NN-checkpoint.ipynb ├── Interp_8_adding_auxillary_variables-checkpoint.ipynb └── Interp_9_different_plotting_metrics-checkpoint.ipynb ├── API_demo ├── API_10_device.ipynb ├── API_11_create_dataset.ipynb ├── API_12_checkpoint_save_load_model.ipynb ├── API_1_indexing.ipynb ├── API_2_plotting.ipynb ├── API_3_extract_activations.ipynb ├── API_4_initialization.ipynb ├── API_5_grid.ipynb ├── API_6_training_hyperparameter.ipynb ├── API_7_pruning.ipynb ├── API_8_regularization.ipynb └── API_9_video.ipynb ├── Community ├── Community_1_physics_informed_kan.ipynb └── Community_2_protein_sequence_classification.ipynb ├── Example ├── .ipynb_checkpoints │ └── Example_1_function_fitting-checkpoint.ipynb ├── Example_10_relativity-addition.ipynb ├── Example_11_encouraing_linear.ipynb ├── Example_12_unsupervised_learning.ipynb ├── Example_13_phase_transition.ipynb ├── Example_14_knot_supervised.ipynb ├── Example_15_knot_unsupervised.ipynb ├── Example_1_function_fitting.ipynb ├── Example_3_deep_formula.ipynb ├── Example_4_classfication.ipynb ├── 
Example_5_special_functions.ipynb ├── Example_6_PDE_interpretation.ipynb ├── Example_7_PDE_accuracy.ipynb ├── Example_8_continual_learning.ipynb ├── Example_9_singularity.ipynb └── model │ ├── 0.0_cache_data │ ├── 0.0_config.yml │ ├── 0.0_state │ ├── 0.10_cache_data │ ├── 0.10_config.yml │ ├── 0.10_state │ ├── 0.11_cache_data │ ├── 0.11_config.yml │ ├── 0.11_state │ ├── 0.1_cache_data │ ├── 0.1_config.yml │ ├── 0.1_state │ ├── 0.2_cache_data │ ├── 0.2_config.yml │ ├── 0.2_state │ ├── 0.3_cache_data │ ├── 0.3_config.yml │ ├── 0.3_state │ ├── 0.4_cache_data │ ├── 0.4_config.yml │ ├── 0.4_state │ ├── 0.5_cache_data │ ├── 0.5_config.yml │ ├── 0.5_state │ ├── 0.6_cache_data │ ├── 0.6_config.yml │ ├── 0.6_state │ ├── 0.7_cache_data │ ├── 0.7_config.yml │ ├── 0.7_state │ ├── 0.8_cache_data │ ├── 0.8_config.yml │ ├── 0.8_state │ ├── 0.9_cache_data │ ├── 0.9_config.yml │ ├── 0.9_state │ └── history.txt ├── Interp ├── .ipynb_checkpoints │ └── Interp_1_Hello, MultKAN-checkpoint.ipynb ├── Interp_10A_swap.ipynb ├── Interp_10B_swap.ipynb ├── Interp_10_hessian.ipynb ├── Interp_11_sparse_init.ipynb ├── Interp_1_Hello, MultKAN.ipynb ├── Interp_2_Advanced MultKAN.ipynb ├── Interp_3_KAN_Compiler.ipynb ├── Interp_4_feature_attribution.ipynb ├── Interp_5_test_symmetry.ipynb ├── Interp_6_test_symmetry_NN.ipynb ├── Interp_8_adding_auxillary_variables.ipynb └── Interp_9_different_plotting_metrics.ipynb └── Physics ├── Physics_1_Lagrangian.ipynb ├── Physics_2A_conservation_law.ipynb ├── Physics_2B_conservation_law_2D.ipynb ├── Physics_3_blackhole.ipynb ├── Physics_4A_constitutive_laws_P11.ipynb ├── Physics_4B_constitutive_laws_P12_with_prior.ipynb └── Physics_4C_constitutive_laws_P12_without_prior.ipynb /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: docs 2 | on: 3 | push: 4 | branches: 5 | - master 6 | 7 | jobs: 8 | docs: 9 | name: Docs 10 | runs-on: ubuntu-latest 11 | steps: 12 | 13 | - uses: actions/checkout@v2 14 | 
15 | - name: Install Python 16 | uses: actions/setup-python@v1 17 | with: 18 | python-version: 3.9.19 19 | 20 | - name: Install requirements 21 | run: | 22 | pip install -U sphinx 23 | pip install sphinx-rtd-theme 24 | 25 | 26 | - name: Build docs 27 | run: | 28 | pip3 install . 29 | cd docs 30 | sphinx-apidoc -o . ../kan 31 | make clean html 32 | make html 33 | # https://github.com/peaceiris/actions-gh-pages 34 | - name: Deploy 35 | if: success() 36 | uses: peaceiris/actions-gh-pages@v3 37 | with: 38 | publish_branch: gh-pages 39 | github_token: ${{ secrets.GITHUB_TOKEN }} 40 | publish_dir: docs/_build/html/ 41 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Upload Python Package to PyPI when a Release is Created 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | pypi-publish: 9 | name: Publish release to PyPI 10 | runs-on: ubuntu-latest 11 | environment: 12 | name: release 13 | url: https://pypi.org/p/pykan 14 | permissions: 15 | id-token: write 16 | steps: 17 | - uses: actions/checkout@v4 18 | - name: Set up Python 19 | uses: actions/setup-python@v4 20 | with: 21 | python-version: "3.x" 22 | - name: Install dependencies 23 | run: | 24 | python -m pip install --upgrade pip 25 | pip install setuptools wheel 26 | - name: Build package 27 | run: | 28 | python setup.py sdist bdist_wheel # Could also be python -m build 29 | - name: Publish package distributions to PyPI 30 | uses: pypa/gh-action-pypi-publish@release/v1 31 | with: 32 | password: ${{ secrets.PYPI_API_TOKEN }} 33 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .DS_Store 3 | docs/_build/ 4 | docs/_static/ 5 | docs/_templates 6 | test 7 | sr 8 | pde 9 | hidden 10 | molecule 11 | 
expressiveness 12 | figures 13 | molecule 14 | applications 15 | experiments 16 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/LICENSE-checkpoint: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Ziming Liu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/__init__-checkpoint.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/.ipynb_checkpoints/__init__-checkpoint.py -------------------------------------------------------------------------------- /.ipynb_checkpoints/setup-checkpoint.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | # Load the long_description from README.md 4 | with open("README.md", "r", encoding="utf8") as fh: 5 | long_description = fh.read() 6 | 7 | setuptools.setup( 8 | name="pykan", 9 | version="0.2.7", 10 | author="Ziming Liu", 11 | author_email="zmliu@mit.edu", 12 | description="Kolmogorov Arnold Networks", 13 | long_description=long_description, 14 | long_description_content_type="text/markdown", 15 | # url="https://github.com/kindxiaoming/", 16 | packages=setuptools.find_packages(), 17 | include_package_data=True, 18 | package_data={ 19 | 'pykan': [ 20 | 'figures/lock.png', 21 | 'assets/img/sum_symbol.png', 22 | 'assets/img/mult_symbol.png', 23 | ], 24 | }, 25 | classifiers=[ 26 | "Programming Language :: Python :: 3", 27 | "License :: OSI Approved :: MIT License", 28 | "Operating System :: OS Independent", 29 | ], 30 | python_requires='>=3.6', 31 | ) 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Ziming Liu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 
9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/__init__.py -------------------------------------------------------------------------------- /docs/.ipynb_checkpoints/API_9_video_-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "134e7f9d", 6 | "metadata": {}, 7 | "source": [ 8 | "# API Demo 9: Videos of KAN training\n", 9 | "\n", 10 | "### We have shown one can visualize KAN with the plot() method. 
If one wants to save the training dynamics of KAN plots, one only needs to pass argument save_video = True to train() method (and set some video related parameters)" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "id": "2075ef56", 17 | "metadata": {}, 18 | "outputs": [ 19 | { 20 | "name": "stderr", 21 | "output_type": "stream", 22 | "text": [ 23 | "train loss: 6.39e-03 | test loss: 6.40e-03 | reg: 7.91e+00 : 100%|██| 50/50 [01:30<00:00, 1.81s/it]\n" 24 | ] 25 | }, 26 | { 27 | "name": "stdout", 28 | "output_type": "stream", 29 | "text": [ 30 | "Moviepy - Building video video.mp4.\n", 31 | "Moviepy - Writing video video.mp4\n", 32 | "\n" 33 | ] 34 | }, 35 | { 36 | "name": "stderr", 37 | "output_type": "stream", 38 | "text": [ 39 | " \r" 40 | ] 41 | }, 42 | { 43 | "name": "stdout", 44 | "output_type": "stream", 45 | "text": [ 46 | "Moviepy - Done !\n", 47 | "Moviepy - video ready video.mp4\n" 48 | ] 49 | } 50 | ], 51 | "source": [ 52 | "from kan import KAN, create_dataset\n", 53 | "import torch\n", 54 | "\n", 55 | "# create a KAN: 2D inputs, 1D output, and 5 hidden neurons. 
cubic spline (k=3), 5 grid intervals (grid=5).\n", 56 | "model = KAN(width=[4,2,1,1], grid=3, k=3, seed=0)\n", 57 | "f = lambda x: torch.exp((torch.sin(torch.pi*(x[:,[0]]**2+x[:,[1]]**2))+torch.sin(torch.pi*(x[:,[2]]**2+x[:,[3]]**2)))/2)\n", 58 | "dataset = create_dataset(f, n_var=4, train_num=3000)\n", 59 | "\n", 60 | "# train the model\n", 61 | "#model.train(dataset, opt=\"LBFGS\", steps=20, lamb=1e-3, lamb_entropy=2.);\n", 62 | "model.train(dataset, opt=\"LBFGS\", steps=50, lamb=5e-5, lamb_entropy=2., save_video=True, beta=10, \n", 63 | " in_vars=[r'$x_1$', r'$x_2$', r'$x_3$', r'$x_4$'],\n", 64 | " out_vars=[r'${\\rm exp}({\\rm sin}(x_1^2+x_2^2)+{\\rm sin}(x_3^2+x_4^2))$'],\n", 65 | " video_name='video', fps=5);" 66 | ] 67 | } 68 | ], 69 | "metadata": { 70 | "kernelspec": { 71 | "display_name": "Python 3 (ipykernel)", 72 | "language": "python", 73 | "name": "python3" 74 | }, 75 | "language_info": { 76 | "codemirror_mode": { 77 | "name": "ipython", 78 | "version": 3 79 | }, 80 | "file_extension": ".py", 81 | "mimetype": "text/x-python", 82 | "name": "python", 83 | "nbconvert_exporter": "python", 84 | "pygments_lexer": "ipython3", 85 | "version": "3.9.7" 86 | } 87 | }, 88 | "nbformat": 4, 89 | "nbformat_minor": 5 90 | } 91 | -------------------------------------------------------------------------------- /docs/.ipynb_checkpoints/Makefile-checkpoint: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 
12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/.ipynb_checkpoints/community-checkpoint.rst: -------------------------------------------------------------------------------- 1 | .. _community: 2 | 3 | Community 4 | --------- 5 | 6 | .. toctree:: 7 | :maxdepth: 1 8 | 9 | Community/Community_1_physics_informed_kan.rst 10 | Community/Community_2_protein_sequence_classification.rst 11 | 12 | -------------------------------------------------------------------------------- /docs/.ipynb_checkpoints/conf-checkpoint.py: -------------------------------------------------------------------------------- 1 | import sphinx_rtd_theme 2 | 3 | # Configuration file for the Sphinx documentation builder. 
4 | # 5 | # For the full list of built-in configuration values, see the documentation: 6 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 7 | 8 | # -- Project information ----------------------------------------------------- 9 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 10 | 11 | project = 'Kolmogorov Arnold Network' 12 | copyright = '2024, Ziming Liu' 13 | author = 'Ziming Liu' 14 | 15 | # -- General configuration --------------------------------------------------- 16 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 17 | 18 | extensions = ["sphinx_rtd_theme", 19 | "sphinx.ext.autodoc", 20 | "sphinx.ext.autosectionlabel" 21 | ] 22 | 23 | templates_path = ['_templates'] 24 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 25 | 26 | 27 | 28 | # -- Options for HTML output ------------------------------------------------- 29 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 30 | 31 | #html_theme = 'alabaster' 32 | html_theme = "sphinx_rtd_theme" 33 | html_static_path = ['_static'] 34 | 35 | def skip(app, what, name, obj, would_skip, options): 36 | if name == "__init__": 37 | return False 38 | return would_skip 39 | 40 | def setup(app): 41 | app.connect("autodoc-skip-member", skip) 42 | 43 | autodoc_mock_imports = ["numpy", 44 | "torch", 45 | "torch.nn", 46 | "matplotlib", 47 | "matplotlib.pyplot", 48 | "tqdm", 49 | "sympy", 50 | "scipy", 51 | "sklearn", 52 | "torch.optim", 53 | "re", 54 | "yaml", 55 | "pandas"] 56 | 57 | 58 | source_suffix = [".rst", ".md"] 59 | #source_suffix = [".rst", ".md", ".ipynb"] 60 | #source_suffix = { 61 | # '.rst': 'restructuredtext', 62 | # '.ipynb': 'myst-nb', 63 | # '.myst': 'myst-nb', 64 | #} 65 | -------------------------------------------------------------------------------- /docs/.ipynb_checkpoints/demos-checkpoint.rst: 
-------------------------------------------------------------------------------- 1 | .. _api-demo: 2 | 3 | API Demos 4 | --------- 5 | 6 | .. toctree:: 7 | :maxdepth: 1 8 | 9 | API_demo/API_1_indexing.rst 10 | API_demo/API_2_plotting.rst 11 | API_demo/API_3_extract_activations.rst 12 | API_demo/API_4_initialization.rst 13 | API_demo/API_5_grid.rst 14 | API_demo/API_6_training_hyperparameter.rst 15 | API_demo/API_7_pruning.rst 16 | API_demo/API_8_regularization.rst 17 | API_demo/API_9_video.rst 18 | API_demo/API_10_device.rst 19 | API_demo/API_11_create_dataset.rst 20 | API_demo/API_12_checkpoint_save_load_model.rst -------------------------------------------------------------------------------- /docs/.ipynb_checkpoints/examples-checkpoint.rst: -------------------------------------------------------------------------------- 1 | .. _examples: 2 | 3 | Examples 4 | -------- 5 | 6 | .. toctree:: 7 | :maxdepth: 1 8 | 9 | Example/Example_1_function_fitting.rst 10 | Example/Example_3_deep_formula.rst 11 | Example/Example_4_classfication.rst 12 | Example/Example_5_special_functions.rst 13 | Example/Example_6_PDE_interpretation.rst 14 | Example/Example_7_PDE_accuracy.rst 15 | Example/Example_8_continual_learning.rst 16 | Example/Example_9_singularity.rst 17 | Example/Example_10_relativity-addition.rst 18 | Example/Example_11_encouraing_linear.rst 19 | Example/Example_12_unsupervised_learning.rst 20 | Example/Example_13_phase_transition.rst 21 | Example/Example_14_knot_supervised.rst 22 | Example/Example_15_knot_unsupervised.rst 23 | 24 | -------------------------------------------------------------------------------- /docs/.ipynb_checkpoints/index-checkpoint.rst: -------------------------------------------------------------------------------- 1 | .. kolmogorov-arnold-network documentation master file, created by 2 | sphinx-quickstart on Sun Apr 21 12:57:28 2024. 
3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to Kolmogorov Arnold Network (KAN) documentation! 7 | ========================================================== 8 | 9 | .. image:: kan_plot.png 10 | 11 | This documentation is for the `paper`_ "KAN: Kolmogorov-Arnold Networks" and the `github repo`_. 12 | Kolmogorov-Arnold Networks, inspired by the Kolmogorov-Arnold representation theorem, are promising alternatives 13 | of Multi-Layer Preceptrons (MLPs). KANs have activation functions on edges, whereas MLPs have activation functions on nodes. 14 | This simple change makes KAN better than MLPs in terms of both accuracy and interpretability. 15 | 16 | .. _github repo: https://github.com/KindXiaoming/pykan 17 | .. _paper: https://arxiv.org/abs/2404.19756 18 | 19 | Installation 20 | ------------ 21 | 22 | Installation via github 23 | ~~~~~~~~~~~~~~~~~~~~~~~~ 24 | 25 | .. code-block:: python 26 | 27 | git clone https://github.com/KindXiaoming/pykan.git 28 | cd pykan 29 | pip install -e . 30 | # pip install -r requirements.txt # install requirements 31 | 32 | 33 | Installation via PyPI 34 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 35 | 36 | .. code-block:: python 37 | 38 | pip install pykan 39 | 40 | 41 | Requirements 42 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 43 | 44 | .. code-block:: python 45 | # python==3.9.7 46 | matplotlib==3.6.2 47 | numpy==1.24.4 48 | scikit_learn==1.1.3 49 | setuptools==65.5.0 50 | sympy==1.11.1 51 | torch==2.2.2 52 | tqdm==4.66.2 53 | 54 | Get started 55 | ----------- 56 | 57 | * Quickstart: :ref:`hello-kan` 58 | * KANs in Action: :ref:`api-demo`, :ref:`examples` 59 | * API (advanced): :ref:`api`. 60 | 61 | .. 
toctree:: 62 | :maxdepth: 1 63 | :caption: Contents: 64 | 65 | intro.rst 66 | modules.rst 67 | demos.rst 68 | examples.rst 69 | interp.rst 70 | physics.rst 71 | community.rst 72 | 73 | Indices and tables 74 | ================== 75 | 76 | * :ref:`genindex` 77 | * :ref:`modindex` 78 | * :ref:`search` 79 | -------------------------------------------------------------------------------- /docs/.ipynb_checkpoints/interp-checkpoint.rst: -------------------------------------------------------------------------------- 1 | .. _interp: 2 | 3 | Interpretability 4 | ---------------- 5 | 6 | .. toctree:: 7 | :maxdepth: 1 8 | 9 | Interp/Interp_1_Hello, MultKAN.rst 10 | Interp/Interp_2_Advanced MultKAN.rst 11 | Interp/Interp_3_KAN_Compiler.rst 12 | Interp/Interp_4_feature_attribution.rst 13 | Interp/Interp_5_test_symmetry.rst 14 | Interp/Interp_6_test_symmetry_NN.rst 15 | Interp/Interp_8_adding_auxillary_variables.rst 16 | Interp/Interp_9_different_plotting_metrics.rst 17 | Interp/Interp_10_hessian.rst 18 | Interp/Interp_10A_swap.rst 19 | Interp/Interp_10B_swap.rst 20 | Interp/Interp_11_sparse_init.rst 21 | 22 | 23 | -------------------------------------------------------------------------------- /docs/.ipynb_checkpoints/kan-checkpoint.rst: -------------------------------------------------------------------------------- 1 | kan package 2 | =========== 3 | 4 | kan.KAN module 5 | -------------- 6 | 7 | .. automodule:: kan.MultKAN 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | 12 | kan.KANLayer module 13 | ------------------- 14 | 15 | .. automodule:: kan.KANLayer 16 | :members: 17 | :undoc-members: 18 | :show-inheritance: 19 | 20 | kan.LBFGS module 21 | ---------------- 22 | 23 | .. automodule:: kan.LBFGS 24 | :members: 25 | :undoc-members: 26 | :show-inheritance: 27 | 28 | kan.Symbolic\_KANLayer module 29 | ----------------------------- 30 | 31 | .. 
automodule:: kan.Symbolic_KANLayer 32 | :members: 33 | :undoc-members: 34 | :show-inheritance: 35 | 36 | kan.spline module 37 | ----------------- 38 | 39 | .. automodule:: kan.spline 40 | :members: 41 | :undoc-members: 42 | :show-inheritance: 43 | 44 | kan.utils module 45 | ---------------- 46 | 47 | .. automodule:: kan.utils 48 | :members: 49 | :undoc-members: 50 | :show-inheritance: 51 | 52 | kan.compiler module 53 | ------------------- 54 | 55 | .. automodule:: kan.compiler 56 | :members: 57 | :undoc-members: 58 | :show-inheritance: 59 | 60 | kan.hypothesis module 61 | --------------------- 62 | 63 | .. automodule:: kan.hypothesis 64 | :members: 65 | :undoc-members: 66 | :show-inheritance: 67 | -------------------------------------------------------------------------------- /docs/.ipynb_checkpoints/kan_plot-checkpoint.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/.ipynb_checkpoints/kan_plot-checkpoint.png -------------------------------------------------------------------------------- /docs/.ipynb_checkpoints/make-checkpoint.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 
21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/.ipynb_checkpoints/modules-checkpoint.rst: -------------------------------------------------------------------------------- 1 | .. _api: 2 | 3 | API 4 | === 5 | 6 | .. toctree:: 7 | :maxdepth: 4 8 | 9 | kan 10 | -------------------------------------------------------------------------------- /docs/.ipynb_checkpoints/physics-checkpoint.rst: -------------------------------------------------------------------------------- 1 | .. _physics: 2 | 3 | Physics 4 | ------- 5 | 6 | .. toctree:: 7 | :maxdepth: 1 8 | 9 | Physics/Physics_1_Lagrangian.rst 10 | Physics/Physics_2A_conservation_law.rst 11 | Physics/Physics_2B_conservation_law_2D.rst 12 | Physics/Physics_3_blackhole.rst 13 | Physics/Physics_4A_constitutive_laws_P11.rst 14 | Physics/Physics_4B_constitutive_laws_P12_with_prior.rst 15 | Physics/Physics_4C_constitutive_laws_P12_without_prior.rst 16 | 17 | 18 | -------------------------------------------------------------------------------- /docs/API_demo/API_10_device.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "134e7f9d", 6 | "metadata": {}, 7 | "source": [ 8 | "# Demo 10: Device\n", 9 | "\n", 10 | "All other demos have by default used device = 'cpu'. In case we want to use cuda, we should pass the device argument to model and dataset." 
11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "id": "7a4ac1e1-84ba-4bc3-91b6-a776a5e7711c", 17 | "metadata": {}, 18 | "outputs": [ 19 | { 20 | "name": "stdout", 21 | "output_type": "stream", 22 | "text": [ 23 | "cpu\n" 24 | ] 25 | } 26 | ], 27 | "source": [ 28 | "from kan import KAN, create_dataset\n", 29 | "import torch\n", 30 | "\n", 31 | "\n", 32 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", 33 | "print(device)\n", 34 | "\n", 35 | "#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", 36 | "device = 'cpu'\n", 37 | "print(device)" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 2, 43 | "id": "2075ef56", 44 | "metadata": {}, 45 | "outputs": [ 46 | { 47 | "name": "stdout", 48 | "output_type": "stream", 49 | "text": [ 50 | "checkpoint directory created: ./model\n", 51 | "saving model version 0.0\n" 52 | ] 53 | }, 54 | { 55 | "name": "stderr", 56 | "output_type": "stream", 57 | "text": [ 58 | "| train_loss: 6.83e-01 | test_loss: 7.21e-01 | reg: 1.04e+03 | : 100%|█| 50/50 [00:19<00:00, 2.56it\n" 59 | ] 60 | }, 61 | { 62 | "name": "stdout", 63 | "output_type": "stream", 64 | "text": [ 65 | "saving model version 0.1\n" 66 | ] 67 | } 68 | ], 69 | "source": [ 70 | "model = KAN(width=[4,100,100,100,1], grid=3, k=3, seed=0).to(device)\n", 71 | "f = lambda x: torch.exp((torch.sin(torch.pi*(x[:,[0]]**2+x[:,[1]]**2))+torch.sin(torch.pi*(x[:,[2]]**2+x[:,[3]]**2)))/2)\n", 72 | "dataset = create_dataset(f, n_var=4, train_num=1000, device=device)\n", 73 | "\n", 74 | "# train the model\n", 75 | "#model.train(dataset, opt=\"LBFGS\", steps=20, lamb=1e-3, lamb_entropy=2.);\n", 76 | "model.fit(dataset, opt=\"Adam\", lr=1e-3, steps=50, lamb=1e-3, lamb_entropy=5., update_grid=False);" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "id": "2f182cc1-51bf-4151-a253-a52fe854919e", 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [] 86 | }, 87 
| { 88 | "cell_type": "code", 89 | "execution_count": 3, 90 | "id": "f6f8125e-d26d-4c97-9e5f-988099bb4737", 91 | "metadata": {}, 92 | "outputs": [ 93 | { 94 | "name": "stdout", 95 | "output_type": "stream", 96 | "text": [ 97 | "cuda\n" 98 | ] 99 | } 100 | ], 101 | "source": [ 102 | "device = 'cuda'\n", 103 | "print(device)" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 4, 109 | "id": "95017dfa-3a2a-43e0-8b68-fb220ca5abc9", 110 | "metadata": {}, 111 | "outputs": [ 112 | { 113 | "name": "stdout", 114 | "output_type": "stream", 115 | "text": [ 116 | "checkpoint directory created: ./model\n", 117 | "saving model version 0.0\n" 118 | ] 119 | }, 120 | { 121 | "name": "stderr", 122 | "output_type": "stream", 123 | "text": [ 124 | "| train_loss: 6.83e-01 | test_loss: 7.21e-01 | reg: 1.04e+03 | : 100%|█| 50/50 [00:01<00:00, 26.45it\n" 125 | ] 126 | }, 127 | { 128 | "name": "stdout", 129 | "output_type": "stream", 130 | "text": [ 131 | "saving model version 0.1\n" 132 | ] 133 | } 134 | ], 135 | "source": [ 136 | "model = KAN(width=[4,100,100,100,1], grid=3, k=3, seed=0).to(device)\n", 137 | "f = lambda x: torch.exp((torch.sin(torch.pi*(x[:,[0]]**2+x[:,[1]]**2))+torch.sin(torch.pi*(x[:,[2]]**2+x[:,[3]]**2)))/2)\n", 138 | "dataset = create_dataset(f, n_var=4, train_num=1000, device=device)\n", 139 | "\n", 140 | "# train the model\n", 141 | "#model.train(dataset, opt=\"LBFGS\", steps=20, lamb=1e-3, lamb_entropy=2.);\n", 142 | "model.fit(dataset, opt=\"Adam\", lr=1e-3, steps=50, lamb=1e-3, lamb_entropy=5., update_grid=False);" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": null, 148 | "id": "8230d562-2635-4adc-b566-06ac679b166a", 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [] 152 | } 153 | ], 154 | "metadata": { 155 | "kernelspec": { 156 | "display_name": "Python 3 (ipykernel)", 157 | "language": "python", 158 | "name": "python3" 159 | }, 160 | "language_info": { 161 | "codemirror_mode": { 162 | "name": 
"ipython", 163 | "version": 3 164 | }, 165 | "file_extension": ".py", 166 | "mimetype": "text/x-python", 167 | "name": "python", 168 | "nbconvert_exporter": "python", 169 | "pygments_lexer": "ipython3", 170 | "version": "3.9.16" 171 | } 172 | }, 173 | "nbformat": 4, 174 | "nbformat_minor": 5 175 | } 176 | -------------------------------------------------------------------------------- /docs/API_demo/API_10_device.rst: -------------------------------------------------------------------------------- 1 | Demo 10: Device 2 | =============== 3 | 4 | All other demos have by default used device = ‘cpu’. In case we want to 5 | use cuda, we should pass the device argument to model and dataset. 6 | 7 | .. code:: ipython3 8 | 9 | from kan import KAN, create_dataset 10 | import torch 11 | 12 | 13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 14 | print(device) 15 | 16 | #device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 17 | device = 'cpu' 18 | print(device) 19 | 20 | 21 | .. parsed-literal:: 22 | 23 | cpu 24 | 25 | 26 | .. code:: ipython3 27 | 28 | model = KAN(width=[4,100,100,100,1], grid=3, k=3, seed=0).to(device) 29 | f = lambda x: torch.exp((torch.sin(torch.pi*(x[:,[0]]**2+x[:,[1]]**2))+torch.sin(torch.pi*(x[:,[2]]**2+x[:,[3]]**2)))/2) 30 | dataset = create_dataset(f, n_var=4, train_num=1000, device=device) 31 | 32 | # train the model 33 | #model.train(dataset, opt="LBFGS", steps=20, lamb=1e-3, lamb_entropy=2.); 34 | model.fit(dataset, opt="Adam", lr=1e-3, steps=50, lamb=1e-3, lamb_entropy=5., update_grid=False); 35 | 36 | 37 | .. parsed-literal:: 38 | 39 | checkpoint directory created: ./model 40 | saving model version 0.0 41 | 42 | 43 | .. parsed-literal:: 44 | 45 | | train_loss: 6.83e-01 | test_loss: 7.21e-01 | reg: 1.04e+03 | : 100%|█| 50/50 [00:19<00:00, 2.56it 46 | 47 | 48 | .. parsed-literal:: 49 | 50 | saving model version 0.1 51 | 52 | 53 | 54 | .. 
code:: ipython3 55 | 56 | device = 'cuda' 57 | print(device) 58 | 59 | 60 | .. parsed-literal:: 61 | 62 | cuda 63 | 64 | 65 | .. code:: ipython3 66 | 67 | model = KAN(width=[4,100,100,100,1], grid=3, k=3, seed=0).to(device) 68 | f = lambda x: torch.exp((torch.sin(torch.pi*(x[:,[0]]**2+x[:,[1]]**2))+torch.sin(torch.pi*(x[:,[2]]**2+x[:,[3]]**2)))/2) 69 | dataset = create_dataset(f, n_var=4, train_num=1000, device=device) 70 | 71 | # train the model 72 | #model.train(dataset, opt="LBFGS", steps=20, lamb=1e-3, lamb_entropy=2.); 73 | model.fit(dataset, opt="Adam", lr=1e-3, steps=50, lamb=1e-3, lamb_entropy=5., update_grid=False); 74 | 75 | 76 | .. parsed-literal:: 77 | 78 | checkpoint directory created: ./model 79 | saving model version 0.0 80 | 81 | 82 | .. parsed-literal:: 83 | 84 | | train_loss: 6.83e-01 | test_loss: 7.21e-01 | reg: 1.04e+03 | : 100%|█| 50/50 [00:01<00:00, 26.45it 85 | 86 | 87 | .. parsed-literal:: 88 | 89 | saving model version 0.1 90 | 91 | 92 | -------------------------------------------------------------------------------- /docs/API_demo/API_11_create_dataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "53ff2e87", 6 | "metadata": {}, 7 | "source": [ 8 | "# API 11: Create dataset" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "25a90774", 14 | "metadata": {}, 15 | "source": [ 16 | "how to use create_dataset in kan.utils" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "id": "2f9ae0c7", 22 | "metadata": {}, 23 | "source": [ 24 | "Standard way" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 1, 30 | "id": "3e2b9f8b", 31 | "metadata": {}, 32 | "outputs": [ 33 | { 34 | "name": "stdout", 35 | "output_type": "stream", 36 | "text": [ 37 | "cuda\n" 38 | ] 39 | }, 40 | { 41 | "data": { 42 | "text/plain": [ 43 | "torch.Size([1000, 1])" 44 | ] 45 | }, 46 | "execution_count": 1, 47 | "metadata": {}, 48 | 
"output_type": "execute_result" 49 | } 50 | ], 51 | "source": [ 52 | "from kan.utils import create_dataset\n", 53 | "import torch\n", 54 | "\n", 55 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", 56 | "print(device)\n", 57 | "\n", 58 | "f = lambda x: x[:,[0]] * x[:,[1]]\n", 59 | "dataset = create_dataset(f, n_var=2, device=device)\n", 60 | "dataset['train_label'].shape" 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "id": "877956c9", 66 | "metadata": {}, 67 | "source": [ 68 | "Lazier way. We sometimes forget to add the bracket, i.e., write x[:,[0]] as x[:,0], and this used to lead to an error in training (loss not going down). Now the create_dataset can automatically detect this simplification and produce the correct behavior." 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 2, 74 | "id": "b14dd4a2", 75 | "metadata": {}, 76 | "outputs": [ 77 | { 78 | "data": { 79 | "text/plain": [ 80 | "torch.Size([1000, 1])" 81 | ] 82 | }, 83 | "execution_count": 2, 84 | "metadata": {}, 85 | "output_type": "execute_result" 86 | } 87 | ], 88 | "source": [ 89 | "f = lambda x: x[:,0] * x[:,1]\n", 90 | "dataset = create_dataset(f, n_var=2, device=device)\n", 91 | "dataset['train_label'].shape" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "id": "60230da4", 97 | "metadata": {}, 98 | "source": [ 99 | "Laziest way. If you even want to get rid of the colon symbol, i.e., you want to write x[;,0] as x[0], you can do that but need to pass in f_mode = 'row'." 
100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 3, 105 | "id": "e764f415", 106 | "metadata": {}, 107 | "outputs": [ 108 | { 109 | "data": { 110 | "text/plain": [ 111 | "torch.Size([1000, 1])" 112 | ] 113 | }, 114 | "execution_count": 3, 115 | "metadata": {}, 116 | "output_type": "execute_result" 117 | } 118 | ], 119 | "source": [ 120 | "f = lambda x: x[0] * x[1]\n", 121 | "dataset = create_dataset(f, n_var=2, f_mode='row', device=device)\n", 122 | "dataset['train_label'].shape" 123 | ] 124 | }, 125 | { 126 | "cell_type": "markdown", 127 | "id": "8e1f1732", 128 | "metadata": {}, 129 | "source": [ 130 | "if you already have x (inputs) and y (outputs), and you only want to partition them into train/test, use create_dataset_from_data" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 4, 136 | "id": "accf900a", 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "import torch\n", 141 | "from kan.utils import create_dataset_from_data\n", 142 | "\n", 143 | "x = torch.rand(100,2)\n", 144 | "y = torch.rand(100,1)\n", 145 | "dataset = create_dataset_from_data(x, y, device=device)" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "id": "c45062a8", 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [] 155 | } 156 | ], 157 | "metadata": { 158 | "kernelspec": { 159 | "display_name": "Python 3 (ipykernel)", 160 | "language": "python", 161 | "name": "python3" 162 | }, 163 | "language_info": { 164 | "codemirror_mode": { 165 | "name": "ipython", 166 | "version": 3 167 | }, 168 | "file_extension": ".py", 169 | "mimetype": "text/x-python", 170 | "name": "python", 171 | "nbconvert_exporter": "python", 172 | "pygments_lexer": "ipython3", 173 | "version": "3.9.16" 174 | } 175 | }, 176 | "nbformat": 4, 177 | "nbformat_minor": 5 178 | } 179 | -------------------------------------------------------------------------------- /docs/API_demo/API_11_create_dataset.rst: 
-------------------------------------------------------------------------------- 1 | API 11: Create dataset 2 | ====================== 3 | 4 | how to use create_dataset in kan.utils 5 | 6 | Standard way 7 | 8 | .. code:: ipython3 9 | 10 | from kan.utils import create_dataset 11 | import torch 12 | 13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 14 | print(device) 15 | 16 | f = lambda x: x[:,[0]] * x[:,[1]] 17 | dataset = create_dataset(f, n_var=2, device=device) 18 | dataset['train_label'].shape 19 | 20 | 21 | .. parsed-literal:: 22 | 23 | cuda 24 | 25 | 26 | 27 | 28 | .. parsed-literal:: 29 | 30 | torch.Size([1000, 1]) 31 | 32 | 33 | 34 | Lazier way. We sometimes forget to add the bracket, i.e., write x[:,[0]] 35 | as x[:,0], and this used to lead to an error in training (loss not going 36 | down). Now the create_dataset can automatically detect this 37 | simplification and produce the correct behavior. 38 | 39 | .. code:: ipython3 40 | 41 | f = lambda x: x[:,0] * x[:,1] 42 | dataset = create_dataset(f, n_var=2, device=device) 43 | dataset['train_label'].shape 44 | 45 | 46 | 47 | 48 | .. parsed-literal:: 49 | 50 | torch.Size([1000, 1]) 51 | 52 | 53 | 54 | Laziest way. If you even want to get rid of the colon symbol, i.e., you 55 | want to write x[;,0] as x[0], you can do that but need to pass in f_mode 56 | = ‘row’. 57 | 58 | .. code:: ipython3 59 | 60 | f = lambda x: x[0] * x[1] 61 | dataset = create_dataset(f, n_var=2, f_mode='row', device=device) 62 | dataset['train_label'].shape 63 | 64 | 65 | 66 | 67 | .. parsed-literal:: 68 | 69 | torch.Size([1000, 1]) 70 | 71 | 72 | 73 | if you already have x (inputs) and y (outputs), and you only want to 74 | partition them into train/test, use create_dataset_from_data 75 | 76 | .. 
code:: ipython3 77 | 78 | import torch 79 | from kan.utils import create_dataset_from_data 80 | 81 | x = torch.rand(100,2) 82 | y = torch.rand(100,1) 83 | dataset = create_dataset_from_data(x, y, device=device) 84 | 85 | -------------------------------------------------------------------------------- /docs/API_demo/API_12_checkpoint_save_load_model.rst: -------------------------------------------------------------------------------- 1 | API 12: Checkpoint, save & load model 2 | ===================================== 3 | 4 | Whenever the KAN (model) is altered (e.g., fit, prune …), a new version 5 | is saved to the model.ckpt folder (by default ‘model’). The version 6 | number is ‘a.b’, where a is the round number (starting from zero, +1 7 | when model.rewind() is called), b is the version number in each round. 8 | 9 | the initialized model has version 0.0 10 | 11 | .. code:: ipython3 12 | 13 | from kan import * 14 | import torch 15 | 16 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 17 | print(device) 18 | 19 | f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2) 20 | dataset = create_dataset(f, n_var=2, device=device) 21 | model = KAN(width=[2,5,1], grid=5, k=3, seed=1, auto_save=True, device=device) 22 | model.get_act(dataset) 23 | model.plot() 24 | 25 | 26 | .. parsed-literal:: 27 | 28 | cuda 29 | checkpoint directory created: ./model 30 | saving model version 0.0 31 | 32 | 33 | 34 | .. image:: API_12_checkpoint_save_load_model_files/API_12_checkpoint_save_load_model_3_1.png 35 | 36 | 37 | the auto_save is on (by default) 38 | 39 | .. code:: ipython3 40 | 41 | model.auto_save 42 | 43 | 44 | 45 | 46 | .. parsed-literal:: 47 | 48 | True 49 | 50 | 51 | 52 | After fitting, the version becomes 0.1 53 | 54 | .. code:: ipython3 55 | 56 | model.fit(dataset, opt="LBFGS", steps=20, lamb=0.01); 57 | model.plot() 58 | 59 | 60 | .. 
parsed-literal:: 61 | 62 | | train_loss: 3.34e-02 | test_loss: 3.29e-02 | reg: 4.93e+00 | : 100%|█| 20/20 [00:03<00:00, 5.10it 63 | 64 | 65 | .. parsed-literal:: 66 | 67 | saving model version 0.1 68 | 69 | 70 | 71 | .. image:: API_12_checkpoint_save_load_model_files/API_12_checkpoint_save_load_model_7_2.png 72 | 73 | 74 | After pruning, the version becomes 0.2 75 | 76 | .. code:: ipython3 77 | 78 | model = model.prune() 79 | model.plot() 80 | 81 | 82 | .. parsed-literal:: 83 | 84 | saving model version 0.2 85 | 86 | 87 | 88 | .. image:: API_12_checkpoint_save_load_model_files/API_12_checkpoint_save_load_model_9_1.png 89 | 90 | 91 | Suppose we want to revert back to version 0.1, use model = 92 | model.rewind(‘0.1’). This starts a new round, meaning version 0.1 93 | renamed to version 1.1. 94 | 95 | .. code:: ipython3 96 | 97 | # revert to version 0.1 (if continuing) 98 | model = model.rewind('0.1') 99 | 100 | # revert to version 0.1 (if starting from scratch) 101 | #model = KAN.loadckpt('./model' + '0.1') 102 | #model.get_act(dataset) 103 | 104 | model.plot() 105 | 106 | 107 | .. parsed-literal:: 108 | 109 | rewind to model version 0.1, renamed as 1.1 110 | 111 | 112 | 113 | .. image:: API_12_checkpoint_save_load_model_files/API_12_checkpoint_save_load_model_11_1.png 114 | 115 | 116 | Suppose we do some more manipulation to version 1.1, we will roll 117 | forward to version 1.2 118 | 119 | .. code:: ipython3 120 | 121 | model.fit(dataset, opt="LBFGS", steps=2); 122 | model.plot() 123 | 124 | 125 | .. parsed-literal:: 126 | 127 | | train_loss: 2.06e-02 | test_loss: 2.18e-02 | reg: 5.48e+00 | : 100%|█| 2/2 [00:00<00:00, 5.83it/s 128 | 129 | 130 | .. parsed-literal:: 131 | 132 | saving model version 1.2 133 | 134 | 135 | 136 | .. 
image:: API_12_checkpoint_save_load_model_files/API_12_checkpoint_save_load_model_13_2.png 137 | 138 | -------------------------------------------------------------------------------- /docs/API_demo/API_12_checkpoint_save_load_model_files/API_12_checkpoint_save_load_model_11_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_12_checkpoint_save_load_model_files/API_12_checkpoint_save_load_model_11_1.png -------------------------------------------------------------------------------- /docs/API_demo/API_12_checkpoint_save_load_model_files/API_12_checkpoint_save_load_model_13_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_12_checkpoint_save_load_model_files/API_12_checkpoint_save_load_model_13_2.png -------------------------------------------------------------------------------- /docs/API_demo/API_12_checkpoint_save_load_model_files/API_12_checkpoint_save_load_model_3_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_12_checkpoint_save_load_model_files/API_12_checkpoint_save_load_model_3_1.png -------------------------------------------------------------------------------- /docs/API_demo/API_12_checkpoint_save_load_model_files/API_12_checkpoint_save_load_model_7_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_12_checkpoint_save_load_model_files/API_12_checkpoint_save_load_model_7_2.png -------------------------------------------------------------------------------- 
/docs/API_demo/API_12_checkpoint_save_load_model_files/API_12_checkpoint_save_load_model_9_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_12_checkpoint_save_load_model_files/API_12_checkpoint_save_load_model_9_1.png -------------------------------------------------------------------------------- /docs/API_demo/API_1_indexing_files/API_1_indexing_12_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_1_indexing_files/API_1_indexing_12_0.png -------------------------------------------------------------------------------- /docs/API_demo/API_1_indexing_files/API_1_indexing_14_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_1_indexing_files/API_1_indexing_14_0.png -------------------------------------------------------------------------------- /docs/API_demo/API_1_indexing_files/API_1_indexing_16_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_1_indexing_files/API_1_indexing_16_0.png -------------------------------------------------------------------------------- /docs/API_demo/API_1_indexing_files/API_1_indexing_1_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_1_indexing_files/API_1_indexing_1_1.png -------------------------------------------------------------------------------- 
/docs/API_demo/API_1_indexing_files/API_1_indexing_4_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_1_indexing_files/API_1_indexing_4_1.png -------------------------------------------------------------------------------- /docs/API_demo/API_1_indexing_files/API_1_indexing_5_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_1_indexing_files/API_1_indexing_5_1.png -------------------------------------------------------------------------------- /docs/API_demo/API_1_indexing_files/API_1_indexing_6_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_1_indexing_files/API_1_indexing_6_1.png -------------------------------------------------------------------------------- /docs/API_demo/API_1_indexing_files/API_1_indexing_7_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_1_indexing_files/API_1_indexing_7_1.png -------------------------------------------------------------------------------- /docs/API_demo/API_1_indexing_files/API_1_indexing_8_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_1_indexing_files/API_1_indexing_8_1.png -------------------------------------------------------------------------------- /docs/API_demo/API_2_plotting_files/API_2_plotting_10_0.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_2_plotting_files/API_2_plotting_10_0.png -------------------------------------------------------------------------------- /docs/API_demo/API_2_plotting_files/API_2_plotting_11_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_2_plotting_files/API_2_plotting_11_0.png -------------------------------------------------------------------------------- /docs/API_demo/API_2_plotting_files/API_2_plotting_13_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_2_plotting_files/API_2_plotting_13_0.png -------------------------------------------------------------------------------- /docs/API_demo/API_2_plotting_files/API_2_plotting_14_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_2_plotting_files/API_2_plotting_14_0.png -------------------------------------------------------------------------------- /docs/API_demo/API_2_plotting_files/API_2_plotting_15_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_2_plotting_files/API_2_plotting_15_0.png -------------------------------------------------------------------------------- /docs/API_demo/API_2_plotting_files/API_2_plotting_17_1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_2_plotting_files/API_2_plotting_17_1.png -------------------------------------------------------------------------------- /docs/API_demo/API_2_plotting_files/API_2_plotting_19_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_2_plotting_files/API_2_plotting_19_0.png -------------------------------------------------------------------------------- /docs/API_demo/API_2_plotting_files/API_2_plotting_20_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_2_plotting_files/API_2_plotting_20_0.png -------------------------------------------------------------------------------- /docs/API_demo/API_2_plotting_files/API_2_plotting_21_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_2_plotting_files/API_2_plotting_21_0.png -------------------------------------------------------------------------------- /docs/API_demo/API_2_plotting_files/API_2_plotting_23_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_2_plotting_files/API_2_plotting_23_0.png -------------------------------------------------------------------------------- /docs/API_demo/API_2_plotting_files/API_2_plotting_25_0.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_2_plotting_files/API_2_plotting_25_0.png -------------------------------------------------------------------------------- /docs/API_demo/API_2_plotting_files/API_2_plotting_28_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_2_plotting_files/API_2_plotting_28_0.png -------------------------------------------------------------------------------- /docs/API_demo/API_2_plotting_files/API_2_plotting_31_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_2_plotting_files/API_2_plotting_31_0.png -------------------------------------------------------------------------------- /docs/API_demo/API_2_plotting_files/API_2_plotting_4_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_2_plotting_files/API_2_plotting_4_0.png -------------------------------------------------------------------------------- /docs/API_demo/API_2_plotting_files/API_2_plotting_5_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_2_plotting_files/API_2_plotting_5_0.png -------------------------------------------------------------------------------- /docs/API_demo/API_2_plotting_files/API_2_plotting_9_0.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_2_plotting_files/API_2_plotting_9_0.png -------------------------------------------------------------------------------- /docs/API_demo/API_3_extract_activations.rst: -------------------------------------------------------------------------------- 1 | API 3: Extracting activation functions 2 | ====================================== 3 | 4 | The KAN diagrams give intuitive illustration, but sometimes we may also 5 | want to extract the values of activation functions for more quantitative 6 | tasks. Using the indexing convention introduced in the indexing 7 | notebook, each edge is indexed as :math:`(l,i,j)`, where :math:`l` is 8 | the layer index, :math:`i` is the input neuron index, and :math:`j` is 9 | output neuron index. 10 | 11 | .. code:: ipython3 12 | 13 | from kan import * 14 | import matplotlib.pyplot as plt 15 | 16 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 17 | print(device) 18 | 19 | # create a KAN: 2D inputs, 1D output, and 5 hidden neurons. cubic spline (k=3), 5 grid intervals (grid=5). 20 | model = KAN(width=[2,5,1], grid=3, k=3, seed=1, device=device) 21 | x = torch.normal(0,1,size=(100,2)).to(device) 22 | model(x) 23 | model.plot(beta=100) 24 | 25 | 26 | .. parsed-literal:: 27 | 28 | cuda 29 | checkpoint directory created: ./model 30 | saving model version 0.0 31 | 32 | 33 | 34 | .. image:: API_3_extract_activations_files/API_3_extract_activations_1_1.png 35 | 36 | 37 | .. code:: ipython3 38 | 39 | l = 0 40 | i = 0 41 | j = 3 42 | x, y = model.get_fun(l,i,j) 43 | 44 | 45 | 46 | .. image:: API_3_extract_activations_files/API_3_extract_activations_2_0.png 47 | 48 | 49 | If we are interested in the range of some activation function, we can 50 | use get_range. 51 | 52 | .. code:: ipython3 53 | 54 | model.get_range(l,i,j) 55 | 56 | 57 | .. 
parsed-literal:: 58 | 59 | x range: [-1.61 , 3.38 ] 60 | y range: [-0.19 , 0.56 ] 61 | 62 | 63 | 64 | 65 | .. parsed-literal:: 66 | 67 | (array(-1.6111118, dtype=float32), 68 | array(3.38374, dtype=float32), 69 | array(-0.18606013, dtype=float32), 70 | array(0.5614974, dtype=float32)) 71 | 72 | 73 | 74 | -------------------------------------------------------------------------------- /docs/API_demo/API_3_extract_activations_files/API_3_extract_activations_1_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_3_extract_activations_files/API_3_extract_activations_1_1.png -------------------------------------------------------------------------------- /docs/API_demo/API_3_extract_activations_files/API_3_extract_activations_2_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_3_extract_activations_files/API_3_extract_activations_2_0.png -------------------------------------------------------------------------------- /docs/API_demo/API_4_initialization.rst: -------------------------------------------------------------------------------- 1 | API 4: Initialization 2 | ===================== 3 | 4 | Initialization is the first step to gaurantee good training. Each 5 | activation function is initialized to be 6 | :math:`\phi(x)={\rm scale\_base}*b(x) + {\rm scale\_sp}*{\rm spline}(x)`. 7 | 1. :math:`b(x)` is the base function, default: ‘silu’, can be set with 8 | :math:`{\rm base\_fun}` 9 | 10 | 2. scale_sp sample from N(0, noise_scale^2) 11 | 12 | 3. scale_base sampled from N(scale_base_mu, scale_base_sigma^2) 13 | 14 | 4. sparse initialization: if sparse_init = True, most scale_base and 15 | scale_sp will be set to zero 16 | 17 | Default setup 18 | 19 | .. 
code:: ipython3 20 | 21 | from kan import KAN, create_dataset 22 | import torch 23 | 24 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 25 | print(device) 26 | 27 | model = KAN(width=[2,5,1], grid=5, k=3, seed=0, device=device) 28 | x = torch.normal(0,1,size=(100,2)).to(device) 29 | model(x) # forward is needed to collect activations for plotting 30 | model.plot() 31 | 32 | 33 | .. parsed-literal:: 34 | 35 | cuda 36 | checkpoint directory created: ./model 37 | saving model version 0.0 38 | 39 | 40 | 41 | .. image:: API_4_initialization_files/API_4_initialization_3_1.png 42 | 43 | 44 | Case 1: Initialize all activation functions to be exactly linear. We 45 | need to set noise_scale_base = 0., base_fun = identity, noise_scale = 0. 46 | 47 | .. code:: ipython3 48 | 49 | model = KAN(width=[2,5,1], grid=5, k=3, seed=0, base_fun = 'identity', device=device) 50 | x = torch.normal(0,1,size=(100,2)).to(device) 51 | model(x) # forward is needed to collect activations for plotting 52 | model.plot() 53 | 54 | 55 | .. parsed-literal:: 56 | 57 | checkpoint directory created: ./model 58 | saving model version 0.0 59 | 60 | 61 | 62 | .. image:: API_4_initialization_files/API_4_initialization_5_1.png 63 | 64 | 65 | Case 2: Noisy spline initialization (not recommended, just for 66 | illustration). Set noise_scale to be a large number. 67 | 68 | .. code:: ipython3 69 | 70 | model = KAN(width=[2,5,1], grid=5, k=3, seed=0, noise_scale=0.3, device=device) 71 | x = torch.normal(0,1,size=(100,2)).to(device) 72 | model(x) # forward is needed to collect activations for plotting 73 | model.plot() 74 | 75 | 76 | .. parsed-literal:: 77 | 78 | checkpoint directory created: ./model 79 | saving model version 0.0 80 | 81 | 82 | 83 | .. image:: API_4_initialization_files/API_4_initialization_7_1.png 84 | 85 | 86 | .. 
code:: ipython3 87 | 88 | model = KAN(width=[2,5,1], grid=5, k=3, seed=0, noise_scale=10., device=device) 89 | x = torch.normal(0,1,size=(100,2)).to(device) 90 | model(x) # forward is needed to collect activations for plotting 91 | model.plot() 92 | 93 | 94 | .. parsed-literal:: 95 | 96 | checkpoint directory created: ./model 97 | saving model version 0.0 98 | 99 | 100 | 101 | .. image:: API_4_initialization_files/API_4_initialization_8_1.png 102 | 103 | 104 | Case 3: scale_base_mu and scale_base_sigma 105 | 106 | .. code:: ipython3 107 | 108 | model = KAN(width=[2,5,1], grid=5, k=3, seed=0, scale_base_mu=5, scale_base_sigma=0, device=device) 109 | x = torch.normal(0,1,size=(100,2)).to(device) 110 | model(x) # forward is needed to collect activations for plotting 111 | model.plot() 112 | 113 | 114 | .. parsed-literal:: 115 | 116 | checkpoint directory created: ./model 117 | saving model version 0.0 118 | 119 | 120 | 121 | .. image:: API_4_initialization_files/API_4_initialization_10_1.png 122 | 123 | 124 | .. code:: ipython3 125 | 126 | model = KAN(width=[2,5,1], grid=5, k=3, seed=0, sparse_init=True, device=device) 127 | x = torch.normal(0,1,size=(100,2)).to(device) 128 | model(x) # forward is needed to collect activations for plotting 129 | model.plot() 130 | 131 | 132 | .. parsed-literal:: 133 | 134 | checkpoint directory created: ./model 135 | saving model version 0.0 136 | 137 | 138 | 139 | .. 
image:: API_4_initialization_files/API_4_initialization_11_1.png 140 | 141 | 142 | -------------------------------------------------------------------------------- /docs/API_demo/API_4_initialization_files/API_4_initialization_10_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_4_initialization_files/API_4_initialization_10_1.png -------------------------------------------------------------------------------- /docs/API_demo/API_4_initialization_files/API_4_initialization_11_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_4_initialization_files/API_4_initialization_11_1.png -------------------------------------------------------------------------------- /docs/API_demo/API_4_initialization_files/API_4_initialization_3_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_4_initialization_files/API_4_initialization_3_1.png -------------------------------------------------------------------------------- /docs/API_demo/API_4_initialization_files/API_4_initialization_5_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_4_initialization_files/API_4_initialization_5_1.png -------------------------------------------------------------------------------- /docs/API_demo/API_4_initialization_files/API_4_initialization_7_1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_4_initialization_files/API_4_initialization_7_1.png -------------------------------------------------------------------------------- /docs/API_demo/API_4_initialization_files/API_4_initialization_8_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_4_initialization_files/API_4_initialization_8_1.png -------------------------------------------------------------------------------- /docs/API_demo/API_5_grid_files/API_5_grid_2_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_5_grid_files/API_5_grid_2_1.png -------------------------------------------------------------------------------- /docs/API_demo/API_6_training_hyperparameter_files/API_6_training_hyperparameter_12_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_6_training_hyperparameter_files/API_6_training_hyperparameter_12_3.png -------------------------------------------------------------------------------- /docs/API_demo/API_6_training_hyperparameter_files/API_6_training_hyperparameter_14_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_6_training_hyperparameter_files/API_6_training_hyperparameter_14_3.png -------------------------------------------------------------------------------- /docs/API_demo/API_6_training_hyperparameter_files/API_6_training_hyperparameter_17_3.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_6_training_hyperparameter_files/API_6_training_hyperparameter_17_3.png -------------------------------------------------------------------------------- /docs/API_demo/API_6_training_hyperparameter_files/API_6_training_hyperparameter_4_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_6_training_hyperparameter_files/API_6_training_hyperparameter_4_3.png -------------------------------------------------------------------------------- /docs/API_demo/API_6_training_hyperparameter_files/API_6_training_hyperparameter_7_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_6_training_hyperparameter_files/API_6_training_hyperparameter_7_3.png -------------------------------------------------------------------------------- /docs/API_demo/API_6_training_hyperparameter_files/API_6_training_hyperparameter_9_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_6_training_hyperparameter_files/API_6_training_hyperparameter_9_3.png -------------------------------------------------------------------------------- /docs/API_demo/API_7_pruning.rst: -------------------------------------------------------------------------------- 1 | API 7: Pruning 2 | ============== 3 | 4 | We usually use pruning to make neural networks sparser hence more 5 | efficient and more interpretable. KANs provide two ways of pruning: 6 | automatic pruning, and manual pruning. 
7 | 8 | Pruning nodes 9 | ------------- 10 | 11 | .. code:: ipython3 12 | 13 | from kan import * 14 | 15 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 16 | print(device) 17 | 18 | # create a KAN: 2D inputs, 1D output, and 5 hidden neurons. cubic spline (k=3), 5 grid intervals (grid=5). 19 | model = KAN(width=[2,5,1], grid=5, k=3, seed=1, device=device) 20 | 21 | # create dataset f(x,y) = exp(sin(pi*x)+y^2) 22 | f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2) 23 | dataset = create_dataset(f, n_var=2, device=device) 24 | dataset['train_input'].shape, dataset['train_label'].shape 25 | 26 | # train the model 27 | model.fit(dataset, opt="LBFGS", steps=20, lamb=0.01); 28 | model(dataset['train_input']) 29 | model.plot() 30 | 31 | 32 | .. parsed-literal:: 33 | 34 | cuda 35 | checkpoint directory created: ./model 36 | saving model version 0.0 37 | 38 | 39 | .. parsed-literal:: 40 | 41 | | train_loss: 3.46e-02 | test_loss: 3.46e-02 | reg: 4.91e+00 | : 100%|█| 20/20 [00:05<00:00, 3.36it 42 | 43 | 44 | .. parsed-literal:: 45 | 46 | saving model version 0.1 47 | 48 | 49 | 50 | .. image:: API_7_pruning_files/API_7_pruning_2_3.png 51 | 52 | 53 | .. code:: ipython3 54 | 55 | mode = 'auto' 56 | 57 | if mode == 'auto': 58 | # automatic 59 | model = model.prune_node(threshold=1e-2) # by default the threshold is 1e-2 60 | model.plot() 61 | elif mode == 'manual': 62 | # manual 63 | model = model.prune_node(active_neurons_id=[[0]]) 64 | 65 | 66 | .. parsed-literal:: 67 | 68 | saving model version 0.2 69 | 70 | 71 | 72 | .. image:: API_7_pruning_files/API_7_pruning_3_1.png 73 | 74 | 75 | Pruning Edges 76 | ------------- 77 | 78 | .. code:: ipython3 79 | 80 | from kan import * 81 | # create a KAN: 2D inputs, 1D output, and 5 hidden neurons. cubic spline (k=3), 5 grid intervals (grid=5). 
82 | model = KAN(width=[2,5,1], grid=5, k=3, seed=1, device=device) 83 | 84 | # create dataset f(x,y) = exp(sin(pi*x)+y^2) 85 | f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2) 86 | dataset = create_dataset(f, n_var=2, device=device) 87 | dataset['train_input'].shape, dataset['train_label'].shape 88 | 89 | # train the model 90 | model.fit(dataset, opt="LBFGS", steps=6, lamb=0.01); 91 | model(dataset['train_input']) 92 | model.plot() 93 | 94 | 95 | .. parsed-literal:: 96 | 97 | checkpoint directory created: ./model 98 | saving model version 0.0 99 | 100 | 101 | .. parsed-literal:: 102 | 103 | | train_loss: 7.84e-02 | test_loss: 7.80e-02 | reg: 7.26e+00 | : 100%|█| 6/6 [00:01<00:00, 3.72it/s 104 | 105 | 106 | .. parsed-literal:: 107 | 108 | saving model version 0.1 109 | 110 | 111 | 112 | .. image:: API_7_pruning_files/API_7_pruning_5_3.png 113 | 114 | 115 | .. code:: ipython3 116 | 117 | model.prune_edge() 118 | 119 | 120 | .. parsed-literal:: 121 | 122 | saving model version 0.2 123 | 124 | 125 | .. code:: ipython3 126 | 127 | model.plot() 128 | 129 | 130 | 131 | .. image:: API_7_pruning_files/API_7_pruning_7_0.png 132 | 133 | 134 | Prune nodes and edges together 135 | ------------------------------ 136 | 137 | just use model.prune() 138 | 139 | .. code:: ipython3 140 | 141 | from kan import * 142 | # create a KAN: 2D inputs, 1D output, and 5 hidden neurons. cubic spline (k=3), 5 grid intervals (grid=5). 143 | model = KAN(width=[2,5,1], grid=5, k=3, seed=1, device=device) 144 | 145 | # create dataset f(x,y) = exp(sin(pi*x)+y^2) 146 | f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2) 147 | dataset = create_dataset(f, n_var=2, device=device) 148 | dataset['train_input'].shape, dataset['train_label'].shape 149 | 150 | # train the model 151 | model.fit(dataset, opt="LBFGS", steps=20, lamb=0.01); 152 | model(dataset['train_input']) 153 | model.plot() 154 | 155 | 156 | .. 
parsed-literal:: 157 | 158 | checkpoint directory created: ./model 159 | saving model version 0.0 160 | 161 | 162 | .. parsed-literal:: 163 | 164 | | train_loss: 3.46e-02 | test_loss: 3.46e-02 | reg: 4.91e+00 | : 100%|█| 20/20 [00:05<00:00, 3.70it 165 | 166 | 167 | .. parsed-literal:: 168 | 169 | saving model version 0.1 170 | 171 | 172 | 173 | .. image:: API_7_pruning_files/API_7_pruning_10_3.png 174 | 175 | 176 | .. code:: ipython3 177 | 178 | model = model.prune() 179 | model.plot() 180 | 181 | 182 | .. parsed-literal:: 183 | 184 | saving model version 0.2 185 | 186 | 187 | 188 | .. image:: API_7_pruning_files/API_7_pruning_11_1.png 189 | 190 | 191 | -------------------------------------------------------------------------------- /docs/API_demo/API_7_pruning_files/API_7_pruning_10_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_7_pruning_files/API_7_pruning_10_3.png -------------------------------------------------------------------------------- /docs/API_demo/API_7_pruning_files/API_7_pruning_11_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_7_pruning_files/API_7_pruning_11_1.png -------------------------------------------------------------------------------- /docs/API_demo/API_7_pruning_files/API_7_pruning_2_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_7_pruning_files/API_7_pruning_2_3.png -------------------------------------------------------------------------------- /docs/API_demo/API_7_pruning_files/API_7_pruning_3_1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_7_pruning_files/API_7_pruning_3_1.png -------------------------------------------------------------------------------- /docs/API_demo/API_7_pruning_files/API_7_pruning_5_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_7_pruning_files/API_7_pruning_5_3.png -------------------------------------------------------------------------------- /docs/API_demo/API_7_pruning_files/API_7_pruning_7_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_7_pruning_files/API_7_pruning_7_0.png -------------------------------------------------------------------------------- /docs/API_demo/API_8_regularization.rst: -------------------------------------------------------------------------------- 1 | API 8: Regularization 2 | ===================== 3 | 4 | Regularization helps interpretability by making KANs sparser. This may 5 | require some hyperparamter tuning. Let’s see how hyperparameters can 6 | affect training 7 | 8 | Load KAN and create_dataset 9 | 10 | .. code:: ipython3 11 | 12 | from kan import * 13 | import torch 14 | 15 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 16 | print(device) 17 | 18 | f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2) 19 | dataset = create_dataset(f, n_var=2, device=device) 20 | dataset['train_input'].shape, dataset['train_label'].shape 21 | 22 | 23 | .. parsed-literal:: 24 | 25 | cuda 26 | 27 | 28 | 29 | 30 | .. parsed-literal:: 31 | 32 | (torch.Size([1000, 2]), torch.Size([1000, 1])) 33 | 34 | 35 | 36 | We apply L1 regularization to which tensor? 
Currently, we support five 37 | choices for reg_metric: \* ‘edge_forward_spline_n’: the “norm” of edge, 38 | normalized (output std/input std), only consider the spline (ignorning 39 | symbolic) \* ‘edge_forward_sum’: the “norm” of edge, normamlized (output 40 | std/input std), including both spline + symbolic \* 41 | ‘edge_forward_spline_u’: the “norm” of edge, unnormalized (output std), 42 | only consider the spline (ignorning symbolic) \* ‘edge_backward’: edge 43 | attribution score \* ‘node_backward’: node attribution score 44 | 45 | .. code:: ipython3 46 | 47 | # train the model 48 | model = KAN(width=[2,5,1], grid=3, k=3, seed=1, device=device) 49 | model.fit(dataset, opt="LBFGS", steps=20, lamb=0.01, reg_metric='edge_forward_spline_n'); # default 50 | #model.fit(dataset, opt="LBFGS", steps=20, lamb=0.01, reg_metric='edge_forward_sum'); 51 | #model.fit(dataset, opt="LBFGS", steps=20, lamb=0.01, reg_metric='edge_forward_spline_u'); 52 | #model.fit(dataset, opt="LBFGS", steps=20, lamb=0.01, reg_metric='edge_backward'); 53 | #model.fit(dataset, opt="LBFGS", steps=20, lamb=0.01, reg_metric='node_backward'); 54 | model.plot() 55 | 56 | 57 | .. parsed-literal:: 58 | 59 | checkpoint directory created: ./model 60 | saving model version 0.0 61 | 62 | 63 | .. parsed-literal:: 64 | 65 | | train_loss: 4.57e-02 | test_loss: 4.35e-02 | reg: 7.15e+00 | : 100%|█| 20/20 [00:04<00:00, 4.58it 66 | 67 | 68 | .. parsed-literal:: 69 | 70 | saving model version 0.1 71 | 72 | 73 | 74 | .. image:: API_8_regularization_files/API_8_regularization_4_3.png 75 | 76 | 77 | Note: To plot the KAN diagram, there are also three options \* 78 | forward_u: same as edge_forward_spline_u \* forward_n: same as 79 | edge_forward_spline_u \* backward: same as edge_backward 80 | 81 | .. code:: ipython3 82 | 83 | model.plot(metric='forward_u') 84 | #model.plot(metric='forward_n') 85 | #model.plot(metric='backward') # default 86 | 87 | 88 | 89 | .. 
image:: API_8_regularization_files/API_8_regularization_6_0.png 90 | 91 | 92 | -------------------------------------------------------------------------------- /docs/API_demo/API_8_regularization_files/API_8_regularization_4_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_8_regularization_files/API_8_regularization_4_3.png -------------------------------------------------------------------------------- /docs/API_demo/API_8_regularization_files/API_8_regularization_6_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/API_demo/API_8_regularization_files/API_8_regularization_6_0.png -------------------------------------------------------------------------------- /docs/API_demo/API_9_video.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "134e7f9d", 6 | "metadata": {}, 7 | "source": [ 8 | "# API 9: Videos\n", 9 | "\n", 10 | "We have shown one can visualize KAN with the plot() method. 
If one wants to save the training dynamics of KAN plots, one only needs to pass argument save_video = True to train() method (and set some video related parameters)" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 6, 16 | "id": "2075ef56", 17 | "metadata": { 18 | "tags": [] 19 | }, 20 | "outputs": [ 21 | { 22 | "name": "stdout", 23 | "output_type": "stream", 24 | "text": [ 25 | "cuda\n", 26 | "checkpoint directory created: ./model\n", 27 | "saving model version 0.0\n" 28 | ] 29 | }, 30 | { 31 | "name": "stderr", 32 | "output_type": "stream", 33 | "text": [ 34 | "| train_loss: 2.89e-01 | test_loss: 2.96e-01 | reg: 1.31e+01 | : 100%|█| 5/5 [00:09<00:00, 1.94s/it" 35 | ] 36 | }, 37 | { 38 | "name": "stdout", 39 | "output_type": "stream", 40 | "text": [ 41 | "saving model version 0.1\n" 42 | ] 43 | }, 44 | { 45 | "name": "stderr", 46 | "output_type": "stream", 47 | "text": [ 48 | "\n" 49 | ] 50 | } 51 | ], 52 | "source": [ 53 | "from kan import *\n", 54 | "import torch\n", 55 | "\n", 56 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", 57 | "print(device)\n", 58 | "\n", 59 | "# create a KAN: 2D inputs, 1D output, and 5 hidden neurons. 
cubic spline (k=3), 5 grid intervals (grid=5).\n", 60 | "model = KAN(width=[4,2,1,1], grid=3, k=3, seed=1, device=device)\n", 61 | "f = lambda x: torch.exp((torch.sin(torch.pi*(x[:,[0]]**2+x[:,[1]]**2))+torch.sin(torch.pi*(x[:,[2]]**2+x[:,[3]]**2)))/2)\n", 62 | "dataset = create_dataset(f, n_var=4, train_num=3000, device=device)\n", 63 | "\n", 64 | "image_folder = 'video_img'\n", 65 | "\n", 66 | "# train the model\n", 67 | "#model.train(dataset, opt=\"LBFGS\", steps=20, lamb=1e-3, lamb_entropy=2.);\n", 68 | "model.fit(dataset, opt=\"LBFGS\", steps=5, lamb=0.001, lamb_entropy=2., save_fig=True, beta=10, \n", 69 | " in_vars=[r'$x_1$', r'$x_2$', r'$x_3$', r'$x_4$'],\n", 70 | " out_vars=[r'${\\rm exp}({\\rm sin}(x_1^2+x_2^2)+{\\rm sin}(x_3^2+x_4^2))$'],\n", 71 | " img_folder=image_folder);\n" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 2, 77 | "id": "c18245a3", 78 | "metadata": {}, 79 | "outputs": [ 80 | { 81 | "name": "stdout", 82 | "output_type": "stream", 83 | "text": [ 84 | "Moviepy - Building video video.mp4.\n", 85 | "Moviepy - Writing video video.mp4\n", 86 | "\n" 87 | ] 88 | }, 89 | { 90 | "name": "stderr", 91 | "output_type": "stream", 92 | "text": [ 93 | " \r" 94 | ] 95 | }, 96 | { 97 | "name": "stdout", 98 | "output_type": "stream", 99 | "text": [ 100 | "Moviepy - Done !\n", 101 | "Moviepy - video ready video.mp4\n" 102 | ] 103 | } 104 | ], 105 | "source": [ 106 | "import os\n", 107 | "import numpy as np\n", 108 | "import moviepy.video.io.ImageSequenceClip # moviepy == 1.0.3\n", 109 | "\n", 110 | "video_name='video'\n", 111 | "fps=5\n", 112 | "\n", 113 | "fps = fps\n", 114 | "files = os.listdir(image_folder)\n", 115 | "train_index = []\n", 116 | "for file in files:\n", 117 | " if file[0].isdigit() and file.endswith('.jpg'):\n", 118 | " train_index.append(int(file[:-4]))\n", 119 | "\n", 120 | "train_index = np.sort(train_index)\n", 121 | "\n", 122 | "image_files = [image_folder+'/'+str(train_index[index])+'.jpg' for index in 
train_index]\n", 123 | "\n", 124 | "clip = moviepy.video.io.ImageSequenceClip.ImageSequenceClip(image_files, fps=fps)\n", 125 | "clip.write_videofile(video_name+'.mp4')" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": null, 131 | "id": "88d0d737", 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [] 135 | } 136 | ], 137 | "metadata": { 138 | "kernelspec": { 139 | "display_name": "Python 3 (ipykernel)", 140 | "language": "python", 141 | "name": "python3" 142 | }, 143 | "language_info": { 144 | "codemirror_mode": { 145 | "name": "ipython", 146 | "version": 3 147 | }, 148 | "file_extension": ".py", 149 | "mimetype": "text/x-python", 150 | "name": "python", 151 | "nbconvert_exporter": "python", 152 | "pygments_lexer": "ipython3", 153 | "version": "3.9.16" 154 | } 155 | }, 156 | "nbformat": 4, 157 | "nbformat_minor": 5 158 | } 159 | -------------------------------------------------------------------------------- /docs/API_demo/API_9_video.rst: -------------------------------------------------------------------------------- 1 | API 9: Videos 2 | ============= 3 | 4 | We have shown one can visualize KAN with the plot() method. If one wants 5 | to save the training dynamics of KAN plots, one only needs to pass 6 | argument save_video = True to train() method (and set some video related 7 | parameters) 8 | 9 | .. code:: ipython3 10 | 11 | from kan import * 12 | import torch 13 | 14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 15 | print(device) 16 | 17 | # create a KAN: 2D inputs, 1D output, and 5 hidden neurons. cubic spline (k=3), 5 grid intervals (grid=5). 
18 | model = KAN(width=[4,2,1,1], grid=3, k=3, seed=1, device=device) 19 | f = lambda x: torch.exp((torch.sin(torch.pi*(x[:,[0]]**2+x[:,[1]]**2))+torch.sin(torch.pi*(x[:,[2]]**2+x[:,[3]]**2)))/2) 20 | dataset = create_dataset(f, n_var=4, train_num=3000, device=device) 21 | 22 | image_folder = 'video_img' 23 | 24 | # train the model 25 | #model.train(dataset, opt="LBFGS", steps=20, lamb=1e-3, lamb_entropy=2.); 26 | model.fit(dataset, opt="LBFGS", steps=5, lamb=0.001, lamb_entropy=2., save_fig=True, beta=10, 27 | in_vars=[r'$x_1$', r'$x_2$', r'$x_3$', r'$x_4$'], 28 | out_vars=[r'${\rm exp}({\rm sin}(x_1^2+x_2^2)+{\rm sin}(x_3^2+x_4^2))$'], 29 | img_folder=image_folder); 30 | 31 | 32 | 33 | .. parsed-literal:: 34 | 35 | cuda 36 | checkpoint directory created: ./model 37 | saving model version 0.0 38 | 39 | 40 | .. parsed-literal:: 41 | 42 | | train_loss: 2.89e-01 | test_loss: 2.96e-01 | reg: 1.31e+01 | : 100%|█| 5/5 [00:09<00:00, 1.94s/it 43 | 44 | .. parsed-literal:: 45 | 46 | saving model version 0.1 47 | 48 | 49 | .. parsed-literal:: 50 | 51 | 52 | 53 | 54 | .. code:: ipython3 55 | 56 | import os 57 | import numpy as np 58 | import moviepy.video.io.ImageSequenceClip # moviepy == 1.0.3 59 | 60 | video_name='video' 61 | fps=5 62 | 63 | fps = fps 64 | files = os.listdir(image_folder) 65 | train_index = [] 66 | for file in files: 67 | if file[0].isdigit() and file.endswith('.jpg'): 68 | train_index.append(int(file[:-4])) 69 | 70 | train_index = np.sort(train_index) 71 | 72 | image_files = [image_folder+'/'+str(train_index[index])+'.jpg' for index in train_index] 73 | 74 | clip = moviepy.video.io.ImageSequenceClip.ImageSequenceClip(image_files, fps=fps) 75 | clip.write_videofile(video_name+'.mp4') 76 | 77 | 78 | .. parsed-literal:: 79 | 80 | Moviepy - Building video video.mp4. 81 | Moviepy - Writing video video.mp4 82 | 83 | 84 | 85 | .. parsed-literal:: 86 | 87 | 88 | 89 | .. parsed-literal:: 90 | 91 | Moviepy - Done ! 
92 | Moviepy - video ready video.mp4 93 | 94 | 95 | -------------------------------------------------------------------------------- /docs/Community/Community_1_physics_informed_kan_files/Community_1_physics_informed_kan_3_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Community/Community_1_physics_informed_kan_files/Community_1_physics_informed_kan_3_0.png -------------------------------------------------------------------------------- /docs/Community/Community_1_physics_informed_kan_files/Community_1_physics_informed_kan_4_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Community/Community_1_physics_informed_kan_files/Community_1_physics_informed_kan_4_0.png -------------------------------------------------------------------------------- /docs/Community/Community_1_physics_informed_kan_files/Community_1_physics_informed_kan_5_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Community/Community_1_physics_informed_kan_files/Community_1_physics_informed_kan_5_0.png -------------------------------------------------------------------------------- /docs/Community/Community_2_protein_sequence_classification_files/Community_2_protein_sequence_classification_13_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Community/Community_2_protein_sequence_classification_files/Community_2_protein_sequence_classification_13_0.png -------------------------------------------------------------------------------- 
/docs/Example/Example_10_relativity-addition_files/Example_10_relativity-addition_12_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Example/Example_10_relativity-addition_files/Example_10_relativity-addition_12_0.png -------------------------------------------------------------------------------- /docs/Example/Example_10_relativity-addition_files/Example_10_relativity-addition_17_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Example/Example_10_relativity-addition_files/Example_10_relativity-addition_17_0.png -------------------------------------------------------------------------------- /docs/Example/Example_10_relativity-addition_files/Example_10_relativity-addition_6_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Example/Example_10_relativity-addition_files/Example_10_relativity-addition_6_0.png -------------------------------------------------------------------------------- /docs/Example/Example_11_encouraing_linear.rst: -------------------------------------------------------------------------------- 1 | Example 11: Encouraging linearity 2 | ================================= 3 | 4 | In cases where we don’t know how deep we should set KANs to be, one 5 | strategy is to try from small models, grudually making models 6 | wider/deeper until we find the minimal model that performs the task 7 | quite well. Another strategy is to start from a big enough model and 8 | prune it down. This jupyter notebook demonstrates cases where we go for 9 | the second strategy. 
Besides sparsity along width, we also want 10 | activation functions to be linear (‘shortcut’ along depth). 11 | 12 | There are two relevant tricks: 13 | 14 | (1) set the base function ‘base_fun’ to be linear; 15 | 16 | (2) penalize spline coefficients. When spline coefficients are zero, the 17 | activation function is linear. 18 | 19 | :math:`f(x)={\rm sin}(\pi x)`. Although we know a [1,1] KAN suffices, we 20 | suppose we don’t know that and use a [1,1,1,1] KAN instead. 21 | 22 | without trick 23 | 24 | .. code:: ipython3 25 | 26 | from kan import * 27 | 28 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 29 | print(device) 30 | 31 | # create dataset f(x,y) = sin(pi*x). This task can be achieved by a [1,1] KAN 32 | f = lambda x: torch.sin(torch.pi*x[:,[0]]) 33 | dataset = create_dataset(f, n_var=1, device=device) 34 | 35 | model = KAN(width=[1,1,1,1], grid=5, k=3, seed=0, noise_scale=0.1, device=device) 36 | 37 | model.fit(dataset, opt="LBFGS", steps=20); 38 | 39 | 40 | .. parsed-literal:: 41 | 42 | cuda 43 | checkpoint directory created: ./model 44 | saving model version 0.0 45 | 46 | 47 | .. parsed-literal:: 48 | 49 | | train_loss: 3.74e-04 | test_loss: 3.84e-04 | reg: 8.88e+00 | : 100%|█| 20/20 [00:05<00:00, 3.79it 50 | 51 | .. parsed-literal:: 52 | 53 | saving model version 0.1 54 | 55 | 56 | .. parsed-literal:: 57 | 58 | 59 | 60 | 61 | .. code:: ipython3 62 | 63 | model.plot() 64 | 65 | 66 | 67 | .. image:: Example_11_encouraing_linear_files/Example_11_encouraing_linear_5_0.png 68 | 69 | 70 | with tricks 71 | 72 | .. code:: ipython3 73 | 74 | from kan import * 75 | 76 | # create dataset f(x,y) = sin(pi*x). 
This task can be achieved by a [1,1] KAN 77 | f = lambda x: torch.sin(torch.pi*x[:,[0]]) 78 | dataset = create_dataset(f, n_var=1, device=device) 79 | 80 | # set base_fun to be linear 81 | model = KAN(width=[1,1,1,1], grid=5, k=3, seed=0, base_fun='identity', noise_scale=0.1, device=device) 82 | 83 | # penality spline coefficients 84 | model.fit(dataset, opt="LBFGS", steps=20, lamb=1e-4, lamb_coef=10.0); 85 | 86 | 87 | .. parsed-literal:: 88 | 89 | checkpoint directory created: ./model 90 | saving model version 0.0 91 | 92 | 93 | .. parsed-literal:: 94 | 95 | | train_loss: 8.89e-03 | test_loss: 8.40e-03 | reg: 1.83e+01 | : 100%|█| 20/20 [00:04<00:00, 4.20it 96 | 97 | .. parsed-literal:: 98 | 99 | saving model version 0.1 100 | 101 | 102 | .. parsed-literal:: 103 | 104 | 105 | 106 | 107 | .. code:: ipython3 108 | 109 | model.plot(beta=10) 110 | 111 | 112 | 113 | .. image:: Example_11_encouraing_linear_files/Example_11_encouraing_linear_8_0.png 114 | 115 | 116 | -------------------------------------------------------------------------------- /docs/Example/Example_11_encouraing_linear_files/Example_11_encouraing_linear_5_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Example/Example_11_encouraing_linear_files/Example_11_encouraing_linear_5_0.png -------------------------------------------------------------------------------- /docs/Example/Example_11_encouraing_linear_files/Example_11_encouraing_linear_8_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Example/Example_11_encouraing_linear_files/Example_11_encouraing_linear_8_0.png -------------------------------------------------------------------------------- /docs/Example/Example_12_unsupervised_learning.rst: 
-------------------------------------------------------------------------------- 1 | Example 12: Unsupervised learning 2 | ================================= 3 | 4 | In this example, we will use KAN for unsupervised learning. Instead of 5 | trying to figure out how a target variable :math:`y` depends on input 6 | variables, we treat all variables on the equal footing (as input 7 | variables). Below we contruct a synthetic dataset where we have six 8 | variables :math:`x_1, x_2, x_3, x_4, x_5, x_6`. :math:`(x_1, x_2, x_3)` 9 | are dependent such that :math:`x_3={\rm exp}({\rm sin}(\pi x_1)+x_2^2)`; 10 | :math:`(x_4,x_5)` are dependent such that :math:`x_5=x_4^3`. And 11 | :math:`x_6` is independent of all other variables. Can we use KANs to 12 | discover these dependent groups? 13 | 14 | The idea is that we treat the problem as a classification problem. The 15 | dataset that satisfies these interdependent relations are ‘positive’ 16 | samples, while corrupted samples (by random permutation of features 17 | across samples) are ‘negative’ samples. We want to train a KAN to output 18 | 1 when it is a positive sample, and output 0 when it is a negative 19 | sample. We set the last layer activation to be Gaussian, so positive 20 | samples will have zero activation in the second to last layer, while 21 | negtive samples will have non-zero activation in the second to last 22 | layer. We can then define the relation implicitly as :math:`g=0` where 23 | :math:`g` is the activation in the second to last layer. 24 | 25 | Intialize model and create dataset 26 | 27 | .. 
code:: ipython3 28 | 29 | from kan import KAN 30 | import torch 31 | import copy 32 | 33 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 34 | print(device) 35 | 36 | seed = 1 37 | 38 | model = KAN(width=[6,1,1], grid=3, k=3, seed=seed, device=device) 39 | 40 | # create dataset 41 | 42 | 43 | def create_dataset(train_num=500, test_num=500): 44 | 45 | def generate_contrastive(x): 46 | # positive samples 47 | batch = x.shape[0] 48 | x[:,2] = torch.exp(torch.sin(torch.pi*x[:,0])+x[:,1]**2) 49 | x[:,3] = x[:,4]**3 50 | 51 | # negative samples 52 | def corrupt(tensor): 53 | y = copy.deepcopy(tensor) 54 | for i in range(y.shape[1]): 55 | y[:,i] = y[:,i][torch.randperm(y.shape[0])] 56 | return y 57 | 58 | x_cor = corrupt(x) 59 | x = torch.cat([x, x_cor], dim=0) 60 | y = torch.cat([torch.ones(batch,), torch.zeros(batch,)], dim=0)[:,None] 61 | return x, y 62 | 63 | x = torch.rand(train_num, 6) * 2 - 1 64 | x_train, y_train = generate_contrastive(x) 65 | 66 | x = torch.rand(test_num, 6) * 2 - 1 67 | x_test, y_test = generate_contrastive(x) 68 | 69 | dataset = {} 70 | dataset['train_input'] = x_train.to(device) 71 | dataset['test_input'] = x_test.to(device) 72 | dataset['train_label'] = y_train.to(device) 73 | dataset['test_label'] = y_test.to(device) 74 | return dataset 75 | 76 | dataset = create_dataset() 77 | 78 | 79 | .. parsed-literal:: 80 | 81 | cuda 82 | checkpoint directory created: ./model 83 | saving model version 0.0 84 | 85 | 86 | .. code:: ipython3 87 | 88 | model(dataset['train_input']) 89 | model.plot(beta=10) 90 | 91 | 92 | 93 | .. image:: Example_12_unsupervised_learning_files/Example_12_unsupervised_learning_4_0.png 94 | 95 | 96 | .. code:: ipython3 97 | 98 | # set the (1,0,0) activation to be gausssian 99 | #model.fix_symbolic(1,0,0,lambda x: torch.exp(-x**2/10),fit_params_bool=False) 100 | model.fix_symbolic(1,0,0,'gaussian',fit_params_bool=False) 101 | 102 | 103 | .. 
parsed-literal:: 104 | 105 | saving model version 0.1 106 | 107 | 108 | .. code:: ipython3 109 | 110 | model(dataset['train_input']) 111 | model.plot(beta=10) 112 | 113 | 114 | 115 | .. image:: Example_12_unsupervised_learning_files/Example_12_unsupervised_learning_6_0.png 116 | 117 | 118 | .. code:: ipython3 119 | 120 | model.fit(dataset, opt="LBFGS", steps=50, lamb=0.002, lamb_entropy=10.0, lamb_coef=1.0); 121 | 122 | 123 | .. parsed-literal:: 124 | 125 | | train_loss: 1.80e-01 | test_loss: 1.78e-01 | reg: 3.77e+01 | : 100%|█| 50/50 [00:13<00:00, 3.76it 126 | 127 | .. parsed-literal:: 128 | 129 | saving model version 0.2 130 | 131 | 132 | .. parsed-literal:: 133 | 134 | 135 | 136 | 137 | .. code:: ipython3 138 | 139 | model.plot(in_vars=[r'$x_{}$'.format(i) for i in range(1,7)]) 140 | 141 | 142 | 143 | .. image:: Example_12_unsupervised_learning_files/Example_12_unsupervised_learning_8_0.png 144 | 145 | 146 | This gives the dependence among :math:`(x_4,x_5)`. Another random seed 147 | can give dependence among :math:`(x_1,x_2,x_3)`. 
148 | 149 | -------------------------------------------------------------------------------- /docs/Example/Example_12_unsupervised_learning_files/Example_12_unsupervised_learning_4_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Example/Example_12_unsupervised_learning_files/Example_12_unsupervised_learning_4_0.png -------------------------------------------------------------------------------- /docs/Example/Example_12_unsupervised_learning_files/Example_12_unsupervised_learning_6_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Example/Example_12_unsupervised_learning_files/Example_12_unsupervised_learning_6_0.png -------------------------------------------------------------------------------- /docs/Example/Example_12_unsupervised_learning_files/Example_12_unsupervised_learning_8_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Example/Example_12_unsupervised_learning_files/Example_12_unsupervised_learning_8_0.png -------------------------------------------------------------------------------- /docs/Example/Example_13_phase_transition.rst: -------------------------------------------------------------------------------- 1 | Example 13: Phase transition 2 | ============================ 3 | 4 | In this example, we will use KAN to learn phase transitions in data. 5 | Phase transition is an important concept in science. We consider a toy 6 | example :math:`f(x_1,x_2,x_3)` is 1 if :math:`g(x_1,x_2,x_3)>0`, and is 7 | 0 if :math:`g(x_1,x_2,x_3)<0`. 8 | :math:`g(x_1,x_2,x_3)={\rm sin}(\pi x_1)+{\rm cos}(\pi x_2)+{\rm tan}(\frac{\pi}{2}x_3)`. 
9 | 10 | Intialize model and create dataset 11 | 12 | .. code:: ipython3 13 | 14 | from kan import KAN, create_dataset 15 | import torch 16 | 17 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 18 | print(device) 19 | 20 | model = KAN(width=[3,1,1], grid=3, k=3, device=device) 21 | 22 | # create dataset 23 | f = lambda x: (torch.sin(torch.pi*x[:,[0]]) + torch.cos(torch.pi*x[:,[1]]) + torch.tan(torch.pi/2*x[:,[2]]) > 0).float() 24 | dataset = create_dataset(f, n_var=3, device=device) 25 | 26 | 27 | 28 | .. parsed-literal:: 29 | 30 | cuda 31 | checkpoint directory created: ./model 32 | saving model version 0.0 33 | 34 | 35 | .. code:: ipython3 36 | 37 | torch.mean(dataset['train_label']) 38 | 39 | 40 | 41 | 42 | .. parsed-literal:: 43 | 44 | tensor(0.5060, device='cuda:0') 45 | 46 | 47 | 48 | .. code:: ipython3 49 | 50 | model(dataset['train_input']) 51 | model.plot(beta=10) 52 | 53 | 54 | 55 | .. image:: Example_13_phase_transition_files/Example_13_phase_transition_5_0.png 56 | 57 | 58 | .. code:: ipython3 59 | 60 | # set the last activation to be tanh 61 | model.fix_symbolic(1,0,0,'tanh',fit_params_bool=False) 62 | 63 | 64 | .. parsed-literal:: 65 | 66 | saving model version 0.1 67 | 68 | 69 | .. code:: ipython3 70 | 71 | model(dataset['train_input']) 72 | model.plot(beta=10) 73 | 74 | 75 | 76 | .. image:: Example_13_phase_transition_files/Example_13_phase_transition_7_0.png 77 | 78 | 79 | .. code:: ipython3 80 | 81 | model.fit(dataset, opt="LBFGS", steps=50); 82 | 83 | 84 | .. parsed-literal:: 85 | 86 | | train_loss: 7.71e-02 | test_loss: 1.17e-01 | reg: 2.43e+02 | : 100%|█| 50/50 [00:09<00:00, 5.32it 87 | 88 | 89 | .. parsed-literal:: 90 | 91 | saving model version 0.2 92 | 93 | 94 | .. code:: ipython3 95 | 96 | model.plot(beta=10) 97 | 98 | 99 | 100 | .. 
image:: Example_13_phase_transition_files/Example_13_phase_transition_9_0.png 101 | 102 | 103 | -------------------------------------------------------------------------------- /docs/Example/Example_13_phase_transition_files/Example_13_phase_transition_5_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Example/Example_13_phase_transition_files/Example_13_phase_transition_5_0.png -------------------------------------------------------------------------------- /docs/Example/Example_13_phase_transition_files/Example_13_phase_transition_7_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Example/Example_13_phase_transition_files/Example_13_phase_transition_7_0.png -------------------------------------------------------------------------------- /docs/Example/Example_13_phase_transition_files/Example_13_phase_transition_9_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Example/Example_13_phase_transition_files/Example_13_phase_transition_9_0.png -------------------------------------------------------------------------------- /docs/Example/Example_1_function_fitting_files/Example_1_function_fitting_12_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Example/Example_1_function_fitting_files/Example_1_function_fitting_12_0.png -------------------------------------------------------------------------------- /docs/Example/Example_1_function_fitting_files/Example_1_function_fitting_14_1.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Example/Example_1_function_fitting_files/Example_1_function_fitting_14_1.png -------------------------------------------------------------------------------- /docs/Example/Example_3_deep_formula_files/Example_3_deep_formula_11_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Example/Example_3_deep_formula_files/Example_3_deep_formula_11_1.png -------------------------------------------------------------------------------- /docs/Example/Example_3_deep_formula_files/Example_3_deep_formula_4_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Example/Example_3_deep_formula_files/Example_3_deep_formula_4_0.png -------------------------------------------------------------------------------- /docs/Example/Example_3_deep_formula_files/Example_3_deep_formula_7_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Example/Example_3_deep_formula_files/Example_3_deep_formula_7_1.png -------------------------------------------------------------------------------- /docs/Example/Example_3_deep_formula_files/Example_3_deep_formula_9_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Example/Example_3_deep_formula_files/Example_3_deep_formula_9_3.png -------------------------------------------------------------------------------- 
/docs/Example/Example_4_classfication_files/Example_4_classfication_12_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Example/Example_4_classfication_files/Example_4_classfication_12_1.png -------------------------------------------------------------------------------- /docs/Example/Example_4_classfication_files/Example_4_classfication_3_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Example/Example_4_classfication_files/Example_4_classfication_3_2.png -------------------------------------------------------------------------------- /docs/Example/Example_5_special_functions_files/Example_5_special_functions_4_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Example/Example_5_special_functions_files/Example_5_special_functions_4_0.png -------------------------------------------------------------------------------- /docs/Example/Example_5_special_functions_files/Example_5_special_functions_6_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Example/Example_5_special_functions_files/Example_5_special_functions_6_0.png -------------------------------------------------------------------------------- /docs/Example/Example_6_PDE_interpretation_files/Example_6_PDE_interpretation_4_0.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Example/Example_6_PDE_interpretation_files/Example_6_PDE_interpretation_4_0.png -------------------------------------------------------------------------------- /docs/Example/Example_7_PDE_accuracy_files/Example_7_PDE_accuracy_3_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Example/Example_7_PDE_accuracy_files/Example_7_PDE_accuracy_3_1.png -------------------------------------------------------------------------------- /docs/Example/Example_8_continual_learning_files/Example_8_continual_learning_2_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Example/Example_8_continual_learning_files/Example_8_continual_learning_2_1.png -------------------------------------------------------------------------------- /docs/Example/Example_8_continual_learning_files/Example_8_continual_learning_4_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Example/Example_8_continual_learning_files/Example_8_continual_learning_4_0.png -------------------------------------------------------------------------------- /docs/Example/Example_8_continual_learning_files/Example_8_continual_learning_8_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Example/Example_8_continual_learning_files/Example_8_continual_learning_8_0.png -------------------------------------------------------------------------------- 
/docs/Example/Example_9_singularity_files/Example_9_singularity_3_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Example/Example_9_singularity_files/Example_9_singularity_3_0.png -------------------------------------------------------------------------------- /docs/Example/Example_9_singularity_files/Example_9_singularity_9_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Example/Example_9_singularity_files/Example_9_singularity_9_0.png -------------------------------------------------------------------------------- /docs/Interp/.ipynb_checkpoints/Interp_11_sparse_init-checkpoint.rst: -------------------------------------------------------------------------------- 1 | Interpretability 11: sparse initialization 2 | ========================================== 3 | 4 | .. code:: ipython3 5 | 6 | from kan import * 7 | 8 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 9 | print(device) 10 | 11 | model = KAN([5,5,5,1], sparse_init=False, device=device) 12 | x = torch.rand(100,5).to(device) 13 | model.get_act(x) 14 | model.plot() 15 | 16 | 17 | .. parsed-literal:: 18 | 19 | cuda 20 | checkpoint directory created: ./model 21 | saving model version 0.0 22 | 23 | 24 | 25 | .. image:: Interp_11_sparse_init_files/Interp_11_sparse_init_1_1.png 26 | 27 | 28 | .. code:: ipython3 29 | 30 | from kan import * 31 | 32 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 33 | print(device) 34 | 35 | model = KAN([5,5,5,1], sparse_init=True, device=device) 36 | x = torch.rand(100,5).to(device) 37 | model.get_act(x) 38 | model.plot() 39 | 40 | 41 | .. 
parsed-literal:: 42 | 43 | cuda 44 | checkpoint directory created: ./model 45 | saving model version 0.0 46 | 47 | 48 | 49 | .. image:: Interp_11_sparse_init_files/Interp_11_sparse_init_2_1.png 50 | 51 | 52 | -------------------------------------------------------------------------------- /docs/Interp/Interp_10A_swap_files/Interp_10A_swap_11_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_10A_swap_files/Interp_10A_swap_11_0.png -------------------------------------------------------------------------------- /docs/Interp/Interp_10A_swap_files/Interp_10A_swap_13_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_10A_swap_files/Interp_10A_swap_13_0.png -------------------------------------------------------------------------------- /docs/Interp/Interp_10A_swap_files/Interp_10A_swap_6_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_10A_swap_files/Interp_10A_swap_6_0.png -------------------------------------------------------------------------------- /docs/Interp/Interp_10A_swap_files/Interp_10A_swap_8_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_10A_swap_files/Interp_10A_swap_8_0.png -------------------------------------------------------------------------------- /docs/Interp/Interp_10B_swap.rst: -------------------------------------------------------------------------------- 1 | Interpretability 10B: swap 2 | ========================== 3 | 4 | The multitask parity 
problem has 10 input bits 5 | :math:`(x_1, x_2, \cdots, x_{10})`, :math:`x_i\in\{0,1\}`. 6 | 7 | The are five output bits :math:`y_1, \cdots, y_5`, where 8 | :math:`y_i = x_{2i-1} + x_{2i-1} ({\rm mod} 2)` 9 | 10 | .. code:: ipython3 11 | 12 | from kan import * 13 | 14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 15 | print(device) 16 | 17 | model = KAN(width=[10,10,5], seed=1, device=device) 18 | x = torch.normal(0,1,size=(100,2), device=device) 19 | 20 | #f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]])+x[:,[1]]**2) 21 | f = lambda x: torch.cat([x[:,[0]] + x[:,[1]], x[:,[2]] + x[:,[3]], x[:,[4]] + x[:,[5]], x[:,[6]] + x[:,[7]], x[:,[8]] + x[:,[9]]], dim=1) 22 | dataset = create_dataset(f, n_var=10, device=device) 23 | model.fit(dataset, steps=20, lamb=1e-2); 24 | 25 | 26 | 27 | .. parsed-literal:: 28 | 29 | cuda 30 | checkpoint directory created: ./model 31 | saving model version 0.0 32 | 33 | 34 | .. parsed-literal:: 35 | 36 | | train_loss: 8.26e-02 | test_loss: 7.72e-02 | reg: 1.66e+01 | : 100%|█| 20/20 [00:04<00:00, 4.93it 37 | 38 | .. parsed-literal:: 39 | 40 | saving model version 0.1 41 | 42 | 43 | .. parsed-literal:: 44 | 45 | 46 | 47 | 48 | .. code:: ipython3 49 | 50 | model.plot() 51 | 52 | 53 | 54 | .. image:: Interp_10B_swap_files/Interp_10B_swap_3_0.png 55 | 56 | 57 | .. code:: ipython3 58 | 59 | model.auto_swap() 60 | 61 | 62 | .. parsed-literal:: 63 | 64 | saving model version 0.2 65 | 66 | 67 | .. code:: ipython3 68 | 69 | model.plot() 70 | 71 | 72 | 73 | .. image:: Interp_10B_swap_files/Interp_10B_swap_5_0.png 74 | 75 | 76 | .. 
code:: ipython3 77 | 78 | # MLP 79 | from kan import * 80 | from kan.MLP import MLP 81 | 82 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 83 | print(device) 84 | 85 | inputs = [] 86 | for i in range(2**10): 87 | string = "{0:b}".format(i) 88 | sample = [int(string[i]) for i in range(len(string))] 89 | sample = (10 - len(sample)) * [0] + sample 90 | inputs.append(sample) 91 | 92 | inputs = np.array(inputs).astype(np.float32) 93 | labels = np.sum(inputs.reshape(2**10,5,2), axis=2) % 2 94 | inputs = torch.tensor(inputs) 95 | labels = torch.tensor(labels) 96 | 97 | dataset = create_dataset_from_data(inputs, labels, device=device) 98 | 99 | model = MLP(width=[10,20,5], seed=5, device=device) 100 | model.fit(dataset, steps=100, lamb=2e-4, reg_metric='w'); 101 | 102 | 103 | .. parsed-literal:: 104 | 105 | cuda 106 | 107 | 108 | .. parsed-literal:: 109 | 110 | | train_loss: 4.58e-03 | test_loss: 4.63e-03 | reg: 5.09e+01 | : 100%|█| 100/100 [00:04<00:00, 23.41 111 | 112 | 113 | .. code:: ipython3 114 | 115 | model.plot(scale=1.5) 116 | 117 | 118 | 119 | .. image:: Interp_10B_swap_files/Interp_10B_swap_7_0.png 120 | 121 | 122 | .. code:: ipython3 123 | 124 | model.auto_swap() 125 | 126 | .. code:: ipython3 127 | 128 | model.plot(scale=1.5) 129 | 130 | 131 | 132 | .. image:: Interp_10B_swap_files/Interp_10B_swap_9_0.png 133 | 134 | 135 | .. code:: ipython3 136 | 137 | model.auto_swap() 138 | 139 | .. code:: ipython3 140 | 141 | model.plot(scale=1.5) 142 | 143 | 144 | 145 | .. 
image:: Interp_10B_swap_files/Interp_10B_swap_11_0.png 146 | 147 | 148 | 149 | -------------------------------------------------------------------------------- /docs/Interp/Interp_10B_swap_files/Interp_10B_swap_11_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_10B_swap_files/Interp_10B_swap_11_0.png -------------------------------------------------------------------------------- /docs/Interp/Interp_10B_swap_files/Interp_10B_swap_3_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_10B_swap_files/Interp_10B_swap_3_0.png -------------------------------------------------------------------------------- /docs/Interp/Interp_10B_swap_files/Interp_10B_swap_5_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_10B_swap_files/Interp_10B_swap_5_0.png -------------------------------------------------------------------------------- /docs/Interp/Interp_10B_swap_files/Interp_10B_swap_7_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_10B_swap_files/Interp_10B_swap_7_0.png -------------------------------------------------------------------------------- /docs/Interp/Interp_10B_swap_files/Interp_10B_swap_9_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_10B_swap_files/Interp_10B_swap_9_0.png 
-------------------------------------------------------------------------------- /docs/Interp/Interp_10_hessian.rst: -------------------------------------------------------------------------------- 1 | Interpretability 10: Hessian 2 | ============================ 3 | 4 | To understand the loss landscape, we compute the hessian (loss wrt model 5 | parameters) and get its eigenvalues 6 | 7 | Try both KAN and MLP, you will usually see that KANs have more non-zero 8 | eigenvalues than MLPs, meaning that KANs have more effective number of 9 | parameters than MLP. 10 | 11 | .. code:: ipython3 12 | 13 | from kan.utils import get_derivative 14 | import torch 15 | from kan.MLP import MLP 16 | from kan.MultKAN import KAN 17 | from kan.utils import create_dataset, model2param 18 | import copy 19 | 20 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 21 | print(device) 22 | 23 | f = lambda x: x[:,[0]]**2 24 | dataset = create_dataset(f, n_var=1, train_num=1000, device=device) 25 | 26 | inputs = dataset['train_input'] 27 | labels = dataset['train_label'] 28 | 29 | #model = MLP(width = [1,30,1]) 30 | model = KAN(width=[1,5,1], device=device) 31 | model.fit(dataset, opt='Adam', lr=1e-2, lamb=0.000, steps=1000); 32 | 33 | 34 | .. parsed-literal:: 35 | 36 | cuda 37 | checkpoint directory created: ./model 38 | saving model version 0.0 39 | 40 | 41 | .. parsed-literal:: 42 | 43 | | train_loss: 8.51e-04 | test_loss: 8.26e-04 | reg: 1.11e+01 | : 100%|█| 1000/1000 [00:08<00:00, 114 44 | 45 | 46 | .. parsed-literal:: 47 | 48 | saving model version 0.1 49 | 50 | 51 | .. code:: ipython3 52 | 53 | model.plot() 54 | 55 | 56 | 57 | .. image:: Interp_10_hessian_files/Interp_10_hessian_4_0.png 58 | 59 | 60 | .. code:: ipython3 61 | 62 | hess = get_derivative(model, inputs, labels, derivative='hessian') 63 | values, vectors = torch.linalg.eigh(hess) 64 | 65 | .. 
code:: ipython3 66 | 67 | import matplotlib.pyplot as plt 68 | plt.plot(values.cpu().numpy()[0], marker='o'); 69 | plt.yscale('log') 70 | 71 | 72 | 73 | .. image:: Interp_10_hessian_files/Interp_10_hessian_6_0.png 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /docs/Interp/Interp_10_hessian_files/Interp_10_hessian_4_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_10_hessian_files/Interp_10_hessian_4_0.png -------------------------------------------------------------------------------- /docs/Interp/Interp_10_hessian_files/Interp_10_hessian_6_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_10_hessian_files/Interp_10_hessian_6_0.png -------------------------------------------------------------------------------- /docs/Interp/Interp_11_sparse_init.rst: -------------------------------------------------------------------------------- 1 | Interpretability 11: sparse initialization 2 | ========================================== 3 | 4 | .. code:: ipython3 5 | 6 | from kan import * 7 | 8 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 9 | print(device) 10 | 11 | model = KAN([5,5,5,1], sparse_init=False, device=device) 12 | x = torch.rand(100,5).to(device) 13 | model.get_act(x) 14 | model.plot() 15 | 16 | 17 | .. parsed-literal:: 18 | 19 | cuda 20 | checkpoint directory created: ./model 21 | saving model version 0.0 22 | 23 | 24 | 25 | .. image:: Interp_11_sparse_init_files/Interp_11_sparse_init_1_1.png 26 | 27 | 28 | .. 
code:: ipython3 29 | 30 | from kan import * 31 | 32 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 33 | print(device) 34 | 35 | model = KAN([5,5,5,1], sparse_init=True, device=device) 36 | x = torch.rand(100,5).to(device) 37 | model.get_act(x) 38 | model.plot() 39 | 40 | 41 | .. parsed-literal:: 42 | 43 | cuda 44 | checkpoint directory created: ./model 45 | saving model version 0.0 46 | 47 | 48 | 49 | .. image:: Interp_11_sparse_init_files/Interp_11_sparse_init_2_1.png 50 | 51 | 52 | -------------------------------------------------------------------------------- /docs/Interp/Interp_11_sparse_init_files/Interp_11_sparse_init_1_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_11_sparse_init_files/Interp_11_sparse_init_1_1.png -------------------------------------------------------------------------------- /docs/Interp/Interp_11_sparse_init_files/Interp_11_sparse_init_2_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_11_sparse_init_files/Interp_11_sparse_init_2_1.png -------------------------------------------------------------------------------- /docs/Interp/Interp_1_Hello, MultKAN_files/Interp_1_Hello, MultKAN_11_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_1_Hello, MultKAN_files/Interp_1_Hello, MultKAN_11_0.png -------------------------------------------------------------------------------- /docs/Interp/Interp_1_Hello, MultKAN_files/Interp_1_Hello, MultKAN_4_0.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_1_Hello, MultKAN_files/Interp_1_Hello, MultKAN_4_0.png -------------------------------------------------------------------------------- /docs/Interp/Interp_1_Hello, MultKAN_files/Interp_1_Hello, MultKAN_7_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_1_Hello, MultKAN_files/Interp_1_Hello, MultKAN_7_1.png -------------------------------------------------------------------------------- /docs/Interp/Interp_1_Hello, MultKAN_files/Interp_1_Hello, MultKAN_9_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_1_Hello, MultKAN_files/Interp_1_Hello, MultKAN_9_0.png -------------------------------------------------------------------------------- /docs/Interp/Interp_2_Advanced MultKAN.rst: -------------------------------------------------------------------------------- 1 | Interpretability 2: Advanced MultKAN 2 | ==================================== 3 | 4 | In the last tutorial, we introduced multiplications to KANs which makes 5 | interpretation easier in the case when multiplications are needed. 6 | Multiplication nodes by default takes in two numbers, but can take more 7 | variables specified by the user. This is done through the mult_arity 8 | argument (by default mult_arity=2). 9 | 10 | .. code:: ipython3 11 | 12 | from kan import * 13 | 14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 15 | print(device) 16 | 17 | model = KAN(width=[2,[3,2],1], device=device) 18 | x = torch.randn(100,2).to(device) 19 | model(x) 20 | model.plot() 21 | 22 | 23 | .. 
parsed-literal:: 24 | 25 | cuda 26 | checkpoint directory created: ./model 27 | saving model version 0.0 28 | 29 | 30 | 31 | .. image:: Interp_2_Advanced%20MultKAN_files/Interp_2_Advanced%20MultKAN_2_1.png 32 | 33 | 34 | mult_arity=3 35 | 36 | .. code:: ipython3 37 | 38 | model = KAN(width=[2,[3,2],1], mult_arity=3, device=device) 39 | model(x) 40 | model.plot() 41 | 42 | 43 | .. parsed-literal:: 44 | 45 | checkpoint directory created: ./model 46 | saving model version 0.0 47 | 48 | 49 | 50 | .. image:: Interp_2_Advanced%20MultKAN_files/Interp_2_Advanced%20MultKAN_4_1.png 51 | 52 | 53 | mult_arity=4 54 | 55 | .. code:: ipython3 56 | 57 | model = KAN(width=[2,[3,2],1], mult_arity=4, device=device) 58 | model(x) 59 | model.plot() 60 | 61 | 62 | .. parsed-literal:: 63 | 64 | checkpoint directory created: ./model 65 | saving model version 0.0 66 | 67 | 68 | 69 | .. image:: Interp_2_Advanced%20MultKAN_files/Interp_2_Advanced%20MultKAN_6_1.png 70 | 71 | 72 | You may want different multiplication nodes to take in different number 73 | of variables. This is also possible: pass in mult_arity as a list of 74 | lists, specifying the arities in each layer, including input layer, 75 | hidden layer(s), and output layer. 76 | 77 | In the following example, we have 0 multiplications in the input or in 78 | the output layer, corresponding to empty lists. In the hidden layer, we 79 | have two multiplications with arity = 2 and arity = 3, so we have the 80 | list [2,3] in the middle. 81 | 82 | .. code:: ipython3 83 | 84 | model = KAN(width=[2,[3,2],1], mult_arity=[[],[2,3],[]], device=device) 85 | model(x) 86 | model.plot() 87 | 88 | 89 | .. parsed-literal:: 90 | 91 | checkpoint directory created: ./model 92 | saving model version 0.0 93 | 94 | 95 | 96 | .. image:: Interp_2_Advanced%20MultKAN_files/Interp_2_Advanced%20MultKAN_9_1.png 97 | 98 | 99 | Make a deeper network 100 | 101 | .. 
code:: ipython3 102 | 103 | model = KAN(width=[2,[2,2],[1,3],[3,2],[1,1]], mult_arity=2, device=device) 104 | model(x) 105 | model.plot() 106 | 107 | 108 | .. parsed-literal:: 109 | 110 | checkpoint directory created: ./model 111 | saving model version 0.0 112 | 113 | 114 | 115 | .. image:: Interp_2_Advanced%20MultKAN_files/Interp_2_Advanced%20MultKAN_11_1.png 116 | 117 | 118 | -------------------------------------------------------------------------------- /docs/Interp/Interp_2_Advanced MultKAN_files/Interp_2_Advanced MultKAN_11_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_2_Advanced MultKAN_files/Interp_2_Advanced MultKAN_11_1.png -------------------------------------------------------------------------------- /docs/Interp/Interp_2_Advanced MultKAN_files/Interp_2_Advanced MultKAN_2_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_2_Advanced MultKAN_files/Interp_2_Advanced MultKAN_2_1.png -------------------------------------------------------------------------------- /docs/Interp/Interp_2_Advanced MultKAN_files/Interp_2_Advanced MultKAN_4_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_2_Advanced MultKAN_files/Interp_2_Advanced MultKAN_4_1.png -------------------------------------------------------------------------------- /docs/Interp/Interp_2_Advanced MultKAN_files/Interp_2_Advanced MultKAN_6_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_2_Advanced 
MultKAN_files/Interp_2_Advanced MultKAN_6_1.png -------------------------------------------------------------------------------- /docs/Interp/Interp_2_Advanced MultKAN_files/Interp_2_Advanced MultKAN_9_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_2_Advanced MultKAN_files/Interp_2_Advanced MultKAN_9_1.png -------------------------------------------------------------------------------- /docs/Interp/Interp_3_KAN_Compiler_files/Interp_3_KAN_Compiler_11_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_3_KAN_Compiler_files/Interp_3_KAN_Compiler_11_0.png -------------------------------------------------------------------------------- /docs/Interp/Interp_3_KAN_Compiler_files/Interp_3_KAN_Compiler_13_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_3_KAN_Compiler_files/Interp_3_KAN_Compiler_13_0.png -------------------------------------------------------------------------------- /docs/Interp/Interp_3_KAN_Compiler_files/Interp_3_KAN_Compiler_15_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_3_KAN_Compiler_files/Interp_3_KAN_Compiler_15_0.png -------------------------------------------------------------------------------- /docs/Interp/Interp_3_KAN_Compiler_files/Interp_3_KAN_Compiler_16_0.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_3_KAN_Compiler_files/Interp_3_KAN_Compiler_16_0.png -------------------------------------------------------------------------------- /docs/Interp/Interp_3_KAN_Compiler_files/Interp_3_KAN_Compiler_2_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_3_KAN_Compiler_files/Interp_3_KAN_Compiler_2_1.png -------------------------------------------------------------------------------- /docs/Interp/Interp_3_KAN_Compiler_files/Interp_3_KAN_Compiler_4_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_3_KAN_Compiler_files/Interp_3_KAN_Compiler_4_0.png -------------------------------------------------------------------------------- /docs/Interp/Interp_3_KAN_Compiler_files/Interp_3_KAN_Compiler_9_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_3_KAN_Compiler_files/Interp_3_KAN_Compiler_9_0.png -------------------------------------------------------------------------------- /docs/Interp/Interp_4_feature_attribution_files/Interp_4_feature_attribution_10_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_4_feature_attribution_files/Interp_4_feature_attribution_10_1.png -------------------------------------------------------------------------------- /docs/Interp/Interp_4_feature_attribution_files/Interp_4_feature_attribution_13_0.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_4_feature_attribution_files/Interp_4_feature_attribution_13_0.png -------------------------------------------------------------------------------- /docs/Interp/Interp_4_feature_attribution_files/Interp_4_feature_attribution_15_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_4_feature_attribution_files/Interp_4_feature_attribution_15_1.png -------------------------------------------------------------------------------- /docs/Interp/Interp_4_feature_attribution_files/Interp_4_feature_attribution_17_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_4_feature_attribution_files/Interp_4_feature_attribution_17_1.png -------------------------------------------------------------------------------- /docs/Interp/Interp_4_feature_attribution_files/Interp_4_feature_attribution_3_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_4_feature_attribution_files/Interp_4_feature_attribution_3_0.png -------------------------------------------------------------------------------- /docs/Interp/Interp_4_feature_attribution_files/Interp_4_feature_attribution_7_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_4_feature_attribution_files/Interp_4_feature_attribution_7_1.png 
-------------------------------------------------------------------------------- /docs/Interp/Interp_4_feature_attribution_files/Interp_4_feature_attribution_8_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_4_feature_attribution_files/Interp_4_feature_attribution_8_1.png -------------------------------------------------------------------------------- /docs/Interp/Interp_5_test_symmetry_files/Interp_5_test_symmetry_16_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_5_test_symmetry_files/Interp_5_test_symmetry_16_0.png -------------------------------------------------------------------------------- /docs/Interp/Interp_5_test_symmetry_files/Interp_5_test_symmetry_17_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_5_test_symmetry_files/Interp_5_test_symmetry_17_0.png -------------------------------------------------------------------------------- /docs/Interp/Interp_5_test_symmetry_files/Interp_5_test_symmetry_18_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_5_test_symmetry_files/Interp_5_test_symmetry_18_0.png -------------------------------------------------------------------------------- /docs/Interp/Interp_5_test_symmetry_files/Interp_5_test_symmetry_19_0.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_5_test_symmetry_files/Interp_5_test_symmetry_19_0.png -------------------------------------------------------------------------------- /docs/Interp/Interp_6_test_symmetry_NN.rst: -------------------------------------------------------------------------------- 1 | Interpretability 6: Test symmetries of trained NN 2 | ================================================= 3 | 4 | .. code:: ipython3 5 | 6 | from kan import * 7 | from kan.hypothesis import plot_tree 8 | 9 | f = lambda x: (x[:,[0]]**2 + x[:,[1]]**2) ** 2 + (x[:,[2]]**2 + x[:,[3]]**2) ** 2 10 | x = torch.rand(100,4) * 2 - 1 11 | plot_tree(f, x) 12 | 13 | 14 | 15 | .. image:: Interp_6_test_symmetry_NN_files/Interp_6_test_symmetry_NN_1_0.png 16 | 17 | 18 | .. code:: ipython3 19 | 20 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 21 | print(device) 22 | 23 | dataset = create_dataset(f, n_var=4, device=device) 24 | model = KAN(width=[4,5,5,1], seed=0, device=device) 25 | model.fit(dataset, steps=100); 26 | 27 | 28 | .. parsed-literal:: 29 | 30 | cuda 31 | checkpoint directory created: ./model 32 | saving model version 0.0 33 | 34 | 35 | .. parsed-literal:: 36 | 37 | | train_loss: 1.58e-03 | test_loss: 4.79e-03 | reg: 2.38e+01 | : 100%|█| 100/100 [00:20<00:00, 4.93 38 | 39 | .. parsed-literal:: 40 | 41 | saving model version 0.1 42 | 43 | 44 | .. parsed-literal:: 45 | 46 | 47 | 48 | 49 | .. code:: ipython3 50 | 51 | model.tree(sym_th=1e-2, sep_th=5e-1) 52 | 53 | 54 | 55 | .. 
image:: Interp_6_test_symmetry_NN_files/Interp_6_test_symmetry_NN_3_0.png 56 | 57 | 58 | -------------------------------------------------------------------------------- /docs/Interp/Interp_6_test_symmetry_NN_files/Interp_6_test_symmetry_NN_1_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_6_test_symmetry_NN_files/Interp_6_test_symmetry_NN_1_0.png -------------------------------------------------------------------------------- /docs/Interp/Interp_6_test_symmetry_NN_files/Interp_6_test_symmetry_NN_3_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_6_test_symmetry_NN_files/Interp_6_test_symmetry_NN_3_0.png -------------------------------------------------------------------------------- /docs/Interp/Interp_8_adding_auxillary_variables.rst: -------------------------------------------------------------------------------- 1 | Interpretability 8: Adding auxiliary variables 2 | ============================================== 3 | 4 | When we do a regression task, it might be good to include auxiliary 5 | input variables, even though they might be dependent on other variables. 6 | For example, to regress :math:`m(m_0, v, c)=m_0/\sqrt{1-(v/c)^2}`, it is 7 | desirable to include the dimensionless variable :math:`\beta = v/c` as 8 | a separate input variable. If we also know this is a task in relativity, 9 | we may also include :math:`\gamma=1/\sqrt{1-(v/c)^2}` because 10 | :math:`\gamma` appears frequently in relativity. 11 | 12 | .. 
code:: ipython3 13 | 14 | from kan.MultKAN import MultKAN 15 | from sympy import * 16 | from kan.utils import create_dataset, augment_input 17 | import torch 18 | 19 | seed = 1 20 | torch.manual_seed(seed) 21 | torch.set_default_dtype(torch.float64) 22 | 23 | input_variables = m0, v, c = symbols('m0 v c') 24 | 25 | # define auxillary variables 26 | beta = v/c 27 | gamma = 1/sqrt(1-beta**2) 28 | 29 | aux_vars = (beta, gamma) 30 | 31 | f = lambda x: x[:,[0]]/torch.sqrt(1-x[:,[1]]**2/x[:,[2]]**2) 32 | dataset = create_dataset(f, n_var=3, ranges=[[0,1],[0,0.9],[1.1,2]]) 33 | 34 | # add auxillary variables 35 | dataset = augment_input(input_variables, aux_vars, dataset) 36 | input_variables = aux_vars + input_variables 37 | 38 | .. code:: ipython3 39 | 40 | model = MultKAN(width=[5,[0,1]], mult_arity=2, grid=3, k=3, seed=seed) 41 | 42 | 43 | .. parsed-literal:: 44 | 45 | checkpoint directory created: ./model 46 | saving model version 0.0 47 | 48 | 49 | .. code:: ipython3 50 | 51 | model(dataset['train_input']) 52 | model.plot(in_vars=input_variables, out_vars=[m0/sqrt(1-v**2/c**2)], scale=1.0, varscale=0.7) 53 | 54 | 55 | 56 | .. image:: Interp_8_adding_auxillary_variables_files/Interp_8_adding_auxillary_variables_4_0.png 57 | 58 | 59 | .. code:: ipython3 60 | 61 | model.fit(dataset, steps=50, lamb=1e-5, lamb_coef=1.0); 62 | 63 | 64 | .. parsed-literal:: 65 | 66 | | train_loss: 5.13e-04 | test_loss: 6.64e-04 | reg: 3.18e+00 | : 100%|█| 50/50 [00:07<00:00, 7.10it 67 | 68 | .. parsed-literal:: 69 | 70 | saving model version 0.1 71 | 72 | 73 | .. parsed-literal:: 74 | 75 | 76 | 77 | 78 | .. code:: ipython3 79 | 80 | model.plot(in_vars=input_variables, out_vars=[m0/sqrt(1-v**2/c**2)], scale=1.0, varscale=0.7) 81 | 82 | 83 | 84 | .. image:: Interp_8_adding_auxillary_variables_files/Interp_8_adding_auxillary_variables_6_0.png 85 | 86 | 87 | .. code:: ipython3 88 | 89 | model = model.prune(edge_th=5e-2) 90 | 91 | 92 | .. 
parsed-literal:: 93 | 94 | saving model version 0.2 95 | 96 | 97 | .. code:: ipython3 98 | 99 | model.plot(in_vars=input_variables, out_vars=[m0/sqrt(1-v**2/c**2)], scale=1.0, varscale=0.7) 100 | 101 | 102 | 103 | .. image:: Interp_8_adding_auxillary_variables_files/Interp_8_adding_auxillary_variables_8_0.png 104 | 105 | 106 | .. code:: ipython3 107 | 108 | model.fit(dataset, steps=100, lamb=0e-3); 109 | 110 | 111 | .. parsed-literal:: 112 | 113 | | train_loss: 3.15e-06 | test_loss: 1.99e-05 | reg: 2.74e+00 | : 100%|█| 100/100 [00:10<00:00, 9.48 114 | 115 | .. parsed-literal:: 116 | 117 | saving model version 0.3 118 | 119 | 120 | .. parsed-literal:: 121 | 122 | 123 | 124 | 125 | .. code:: ipython3 126 | 127 | model.auto_symbolic() 128 | 129 | 130 | .. parsed-literal:: 131 | 132 | fixing (0,0,0) with 0 133 | fixing (0,0,1) with 0 134 | fixing (0,1,0) with x, r2=0.999998976626967, c=1 135 | fixing (0,1,1) with 0 136 | fixing (0,2,0) with 0 137 | fixing (0,2,1) with x, r2=0.9999999998075859, c=1 138 | fixing (0,3,0) with 0 139 | fixing (0,3,1) with 0 140 | fixing (0,4,0) with 0 141 | fixing (0,4,1) with 0 142 | saving model version 0.4 143 | 144 | 145 | .. code:: ipython3 146 | 147 | sf = model.symbolic_formula(var=input_variables)[0][0] 148 | sf 149 | 150 | 151 | 152 | 153 | .. math:: 154 | 155 | \displaystyle 1.0 \cdot \left(0.000189505852432992 - \frac{0.817980335069318}{\sqrt{1 - \frac{v^{2}}{c^{2}}}}\right) \left(- 1.22278885546569 m_{0} - 2.33019836537451 \cdot 10^{-7}\right) 156 | 157 | 158 | 159 | .. code:: ipython3 160 | 161 | from kan.utils import ex_round 162 | 163 | nsimplify(ex_round(ex_round(ex_round(sf,6),3),3)) 164 | 165 | 166 | 167 | 168 | .. 
math:: 169 | 170 | \displaystyle \frac{m_{0}}{\sqrt{1 - \frac{v^{2}}{c^{2}}}} 171 | 172 | 173 | 174 | -------------------------------------------------------------------------------- /docs/Interp/Interp_8_adding_auxillary_variables_files/Interp_8_adding_auxillary_variables_4_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_8_adding_auxillary_variables_files/Interp_8_adding_auxillary_variables_4_0.png -------------------------------------------------------------------------------- /docs/Interp/Interp_8_adding_auxillary_variables_files/Interp_8_adding_auxillary_variables_6_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_8_adding_auxillary_variables_files/Interp_8_adding_auxillary_variables_6_0.png -------------------------------------------------------------------------------- /docs/Interp/Interp_8_adding_auxillary_variables_files/Interp_8_adding_auxillary_variables_8_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_8_adding_auxillary_variables_files/Interp_8_adding_auxillary_variables_8_0.png -------------------------------------------------------------------------------- /docs/Interp/Interp_9_different_plotting_metrics.rst: -------------------------------------------------------------------------------- 1 | Interpretability 9: Different plotting metrics 2 | ============================================== 3 | 4 | .. 
code:: ipython3 5 | 6 | from kan import * 7 | 8 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 9 | print(device) 10 | 11 | model = KAN(width=[2,5,1], device=device) 12 | f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2) 13 | dataset = create_dataset(f, n_var=2, device=device) 14 | model.fit(dataset, steps = 20, lamb=1e-3); 15 | 16 | 17 | .. parsed-literal:: 18 | 19 | cuda 20 | checkpoint directory created: ./model 21 | saving model version 0.0 22 | 23 | 24 | .. parsed-literal:: 25 | 26 | | train_loss: 1.48e-02 | test_loss: 1.53e-02 | reg: 7.01e+00 | : 100%|█| 20/20 [00:04<00:00, 4.64it 27 | 28 | 29 | .. parsed-literal:: 30 | 31 | saving model version 0.1 32 | 33 | 34 | Note: To plot the KAN diagram, there are also three options \* 35 | forward_u: the “norm” of edge, normalized (output std/input std) \* 36 | forward_n: the “norm” of edge, unnormalized (output std) \* backward: 37 | the edge attribution score (default) 38 | 39 | .. code:: ipython3 40 | 41 | model.plot(metric='forward_u') 42 | 43 | 44 | 45 | .. image:: Interp_9_different_plotting_metrics_files/Interp_9_different_plotting_metrics_3_0.png 46 | 47 | 48 | .. code:: ipython3 49 | 50 | model.plot(metric='forward_n') 51 | 52 | 53 | 54 | .. image:: Interp_9_different_plotting_metrics_files/Interp_9_different_plotting_metrics_4_0.png 55 | 56 | 57 | .. code:: ipython3 58 | 59 | model.plot(metric='backward') 60 | 61 | 62 | 63 | .. 
image:: Interp_9_different_plotting_metrics_files/Interp_9_different_plotting_metrics_5_0.png 64 | 65 | -------------------------------------------------------------------------------- /docs/Interp/Interp_9_different_plotting_metrics_files/Interp_9_different_plotting_metrics_3_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_9_different_plotting_metrics_files/Interp_9_different_plotting_metrics_3_0.png -------------------------------------------------------------------------------- /docs/Interp/Interp_9_different_plotting_metrics_files/Interp_9_different_plotting_metrics_4_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_9_different_plotting_metrics_files/Interp_9_different_plotting_metrics_4_0.png -------------------------------------------------------------------------------- /docs/Interp/Interp_9_different_plotting_metrics_files/Interp_9_different_plotting_metrics_5_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Interp/Interp_9_different_plotting_metrics_files/Interp_9_different_plotting_metrics_5_0.png -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 
12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/Physics/Physics_1_Lagrangian_files/Physics_1_Lagrangian_10_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Physics/Physics_1_Lagrangian_files/Physics_1_Lagrangian_10_0.png -------------------------------------------------------------------------------- /docs/Physics/Physics_1_Lagrangian_files/Physics_1_Lagrangian_12_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Physics/Physics_1_Lagrangian_files/Physics_1_Lagrangian_12_0.png -------------------------------------------------------------------------------- /docs/Physics/Physics_1_Lagrangian_files/Physics_1_Lagrangian_15_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Physics/Physics_1_Lagrangian_files/Physics_1_Lagrangian_15_0.png -------------------------------------------------------------------------------- /docs/Physics/Physics_1_Lagrangian_files/Physics_1_Lagrangian_3_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Physics/Physics_1_Lagrangian_files/Physics_1_Lagrangian_3_0.png 
-------------------------------------------------------------------------------- /docs/Physics/Physics_1_Lagrangian_files/Physics_1_Lagrangian_4_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Physics/Physics_1_Lagrangian_files/Physics_1_Lagrangian_4_1.png -------------------------------------------------------------------------------- /docs/Physics/Physics_1_Lagrangian_files/Physics_1_Lagrangian_6_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Physics/Physics_1_Lagrangian_files/Physics_1_Lagrangian_6_0.png -------------------------------------------------------------------------------- /docs/Physics/Physics_2A_conservation_law.rst: -------------------------------------------------------------------------------- 1 | Physics 2A: Conservation Laws 2 | ============================= 3 | 4 | .. code:: ipython3 5 | 6 | from kan import * 7 | from kan.utils import batch_jacobian, create_dataset_from_data 8 | import numpy as np 9 | 10 | model = KAN(width=[2,1], seed=42) 11 | 12 | # the model learns the Hamiltonian H = 1/2 * (x**2 + p**2) 13 | x = torch.rand(1000,2) * 2 - 1 14 | flow = torch.cat([x[:,[1]], -x[:,[0]]], dim=1) 15 | 16 | def pred_fn(model, x): 17 | grad = batch_jacobian(model, x, create_graph=True) 18 | grad_normalized = grad/torch.linalg.norm(grad, dim=1, keepdim=True) 19 | return grad_normalized 20 | 21 | loss_fn = lambda grad_normalized, flow: torch.mean(torch.sum(flow * grad_normalized, dim=1)**2) 22 | 23 | 24 | dataset = create_dataset_from_data(x, flow) 25 | model.fit(dataset, steps=20, pred_fn=pred_fn, loss_fn=loss_fn); 26 | 27 | 28 | .. parsed-literal:: 29 | 30 | checkpoint directory created: ./model 31 | saving model version 0.0 32 | 33 | 34 | .. 
parsed-literal:: 35 | 36 | | train_loss: 1.07e-04 | test_loss: 1.17e-04 | reg: 4.12e+00 | : 100%|█| 20/20 [00:01<00:00, 16.52it 37 | 38 | .. parsed-literal:: 39 | 40 | saving model version 0.1 41 | 42 | 43 | .. parsed-literal:: 44 | 45 | 46 | 47 | 48 | .. code:: ipython3 49 | 50 | model.plot() 51 | 52 | 53 | 54 | .. image:: Physics_2A_conservation_law_files/Physics_2A_conservation_law_2_0.png 55 | 56 | 57 | .. code:: ipython3 58 | 59 | model.auto_symbolic() 60 | 61 | 62 | .. parsed-literal:: 63 | 64 | fixing (0,0,0) with x^2, r2=1.0000003576278687, c=2 65 | fixing (0,1,0) with x^2, r2=1.0000004768371582, c=2 66 | saving model version 0.2 67 | 68 | 69 | .. code:: ipython3 70 | 71 | from kan.utils import ex_round 72 | ex_round(model.symbolic_formula()[0][0], 3) 73 | 74 | 75 | 76 | 77 | .. math:: 78 | 79 | \displaystyle - 1.191 x_{1}^{2} - 1.191 x_{2}^{2} + 2.329 80 | 81 | 82 | 83 | -------------------------------------------------------------------------------- /docs/Physics/Physics_2A_conservation_law_files/Physics_2A_conservation_law_2_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Physics/Physics_2A_conservation_law_files/Physics_2A_conservation_law_2_0.png -------------------------------------------------------------------------------- /docs/Physics/Physics_2B_conservation_law_2D_files/Physics_2B_conservation_law_2D_12_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Physics/Physics_2B_conservation_law_2D_files/Physics_2B_conservation_law_2D_12_0.png -------------------------------------------------------------------------------- /docs/Physics/Physics_2B_conservation_law_2D_files/Physics_2B_conservation_law_2D_2_0.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Physics/Physics_2B_conservation_law_2D_files/Physics_2B_conservation_law_2D_2_0.png -------------------------------------------------------------------------------- /docs/Physics/Physics_2B_conservation_law_2D_files/Physics_2B_conservation_law_2D_6_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Physics/Physics_2B_conservation_law_2D_files/Physics_2B_conservation_law_2D_6_0.png -------------------------------------------------------------------------------- /docs/Physics/Physics_3_blackhole_files/Physics_3_blackhole_10_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Physics/Physics_3_blackhole_files/Physics_3_blackhole_10_1.png -------------------------------------------------------------------------------- /docs/Physics/Physics_3_blackhole_files/Physics_3_blackhole_5_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Physics/Physics_3_blackhole_files/Physics_3_blackhole_5_1.png -------------------------------------------------------------------------------- /docs/Physics/Physics_3_blackhole_files/Physics_3_blackhole_6_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Physics/Physics_3_blackhole_files/Physics_3_blackhole_6_1.png -------------------------------------------------------------------------------- 
/docs/Physics/Physics_3_blackhole_files/Physics_3_blackhole_7_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Physics/Physics_3_blackhole_files/Physics_3_blackhole_7_1.png -------------------------------------------------------------------------------- /docs/Physics/Physics_4A_constitutive_laws_P11_files/Physics_4A_constitutive_laws_P11_10_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Physics/Physics_4A_constitutive_laws_P11_files/Physics_4A_constitutive_laws_P11_10_0.png -------------------------------------------------------------------------------- /docs/Physics/Physics_4A_constitutive_laws_P11_files/Physics_4A_constitutive_laws_P11_11_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Physics/Physics_4A_constitutive_laws_P11_files/Physics_4A_constitutive_laws_P11_11_1.png -------------------------------------------------------------------------------- /docs/Physics/Physics_4A_constitutive_laws_P11_files/Physics_4A_constitutive_laws_P11_14_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Physics/Physics_4A_constitutive_laws_P11_files/Physics_4A_constitutive_laws_P11_14_0.png -------------------------------------------------------------------------------- /docs/Physics/Physics_4A_constitutive_laws_P11_files/Physics_4A_constitutive_laws_P11_2_1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Physics/Physics_4A_constitutive_laws_P11_files/Physics_4A_constitutive_laws_P11_2_1.png -------------------------------------------------------------------------------- /docs/Physics/Physics_4A_constitutive_laws_P11_files/Physics_4A_constitutive_laws_P11_3_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Physics/Physics_4A_constitutive_laws_P11_files/Physics_4A_constitutive_laws_P11_3_1.png -------------------------------------------------------------------------------- /docs/Physics/Physics_4A_constitutive_laws_P11_files/Physics_4A_constitutive_laws_P11_5_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Physics/Physics_4A_constitutive_laws_P11_files/Physics_4A_constitutive_laws_P11_5_0.png -------------------------------------------------------------------------------- /docs/Physics/Physics_4A_constitutive_laws_P11_files/Physics_4A_constitutive_laws_P11_8_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Physics/Physics_4A_constitutive_laws_P11_files/Physics_4A_constitutive_laws_P11_8_0.png -------------------------------------------------------------------------------- /docs/Physics/Physics_4B_constitutive_laws_P12_with_prior_files/Physics_4B_constitutive_laws_P12_with_prior_10_1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Physics/Physics_4B_constitutive_laws_P12_with_prior_files/Physics_4B_constitutive_laws_P12_with_prior_10_1.png -------------------------------------------------------------------------------- /docs/Physics/Physics_4B_constitutive_laws_P12_with_prior_files/Physics_4B_constitutive_laws_P12_with_prior_13_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Physics/Physics_4B_constitutive_laws_P12_with_prior_files/Physics_4B_constitutive_laws_P12_with_prior_13_0.png -------------------------------------------------------------------------------- /docs/Physics/Physics_4B_constitutive_laws_P12_with_prior_files/Physics_4B_constitutive_laws_P12_with_prior_2_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Physics/Physics_4B_constitutive_laws_P12_with_prior_files/Physics_4B_constitutive_laws_P12_with_prior_2_1.png -------------------------------------------------------------------------------- /docs/Physics/Physics_4B_constitutive_laws_P12_with_prior_files/Physics_4B_constitutive_laws_P12_with_prior_3_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Physics/Physics_4B_constitutive_laws_P12_with_prior_files/Physics_4B_constitutive_laws_P12_with_prior_3_1.png -------------------------------------------------------------------------------- /docs/Physics/Physics_4B_constitutive_laws_P12_with_prior_files/Physics_4B_constitutive_laws_P12_with_prior_4_1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Physics/Physics_4B_constitutive_laws_P12_with_prior_files/Physics_4B_constitutive_laws_P12_with_prior_4_1.png -------------------------------------------------------------------------------- /docs/Physics/Physics_4B_constitutive_laws_P12_with_prior_files/Physics_4B_constitutive_laws_P12_with_prior_6_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Physics/Physics_4B_constitutive_laws_P12_with_prior_files/Physics_4B_constitutive_laws_P12_with_prior_6_0.png -------------------------------------------------------------------------------- /docs/Physics/Physics_4B_constitutive_laws_P12_with_prior_files/Physics_4B_constitutive_laws_P12_with_prior_9_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Physics/Physics_4B_constitutive_laws_P12_with_prior_files/Physics_4B_constitutive_laws_P12_with_prior_9_0.png -------------------------------------------------------------------------------- /docs/Physics/Physics_4C_constitutive_laws_P12_without_prior_files/Physics_4C_constitutive_laws_P12_without_prior_3_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/Physics/Physics_4C_constitutive_laws_P12_without_prior_files/Physics_4C_constitutive_laws_P12_without_prior_3_1.png -------------------------------------------------------------------------------- /docs/Physics/Physics_4C_constitutive_laws_P12_without_prior_files/Physics_4C_constitutive_laws_P12_without_prior_4_0.png: -------------------------------------------------------------------------------- 
# Sphinx configuration for the Kolmogorov Arnold Network documentation.
# Reference: https://www.sphinx-doc.org/en/master/usage/configuration.html
import sphinx_rtd_theme

# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
project = 'Kolmogorov Arnold Network'
copyright = '2024, Ziming Liu'
author = 'Ziming Liu'

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
extensions = [
    "sphinx_rtd_theme",
    "sphinx.ext.autodoc",
    "sphinx.ext.autosectionlabel",
]

templates_path = ['_templates']
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']

# -- Options for HTML output --------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
html_theme = "sphinx_rtd_theme"
html_static_path = ['_static']


def skip(app, what, name, obj, would_skip, options):
    # Force autodoc to document __init__, which it would normally skip.
    if name == "__init__":
        return False
    return would_skip


def setup(app):
    app.connect("autodoc-skip-member", skip)


# Heavy runtime dependencies are mocked so autodoc can import the package on
# the docs build host without having them installed.
autodoc_mock_imports = [
    "numpy",
    "torch",
    "torch.nn",
    "matplotlib",
    "matplotlib.pyplot",
    "tqdm",
    "sympy",
    "scipy",
    "sklearn",
    "torch.optim",
    "re",
    "yaml",
    "pandas",
]

source_suffix = [".rst", ".md"]
_api-demo: 2 | 3 | API Demos 4 | --------- 5 | 6 | .. toctree:: 7 | :maxdepth: 1 8 | 9 | API_demo/API_1_indexing.rst 10 | API_demo/API_2_plotting.rst 11 | API_demo/API_3_extract_activations.rst 12 | API_demo/API_4_initialization.rst 13 | API_demo/API_5_grid.rst 14 | API_demo/API_6_training_hyperparameter.rst 15 | API_demo/API_7_pruning.rst 16 | API_demo/API_8_regularization.rst 17 | API_demo/API_9_video.rst 18 | API_demo/API_10_device.rst 19 | API_demo/API_11_create_dataset.rst 20 | API_demo/API_12_checkpoint_save_load_model.rst -------------------------------------------------------------------------------- /docs/examples.rst: -------------------------------------------------------------------------------- 1 | .. _examples: 2 | 3 | Examples 4 | -------- 5 | 6 | .. toctree:: 7 | :maxdepth: 1 8 | 9 | Example/Example_1_function_fitting.rst 10 | Example/Example_3_deep_formula.rst 11 | Example/Example_4_classfication.rst 12 | Example/Example_5_special_functions.rst 13 | Example/Example_6_PDE_interpretation.rst 14 | Example/Example_7_PDE_accuracy.rst 15 | Example/Example_8_continual_learning.rst 16 | Example/Example_9_singularity.rst 17 | Example/Example_10_relativity-addition.rst 18 | Example/Example_11_encouraing_linear.rst 19 | Example/Example_12_unsupervised_learning.rst 20 | Example/Example_13_phase_transition.rst 21 | Example/Example_14_knot_supervised.rst 22 | Example/Example_15_knot_unsupervised.rst 23 | 24 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. kolmogorov-arnold-network documentation master file, created by 2 | sphinx-quickstart on Sun Apr 21 12:57:28 2024. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to Kolmogorov Arnold Network (KAN) documentation! 7 | ========================================================== 8 | 9 | .. 
image:: kan_plot.png 10 | 11 | This documentation is for the `paper`_ "KAN: Kolmogorov-Arnold Networks" and the `github repo`_. 12 | Kolmogorov-Arnold Networks, inspired by the Kolmogorov-Arnold representation theorem, are promising alternatives 13 | of Multi-Layer Preceptrons (MLPs). KANs have activation functions on edges, whereas MLPs have activation functions on nodes. 14 | This simple change makes KAN better than MLPs in terms of both accuracy and interpretability. 15 | 16 | .. _github repo: https://github.com/KindXiaoming/pykan 17 | .. _paper: https://arxiv.org/abs/2404.19756 18 | 19 | Installation 20 | ------------ 21 | 22 | Installation via github 23 | ~~~~~~~~~~~~~~~~~~~~~~~~ 24 | 25 | .. code-block:: python 26 | 27 | git clone https://github.com/KindXiaoming/pykan.git 28 | cd pykan 29 | pip install -e . 30 | # pip install -r requirements.txt # install requirements 31 | 32 | 33 | Installation via PyPI 34 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 35 | 36 | .. code-block:: python 37 | 38 | pip install pykan 39 | 40 | 41 | Requirements 42 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 43 | 44 | .. code-block:: python 45 | # python==3.9.7 46 | matplotlib==3.6.2 47 | numpy==1.24.4 48 | scikit_learn==1.1.3 49 | setuptools==65.5.0 50 | sympy==1.11.1 51 | torch==2.2.2 52 | tqdm==4.66.2 53 | 54 | Get started 55 | ----------- 56 | 57 | * Quickstart: :ref:`hello-kan` 58 | * KANs in Action: :ref:`api-demo`, :ref:`examples` 59 | * API (advanced): :ref:`api`. 60 | 61 | .. toctree:: 62 | :maxdepth: 1 63 | :caption: Contents: 64 | 65 | intro.rst 66 | modules.rst 67 | demos.rst 68 | examples.rst 69 | interp.rst 70 | physics.rst 71 | community.rst 72 | 73 | Indices and tables 74 | ================== 75 | 76 | * :ref:`genindex` 77 | * :ref:`modindex` 78 | * :ref:`search` 79 | -------------------------------------------------------------------------------- /docs/interp.rst: -------------------------------------------------------------------------------- 1 | .. 
_interp: 2 | 3 | Interpretability 4 | ---------------- 5 | 6 | .. toctree:: 7 | :maxdepth: 1 8 | 9 | Interp/Interp_1_Hello, MultKAN.rst 10 | Interp/Interp_2_Advanced MultKAN.rst 11 | Interp/Interp_3_KAN_Compiler.rst 12 | Interp/Interp_4_feature_attribution.rst 13 | Interp/Interp_5_test_symmetry.rst 14 | Interp/Interp_6_test_symmetry_NN.rst 15 | Interp/Interp_8_adding_auxillary_variables.rst 16 | Interp/Interp_9_different_plotting_metrics.rst 17 | Interp/Interp_10_hessian.rst 18 | Interp/Interp_10A_swap.rst 19 | Interp/Interp_10B_swap.rst 20 | Interp/Interp_11_sparse_init.rst 21 | 22 | 23 | -------------------------------------------------------------------------------- /docs/intro_files/intro_10_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/intro_files/intro_10_0.png -------------------------------------------------------------------------------- /docs/intro_files/intro_12_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/intro_files/intro_12_0.png -------------------------------------------------------------------------------- /docs/intro_files/intro_14_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/intro_files/intro_14_0.png -------------------------------------------------------------------------------- /docs/intro_files/intro_15_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/intro_files/intro_15_0.png -------------------------------------------------------------------------------- /docs/intro_files/intro_17_0.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/intro_files/intro_17_0.png -------------------------------------------------------------------------------- /docs/intro_files/intro_19_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/intro_files/intro_19_0.png -------------------------------------------------------------------------------- /docs/intro_files/intro_21_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/intro_files/intro_21_0.png -------------------------------------------------------------------------------- /docs/intro_files/intro_23_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/intro_files/intro_23_0.png -------------------------------------------------------------------------------- /docs/intro_files/intro_26_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/intro_files/intro_26_0.png -------------------------------------------------------------------------------- /docs/intro_files/intro_6_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/intro_files/intro_6_0.png -------------------------------------------------------------------------------- /docs/kan.rst: -------------------------------------------------------------------------------- 1 | kan 
package 2 | =========== 3 | 4 | kan.KAN module 5 | -------------- 6 | 7 | .. automodule:: kan.MultKAN 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | 12 | kan.KANLayer module 13 | ------------------- 14 | 15 | .. automodule:: kan.KANLayer 16 | :members: 17 | :undoc-members: 18 | :show-inheritance: 19 | 20 | kan.LBFGS module 21 | ---------------- 22 | 23 | .. automodule:: kan.LBFGS 24 | :members: 25 | :undoc-members: 26 | :show-inheritance: 27 | 28 | kan.Symbolic\_KANLayer module 29 | ----------------------------- 30 | 31 | .. automodule:: kan.Symbolic_KANLayer 32 | :members: 33 | :undoc-members: 34 | :show-inheritance: 35 | 36 | kan.spline module 37 | ----------------- 38 | 39 | .. automodule:: kan.spline 40 | :members: 41 | :undoc-members: 42 | :show-inheritance: 43 | 44 | kan.utils module 45 | ---------------- 46 | 47 | .. automodule:: kan.utils 48 | :members: 49 | :undoc-members: 50 | :show-inheritance: 51 | 52 | kan.compiler module 53 | ------------------- 54 | 55 | .. automodule:: kan.compiler 56 | :members: 57 | :undoc-members: 58 | :show-inheritance: 59 | 60 | kan.hypothesis module 61 | --------------------- 62 | 63 | .. automodule:: kan.hypothesis 64 | :members: 65 | :undoc-members: 66 | :show-inheritance: 67 | -------------------------------------------------------------------------------- /docs/kan_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/docs/kan_plot.png -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 
11 | set BUILDDIR=_build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/modules.rst: -------------------------------------------------------------------------------- 1 | .. _api: 2 | 3 | API 4 | === 5 | 6 | .. toctree:: 7 | :maxdepth: 4 8 | 9 | kan 10 | -------------------------------------------------------------------------------- /docs/physics.rst: -------------------------------------------------------------------------------- 1 | .. _physics: 2 | 3 | Physics 4 | ------- 5 | 6 | .. 
# --- kan/.ipynb_checkpoints/__init__-checkpoint.py ---------------------------
# Package init (checkpoint copy): re-export the public KAN API.
# NOTE(review): the relative imports resolve only when this file lives inside
# the `kan` package; guarded so the module can also be imported standalone.
try:
    from .MultKAN import *
    from .utils import *
except ImportError:
    pass
# torch.use_deterministic_algorithms(True)

# --- kan/.ipynb_checkpoints/experiment-checkpoint.py -------------------------
import numpy as np  # explicit import: runner1 converts results with np.array
import torch


def runner1(width, dataset, grids=[5, 10, 20], steps=20, lamb=0.001,
            prune_round=3, refine_round=3, edge_th=1e-2, node_th=1e-2,
            metrics=None, seed=1):
    '''
    Repeatedly train, prune and refine a KAN, recording an accuracy/complexity
    trace after every stage.

    Args:
    -----
        width : list
            layer widths, forwarded to KAN(width=...)
        dataset : dict
            dataset dict (train/test inputs and labels) as used by model.fit
        grids : list of int
            grid sizes; grids[0] initializes the model, grids[j] is the target
            of the j-th refinement round
        steps : int
            optimization steps per fit() call
        lamb : float
            regularization strength for the train-and-prune phase
        prune_round : int
            number of train -> prune cycles
        refine_round : int
            number of grid-refinement cycles per prune round
        edge_th : float
            edge pruning threshold
        node_th : float
            node pruning threshold
        metrics : list of callables or None
            optional extra metrics; each is called as metric(model, dataset)
            and recorded under its __name__
        seed : int
            random seed for the initial model

    Returns:
    --------
        result : dict of np.ndarray
            keys: 'test_loss', 'c' (edge count), 'G' (grid size),
            'id' (checkpoint id '<round>.<state>'), plus one key per metric
    '''
    result = {'test_loss': [], 'c': [], 'G': [], 'id': []}
    if metrics is not None:
        for metric in metrics:
            result[metric.__name__] = []

    def collect(evaluation):
        # Snapshot the current model's stats into the running trace.
        result['test_loss'].append(evaluation['test_loss'])
        result['c'].append(evaluation['n_edge'])
        result['G'].append(evaluation['n_grid'])
        result['id'].append(f'{model.round}.{model.state_id}')
        if metrics is not None:
            for metric in metrics:
                result[metric.__name__].append(metric(model, dataset).item())

    for i in range(prune_round):
        # Train and prune.
        if i == 0:
            model = KAN(width=width, grid=grids[0], seed=seed)
        else:
            # Rewind to the checkpoint taken before the previous refinements.
            # assumes checkpoint ids advance as '<round>.<state>' — TODO confirm
            model = model.rewind(f'{i-1}.{2*i}')

        model.fit(dataset, steps=steps, lamb=lamb)
        model = model.prune(edge_th=edge_th, node_th=node_th)
        collect(model.evaluate(dataset))

        for j in range(refine_round):
            model = model.refine(grids[j])
            model.fit(dataset, steps=steps)
            collect(model.evaluate(dataset))

    for key in result:
        result[key] = np.array(result[key])

    return result


def pareto_frontier(x, y):
    '''
    Pareto-minimal points of (x, y): a point is kept when no other point is <=
    in both coordinates. Each point weakly dominates itself, so a dominator
    count of exactly 1 marks a frontier point.

    Args:
    -----
        x : 1D np.ndarray
        y : 1D np.ndarray (same length as x)

    Returns:
    --------
        x_pf, y_pf : np.ndarray
            coordinates of the frontier points
        pf_id : np.ndarray
            indices of the frontier points in the inputs
    '''
    dominated_by = np.sum((x[:, None] <= x[None, :]) * (y[:, None] <= y[None, :]), axis=0)
    pf_id = np.where(dominated_by == 1)[0]
    return x[pf_id], y[pf_id], pf_id


# --- kan/.ipynb_checkpoints/spline-checkpoint.py ------------------------------

def B_batch(x, grid, k=0, extend=True, device='cpu'):
    '''
    Evaluate x on the B-spline bases (Cox-de Boor recursion).

    Args:
    -----
        x : 2D torch.tensor
            inputs, shape (batch, in_dim)
        grid : 2D torch.tensor
            grids, shape (in_dim, number of grid points)
        k : int
            the piecewise polynomial order of splines
        extend : bool
            unused; kept for backward compatibility
        device : str
            unused; kept for backward compatibility (output follows x's device)

    Returns:
    --------
        spline values : 3D torch.tensor
            shape (batch, in_dim, G+k). G: the number of grid intervals,
            k: spline order.

    Example
    -------
    >>> from kan.spline import B_batch
    >>> x = torch.rand(100,2)
    >>> grid = torch.linspace(-1,1,steps=11)[None, :].expand(2, 11)
    >>> B_batch(x, grid, k=3).shape
    '''
    x = x.unsqueeze(dim=2)
    grid = grid.unsqueeze(dim=0)

    if k == 0:
        # Order 0: indicator of the half-open interval [grid_i, grid_{i+1}).
        value = (x >= grid[:, :, :-1]) * (x < grid[:, :, 1:])
    else:
        # Recurse on the order-(k-1) bases.
        B_km1 = B_batch(x[:, :, 0], grid=grid[0], k=k - 1)
        value = (x - grid[:, :, :-(k + 1)]) / (grid[:, :, k:-1] - grid[:, :, :-(k + 1)]) * B_km1[:, :, :-1] + (
                grid[:, :, k + 1:] - x) / (grid[:, :, k + 1:] - grid[:, :, 1:(-k)]) * B_km1[:, :, 1:]

    # In case the grid is degenerate (zero-width intervals yield NaN/inf).
    value = torch.nan_to_num(value)
    return value


def coef2curve(x_eval, grid, coef, k, device="cpu"):
    '''
    Convert B-spline coefficients to B-spline curves: evaluate x on the curves
    (sum B_batch results over the B-spline basis).

    Args:
    -----
        x_eval : 2D torch.tensor
            shape (batch, in_dim)
        grid : 2D torch.tensor
            shape (in_dim, G+2k). G: the number of grid intervals; k: spline order.
        coef : 3D torch.tensor
            shape (in_dim, out_dim, G+k)
        k : int
            the piecewise polynomial order of splines
        device : str
            unused; kept for backward compatibility (coef is moved to the
            basis tensor's device)

    Returns:
    --------
        y_eval : 3D torch.tensor
            shape (batch, in_dim, out_dim)
    '''
    b_splines = B_batch(x_eval, grid, k=k)
    y_eval = torch.einsum('ijk,jlk->ijl', b_splines, coef.to(b_splines.device))
    return y_eval


def curve2coef(x_eval, y_eval, grid, k):
    '''
    Convert B-spline curves to B-spline coefficients using least squares.

    Args:
    -----
        x_eval : 2D torch.tensor
            shape (batch, in_dim)
        y_eval : 3D torch.tensor
            shape (batch, in_dim, out_dim)
        grid : 2D torch.tensor
            shape (in_dim, G+2k)
        k : int
            spline order

    Returns:
    --------
        coef : 3D torch.tensor
            shape (in_dim, out_dim, G+k)

    Raises:
    -------
        RuntimeError
            if the least-squares solve fails
    '''
    batch = x_eval.shape[0]
    in_dim = x_eval.shape[1]
    out_dim = y_eval.shape[2]
    n_coef = grid.shape[1] - k - 1
    mat = B_batch(x_eval, grid, k)
    # (batch, in_dim, n_coef) -> (in_dim, out_dim, batch, n_coef): one
    # least-squares system per (input, output) pair.
    mat = mat.permute(1, 0, 2)[:, None, :, :].expand(in_dim, out_dim, batch, n_coef)
    y_eval = y_eval.permute(1, 2, 0).unsqueeze(dim=3)
    try:
        coef = torch.linalg.lstsq(mat, y_eval).solution[:, :, :, 0]
    except Exception as exc:
        # Was a bare `except:` that printed 'lstsq failed' and then crashed
        # with UnboundLocalError on the unbound `coef`; fail loudly instead.
        raise RuntimeError('curve2coef: least-squares solve failed') from exc
    return coef


def extend_grid(grid, k_extend=0):
    '''
    Extend a grid by k_extend uniformly spaced points on each end, using the
    grid's average spacing.

    Args:
    -----
        grid : 2D torch.tensor
            shape (in_dim, number of grid points)
        k_extend : int
            number of points appended on each side

    Returns:
    --------
        grid : 2D torch.tensor
            shape (in_dim, number of grid points + 2*k_extend)
    '''
    h = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1)

    for i in range(k_extend):
        grid = torch.cat([grid[:, [0]] - h, grid], dim=1)
        grid = torch.cat([grid, grid[:, [-1]] + h], dim=1)

    return grid
-------------------------------------------------------------------------------- 1 | from .MultKAN import * 2 | from .utils import * 3 | #torch.use_deterministic_algorithms(True) -------------------------------------------------------------------------------- /kan/assets/img/mult_symbol.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/kan/assets/img/mult_symbol.png -------------------------------------------------------------------------------- /kan/assets/img/sum_symbol.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/kan/assets/img/sum_symbol.png -------------------------------------------------------------------------------- /kan/experiment.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .MultKAN import * 3 | 4 | 5 | def runner1(width, dataset, grids=[5,10,20], steps=20, lamb=0.001, prune_round=3, refine_round=3, edge_th=1e-2, node_th=1e-2, metrics=None, seed=1): 6 | 7 | result = {} 8 | result['test_loss'] = [] 9 | result['c'] = [] 10 | result['G'] = [] 11 | result['id'] = [] 12 | if metrics != None: 13 | for i in range(len(metrics)): 14 | result[metrics[i].__name__] = [] 15 | 16 | def collect(evaluation): 17 | result['test_loss'].append(evaluation['test_loss']) 18 | result['c'].append(evaluation['n_edge']) 19 | result['G'].append(evaluation['n_grid']) 20 | result['id'].append(f'{model.round}.{model.state_id}') 21 | if metrics != None: 22 | for i in range(len(metrics)): 23 | result[metrics[i].__name__].append(metrics[i](model, dataset).item()) 24 | 25 | for i in range(prune_round): 26 | # train and prune 27 | if i == 0: 28 | model = KAN(width=width, grid=grids[0], seed=seed) 29 | else: 30 | model = model.rewind(f'{i-1}.{2*i}') 31 | 32 | 
model.fit(dataset, steps=steps, lamb=lamb) 33 | model = model.prune(edge_th=edge_th, node_th=node_th) 34 | evaluation = model.evaluate(dataset) 35 | collect(evaluation) 36 | 37 | for j in range(refine_round): 38 | model = model.refine(grids[j]) 39 | model.fit(dataset, steps=steps) 40 | evaluation = model.evaluate(dataset) 41 | collect(evaluation) 42 | 43 | for key in list(result.keys()): 44 | result[key] = np.array(result[key]) 45 | 46 | return result 47 | 48 | 49 | def pareto_frontier(x,y): 50 | 51 | pf_id = np.where(np.sum((x[:,None] <= x[None,:]) * (y[:,None] <= y[None,:]), axis=0) == 1)[0] 52 | x_pf = x[pf_id] 53 | y_pf = y[pf_id] 54 | 55 | return x_pf, y_pf, pf_id -------------------------------------------------------------------------------- /kan/spline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def B_batch(x, grid, k=0, extend=True, device='cpu'): 5 | ''' 6 | evaludate x on B-spline bases 7 | 8 | Args: 9 | ----- 10 | x : 2D torch.tensor 11 | inputs, shape (number of splines, number of samples) 12 | grid : 2D torch.tensor 13 | grids, shape (number of splines, number of grid points) 14 | k : int 15 | the piecewise polynomial order of splines. 16 | extend : bool 17 | If True, k points are extended on both ends. If False, no extension (zero boundary condition). Default: True 18 | device : str 19 | devicde 20 | 21 | Returns: 22 | -------- 23 | spline values : 3D torch.tensor 24 | shape (batch, in_dim, G+k). G: the number of grid intervals, k: spline order. 
25 | 26 | Example 27 | ------- 28 | >>> from kan.spline import B_batch 29 | >>> x = torch.rand(100,2) 30 | >>> grid = torch.linspace(-1,1,steps=11)[None, :].expand(2, 11) 31 | >>> B_batch(x, grid, k=3).shape 32 | ''' 33 | 34 | x = x.unsqueeze(dim=2) 35 | grid = grid.unsqueeze(dim=0) 36 | 37 | if k == 0: 38 | value = (x >= grid[:, :, :-1]) * (x < grid[:, :, 1:]) 39 | else: 40 | B_km1 = B_batch(x[:,:,0], grid=grid[0], k=k - 1) 41 | 42 | value = (x - grid[:, :, :-(k + 1)]) / (grid[:, :, k:-1] - grid[:, :, :-(k + 1)]) * B_km1[:, :, :-1] + ( 43 | grid[:, :, k + 1:] - x) / (grid[:, :, k + 1:] - grid[:, :, 1:(-k)]) * B_km1[:, :, 1:] 44 | 45 | # in case grid is degenerate 46 | value = torch.nan_to_num(value) 47 | return value 48 | 49 | 50 | 51 | def coef2curve(x_eval, grid, coef, k, device="cpu"): 52 | ''' 53 | converting B-spline coefficients to B-spline curves. Evaluate x on B-spline curves (summing up B_batch results over B-spline basis). 54 | 55 | Args: 56 | ----- 57 | x_eval : 2D torch.tensor 58 | shape (batch, in_dim) 59 | grid : 2D torch.tensor 60 | shape (in_dim, G+2k). G: the number of grid intervals; k: spline order. 61 | coef : 3D torch.tensor 62 | shape (in_dim, out_dim, G+k) 63 | k : int 64 | the piecewise polynomial order of splines. 65 | device : str 66 | devicde 67 | 68 | Returns: 69 | -------- 70 | y_eval : 3D torch.tensor 71 | shape (batch, in_dim, out_dim) 72 | 73 | ''' 74 | 75 | b_splines = B_batch(x_eval, grid, k=k) 76 | y_eval = torch.einsum('ijk,jlk->ijl', b_splines, coef.to(b_splines.device)) 77 | 78 | return y_eval 79 | 80 | 81 | def curve2coef(x_eval, y_eval, grid, k): 82 | ''' 83 | converting B-spline curves to B-spline coefficients using least squares. 
84 | 85 | Args: 86 | ----- 87 | x_eval : 2D torch.tensor 88 | shape (batch, in_dim) 89 | y_eval : 3D torch.tensor 90 | shape (batch, in_dim, out_dim) 91 | grid : 2D torch.tensor 92 | shape (in_dim, grid+2*k) 93 | k : int 94 | spline order 95 | lamb : float 96 | regularized least square lambda 97 | 98 | Returns: 99 | -------- 100 | coef : 3D torch.tensor 101 | shape (in_dim, out_dim, G+k) 102 | ''' 103 | #print('haha', x_eval.shape, y_eval.shape, grid.shape) 104 | batch = x_eval.shape[0] 105 | in_dim = x_eval.shape[1] 106 | out_dim = y_eval.shape[2] 107 | n_coef = grid.shape[1] - k - 1 108 | mat = B_batch(x_eval, grid, k) 109 | mat = mat.permute(1,0,2)[:,None,:,:].expand(in_dim, out_dim, batch, n_coef) 110 | #print('mat', mat.shape) 111 | y_eval = y_eval.permute(1,2,0).unsqueeze(dim=3) 112 | #print('y_eval', y_eval.shape) 113 | device = mat.device 114 | 115 | #coef = torch.linalg.lstsq(mat, y_eval, driver='gelsy' if device == 'cpu' else 'gels').solution[:,:,:,0] 116 | try: 117 | coef = torch.linalg.lstsq(mat, y_eval).solution[:,:,:,0] 118 | except: 119 | print('lstsq failed') 120 | 121 | # manual psuedo-inverse 122 | '''lamb=1e-8 123 | XtX = torch.einsum('ijmn,ijnp->ijmp', mat.permute(0,1,3,2), mat) 124 | Xty = torch.einsum('ijmn,ijnp->ijmp', mat.permute(0,1,3,2), y_eval) 125 | n1, n2, n = XtX.shape[0], XtX.shape[1], XtX.shape[2] 126 | identity = torch.eye(n,n)[None, None, :, :].expand(n1, n2, n, n).to(device) 127 | A = XtX + lamb * identity 128 | B = Xty 129 | coef = (A.pinverse() @ B)[:,:,:,0]''' 130 | 131 | return coef 132 | 133 | 134 | def extend_grid(grid, k_extend=0): 135 | ''' 136 | extend grid 137 | ''' 138 | h = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1) 139 | 140 | for i in range(k_extend): 141 | grid = torch.cat([grid[:, [0]] - h, grid], dim=1) 142 | grid = torch.cat([grid, grid[:, [-1]] + h], dim=1) 143 | 144 | return grid -------------------------------------------------------------------------------- /model/0.0_cache_data: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/model/0.0_cache_data -------------------------------------------------------------------------------- /model/0.0_config.yml: -------------------------------------------------------------------------------- 1 | affine_trainable: false 2 | auto_save: true 3 | base_fun_name: silu 4 | ckpt_path: ./model 5 | device: cpu 6 | grid: 3 7 | grid_eps: 0.02 8 | grid_range: 9 | - -1 10 | - 1 11 | k: 3 12 | mult_arity: 2 13 | round: 0 14 | sb_trainable: true 15 | sp_trainable: true 16 | state_id: 0 17 | symbolic.funs_name.0: 18 | - - '0' 19 | - '0' 20 | - - '0' 21 | - '0' 22 | - - '0' 23 | - '0' 24 | - - '0' 25 | - '0' 26 | - - '0' 27 | - '0' 28 | symbolic.funs_name.1: 29 | - - '0' 30 | - '0' 31 | - '0' 32 | - '0' 33 | - '0' 34 | symbolic_enabled: true 35 | width: 36 | - - 2 37 | - 0 38 | - - 5 39 | - 0 40 | - - 1 41 | - 0 42 | -------------------------------------------------------------------------------- /model/0.0_state: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/model/0.0_state -------------------------------------------------------------------------------- /model/history.txt: -------------------------------------------------------------------------------- 1 | ### Round 0 ### 2 | init => 0.0 3 | -------------------------------------------------------------------------------- /pykan.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | LICENSE 2 | README.md 3 | setup.py 4 | experiments/__init__.py 5 | experiments/baselines/MLP.py 6 | experiments/baselines/__init__.py 7 | kan/KANLayer.py 8 | kan/LBFGS.py 9 | kan/MLP.py 10 | kan/MultKAN.py 11 | kan/Symbolic_KANLayer.py 12 | kan/__init__.py 13 | kan/compiler.py 14 | 
kan/experiment.py 15 | kan/feynman.py 16 | kan/hypothesis.py 17 | kan/spline.py 18 | kan/utils.py 19 | kan/assets/img/mult_symbol.png 20 | kan/assets/img/sum_symbol.png 21 | pykan.egg-info/PKG-INFO 22 | pykan.egg-info/SOURCES.txt 23 | pykan.egg-info/dependency_links.txt 24 | pykan.egg-info/top_level.txt -------------------------------------------------------------------------------- /pykan.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /pykan.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | experiments 2 | kan 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.6.2 2 | numpy==1.24.4 3 | scikit_learn==1.1.3 4 | setuptools==65.5.0 5 | sympy==1.11.1 6 | torch==2.2.2 7 | tqdm==4.66.2 8 | pandas==2.0.1 9 | seaborn 10 | pyyaml 11 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | # Load the long_description from README.md 4 | with open("README.md", "r", encoding="utf8") as fh: 5 | long_description = fh.read() 6 | 7 | setuptools.setup( 8 | name="pykan", 9 | version="0.2.8", 10 | author="Ziming Liu", 11 | author_email="zmliu@mit.edu", 12 | description="Kolmogorov Arnold Networks", 13 | long_description=long_description, 14 | long_description_content_type="text/markdown", 15 | # url="https://github.com/kindxiaoming/", 16 | packages=setuptools.find_packages(), 17 | include_package_data=True, 18 | package_data={ 19 | 'pykan': [ 20 | 'figures/lock.png', 21 | 'assets/img/sum_symbol.png', 22 | 'assets/img/mult_symbol.png', 23 | ], 24 | }, 25 | classifiers=[ 26 | 
"Programming Language :: Python :: 3", 27 | "License :: OSI Approved :: MIT License", 28 | "Operating System :: OS Independent", 29 | ], 30 | python_requires='>=3.6', 31 | ) 32 | -------------------------------------------------------------------------------- /tutorials/.ipynb_checkpoints/API_11_create_dataset-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "53ff2e87", 6 | "metadata": {}, 7 | "source": [ 8 | "# API 11: Create dataset" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "25a90774", 14 | "metadata": {}, 15 | "source": [ 16 | "how to use create_dataset in kan.utils" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "id": "2f9ae0c7", 22 | "metadata": {}, 23 | "source": [ 24 | "Standard way" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 1, 30 | "id": "3e2b9f8b", 31 | "metadata": {}, 32 | "outputs": [ 33 | { 34 | "name": "stdout", 35 | "output_type": "stream", 36 | "text": [ 37 | "cuda\n" 38 | ] 39 | }, 40 | { 41 | "data": { 42 | "text/plain": [ 43 | "torch.Size([1000, 1])" 44 | ] 45 | }, 46 | "execution_count": 1, 47 | "metadata": {}, 48 | "output_type": "execute_result" 49 | } 50 | ], 51 | "source": [ 52 | "from kan.utils import create_dataset\n", 53 | "import torch\n", 54 | "\n", 55 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", 56 | "print(device)\n", 57 | "\n", 58 | "f = lambda x: x[:,[0]] * x[:,[1]]\n", 59 | "dataset = create_dataset(f, n_var=2, device=device)\n", 60 | "dataset['train_label'].shape" 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "id": "877956c9", 66 | "metadata": {}, 67 | "source": [ 68 | "Lazier way. We sometimes forget to add the bracket, i.e., write x[:,[0]] as x[:,0], and this used to lead to an error in training (loss not going down). 
Now the create_dataset can automatically detect this simplification and produce the correct behavior." 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 2, 74 | "id": "b14dd4a2", 75 | "metadata": {}, 76 | "outputs": [ 77 | { 78 | "data": { 79 | "text/plain": [ 80 | "torch.Size([1000, 1])" 81 | ] 82 | }, 83 | "execution_count": 2, 84 | "metadata": {}, 85 | "output_type": "execute_result" 86 | } 87 | ], 88 | "source": [ 89 | "f = lambda x: x[:,0] * x[:,1]\n", 90 | "dataset = create_dataset(f, n_var=2, device=device)\n", 91 | "dataset['train_label'].shape" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "id": "60230da4", 97 | "metadata": {}, 98 | "source": [ 99 | "Laziest way. If you even want to get rid of the colon symbol, i.e., you want to write x[;,0] as x[0], you can do that but need to pass in f_mode = 'row'." 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 3, 105 | "id": "e764f415", 106 | "metadata": {}, 107 | "outputs": [ 108 | { 109 | "data": { 110 | "text/plain": [ 111 | "torch.Size([1000, 1])" 112 | ] 113 | }, 114 | "execution_count": 3, 115 | "metadata": {}, 116 | "output_type": "execute_result" 117 | } 118 | ], 119 | "source": [ 120 | "f = lambda x: x[0] * x[1]\n", 121 | "dataset = create_dataset(f, n_var=2, f_mode='row', device=device)\n", 122 | "dataset['train_label'].shape" 123 | ] 124 | }, 125 | { 126 | "cell_type": "markdown", 127 | "id": "8e1f1732", 128 | "metadata": {}, 129 | "source": [ 130 | "if you already have x (inputs) and y (outputs), and you only want to partition them into train/test, use create_dataset_from_data" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 4, 136 | "id": "accf900a", 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "import torch\n", 141 | "from kan.utils import create_dataset_from_data\n", 142 | "\n", 143 | "x = torch.rand(100,2)\n", 144 | "y = torch.rand(100,1)\n", 145 | "dataset = create_dataset_from_data(x, y, 
device=device)" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "id": "c45062a8", 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [] 155 | } 156 | ], 157 | "metadata": { 158 | "kernelspec": { 159 | "display_name": "Python 3 (ipykernel)", 160 | "language": "python", 161 | "name": "python3" 162 | }, 163 | "language_info": { 164 | "codemirror_mode": { 165 | "name": "ipython", 166 | "version": 3 167 | }, 168 | "file_extension": ".py", 169 | "mimetype": "text/x-python", 170 | "name": "python", 171 | "nbconvert_exporter": "python", 172 | "pygments_lexer": "ipython3", 173 | "version": "3.9.16" 174 | } 175 | }, 176 | "nbformat": 4, 177 | "nbformat_minor": 5 178 | } 179 | -------------------------------------------------------------------------------- /tutorials/.ipynb_checkpoints/API_9_video-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "134e7f9d", 6 | "metadata": {}, 7 | "source": [ 8 | "# API 9: Videos\n", 9 | "\n", 10 | "We have shown one can visualize KAN with the plot() method. 
If one wants to save the training dynamics of KAN plots, one only needs to pass argument save_video = True to train() method (and set some video related parameters)" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 6, 16 | "id": "2075ef56", 17 | "metadata": { 18 | "tags": [] 19 | }, 20 | "outputs": [ 21 | { 22 | "name": "stdout", 23 | "output_type": "stream", 24 | "text": [ 25 | "cuda\n", 26 | "checkpoint directory created: ./model\n", 27 | "saving model version 0.0\n" 28 | ] 29 | }, 30 | { 31 | "name": "stderr", 32 | "output_type": "stream", 33 | "text": [ 34 | "| train_loss: 2.89e-01 | test_loss: 2.96e-01 | reg: 1.31e+01 | : 100%|█| 5/5 [00:09<00:00, 1.94s/it" 35 | ] 36 | }, 37 | { 38 | "name": "stdout", 39 | "output_type": "stream", 40 | "text": [ 41 | "saving model version 0.1\n" 42 | ] 43 | }, 44 | { 45 | "name": "stderr", 46 | "output_type": "stream", 47 | "text": [ 48 | "\n" 49 | ] 50 | } 51 | ], 52 | "source": [ 53 | "from kan import *\n", 54 | "import torch\n", 55 | "\n", 56 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", 57 | "print(device)\n", 58 | "\n", 59 | "# create a KAN: 2D inputs, 1D output, and 5 hidden neurons. 
cubic spline (k=3), 5 grid intervals (grid=5).\n", 60 | "model = KAN(width=[4,2,1,1], grid=3, k=3, seed=1, device=device)\n", 61 | "f = lambda x: torch.exp((torch.sin(torch.pi*(x[:,[0]]**2+x[:,[1]]**2))+torch.sin(torch.pi*(x[:,[2]]**2+x[:,[3]]**2)))/2)\n", 62 | "dataset = create_dataset(f, n_var=4, train_num=3000, device=device)\n", 63 | "\n", 64 | "image_folder = 'video_img'\n", 65 | "\n", 66 | "# train the model\n", 67 | "#model.train(dataset, opt=\"LBFGS\", steps=20, lamb=1e-3, lamb_entropy=2.);\n", 68 | "model.fit(dataset, opt=\"LBFGS\", steps=5, lamb=0.001, lamb_entropy=2., save_fig=True, beta=10, \n", 69 | " in_vars=[r'$x_1$', r'$x_2$', r'$x_3$', r'$x_4$'],\n", 70 | " out_vars=[r'${\\rm exp}({\\rm sin}(x_1^2+x_2^2)+{\\rm sin}(x_3^2+x_4^2))$'],\n", 71 | " img_folder=image_folder);\n" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 2, 77 | "id": "c18245a3", 78 | "metadata": {}, 79 | "outputs": [ 80 | { 81 | "name": "stdout", 82 | "output_type": "stream", 83 | "text": [ 84 | "Moviepy - Building video video.mp4.\n", 85 | "Moviepy - Writing video video.mp4\n", 86 | "\n" 87 | ] 88 | }, 89 | { 90 | "name": "stderr", 91 | "output_type": "stream", 92 | "text": [ 93 | " \r" 94 | ] 95 | }, 96 | { 97 | "name": "stdout", 98 | "output_type": "stream", 99 | "text": [ 100 | "Moviepy - Done !\n", 101 | "Moviepy - video ready video.mp4\n" 102 | ] 103 | } 104 | ], 105 | "source": [ 106 | "import os\n", 107 | "import numpy as np\n", 108 | "import moviepy.video.io.ImageSequenceClip # moviepy == 1.0.3\n", 109 | "\n", 110 | "video_name='video'\n", 111 | "fps=5\n", 112 | "\n", 113 | "fps = fps\n", 114 | "files = os.listdir(image_folder)\n", 115 | "train_index = []\n", 116 | "for file in files:\n", 117 | " if file[0].isdigit() and file.endswith('.jpg'):\n", 118 | " train_index.append(int(file[:-4]))\n", 119 | "\n", 120 | "train_index = np.sort(train_index)\n", 121 | "\n", 122 | "image_files = [image_folder+'/'+str(train_index[index])+'.jpg' for index in 
train_index]\n", 123 | "\n", 124 | "clip = moviepy.video.io.ImageSequenceClip.ImageSequenceClip(image_files, fps=fps)\n", 125 | "clip.write_videofile(video_name+'.mp4')" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": null, 131 | "id": "88d0d737", 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [] 135 | } 136 | ], 137 | "metadata": { 138 | "kernelspec": { 139 | "display_name": "Python 3 (ipykernel)", 140 | "language": "python", 141 | "name": "python3" 142 | }, 143 | "language_info": { 144 | "codemirror_mode": { 145 | "name": "ipython", 146 | "version": 3 147 | }, 148 | "file_extension": ".py", 149 | "mimetype": "text/x-python", 150 | "name": "python", 151 | "nbconvert_exporter": "python", 152 | "pygments_lexer": "ipython3", 153 | "version": "3.9.16" 154 | } 155 | }, 156 | "nbformat": 4, 157 | "nbformat_minor": 5 158 | } 159 | -------------------------------------------------------------------------------- /tutorials/API_demo/API_11_create_dataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "53ff2e87", 6 | "metadata": {}, 7 | "source": [ 8 | "# API 11: Create dataset" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "25a90774", 14 | "metadata": {}, 15 | "source": [ 16 | "how to use create_dataset in kan.utils" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "id": "2f9ae0c7", 22 | "metadata": {}, 23 | "source": [ 24 | "Standard way" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 1, 30 | "id": "3e2b9f8b", 31 | "metadata": {}, 32 | "outputs": [ 33 | { 34 | "name": "stdout", 35 | "output_type": "stream", 36 | "text": [ 37 | "cuda\n" 38 | ] 39 | }, 40 | { 41 | "data": { 42 | "text/plain": [ 43 | "torch.Size([1000, 1])" 44 | ] 45 | }, 46 | "execution_count": 1, 47 | "metadata": {}, 48 | "output_type": "execute_result" 49 | } 50 | ], 51 | "source": [ 52 | "from kan.utils import 
create_dataset\n", 53 | "import torch\n", 54 | "\n", 55 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", 56 | "print(device)\n", 57 | "\n", 58 | "f = lambda x: x[:,[0]] * x[:,[1]]\n", 59 | "dataset = create_dataset(f, n_var=2, device=device)\n", 60 | "dataset['train_label'].shape" 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "id": "877956c9", 66 | "metadata": {}, 67 | "source": [ 68 | "Lazier way. We sometimes forget to add the bracket, i.e., write x[:,[0]] as x[:,0], and this used to lead to an error in training (loss not going down). Now the create_dataset can automatically detect this simplification and produce the correct behavior." 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 2, 74 | "id": "b14dd4a2", 75 | "metadata": {}, 76 | "outputs": [ 77 | { 78 | "data": { 79 | "text/plain": [ 80 | "torch.Size([1000, 1])" 81 | ] 82 | }, 83 | "execution_count": 2, 84 | "metadata": {}, 85 | "output_type": "execute_result" 86 | } 87 | ], 88 | "source": [ 89 | "f = lambda x: x[:,0] * x[:,1]\n", 90 | "dataset = create_dataset(f, n_var=2, device=device)\n", 91 | "dataset['train_label'].shape" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "id": "60230da4", 97 | "metadata": {}, 98 | "source": [ 99 | "Laziest way. If you even want to get rid of the colon symbol, i.e., you want to write x[;,0] as x[0], you can do that but need to pass in f_mode = 'row'." 
100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 3, 105 | "id": "e764f415", 106 | "metadata": {}, 107 | "outputs": [ 108 | { 109 | "data": { 110 | "text/plain": [ 111 | "torch.Size([1000, 1])" 112 | ] 113 | }, 114 | "execution_count": 3, 115 | "metadata": {}, 116 | "output_type": "execute_result" 117 | } 118 | ], 119 | "source": [ 120 | "f = lambda x: x[0] * x[1]\n", 121 | "dataset = create_dataset(f, n_var=2, f_mode='row', device=device)\n", 122 | "dataset['train_label'].shape" 123 | ] 124 | }, 125 | { 126 | "cell_type": "markdown", 127 | "id": "8e1f1732", 128 | "metadata": {}, 129 | "source": [ 130 | "if you already have x (inputs) and y (outputs), and you only want to partition them into train/test, use create_dataset_from_data" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 4, 136 | "id": "accf900a", 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "import torch\n", 141 | "from kan.utils import create_dataset_from_data\n", 142 | "\n", 143 | "x = torch.rand(100,2)\n", 144 | "y = torch.rand(100,1)\n", 145 | "dataset = create_dataset_from_data(x, y, device=device)" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "id": "c45062a8", 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [] 155 | } 156 | ], 157 | "metadata": { 158 | "kernelspec": { 159 | "display_name": "Python 3 (ipykernel)", 160 | "language": "python", 161 | "name": "python3" 162 | }, 163 | "language_info": { 164 | "codemirror_mode": { 165 | "name": "ipython", 166 | "version": 3 167 | }, 168 | "file_extension": ".py", 169 | "mimetype": "text/x-python", 170 | "name": "python", 171 | "nbconvert_exporter": "python", 172 | "pygments_lexer": "ipython3", 173 | "version": "3.9.16" 174 | } 175 | }, 176 | "nbformat": 4, 177 | "nbformat_minor": 5 178 | } 179 | -------------------------------------------------------------------------------- /tutorials/API_demo/API_9_video.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "134e7f9d", 6 | "metadata": {}, 7 | "source": [ 8 | "# API 9: Videos\n", 9 | "\n", 10 | "We have shown one can visualize KAN with the plot() method. If one wants to save the training dynamics of KAN plots, one only needs to pass argument save_video = True to train() method (and set some video related parameters)" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 6, 16 | "id": "2075ef56", 17 | "metadata": { 18 | "tags": [] 19 | }, 20 | "outputs": [ 21 | { 22 | "name": "stdout", 23 | "output_type": "stream", 24 | "text": [ 25 | "cuda\n", 26 | "checkpoint directory created: ./model\n", 27 | "saving model version 0.0\n" 28 | ] 29 | }, 30 | { 31 | "name": "stderr", 32 | "output_type": "stream", 33 | "text": [ 34 | "| train_loss: 2.89e-01 | test_loss: 2.96e-01 | reg: 1.31e+01 | : 100%|█| 5/5 [00:09<00:00, 1.94s/it" 35 | ] 36 | }, 37 | { 38 | "name": "stdout", 39 | "output_type": "stream", 40 | "text": [ 41 | "saving model version 0.1\n" 42 | ] 43 | }, 44 | { 45 | "name": "stderr", 46 | "output_type": "stream", 47 | "text": [ 48 | "\n" 49 | ] 50 | } 51 | ], 52 | "source": [ 53 | "from kan import *\n", 54 | "import torch\n", 55 | "\n", 56 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", 57 | "print(device)\n", 58 | "\n", 59 | "# create a KAN: 2D inputs, 1D output, and 5 hidden neurons. 
cubic spline (k=3), 5 grid intervals (grid=5).\n", 60 | "model = KAN(width=[4,2,1,1], grid=3, k=3, seed=1, device=device)\n", 61 | "f = lambda x: torch.exp((torch.sin(torch.pi*(x[:,[0]]**2+x[:,[1]]**2))+torch.sin(torch.pi*(x[:,[2]]**2+x[:,[3]]**2)))/2)\n", 62 | "dataset = create_dataset(f, n_var=4, train_num=3000, device=device)\n", 63 | "\n", 64 | "image_folder = 'video_img'\n", 65 | "\n", 66 | "# train the model\n", 67 | "#model.train(dataset, opt=\"LBFGS\", steps=20, lamb=1e-3, lamb_entropy=2.);\n", 68 | "model.fit(dataset, opt=\"LBFGS\", steps=5, lamb=0.001, lamb_entropy=2., save_fig=True, beta=10, \n", 69 | " in_vars=[r'$x_1$', r'$x_2$', r'$x_3$', r'$x_4$'],\n", 70 | " out_vars=[r'${\\rm exp}({\\rm sin}(x_1^2+x_2^2)+{\\rm sin}(x_3^2+x_4^2))$'],\n", 71 | " img_folder=image_folder);\n" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 2, 77 | "id": "c18245a3", 78 | "metadata": {}, 79 | "outputs": [ 80 | { 81 | "name": "stdout", 82 | "output_type": "stream", 83 | "text": [ 84 | "Moviepy - Building video video.mp4.\n", 85 | "Moviepy - Writing video video.mp4\n", 86 | "\n" 87 | ] 88 | }, 89 | { 90 | "name": "stderr", 91 | "output_type": "stream", 92 | "text": [ 93 | " \r" 94 | ] 95 | }, 96 | { 97 | "name": "stdout", 98 | "output_type": "stream", 99 | "text": [ 100 | "Moviepy - Done !\n", 101 | "Moviepy - video ready video.mp4\n" 102 | ] 103 | } 104 | ], 105 | "source": [ 106 | "import os\n", 107 | "import numpy as np\n", 108 | "import moviepy.video.io.ImageSequenceClip # moviepy == 1.0.3\n", 109 | "\n", 110 | "video_name='video'\n", 111 | "fps=5\n", 112 | "\n", 113 | "fps = fps\n", 114 | "files = os.listdir(image_folder)\n", 115 | "train_index = []\n", 116 | "for file in files:\n", 117 | " if file[0].isdigit() and file.endswith('.jpg'):\n", 118 | " train_index.append(int(file[:-4]))\n", 119 | "\n", 120 | "train_index = np.sort(train_index)\n", 121 | "\n", 122 | "image_files = [image_folder+'/'+str(train_index[index])+'.jpg' for index in 
train_index]\n", 123 | "\n", 124 | "clip = moviepy.video.io.ImageSequenceClip.ImageSequenceClip(image_files, fps=fps)\n", 125 | "clip.write_videofile(video_name+'.mp4')" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": null, 131 | "id": "88d0d737", 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [] 135 | } 136 | ], 137 | "metadata": { 138 | "kernelspec": { 139 | "display_name": "Python 3 (ipykernel)", 140 | "language": "python", 141 | "name": "python3" 142 | }, 143 | "language_info": { 144 | "codemirror_mode": { 145 | "name": "ipython", 146 | "version": 3 147 | }, 148 | "file_extension": ".py", 149 | "mimetype": "text/x-python", 150 | "name": "python", 151 | "nbconvert_exporter": "python", 152 | "pygments_lexer": "ipython3", 153 | "version": "3.9.16" 154 | } 155 | }, 156 | "nbformat": 4, 157 | "nbformat_minor": 5 158 | } 159 | -------------------------------------------------------------------------------- /tutorials/Example/model/0.0_cache_data: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/tutorials/Example/model/0.0_cache_data -------------------------------------------------------------------------------- /tutorials/Example/model/0.0_config.yml: -------------------------------------------------------------------------------- 1 | affine_trainable: false 2 | auto_save: true 3 | base_fun_name: silu 4 | ckpt_path: ./model 5 | device: cpu 6 | grid: !!python/object/apply:numpy.core.multiarray.scalar 7 | - !!python/object/apply:numpy.dtype 8 | args: 9 | - i8 10 | - false 11 | - true 12 | state: !!python/tuple 13 | - 3 14 | - < 15 | - null 16 | - null 17 | - null 18 | - -1 19 | - -1 20 | - 0 21 | - !!binary | 22 | AwAAAAAAAAA= 23 | grid_eps: 0.02 24 | grid_range: 25 | - -1 26 | - 1 27 | k: 3 28 | mult_arity: 2 29 | round: 0 30 | sb_trainable: true 31 | sp_trainable: true 32 | state_id: 0 33 | 
symbolic.funs_name.0: 34 | - - '0' 35 | - '0' 36 | symbolic.funs_name.1: 37 | - - '0' 38 | symbolic_enabled: true 39 | width: 40 | - - 2 41 | - 0 42 | - - 1 43 | - 0 44 | - - 1 45 | - 0 46 | -------------------------------------------------------------------------------- /tutorials/Example/model/0.0_state: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/tutorials/Example/model/0.0_state -------------------------------------------------------------------------------- /tutorials/Example/model/0.10_cache_data: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/tutorials/Example/model/0.10_cache_data -------------------------------------------------------------------------------- /tutorials/Example/model/0.10_config.yml: -------------------------------------------------------------------------------- 1 | affine_trainable: false 2 | auto_save: true 3 | base_fun_name: silu 4 | ckpt_path: ./model 5 | device: cpu 6 | grid: !!python/object/apply:numpy.core.multiarray.scalar 7 | - !!python/object/apply:numpy.dtype 8 | args: 9 | - i8 10 | - false 11 | - true 12 | state: !!python/tuple 13 | - 3 14 | - < 15 | - null 16 | - null 17 | - null 18 | - -1 19 | - -1 20 | - 0 21 | - !!binary | 22 | MgAAAAAAAAA= 23 | grid_eps: 0.02 24 | grid_range: 25 | - -1 26 | - 1 27 | k: 3 28 | mult_arity: 2 29 | round: 0 30 | sb_trainable: true 31 | sp_trainable: true 32 | state_id: 10 33 | symbolic.funs_name.0: 34 | - - '0' 35 | - '0' 36 | symbolic.funs_name.1: 37 | - - '0' 38 | symbolic_enabled: true 39 | width: 40 | - - 2 41 | - 0 42 | - - 1 43 | - 0 44 | - - 1 45 | - 0 46 | -------------------------------------------------------------------------------- /tutorials/Example/model/0.10_state: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/tutorials/Example/model/0.10_state -------------------------------------------------------------------------------- /tutorials/Example/model/0.11_cache_data: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/tutorials/Example/model/0.11_cache_data -------------------------------------------------------------------------------- /tutorials/Example/model/0.11_config.yml: -------------------------------------------------------------------------------- 1 | affine_trainable: false 2 | auto_save: true 3 | base_fun_name: silu 4 | ckpt_path: ./model 5 | device: cpu 6 | grid: !!python/object/apply:numpy.core.multiarray.scalar 7 | - !!python/object/apply:numpy.dtype 8 | args: 9 | - i8 10 | - false 11 | - true 12 | state: !!python/tuple 13 | - 3 14 | - < 15 | - null 16 | - null 17 | - null 18 | - -1 19 | - -1 20 | - 0 21 | - !!binary | 22 | ZAAAAAAAAAA= 23 | grid_eps: 0.02 24 | grid_range: 25 | - -1 26 | - 1 27 | k: 3 28 | mult_arity: 2 29 | round: 0 30 | sb_trainable: true 31 | sp_trainable: true 32 | state_id: 11 33 | symbolic.funs_name.0: 34 | - - '0' 35 | - '0' 36 | symbolic.funs_name.1: 37 | - - '0' 38 | symbolic_enabled: false 39 | width: 40 | - - 2 41 | - 0 42 | - - 1 43 | - 0 44 | - - 1 45 | - 0 46 | -------------------------------------------------------------------------------- /tutorials/Example/model/0.11_state: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/tutorials/Example/model/0.11_state -------------------------------------------------------------------------------- /tutorials/Example/model/0.1_cache_data: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/tutorials/Example/model/0.1_cache_data -------------------------------------------------------------------------------- /tutorials/Example/model/0.1_config.yml: -------------------------------------------------------------------------------- 1 | affine_trainable: false 2 | auto_save: true 3 | base_fun_name: silu 4 | ckpt_path: ./model 5 | device: cpu 6 | grid: !!python/object/apply:numpy.core.multiarray.scalar 7 | - !!python/object/apply:numpy.dtype 8 | args: 9 | - i8 10 | - false 11 | - true 12 | state: !!python/tuple 13 | - 3 14 | - < 15 | - null 16 | - null 17 | - null 18 | - -1 19 | - -1 20 | - 0 21 | - !!binary | 22 | AwAAAAAAAAA= 23 | grid_eps: 0.02 24 | grid_range: 25 | - -1 26 | - 1 27 | k: 3 28 | mult_arity: 2 29 | round: 0 30 | sb_trainable: true 31 | sp_trainable: true 32 | state_id: 1 33 | symbolic.funs_name.0: 34 | - - '0' 35 | - '0' 36 | symbolic.funs_name.1: 37 | - - '0' 38 | symbolic_enabled: false 39 | width: 40 | - - 2 41 | - 0 42 | - - 1 43 | - 0 44 | - - 1 45 | - 0 46 | -------------------------------------------------------------------------------- /tutorials/Example/model/0.1_state: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/tutorials/Example/model/0.1_state -------------------------------------------------------------------------------- /tutorials/Example/model/0.2_cache_data: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/tutorials/Example/model/0.2_cache_data -------------------------------------------------------------------------------- /tutorials/Example/model/0.2_config.yml: 
-------------------------------------------------------------------------------- 1 | affine_trainable: false 2 | auto_save: true 3 | base_fun_name: silu 4 | ckpt_path: ./model 5 | device: cpu 6 | grid: !!python/object/apply:numpy.core.multiarray.scalar 7 | - !!python/object/apply:numpy.dtype 8 | args: 9 | - i8 10 | - false 11 | - true 12 | state: !!python/tuple 13 | - 3 14 | - < 15 | - null 16 | - null 17 | - null 18 | - -1 19 | - -1 20 | - 0 21 | - !!binary | 22 | AwAAAAAAAAA= 23 | grid_eps: 0.02 24 | grid_range: 25 | - -1 26 | - 1 27 | k: 3 28 | mult_arity: 2 29 | round: 0 30 | sb_trainable: true 31 | sp_trainable: true 32 | state_id: 2 33 | symbolic.funs_name.0: 34 | - - '0' 35 | - '0' 36 | symbolic.funs_name.1: 37 | - - '0' 38 | symbolic_enabled: true 39 | width: 40 | - - 2 41 | - 0 42 | - - 1 43 | - 0 44 | - - 1 45 | - 0 46 | -------------------------------------------------------------------------------- /tutorials/Example/model/0.2_state: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/tutorials/Example/model/0.2_state -------------------------------------------------------------------------------- /tutorials/Example/model/0.3_cache_data: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/tutorials/Example/model/0.3_cache_data -------------------------------------------------------------------------------- /tutorials/Example/model/0.3_config.yml: -------------------------------------------------------------------------------- 1 | affine_trainable: false 2 | auto_save: true 3 | base_fun_name: silu 4 | ckpt_path: ./model 5 | device: cpu 6 | grid: !!python/object/apply:numpy.core.multiarray.scalar 7 | - !!python/object/apply:numpy.dtype 8 | args: 9 | - i8 10 | - false 11 | - true 12 | state: !!python/tuple 13 | - 3 
14 | - < 15 | - null 16 | - null 17 | - null 18 | - -1 19 | - -1 20 | - 0 21 | - !!binary | 22 | BQAAAAAAAAA= 23 | grid_eps: 0.02 24 | grid_range: 25 | - -1 26 | - 1 27 | k: 3 28 | mult_arity: 2 29 | round: 0 30 | sb_trainable: true 31 | sp_trainable: true 32 | state_id: 3 33 | symbolic.funs_name.0: 34 | - - '0' 35 | - '0' 36 | symbolic.funs_name.1: 37 | - - '0' 38 | symbolic_enabled: false 39 | width: 40 | - - 2 41 | - 0 42 | - - 1 43 | - 0 44 | - - 1 45 | - 0 46 | -------------------------------------------------------------------------------- /tutorials/Example/model/0.3_state: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/tutorials/Example/model/0.3_state -------------------------------------------------------------------------------- /tutorials/Example/model/0.4_cache_data: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/tutorials/Example/model/0.4_cache_data -------------------------------------------------------------------------------- /tutorials/Example/model/0.4_config.yml: -------------------------------------------------------------------------------- 1 | affine_trainable: false 2 | auto_save: true 3 | base_fun_name: silu 4 | ckpt_path: ./model 5 | device: cpu 6 | grid: !!python/object/apply:numpy.core.multiarray.scalar 7 | - !!python/object/apply:numpy.dtype 8 | args: 9 | - i8 10 | - false 11 | - true 12 | state: !!python/tuple 13 | - 3 14 | - < 15 | - null 16 | - null 17 | - null 18 | - -1 19 | - -1 20 | - 0 21 | - !!binary | 22 | BQAAAAAAAAA= 23 | grid_eps: 0.02 24 | grid_range: 25 | - -1 26 | - 1 27 | k: 3 28 | mult_arity: 2 29 | round: 0 30 | sb_trainable: true 31 | sp_trainable: true 32 | state_id: 4 33 | symbolic.funs_name.0: 34 | - - '0' 35 | - '0' 36 | symbolic.funs_name.1: 37 | - - '0' 38 | 
symbolic_enabled: true 39 | width: 40 | - - 2 41 | - 0 42 | - - 1 43 | - 0 44 | - - 1 45 | - 0 46 | -------------------------------------------------------------------------------- /tutorials/Example/model/0.4_state: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/tutorials/Example/model/0.4_state -------------------------------------------------------------------------------- /tutorials/Example/model/0.5_cache_data: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/tutorials/Example/model/0.5_cache_data -------------------------------------------------------------------------------- /tutorials/Example/model/0.5_config.yml: -------------------------------------------------------------------------------- 1 | affine_trainable: false 2 | auto_save: true 3 | base_fun_name: silu 4 | ckpt_path: ./model 5 | device: cpu 6 | grid: !!python/object/apply:numpy.core.multiarray.scalar 7 | - !!python/object/apply:numpy.dtype 8 | args: 9 | - i8 10 | - false 11 | - true 12 | state: !!python/tuple 13 | - 3 14 | - < 15 | - null 16 | - null 17 | - null 18 | - -1 19 | - -1 20 | - 0 21 | - !!binary | 22 | CgAAAAAAAAA= 23 | grid_eps: 0.02 24 | grid_range: 25 | - -1 26 | - 1 27 | k: 3 28 | mult_arity: 2 29 | round: 0 30 | sb_trainable: true 31 | sp_trainable: true 32 | state_id: 5 33 | symbolic.funs_name.0: 34 | - - '0' 35 | - '0' 36 | symbolic.funs_name.1: 37 | - - '0' 38 | symbolic_enabled: false 39 | width: 40 | - - 2 41 | - 0 42 | - - 1 43 | - 0 44 | - - 1 45 | - 0 46 | -------------------------------------------------------------------------------- /tutorials/Example/model/0.5_state: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/tutorials/Example/model/0.5_state -------------------------------------------------------------------------------- /tutorials/Example/model/0.6_cache_data: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/tutorials/Example/model/0.6_cache_data -------------------------------------------------------------------------------- /tutorials/Example/model/0.6_config.yml: -------------------------------------------------------------------------------- 1 | affine_trainable: false 2 | auto_save: true 3 | base_fun_name: silu 4 | ckpt_path: ./model 5 | device: cpu 6 | grid: !!python/object/apply:numpy.core.multiarray.scalar 7 | - !!python/object/apply:numpy.dtype 8 | args: 9 | - i8 10 | - false 11 | - true 12 | state: !!python/tuple 13 | - 3 14 | - < 15 | - null 16 | - null 17 | - null 18 | - -1 19 | - -1 20 | - 0 21 | - !!binary | 22 | CgAAAAAAAAA= 23 | grid_eps: 0.02 24 | grid_range: 25 | - -1 26 | - 1 27 | k: 3 28 | mult_arity: 2 29 | round: 0 30 | sb_trainable: true 31 | sp_trainable: true 32 | state_id: 6 33 | symbolic.funs_name.0: 34 | - - '0' 35 | - '0' 36 | symbolic.funs_name.1: 37 | - - '0' 38 | symbolic_enabled: true 39 | width: 40 | - - 2 41 | - 0 42 | - - 1 43 | - 0 44 | - - 1 45 | - 0 46 | -------------------------------------------------------------------------------- /tutorials/Example/model/0.6_state: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/tutorials/Example/model/0.6_state -------------------------------------------------------------------------------- /tutorials/Example/model/0.7_cache_data: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/tutorials/Example/model/0.7_cache_data -------------------------------------------------------------------------------- /tutorials/Example/model/0.7_config.yml: -------------------------------------------------------------------------------- 1 | affine_trainable: false 2 | auto_save: true 3 | base_fun_name: silu 4 | ckpt_path: ./model 5 | device: cpu 6 | grid: !!python/object/apply:numpy.core.multiarray.scalar 7 | - !!python/object/apply:numpy.dtype 8 | args: 9 | - i8 10 | - false 11 | - true 12 | state: !!python/tuple 13 | - 3 14 | - < 15 | - null 16 | - null 17 | - null 18 | - -1 19 | - -1 20 | - 0 21 | - !!binary | 22 | FAAAAAAAAAA= 23 | grid_eps: 0.02 24 | grid_range: 25 | - -1 26 | - 1 27 | k: 3 28 | mult_arity: 2 29 | round: 0 30 | sb_trainable: true 31 | sp_trainable: true 32 | state_id: 7 33 | symbolic.funs_name.0: 34 | - - '0' 35 | - '0' 36 | symbolic.funs_name.1: 37 | - - '0' 38 | symbolic_enabled: false 39 | width: 40 | - - 2 41 | - 0 42 | - - 1 43 | - 0 44 | - - 1 45 | - 0 46 | -------------------------------------------------------------------------------- /tutorials/Example/model/0.7_state: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/tutorials/Example/model/0.7_state -------------------------------------------------------------------------------- /tutorials/Example/model/0.8_cache_data: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/tutorials/Example/model/0.8_cache_data -------------------------------------------------------------------------------- /tutorials/Example/model/0.8_config.yml: -------------------------------------------------------------------------------- 1 | affine_trainable: false 2 | 
auto_save: true 3 | base_fun_name: silu 4 | ckpt_path: ./model 5 | device: cpu 6 | grid: !!python/object/apply:numpy.core.multiarray.scalar 7 | - !!python/object/apply:numpy.dtype 8 | args: 9 | - i8 10 | - false 11 | - true 12 | state: !!python/tuple 13 | - 3 14 | - < 15 | - null 16 | - null 17 | - null 18 | - -1 19 | - -1 20 | - 0 21 | - !!binary | 22 | FAAAAAAAAAA= 23 | grid_eps: 0.02 24 | grid_range: 25 | - -1 26 | - 1 27 | k: 3 28 | mult_arity: 2 29 | round: 0 30 | sb_trainable: true 31 | sp_trainable: true 32 | state_id: 8 33 | symbolic.funs_name.0: 34 | - - '0' 35 | - '0' 36 | symbolic.funs_name.1: 37 | - - '0' 38 | symbolic_enabled: true 39 | width: 40 | - - 2 41 | - 0 42 | - - 1 43 | - 0 44 | - - 1 45 | - 0 46 | -------------------------------------------------------------------------------- /tutorials/Example/model/0.8_state: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/tutorials/Example/model/0.8_state -------------------------------------------------------------------------------- /tutorials/Example/model/0.9_cache_data: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/tutorials/Example/model/0.9_cache_data -------------------------------------------------------------------------------- /tutorials/Example/model/0.9_config.yml: -------------------------------------------------------------------------------- 1 | affine_trainable: false 2 | auto_save: true 3 | base_fun_name: silu 4 | ckpt_path: ./model 5 | device: cpu 6 | grid: !!python/object/apply:numpy.core.multiarray.scalar 7 | - !!python/object/apply:numpy.dtype 8 | args: 9 | - i8 10 | - false 11 | - true 12 | state: !!python/tuple 13 | - 3 14 | - < 15 | - null 16 | - null 17 | - null 18 | - -1 19 | - -1 20 | - 0 21 | - !!binary | 22 | MgAAAAAAAAA= 23 
| grid_eps: 0.02 24 | grid_range: 25 | - -1 26 | - 1 27 | k: 3 28 | mult_arity: 2 29 | round: 0 30 | sb_trainable: true 31 | sp_trainable: true 32 | state_id: 9 33 | symbolic.funs_name.0: 34 | - - '0' 35 | - '0' 36 | symbolic.funs_name.1: 37 | - - '0' 38 | symbolic_enabled: false 39 | width: 40 | - - 2 41 | - 0 42 | - - 1 43 | - 0 44 | - - 1 45 | - 0 46 | -------------------------------------------------------------------------------- /tutorials/Example/model/0.9_state: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/pykan/ecde4ec3274d3bef1ad737479cf126aed38ab530/tutorials/Example/model/0.9_state -------------------------------------------------------------------------------- /tutorials/Example/model/history.txt: -------------------------------------------------------------------------------- 1 | ### Round 0 ### 2 | init => 0.0 3 | --------------------------------------------------------------------------------