├── MNIST ├── MNIST_read_and_plot_rotation5M.ipynb ├── MNIST_read_and_plot_whitepatch2M.ipynb ├── mnist_get_pretrained.ipynb ├── semitorchMNISTclass.py ├── simu_MNIST_patches.py ├── simu_MNIST_patches_2M.py ├── simudata_MNIST.py ├── submit_simu_MNIST_patches.py └── submit_simu_MNIST_patches_2M.py ├── README.md ├── amazon ├── amazon_review_data_2018_subset_regression.py ├── read_and_preprocess_amazon_review_data_2018_subset.ipynb ├── results_amazon │ └── .gitignore └── submit_amazon_review_data_2018_subset_regression.py ├── mmd.py ├── myrandom.py ├── sem.py ├── semiclass.py ├── semitorchclass.py ├── semitorchstocclass.py ├── sim ├── sim_linearSCM_mean_shift_exp1-7.ipynb ├── sim_linearSCM_var_shift_exp8_box_run.py ├── sim_linearSCM_var_shift_exp8_box_submit.py ├── sim_linearSCM_var_shift_exp8_scat_run.py ├── sim_linearSCM_var_shift_exp8_scat_submit.py ├── sim_linearSCM_var_shift_exp9_scat_run.py ├── sim_linearSCM_var_shift_exp9_scat_submit.py ├── sim_linearSCM_variance_shift_exp8-9.ipynb ├── simu_results │ └── .gitignore └── simudata.py └── util.py /MNIST/MNIST_read_and_plot_rotation5M.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import matplotlib.pyplot as plt\n", 11 | "import seaborn as sns\n", 12 | "\n", 13 | "import pandas as pd\n", 14 | "\n", 15 | "plt.rcParams['axes.facecolor'] = 'lightgray'\n", 16 | "sns.set(style=\"darkgrid\")\n", 17 | "np.set_printoptions(precision=3)" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 2, 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "def boxplot_all_methods(plt_handle, res_all, title='', names=[], color=[]):\n", 27 | " res_all_df = pd.DataFrame(res_all.T)\n", 28 | " res_all_df.columns = names\n", 29 | " res_all_df_melt = res_all_df.melt(var_name='methods', value_name='accuracy')\n", 30 | " res_all_mean = np.mean(res_all, axis=1)\n", 31 | "# print(res_all_df_melt)\n", 32 | " print(res_all_df.shape, res_all_mean.shape, res_all_df_melt.shape)\n", 33 | " \n", 34 | "# plt_handle.set_title(title, fontsize=15)\n", 35 | "\n", 36 | " plt_handle.axhline(res_all_mean[2], ls='--', color='b')\n", 37 | " plt_handle.axhline(res_all_mean[1], ls='--', color='r')\n", 38 | " ax = sns.boxplot(x=\"methods\", y=\"accuracy\", data=res_all_df_melt, palette=color, ax=plt_handle)\n", 39 | " ax.set_xticklabels(ax.get_xticklabels(), rotation=-60, ha='left', fontsize=15)\n", 40 | " ax.tick_params(labelsize=15)\n", 41 | " ax.yaxis.grid(False) # Hide the horizontal gridlines\n", 42 | " ax.xaxis.grid(True) # Show the vertical gridlines\n", 43 | " ax.set_xlabel(\"methods\")\n", 44 | " ax.set_ylabel(\"accuracy\")\n", 45 | " \n", 46 | " ax.set_xlabel(\"\")\n", 47 | " ax.set_ylabel(\"Accuracy (%)\", fontsize=15)" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 3, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "# perturb = 'whitepatch'\n", 57 | "perturb = 'rotation'\n", 58 | "M = 5\n", 59 | "subset_prop = 0.2\n", 60 | "lamL2 = 0.\n", 61 | "lamL1 = 0.\n", 62 | "lr = 1e-4\n", 63 | "epochs= 100" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 4, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "names_short = ['Original', \"Tar\", \"Src[1]\",\n", 73 | " 'DIP[1]', 'DIPweigh', 'CIP', 'CIRMweigh',\n", 74 | " 'DIP[1]-MMD', 'DIPweigh-MMD', 'CIP-MMD', 'CIRMweigh-MMD']\n", 75 | "\n", 76 | "prefix_template = 'results_MNIST/report_v8_%s_M%d_subsetprop%s_%s_lamMatch%s_lamCIP%s_lamMatchMMD%s_lamCIPMMD%s_epochs%d_seed%d'\n", 77 | "\n" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 5, 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "repeats = 10\n", 87 | "nb_ba = 3 # Original, Tar, Src[1]\n", 88 | "results_src_ba = np.zeros((M-1, nb_ba, 2, 10))\n", 89 | "results_tar_ba = np.zeros((nb_ba, 2, 10))\n", 90 | "for seed in range(repeats):\n", 91 | " savefilename_prefix = prefix_template % (perturb,\n", 92 | " M, str(subset_prop), 'baseline', 1., 0.1, 1., 0.1, epochs, seed)\n", 93 | " res = np.load(\"%s.npy\" %savefilename_prefix, allow_pickle=True)\n", 94 | "\n", 95 | " results_src_ba[:, :, :, seed] =res.item()['src']\n", 96 | " results_tar_ba[:, :, seed] = res.item()['tar']" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 6, 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "lamMatches = [10.**(k) for k in (np.arange(10)-5)]" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 7, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "# DAmean methods: DIP, DIPOracle, DIPweigh, CIRMweigh\n", 115 | "nb_methods_damean = 4\n", 116 | "repeats = 10\n", 117 | "results_src_damean = np.zeros((len(lamMatches), M-1, nb_methods_damean, 2, 10))\n", 118 | "results_tar_damean = np.zeros((len(lamMatches), nb_methods_damean, 2, 10))\n", 119 | "for i, lam in enumerate(lamMatches):\n", 120 | " for seed in range(repeats):\n", 121 | " savefilename_prefix = prefix_template % (perturb,\n", 122 | " M, str(subset_prop), 'DAmean', lam, 10., lam, 10., epochs, seed)\n", 123 | " res = np.load(\"%s.npy\" %savefilename_prefix, allow_pickle=True)\n", 124 | "\n", 125 | " results_src_damean[i, :, :, :, seed] =res.item()['src']\n", 126 | " results_tar_damean[i, :, :, seed] = res.item()['tar']" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 8, 132 | "metadata": {}, 133 | "outputs": [ 134 | { 135 | "name": "stdout", 136 | "output_type": "stream", 137 | "text": [ 138 | "7 100.0\n", 139 | "3 0.01\n", 140 | "7 100.0\n", 141 | "4 0.1\n" 142 | ] 143 | } 144 | ], 145 | "source": [ 146 | "# choose lambda based on the source test performance\n", 147 | "lam_index_damean = np.zeros(nb_methods_damean, dtype=int)\n", 148 | "for i in range(nb_methods_damean):\n", 149 | " if i == 0 or i == 1:\n", 150 | " src_test_acc_all = results_src_damean[:, 0, i, 1, :].mean(axis=1)\n", 151 | " else:\n", 152 | " # M-2 for the source environment that is selected by weighting methods\n", 153 | " src_test_acc_all = results_src_damean[:, M-2, i, 1, :].mean(axis=1)\n", 154 | " # choose the largest lambda such that the source performance does not drop too much (5%)\n", 155 | " lam_index = 0\n", 156 | " for k, src_test_acc in enumerate(src_test_acc_all):\n", 157 | " \n", 158 | " if src_test_acc > np.max(src_test_acc_all) * 0.99:\n", 159 | " lam_index = k\n", 160 | " lam_index_damean[i] = lam_index\n", 161 | " print(lam_index, lamMatches[lam_index])" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 9, 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [ 170 | "# DAMMD methods: DIP-MMD, DIPweigh-MMD, CIRMweigh-MMD\n", 171 | "nb_methods_dammd = 3\n", 172 | "repeats = 10\n", 173 | "results_src_dammd = np.zeros((len(lamMatches), M-1, nb_methods_dammd, 2, 10))\n", 174 | "results_tar_dammd = np.zeros((len(lamMatches), nb_methods_dammd, 2, 10))\n", 175 | "for i, lam in enumerate(lamMatches):\n", 176 | " for seed in range(repeats):\n", 177 | " savefilename_prefix = prefix_template % (perturb,\n", 178 | " M, str(subset_prop), 'DAMMD', lam, 10., lam, 10., epochs, seed)\n", 179 | " res = np.load(\"%s.npy\" %savefilename_prefix, allow_pickle=True)\n", 180 | "\n", 181 | " results_src_dammd[i, :, :, :, seed] =res.item()['src']\n", 182 | " results_tar_dammd[i, :, :, seed] = res.item()['tar']" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": 10, 188 | "metadata": {}, 189 | "outputs": [ 190 | { 191 | "name": "stdout", 192 | "output_type": "stream", 193 | "text": [ 194 | "6\n", 195 | "5\n", 196 | "7\n" 197 | ] 198 | } 199 | ], 200 | "source": [ 201 | "# choose lambda based on the source test performance\n", 202 | "lam_index_dammd = np.zeros(nb_methods_dammd, dtype=int)\n", 203 | "for i in range(nb_methods_dammd):\n", 204 | " if i == 0:\n", 205 | " src_test_acc_all = results_src_dammd[:, 0, i, 1, :].mean(axis=1)\n", 206 | " else:\n", 207 | " # M-2 for the source environment that is selected by weighting methods\n", 208 | " src_test_acc_all = results_src_dammd[:, M-2, i, 1, :].mean(axis=1)\n", 209 | " # choose the largest lambda such that the source performance does not drop too much (5%)\n", 210 | " lam_index = 0\n", 211 | " for k, src_test_acc in enumerate(src_test_acc_all):\n", 212 | " if src_test_acc > np.max(src_test_acc_all) * 0.99:\n", 213 | " lam_index = k\n", 214 | " lam_index_dammd[i] = lam_index\n", 215 | " print(lam_index)" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": 11, 221 | "metadata": {}, 222 | "outputs": [], 223 | "source": [ 224 | "# DACIPmean methods\n", 225 | "nb_methods_dacipmean = 1\n", 226 | "repeats = 10\n", 227 | "results_src_dacipmean = np.zeros((len(lamMatches), M-1, nb_methods_dacipmean, 2, 10))\n", 228 | "results_tar_dacipmean = np.zeros((len(lamMatches), nb_methods_dacipmean, 2, 10))\n", 229 | "for i, lam in enumerate(lamMatches):\n", 230 | " for seed in range(repeats):\n", 231 | " savefilename_prefix = prefix_template % (perturb,\n", 232 | " M, str(subset_prop), 'DACIPmean', 1., lam, 1., lam, 100, seed)\n", 233 | " res = np.load(\"%s.npy\" %savefilename_prefix, allow_pickle=True)\n", 234 | "\n", 235 | " results_src_dacipmean[i, :, :, :, seed] = res.item()['src']\n", 236 | " results_tar_dacipmean[i, :, :, seed] = res.item()['tar']" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": 12, 242 | "metadata": {}, 243 | "outputs": [ 244 | { 245 | "name": "stdout", 246 | "output_type": "stream", 247 | "text": [ 248 | "5 1.0\n" 249 | ] 250 | } 251 | ], 252 | "source": [ 253 | "# choose lambda based on the source test performance\n", 254 | "lam_index_dacipmean = np.zeros(nb_methods_dacipmean, dtype=int)\n", 255 | "for i in range(nb_methods_dacipmean):\n", 256 | " src_test_acc_all = results_src_dacipmean[:, :-1, i, 1, :].mean(axis=2).mean(axis=1)\n", 257 | " # choose the largest lambda such that the source performance does not drop too much (5%)\n", 258 | " lam_index = 0\n", 259 | " for k, src_test_acc in enumerate(src_test_acc_all):\n", 260 | " \n", 261 | " if src_test_acc > np.max(src_test_acc_all) * 0.99:\n", 262 | " lam_index = k\n", 263 | " lam_index_dacipmean[i] = lam_index\n", 264 | " print(lam_index, lamMatches[lam_index])" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": 13, 270 | "metadata": {}, 271 | "outputs": [], 272 | "source": [ 273 | "# DACIPMMD methods\n", 274 | "nb_methods_dacipmmd = 1\n", 275 | "repeats = 10\n", 276 | "results_src_dacipmmd = np.zeros((len(lamMatches), M-1, nb_methods_dacipmmd, 2, 10))\n", 277 | "results_tar_dacipmmd = np.zeros((len(lamMatches), nb_methods_dacipmmd, 2, 10))\n", 278 | "for i, lam in enumerate(lamMatches):\n", 279 | " for seed in range(repeats):\n", 280 | " savefilename_prefix = prefix_template % (perturb,\n", 281 | " M, str(subset_prop), 'DACIPMMD', 1., lam, 1., lam, 100, seed)\n", 282 | " res = np.load(\"%s.npy\" %savefilename_prefix, allow_pickle=True)\n", 283 | "\n", 284 | " results_src_dacipmmd[i, :, :, :, seed] = res.item()['src']\n", 285 | " results_tar_dacipmmd[i, :, :, seed] = res.item()['tar']" 286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": 14, 291 | "metadata": {}, 292 | "outputs": [ 293 | { 294 | "name": "stdout", 295 | "output_type": "stream", 296 | "text": [ 297 | "6 10.0\n" 298 | ] 299 | } 300 | ], 301 | "source": [ 302 | "# choose lambda based on the source test performance\n", 303 | "lam_index_dacipmmd = np.zeros(nb_methods_dacipmmd, dtype=int)\n", 304 | "for i in range(nb_methods_dacipmmd):\n", 305 | " src_test_acc_all = results_src_dacipmmd[:, :-1, i, 1, :].mean(axis=2).mean(axis=1)\n", 306 | " # choose the largest lambda such that the source performance does not drop too much (5%)\n", 307 | " lam_index = 0\n", 308 | " for k, src_test_acc in enumerate(src_test_acc_all):\n", 309 | " \n", 310 | " if src_test_acc > np.max(src_test_acc_all) * 0.99:\n", 311 | " lam_index = k\n", 312 | " lam_index_dacipmmd[i] = lam_index\n", 313 | " print(lam_index, lamMatches[lam_index])" 314 | ] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "execution_count": 15, 319 | "metadata": {}, 320 | "outputs": [ 321 | { 322 | "data": { 323 | "text/plain": [ 324 | "array([7, 7, 4])" 325 | ] 326 | }, 327 | "execution_count": 15, 328 | "metadata": {}, 329 | "output_type": "execute_result" 330 | } 331 | ], 332 | "source": [ 333 | "lam_index_damean[[0, 2, 3]]" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": 16, 339 | "metadata": {}, 340 | "outputs": [], 341 | "source": [ 342 | "test_err_index = 0\n", 343 | "\n", 344 | "results_tar_plot = np.concatenate((results_tar_ba[:, test_err_index, :],\n", 345 | " results_tar_damean[lam_index_damean[0], 0, test_err_index, :].reshape((-1, 10), order='F'),\n", 346 | " results_tar_damean[lam_index_damean[2], 2, test_err_index, :].reshape((-1, 10), order='F'),\n", 347 | " results_tar_dacipmean[lam_index_dacipmean, 0, test_err_index, :].reshape((-1, 10), order='F'),\n", 348 | " results_tar_damean[lam_index_damean[3], 3, test_err_index, :].reshape((-1, 10), order='F'),\n", 349 | " results_tar_dammd[lam_index_dammd[0], 0, test_err_index, :].reshape((-1, 10), order='F'),\n", 350 | " results_tar_dammd[lam_index_dammd[1], 1, test_err_index, :].reshape((-1, 10), order='F'),\n", 351 | " results_tar_dacipmmd[lam_index_dacipmmd, 0, test_err_index, :].reshape((-1, 10), order='F'),\n", 352 | " results_tar_dammd[lam_index_dammd[2], 2, test_err_index, :].reshape((-1, 10), order='F')), axis=0)\n", 353 | " \n" 354 | ] 355 | }, 356 | { 357 | "cell_type": "code", 358 | "execution_count": 17, 359 | "metadata": {}, 360 | "outputs": [ 361 | { 362 | "data": { 363 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAnQAAABECAYAAAAIjKhLAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAADJklEQVR4nO3bsWucdRzH8U96xYgnTRqtECzYP8A1WNBOjoIl4JTOgohoFhcHJycX26JTZ2eJ0NFBWqHSVQQ3hcqBrSWBPGKEcA6aIUc8Ufn15/d4vZbneH4cfIbngfdwtzSdTgMAQF2neg8AAOC/EXQAAMUJOgCA4gQdAEBxp+ecLSfZSDJJcvho5gAAcIJRkvUkd5MczB7OC7qNJLcajQIA4J+7lOT27M15QTdJkp2dnQzD0GpUV1tbW5m8cLH3jGbWv76TzY++7D2jide+fZgrNzbz6euf9Z7SxJUbm8nVC71ntLP9fQ6u3em9opnldy7m87de7j2jmVc//iLX3v+k94wmLqx9l8vb17Nz9e3eU5q4vH0939z8qveMZp5/5cV8+O6bvWc0cebsWt5474Pkzz6bNS/oDpNkGIbs7+83mPb/cHjvXu8JTU12f+09oYn9n4Zj14W0+0PvBU1N9xbz2Twy3P+x94Sm9h7u9Z7QxHDqwR/X3Qedl7Tz2y+L/e7t/ny/94TWTvwZnD9FAAAUJ+gAAIoTdAAAxQk6AIDiBB0AQHGCDgCgOEEHAFCcoAMAKE7QAQAUJ+gAAIoTdAAAxQk6AIDiBB0AQHGCDgCgOEEHAFCcoAMAKE7QAQAUJ+gAAIoTdAAAxQk6AIDiBB0AQHGCDgCgOEEHAFCcoAMAKE7QAQAUJ+gAAIoTdAAAxQk6AIDiBB0AQHGCDgCgOEEHAFCcoAMAKE7QAQAUJ+gAAIoTdAAAxQk6AIDiBB0AQHGCDgCgOEEHAFCcoAMAKE7QAQAUJ+gAAIoTdAAAxQk6AIDiBB0AQHGCDgCgOEEHAFCcoAMAKE7QAQAUJ+gAAIoTdAAAxZ2eczZKkvF4/Iim9DE6f773hKbWVx/vPaGJJ58ZH7supNXnei9oamllMZ/NI+Nzz/ae0NTK2krvCU2MV58+dl1Ejz2x2O/e6lPnek9o4szZtaOPo5POl6bT6V9996UktxpsAgDg37mU5PbszXlBt5xkI8kkyWG7XQAA/I1RkvUkd5MczB7OCzoAAArwpwgAgOIEHQBAcYIOAKA4QQcAUNzval1qaR6YgroAAAAASUVORK5CYII=\n", 364 | "text/plain": [ 365 | "
" 366 | ] 367 | }, 368 | "metadata": { 369 | "needs_background": "light" 370 | }, 371 | "output_type": "display_data" 372 | } 373 | ], 374 | "source": [ 375 | "COLOR_PALETTE1 = sns.color_palette(\"Set1\", 9, desat=1.)\n", 376 | "COLOR_PALETTE2 = sns.color_palette(\"Set1\", 9, desat=.7)\n", 377 | "COLOR_PALETTE3 = sns.color_palette(\"Set1\", 9, desat=.5)\n", 378 | "COLOR_PALETTE4 = sns.color_palette(\"Set1\", 9, desat=.3)\n", 379 | "# COLOR_PALETTE2 = sns.color_palette(\"Dark2\", 30)\n", 380 | "# COLOR_PALETTE = COLOR_PALETTE1[:8] + COLOR_PALETTE2[:30]\n", 381 | "COLOR_PALETTE = [COLOR_PALETTE1[8], COLOR_PALETTE1[0], COLOR_PALETTE1[1],\n", 382 | " COLOR_PALETTE1[3], COLOR_PALETTE1[4], COLOR_PALETTE1[7], \n", 383 | " COLOR_PALETTE1[6], \n", 384 | " COLOR_PALETTE4[3], COLOR_PALETTE4[4], COLOR_PALETTE4[7], \n", 385 | " COLOR_PALETTE4[6]]\n", 386 | "sns.palplot(COLOR_PALETTE)\n" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": 18, 392 | "metadata": {}, 393 | "outputs": [ 394 | { 395 | "name": "stdout", 396 | "output_type": "stream", 397 | "text": [ 398 | "(10, 11) (11,) (110, 2)\n" 399 | ] 400 | }, 401 | { 402 | "data": { 403 | "image/png": "\n", 404 | "text/plain": [ 405 | "
" 406 | ] 407 | }, 408 | "metadata": { 409 | "needs_background": "light" 410 | }, 411 | "output_type": "display_data" 412 | } 413 | ], 414 | "source": [ 415 | "# hyperparameter choice plot\n", 416 | "fig, axs = plt.subplots(1, 1, figsize=(10,5))\n", 417 | "boxplot_all_methods(axs, results_tar_plot*100,\n", 418 | " title=\"\", names=names_short,\n", 419 | " color=np.array(COLOR_PALETTE)[:len(names_short)])\n", 420 | "\n", 421 | "plt.savefig(\"paper_figures/MNIST_%s_M5_Yintervention.pdf\" %perturb, bbox_inches=\"tight\")\n", 422 | "plt.show()" 423 | ] 424 | }, 425 | { 426 | "cell_type": "code", 427 | "execution_count": null, 428 | "metadata": {}, 429 | "outputs": [], 430 | "source": [] 431 | } 432 | ], 433 | "metadata": { 434 | "kernelspec": { 435 | "display_name": "Python 3", 436 | "language": "python", 437 | "name": "python3" 438 | }, 439 | "language_info": { 440 | "codemirror_mode": { 441 | "name": "ipython", 442 | "version": 3 443 | }, 444 | "file_extension": ".py", 445 | "mimetype": "text/x-python", 446 | "name": "python", 447 | "nbconvert_exporter": "python", 448 | "pygments_lexer": "ipython3", 449 | "version": "3.6.4" 450 | } 451 | }, 452 | "nbformat": 4, 453 | "nbformat_minor": 2 454 | } 455 | -------------------------------------------------------------------------------- /MNIST/MNIST_read_and_plot_whitepatch2M.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import matplotlib.pyplot as plt\n", 11 | "import seaborn as sns\n", 12 | "\n", 13 | "import pandas as pd\n", 14 | "\n", 15 | "plt.rcParams['axes.facecolor'] = 'lightgray'\n", 16 | "sns.set(style=\"darkgrid\")\n", 17 | "np.set_printoptions(precision=3)" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 2, 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "def boxplot_all_methods(plt_handle, res_all, title='', names=[], color=[]):\n", 27 | " res_all_df = pd.DataFrame(res_all.T)\n", 28 | " res_all_df.columns = names\n", 29 | " res_all_df_melt = res_all_df.melt(var_name='methods', value_name='accuracy')\n", 30 | " res_all_mean = np.mean(res_all, axis=1)\n", 31 | " \n", 32 | "# plt_handle.set_title(title, fontsize=15)\n", 33 | "\n", 34 | " plt_handle.axhline(res_all_mean[2], ls='--', color='b')\n", 35 | " plt_handle.axhline(res_all_mean[1], ls='--', color='r')\n", 36 | " ax = sns.boxplot(x=\"methods\", y=\"accuracy\", data=res_all_df_melt, palette=color, ax=plt_handle)\n", 37 | " ax.set_xticklabels(ax.get_xticklabels(), rotation=-60, ha='left', fontsize=20)\n", 38 | " ax.tick_params(labelsize=20)\n", 39 | " ax.yaxis.grid(False) # Hide the horizontal gridlines\n", 40 | " ax.xaxis.grid(True) # Show the vertical gridlines\n", 41 | " ax.set_xlabel(\"methods\")\n", 42 | " ax.set_ylabel(\"accuracy\")\n", 43 | " \n", 44 | " ax.set_xlabel(\"\")\n", 45 | " ax.set_ylabel(\"Accuracy (%)\", fontsize=20)" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 3, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "# perturb = 'whitepatch2M'\n", 55 | "# perturb = 'rotation2M'\n", 56 | "perturb = 'rotation2Ma'\n", 57 | "# perturb = 'translation2M'\n", 58 | "M = 2\n", 59 | "subset_prop = 0.2\n", 60 | "lamL2 = 0.\n", 61 | "lamL1 = 0.\n", 62 | "lr = 1e-4\n", 63 | "epochs=100" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 4, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "names_short = [\"Original\", \"Tar\", \"Src[1]\", 'DIP[1]', 'DIP[1]-MMD']\n", 73 | "\n", 74 | "prefix_template = 'results_MNIST/report_v8_%s_M%d_subsetprop%s_%s_lamMatch%s_lamMatchMMD%s_epochs%d_seed%d'" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 5, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "repeats = 10\n", 84 | "nb_ba = 3 # Original, Tar, Src[1]\n", 85 | "results_src_ba = np.zeros((M-1, nb_ba, 2, repeats))\n", 86 | "results_tar_ba = np.zeros((nb_ba, 2, repeats))\n", 87 | "for seed in range(repeats):\n", 88 | " savefilename_prefix = prefix_template % (perturb,\n", 89 | " M, str(subset_prop), 'baseline', 1., 1., epochs, seed)\n", 90 | " res = np.load(\"%s.npy\" %savefilename_prefix, allow_pickle=True)\n", 91 | " \n", 92 | " results_src_ba[:, :, :, seed] =res.item()['src']\n", 93 | " results_tar_ba[:, :, seed] = res.item()['tar']" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 6, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "lamMatches = [10.**(k) for k in (np.arange(10)-5)]" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 7, 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "nb_dip = 3 # DIP DIPOracle DIP-MMD\n", 112 | "results_src_dip = np.zeros((len(lamMatches), M-1, nb_dip, 2, repeats))\n", 113 | "results_tar_dip = np.zeros((len(lamMatches), nb_dip, 2, repeats))\n", 114 | "for i, lam in enumerate(lamMatches):\n", 115 | " for seed in range(repeats):\n", 116 | " savefilename_prefix = prefix_template % (perturb,\n", 117 | " M, str(subset_prop), 'DIP', lam, lam, epochs, seed)\n", 118 | " res = np.load(\"%s.npy\" %savefilename_prefix, allow_pickle=True)\n", 119 | "\n", 120 | " results_src_dip[i, :, :, :, seed] =res.item()['src']\n", 121 | " results_tar_dip[i, :, :, seed] = res.item()['tar']" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 8, 127 | "metadata": {}, 128 | "outputs": [ 129 | { 130 | "name": "stdout", 131 | "output_type": "stream", 132 | "text": [ 133 | "8\n", 134 | "8\n", 135 | "7\n" 136 | ] 137 | } 138 | ], 139 | "source": [ 140 | "# choose lambda based on the source test performance\n", 141 | "lam_index_dip = np.zeros(nb_dip, dtype=int)\n", 142 | "for i in range(nb_dip):\n", 143 | " src_test_acc_all = results_src_dip[:, 0, i, 1, :].mean(axis=1)\n", 144 | " # choose the largest lambda such that the source performance does not drop too much (5%)\n", 145 | " lam_index = 0\n", 146 | " for k, src_test_acc in enumerate(src_test_acc_all):\n", 147 | " if src_test_acc > src_test_acc_all[0] * 0.95:\n", 148 | " lam_index = k\n", 149 | " lam_index_dip[i] = lam_index\n", 150 | " print(lam_index)" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 9, 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "results_tar_plot = np.concatenate((results_tar_ba[:, 0, :],\n", 160 | " results_tar_dip[lam_index_dip[0], 0, 0, :].reshape(1, -1),\n", 161 | " results_tar_dip[lam_index_dip[2], 2, 0, :].reshape(1, -1)), axis=0)" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 10, 167 | "metadata": {}, 168 | "outputs": [ 169 | { 170 | "data": { 171 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAASUAAABECAYAAADHuCM8AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAACHUlEQVR4nO3YMWoUYQCG4S+uEGENBgshkGsELfQQAasELAQv4AG8glhYpbaWHEGSQrH3BIEFi2AgC6YIa6NgFt1uMh/J8zQzzM/AV73Fv7ZYLALQ4s7YAwD+JkpAFVECqogSUOXuirP1JDtJZkkur2cOcAtMkmwl+ZrkYvlwVZR2khwNNArgWZLj5Y+rojRLksPDw8zn86FGjWpvby+zx0/GnjGYrS+fs/v209gzBvH822n2D3bz4dXHsacMYv9gN+/evB97xiA2Njfy8vWL5Hdjlq2K0mWSzOfznJ+fDzCtw+XJydgTBjX78XPsCYM4/z6/8ryJzk7Pxp4wtH9eC7noBqqIElBFlIAqogRUESWgiigBVUQJqCJKQBVRAqqIElBFlIAqogRUESWgiigBVUQJqCJKQBVRAqqIElBFlIAqogRUESWgiigBVUQJqCJKQBVRAqqIElBFlIAqogRUESWgiigBVUQJqCJKQBVRAqqIElBFlIAqogRUESWgiigBVUQJqCJKQBVRAqqIElBFlIAqogRUESWgiigBVUQJqCJKQBVRAqqIElBFlIAqd1ecTZJkOp1e05RxTLa3x54wqK3Ne2NPGMT9R9Mrz5vowcMHY08YxMbmxp/Xyb/O1xaLxf/+fZrkaIBNAEnyLMnx8sdVUVpPspNkluRyuF3ALTNJspXka5KL5cNVUQK4di66gSqiBFQRJaCKKAFVfgGw30VQqzUq5QAAAABJRU5ErkJggg==\n", 172 | "text/plain": [ 173 | "
" 174 | ] 175 | }, 176 | "metadata": { 177 | "needs_background": "light" 178 | }, 179 | "output_type": "display_data" 180 | } 181 | ], 182 | "source": [ 183 | "COLOR_PALETTE1 = sns.color_palette(\"Set1\", 9, desat=1.)\n", 184 | "COLOR_PALETTE2 = sns.color_palette(\"Set1\", 9, desat=.7)\n", 185 | "COLOR_PALETTE3 = sns.color_palette(\"Set1\", 9, desat=.5)\n", 186 | "COLOR_PALETTE4 = sns.color_palette(\"Set1\", 9, desat=.3)\n", 187 | "# COLOR_PALETTE2 = sns.color_palette(\"Dark2\", 30)\n", 188 | "# COLOR_PALETTE = COLOR_PALETTE1[:8] + COLOR_PALETTE2[:30]\n", 189 | "COLOR_PALETTE = [COLOR_PALETTE1[8], COLOR_PALETTE1[0], COLOR_PALETTE1[1], COLOR_PALETTE1[3], COLOR_PALETTE4[3]]\n", 190 | "sns.palplot(COLOR_PALETTE)\n" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": 11, 196 | "metadata": { 197 | "scrolled": false 198 | }, 199 | "outputs": [ 200 | { 201 | "data": { 202 | "image/png": "\n", 203 | "text/plain": [ 204 | "
" 205 | ] 206 | }, 207 | "metadata": { 208 | "needs_background": "light" 209 | }, 210 | "output_type": "display_data" 211 | } 212 | ], 213 | "source": [ 214 | "fig, axs = plt.subplots(1, 1, figsize=(5,5))\n", 215 | "boxplot_all_methods(axs, results_tar_plot*100,\n", 216 | " title=\"MNIST: single source patch intervention\", names=names_short,\n", 217 | " color=np.array(COLOR_PALETTE)[:len(names_short)])\n", 218 | "\n", 219 | "plt.savefig(\"paper_figures/%s\" %\"MNIST_%s_2M.pdf\" %perturb, bbox_inches=\"tight\")\n", 220 | "plt.show()" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": null, 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [] 229 | } 230 | ], 231 | "metadata": { 232 | "kernelspec": { 233 | "display_name": "Python 3", 234 | "language": "python", 235 | "name": "python3" 236 | }, 237 | "language_info": { 238 | "codemirror_mode": { 239 | "name": "ipython", 240 | "version": 3 241 | }, 242 | "file_extension": ".py", 243 | "mimetype": "text/x-python", 244 | "name": "python", 245 | "nbconvert_exporter": "python", 246 | "pygments_lexer": "ipython3", 247 | "version": "3.6.4" 248 | } 249 | }, 250 | "nbformat": 4, 251 | "nbformat_minor": 2 252 | } 253 | -------------------------------------------------------------------------------- /MNIST/simu_MNIST_patches.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # In[1]: 5 | 6 | 7 | import numpy as np 8 | import pandas as pd 9 | import sys 10 | import argparse 11 | 12 | np.set_printoptions(precision=3) 13 | 14 | import torch 15 | 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | import torch.optim as optim 19 | 20 | import torchvision 21 | import time 22 | 23 | import simudata_MNIST 24 | import semitorchMNISTclass 25 | 26 | 27 | # In[2]: 28 | 29 | 30 | # check gpu avail 31 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 32 | 33 | # Assuming that we are on a CUDA machine, this should print a CUDA device: 34 | 35 | # print(device) 36 | 37 | 38 | # In[3]: 39 | 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument("--perturb", type=str, default="whitepatch", help="type of perturbation") 42 | parser.add_argument("--subset_prop", type=float, default=0.2, help="proportion of data points to be used for each env") 43 | parser.add_argument("--seed", type=int, default=0, help="seed") 44 | parser.add_argument("--lamMatch", type=float, default=1., help="DIP matching penalty") 45 | parser.add_argument("--lamMatchMMD", type=float, default=1., help="DIP matching penalty with MMD") 46 | parser.add_argument("--lamCIP", type=float, default=0.1, help="CIP matching penalty") 47 | parser.add_argument("--lamCIPMMD", type=float, default=0.1, help="CIP matching penalty with MMD") 48 | parser.add_argument("--epochs", type=int, default=50, help="number of epochs") 49 | parser.add_argument("--tag_DA", type=str, default="baseline", help="whether to run baseline methods or DA methods") 50 | parser.add_argument("--M", type=int, default=12, help="total number of environments") 51 | myargs = parser.parse_args() 52 | print(myargs) 53 | 54 | M = myargs.M 55 | train_batch_size = 500 56 | test_batch_size = 500 57 | np.random.seed(123456+myargs.seed) 58 | 59 | trainloaders, testloaders = simudata_MNIST.generate_MNIST_envs(perturb=myargs.perturb, subset_prop=myargs.subset_prop, 60 | M=M, interY=True, 61 | train_batch_size=train_batch_size, 62 | test_batch_size=test_batch_size) 63 | 64 | lamL2 = 0. 65 | lamL1 = 0. 66 | 67 | lr = 1e-4 68 | 69 | source = list(np.arange(M)) 70 | target = M-1 71 | source.remove(target) 72 | 73 | savefilename_prefix = 'results_MNIST/report_v8_%s_M%d_subsetprop%s_%s_lamMatch%s_lamCIP%s_lamMatchMMD%s_lamCIPMMD%s_epochs%d_seed%d' % (myargs.perturb, M, 74 | str(myargs.subset_prop), myargs.tag_DA, str(myargs.lamMatch), str(myargs.lamCIP), str(myargs.lamMatchMMD), 75 | str(myargs.lamCIPMMD), myargs.epochs, myargs.seed) 76 | savefilename = '%s.txt' % savefilename_prefix 77 | savefile = open(savefilename, 'w') 78 | 79 | if myargs.tag_DA == 'baseline': 80 | methods = [ 81 | semitorchMNISTclass.Original(), 82 | semitorchMNISTclass.Tar(lamL2=lamL2, lamL1=lamL1, lr=lr, epochs=myargs.epochs), 83 | semitorchMNISTclass.SrcPool(lamL2=lamL2, lamL1=lamL1, lr=lr, epochs=myargs.epochs), 84 | ] 85 | elif myargs.tag_DA == 'DAmean': 86 | methods = [ 87 | semitorchMNISTclass.DIP(lamMatch=myargs.lamMatch, lamL2=0., lamL1=0., 88 | sourceInd = 0, lr=lr, epochs=myargs.epochs, wayMatch='mean'), 89 | semitorchMNISTclass.DIPOracle(lamMatch=myargs.lamMatch, lamL2=0., lamL1=0., 90 | sourceInd = 0, lr=lr, epochs=myargs.epochs, wayMatch='mean'), 91 | semitorchMNISTclass.DIPweigh(lamMatch=myargs.lamMatch, lamL2=0., lamL1=0., 92 | lr=lr, epochs=myargs.epochs, wayMatch='mean'), 93 | semitorchMNISTclass.CIRMweigh(lamMatch=myargs.lamMatch, lamCIP=myargs.lamCIP, lamL2=0., lamL1=0., 94 | lr=lr, epochs=myargs.epochs, wayMatch='mean'), 95 | ] 96 | elif myargs.tag_DA == 'DAMMD': 97 | methods = [ 98 | semitorchMNISTclass.DIP_MMD(lamMatch=myargs.lamMatchMMD, lamL2=0., lamL1=0., 99 | sourceInd = 0, lr=lr, epochs=myargs.epochs, wayMatch='mmd', sigma_list=[1., 10.]), 100 | semitorchMNISTclass.DIPweigh_MMD(lamMatch=myargs.lamMatchMMD, lamL2=0., lamL1=0., 101 | lr=lr, epochs=myargs.epochs, wayMatch='mmd', sigma_list=[1., 10.]), 102 | semitorchMNISTclass.CIRMweigh_MMD(lamMatch=myargs.lamMatchMMD, lamCIP=myargs.lamCIPMMD, lamL2=0., lamL1=0., 103 | lr=lr, epochs=myargs.epochs, wayMatch='mmd', sigma_list=[1., 10.]), 104 | ] 105 | elif myargs.tag_DA == 'DACIPMMD': 106 | methods = [ 107 | semitorchMNISTclass.CIP_MMD(lamCIP=myargs.lamCIPMMD, lamL2=0., lamL1=0., 108 | lr=lr, epochs=myargs.epochs, wayMatch='mmd', sigma_list=[1., 10.]), 109 | ] 110 | elif myargs.tag_DA == 'DACIPmean': 111 | methods = [ 112 | semitorchMNISTclass.CIP(lamCIP=myargs.lamCIP, lamL2=0., lamL1=0., 113 | lr=lr, epochs=myargs.epochs, wayMatch='mean'), 114 | ] 115 | else: 116 | print('tag_DA unrecognized') 117 | 118 | names = [str(m) for m in methods] 119 | print(names, file=savefile) 120 | 121 | 122 | 123 | trained_methods = [None]*len(methods) 124 | results_src_all = np.zeros((M-1, len(methods), 2)) 125 | results_tar_all = np.zeros((len(methods), 2)) 126 | 127 | def compute_accuracy(loader, env, me): 128 | total = 0 129 | correct = 0 130 | with torch.no_grad(): 131 | for data in loader[env]: 132 | images, labels = data[0].to(device), data[1].to(device) 133 | predicted = me.predict(images) 134 | total += labels.size(0) 135 | correct += (predicted == labels).sum().item() 136 | return correct/total 137 | 138 | for i, me in enumerate(methods): 139 | starttime = time.time() 140 | print("fitting %s" %names[i], file=savefile) 141 | me = me.fit(trainloaders, source=source, target=target) 142 | if hasattr(me, 'losses'): 143 | print(me.losses, file=savefile) 144 | if hasattr(me, 'minDiffIndx'): 145 | print("best index="+str(me.minDiffIndx), file=savefile) 146 | trained_methods[i] = me 147 | # evaluate the methods 148 | # target train and test accuracy 149 | results_tar_all[i, 0] = compute_accuracy(trainloaders, target, me) 150 | results_tar_all[i, 1] = compute_accuracy(testloaders, target, me) 151 | # source train and test accuracy 152 | for j, sourcej in enumerate(source): 153 | results_src_all[j, i, 0] = compute_accuracy(trainloaders, sourcej, me) 154 | results_src_all[j, i, 1] = compute_accuracy(testloaders, sourcej, me) 155 | 156 | print('Method %-30s, Target %d, Source accuracy: %.3f %%, Target accuracy: %.3f %%' % (names[i], target, 157 | 100 * results_tar_all[i, 0], 100 * results_tar_all[i, 1]), file=savefile) 158 | endtime = time.time() 159 | print("time elapsed: %.1f s" % (endtime - starttime), file=savefile) 160 | print("\n", file=savefile) 161 | 162 | results_all = {} 163 | results_all['src'] = results_src_all 164 | results_all['tar'] = results_tar_all 165 | np.save("%s.npy" %savefilename_prefix, results_all) 166 | 167 | -------------------------------------------------------------------------------- /MNIST/simu_MNIST_patches_2M.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # In[1]: 5 | 6 | 7 | import numpy as np 8 | import pandas as pd 9 | import sys 10 | import argparse 11 | 12 | np.set_printoptions(precision=3) 13 | 14 | import torch 15 | 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | import torch.optim as optim 19 | 20 | import torchvision 21 | import time 22 | 23 | import simudata_MNIST 24 | import semitorchMNISTclass 25 | 26 | 27 | # In[2]: 28 | 29 | 30 | # check gpu avail 31 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 32 | 33 | # Assuming that we are on a CUDA machine, this should print a CUDA device: 34 | 35 | # print(device) 36 | 37 | 38 | # In[3]: 39 | 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument("--perturb", type=str, default="whitepatch2M", help="type of perturbation") 42 | parser.add_argument("--subset_prop", type=float, default=0.2, help="proportion of data points to be used for each env") 43 | parser.add_argument("--seed", type=int, default=0, help="seed") 44 | parser.add_argument("--lamMatch", type=float, default=1., help="DIP matching penalty") 45 | parser.add_argument("--lamMatchMMD", type=float, default=1., help="DIP matching penalty with MMD") 46 | parser.add_argument("--epochs", type=int, default=50, help="number of epochs") 47 | parser.add_argument("--tag_DA", type=str, default="baseline", help="whether to run baseline methods or DA methods") 48 | myargs = parser.parse_args() 49 | print(myargs) 50 | 51 | M = 2 52 | train_batch_size = 64 53 | test_batch_size = 500 54 | save_model = False 55 | np.random.seed(123456+myargs.seed) 56 | 57 | trainloaders, testloaders = simudata_MNIST.generate_MNIST_envs(perturb=myargs.perturb, subset_prop=myargs.subset_prop, 58 | M=M, interY=False, 59 | train_batch_size=train_batch_size, 60 | test_batch_size=test_batch_size) 61 | 62 | lamL2 = 0. 63 | lamL1 = 0. 64 | lr = 1e-4 65 | 66 | source = list(np.arange(M)) 67 | target = M-1 68 | source.remove(target) 69 | 70 | savefilename_prefix = 'results_MNIST/report_v8_%s_M%d_subsetprop%s_%s_lamMatch%s_lamMatchMMD%s_epochs%d_seed%d' % (myargs.perturb, M, 71 | str(myargs.subset_prop), myargs.tag_DA, myargs.lamMatch, myargs.lamMatchMMD, myargs.epochs, myargs.seed) 72 | savefilename = '%s.txt' % savefilename_prefix 73 | savefile = open(savefilename, 'w') 74 | 75 | if myargs.tag_DA == 'baseline': 76 | methods = [ 77 | semitorchMNISTclass.Original(), 78 | semitorchMNISTclass.Tar(lamL2=lamL2, lamL1=lamL1, lr=lr, epochs=myargs.epochs), 79 | semitorchMNISTclass.SrcPool(lamL2=lamL2, lamL1=lamL1, lr=lr, epochs=myargs.epochs), 80 | ] 81 | else: # DIP 82 | methods = [ 83 | semitorchMNISTclass.DIP(lamMatch=myargs.lamMatch, lamL2=0., lamL1=0., 84 | sourceInd = 0, lr=lr, epochs=myargs.epochs, wayMatch='mean'), 85 | semitorchMNISTclass.DIPOracle(lamMatch=myargs.lamMatch, lamL2=0., lamL1=0., 86 | sourceInd = 0, lr=lr, epochs=myargs.epochs, wayMatch='mean'), 87 | semitorchMNISTclass.DIP_MMD(lamMatch=myargs.lamMatchMMD, lamL2=0., lamL1=0., 88 | sourceInd = 0, lr=lr, epochs=myargs.epochs, wayMatch='mmd', sigma_list=[1., 10.]), 89 | ] 90 | 91 | names = [str(m) for m in methods] 92 | print(names, file=savefile) 93 | 94 | 95 | 96 | trained_methods = [None]*len(methods) 97 | results_src_all = np.zeros((M-1, len(methods), 2)) 98 | results_tar_all = np.zeros((len(methods), 2)) 99 | 100 | def compute_accuracy(loader, env, me): 101 | total = 0 102 | correct = 0 103 | with torch.no_grad(): 104 | for data in loader[env]: 105 | images, labels = data[0].to(device), data[1].to(device) 106 | predicted = me.predict(images) 107 | total += labels.size(0) 108 | correct += (predicted == labels).sum().item() 109 | return correct/total 110 | 111 | for i, me in enumerate(methods): 112 | starttime = time.time() 113 | print("fitting %s" %names[i], file=savefile) 114 | me = me.fit(trainloaders, source=source, target=target) 115 | trained_methods[i] = me 116 | # evaluate the methods 117 | # target train and test accuracy 118 | results_tar_all[i, 0] = compute_accuracy(trainloaders, target, me) 119 | results_tar_all[i, 1] = compute_accuracy(testloaders, target, me) 120 | # source train and test accuracy 121 | for j, sourcej in enumerate(source): 122 | results_src_all[j, i, 0] = compute_accuracy(trainloaders, sourcej, me) 123 | results_src_all[j, i, 1] = compute_accuracy(testloaders, sourcej, me) 124 | 125 | 126 | print('Method %-30s, Target %d, Source accuracy: %.3f %%, Target accuracy: %.3f %%' % (names[i], target, 127 | 100 * results_tar_all[i, 0], 100 * results_tar_all[i, 1]), file=savefile) 128 | endtime = time.time() 129 | print("time elapsed: %.1f s" % (endtime - starttime), file=savefile) 130 | print("\n", file=savefile) 131 | 132 | results_all = {} 133 | results_all['src'] = results_src_all 134 | results_all['tar'] = results_tar_all 135 | np.save("%s.npy" %savefilename_prefix, results_all) 136 | 137 | 138 | 139 | -------------------------------------------------------------------------------- /MNIST/simudata_MNIST.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | 8 | import torchvision 9 | import time 10 | 11 | 12 | # This file generates MNIST perturbed environments 13 | datapath = '/cluster/home/chenyua/Code/causal/data' 14 | 15 | def get_actual_data_idx(all_labels, subset_prop, interY=False): 16 | num_classes = 10 17 | # mask a certain proportion of data 18 | data_idx_mask_prop = torch.rand_like(all_labels.float()) < subset_prop 19 | 20 | if not interY: 21 | data_idx_list = torch.where(data_idx_mask_prop)[0] 22 | else: 23 | # intervention on Y 24 | # the following digits will only have 50% data 25 | mod_digits = [3, 4, 5, 6, 8, 9] 26 | data_idx_mask_mod = torch.rand_like(all_labels.float()) < 0.8 27 | idx_trainset_mod = (all_labels == mod_digits[0]) 28 | for digit in mod_digits: 29 | idx_trainset_mod = idx_trainset_mod | (all_labels == digit) 30 | 31 | data_idx_mask_final = data_idx_mask_prop & (~(data_idx_mask_mod & idx_trainset_mod)) 32 | data_idx_list = torch.where(data_idx_mask_final)[0] 33 | return data_idx_list 34 | 35 | 36 | def generate_MNIST_envs(perturb='noisepatch', subset_prop=0.1, M=12, interY=False, train_batch_size=64, test_batch_size=1000): 37 | trainset_original = torchvision.datasets.MNIST(root=datapath, train=True, 38 | download=False) 39 | testset_original = torchvision.datasets.MNIST(root=datapath, train=False, 40 | download=False) 41 | 42 | # to be commented 43 | # idx_trainsetlist = get_actual_data_idx(trainset_original.targets, subset_prop=subset_prop, interY=interY) 44 | # idx_testsetlist = get_actual_data_idx(testset_original.targets, subset_prop=subset_prop, interY=interY) 45 | 46 | 47 | 48 | trainloaders = {} 49 | testloaders = {} 50 | if perturb == 'whitepatch': 51 | # prepare the noise patces 52 | # noise_patches = [torch.zeros(1, 28, 28)] 53 | noise_patches = [] 54 | offset = 10 55 | initpos = 2 56 | # sqsize 12 works to make CIRM better than CIP 57 | # offset = 4, sqsize = 12, initpos = 6, interY = 0.3 works for CIRM better than CIP, DIP 58 | sqsizesmall = 12 59 | sqsizelarge = 16 60 | if M == 12: 61 | for m in range(M-2): 62 | a = torch.zeros(1, 28, 28) 63 | a[0, (initpos-m+offset):initpos+sqsizelarge-m+offset, (initpos-m+offset):initpos+sqsizelarge-m+offset] = 3.25 64 | noise_patches.append(a) 65 | 66 | for m in [M-2, M-1]: 67 | a = torch.zeros(1, 28, 28) 68 | a[0, (initpos-m+offset):initpos+sqsizesmall-m+offset, (initpos-m+offset):initpos+sqsizesmall-m+offset] = 3.25 69 | noise_patches.append(a) 70 | elif M == 6: 71 | for m in range(M): 72 | a = torch.zeros(1, 28, 28) 73 | a[0, (initpos-2*m+offset):initpos+sqsizelarge-2*m+offset, (initpos-2*m+offset):initpos+sqsizelarge-2*m+offset] = 3.25 74 | noise_patches.append(a) 75 | 76 | 77 | # now transform the data 78 | for m in range(M): 79 | # load MNIST data 80 | transformer = torchvision.transforms.Compose( 81 | [ 82 | torchvision.transforms.ToTensor(), 83 | torchvision.transforms.Normalize((0.1306,), (0.3081,)), 84 | torchvision.transforms.Lambda((lambda y: lambda x: torch.max(x, noise_patches[y]))(m)), 85 | ]) 86 | 87 | trainset = torchvision.datasets.MNIST(root=datapath, train=True, 88 | download=False, transform=transformer) 89 | 90 | testset = torchvision.datasets.MNIST(root=datapath, train=False, 91 | download=False, transform=transformer) 92 | 93 | if m != M-1: 94 | idx_trainsetlist_loc = get_actual_data_idx(trainset_original.targets, subset_prop=subset_prop, interY=False) 95 | idx_testsetlist_loc = get_actual_data_idx(testset_original.targets, subset_prop=subset_prop, interY=False) 96 | else: 97 | idx_trainsetlist_loc = get_actual_data_idx(trainset_original.targets, subset_prop=subset_prop, interY=interY) 98 | idx_testsetlist_loc = get_actual_data_idx(testset_original.targets, subset_prop=subset_prop, interY=interY) 99 | print('actual env=%d trainsize %s, testsize %s' %(m, idx_trainsetlist_loc.shape, idx_testsetlist_loc.shape)) 100 | 101 | trainloaders[m] = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size, num_workers=0, 102 | sampler = torch.utils.data.sampler.SubsetRandomSampler(idx_trainsetlist_loc)) 103 | testloaders[m] = torch.utils.data.DataLoader(testset, batch_size=test_batch_size, num_workers=0, 104 | sampler = torch.utils.data.sampler.SubsetRandomSampler(idx_testsetlist_loc)) 105 | trainloaders[m].dataset.targets_mod = trainloaders[m].dataset.targets[idx_trainsetlist_loc] 106 | testloaders[m].dataset.targets_mod = testloaders[m].dataset.targets[idx_testsetlist_loc] 107 | elif perturb == 'whitepatch2M': 108 | # prepare the noise patces 109 | # noise_patches = [torch.zeros(1, 28, 28)] 110 | noise_patches = [] 111 | offset = 10 112 | initpos = 2 113 | # sqsize 12 works to make CIRM better than CIP 114 | # offset = 4, sqsize = 12, initpos = 6, interY = 0.3 works for CIRM better than CIP, DIP 115 | sqsizesmall = 12 116 | for m in range(M): 117 | a = torch.zeros(1, 28, 28) 118 | # this will make the pixel white, 3.25 is because of normalization 119 | a[0, (initpos-m*5+offset):initpos+sqsizesmall-m+offset, (initpos-m+offset):initpos+sqsizesmall-m*5+offset] = 3.25 120 | noise_patches.append(a) 121 | 122 | # now transform the data 123 | for m in range(M): 124 | # load MNIST data 125 | transformer = torchvision.transforms.Compose( 126 | [ 127 | torchvision.transforms.ToTensor(), 128 | torchvision.transforms.Normalize((0.1306,), (0.3081,)), 129 | torchvision.transforms.Lambda((lambda y: lambda x: torch.max(x, noise_patches[y]))(m)), 130 | ]) 131 | 132 | trainset = torchvision.datasets.MNIST(root=datapath, train=True, 133 | download=True, transform=transformer) 134 | 135 | testset = torchvision.datasets.MNIST(root=datapath, train=False, 136 | download=True, transform=transformer) 137 | 138 | if m != M-1: 139 | idx_trainsetlist_loc = get_actual_data_idx(trainset_original.targets, subset_prop=subset_prop, interY=False) 140 | idx_testsetlist_loc = get_actual_data_idx(testset_original.targets, subset_prop=subset_prop, interY=False) 141 | else: 142 | idx_trainsetlist_loc = get_actual_data_idx(trainset_original.targets, subset_prop=subset_prop, interY=interY) 143 | idx_testsetlist_loc = get_actual_data_idx(testset_original.targets, subset_prop=subset_prop, interY=interY) 144 | print('actual env=%d trainsize %s, testsize %s' %(m, idx_trainsetlist_loc.shape, idx_testsetlist_loc.shape)) 145 | 146 | trainloaders[m] = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size, num_workers=0, 147 | sampler = torch.utils.data.sampler.SubsetRandomSampler(idx_trainsetlist_loc)) 148 | testloaders[m] = torch.utils.data.DataLoader(testset, batch_size=test_batch_size, num_workers=0, 149 | sampler = torch.utils.data.sampler.SubsetRandomSampler(idx_testsetlist_loc)) 150 | trainloaders[m].dataset.targets_mod = trainloaders[m].dataset.targets[idx_trainsetlist_loc] 151 | testloaders[m].dataset.targets_mod = testloaders[m].dataset.targets[idx_testsetlist_loc] 152 | elif perturb == 'noisepatch': 153 | # prepare the noise patces 154 | noise_patches = [] 155 | offset = 10 156 | initpos = 2 157 | # sqsize 12 works to make CIRM better than CIP 158 | # offset = 4, sqsize = 12, initpos = 6, interY = 0.3 works for CIRM better than CIP, DIP 159 | sqsizesmall = 12 160 | sqsizelarge = 16 161 | for m in range(M-2): 162 | a = torch.zeros(1, 28, 28) 163 | a[0, (initpos-m+offset):initpos+sqsizelarge-m+offset, (initpos-m+offset):initpos+sqsizelarge-m+offset] = 3.25 * (torch.rand(1, sqsizelarge, sqsizelarge) > 0.5) 164 | noise_patches.append(a) 165 | 166 | for m in [M-2, M-1]: 167 | a = torch.zeros(1, 28, 28) 168 | a[0, (initpos-m+offset):initpos+sqsizesmall-m+offset, (initpos-m+offset):initpos+sqsizesmall-m+offset] = 3.25 * (torch.rand(1, sqsizesmall, sqsizesmall) > 0.5) 169 | noise_patches.append(a) 170 | 171 | # now transform the data 172 | for m in range(M): 173 | # load MNIST data 174 | transformer = torchvision.transforms.Compose( 175 | [ 176 | torchvision.transforms.ToTensor(), 177 | torchvision.transforms.Normalize((0.1306,), (0.3081,)), 178 | torchvision.transforms.Lambda((lambda y: lambda x: torch.max(x, noise_patches[y]))(m)), 179 | ]) 180 | 181 | trainset = torchvision.datasets.MNIST(root=datapath, train=True, 182 | download=True, transform=transformer) 183 | 184 | testset = torchvision.datasets.MNIST(root=datapath, train=False, 185 | download=True, transform=transformer) 186 | 187 | if m != M-1: 188 | idx_trainsetlist_loc = get_actual_data_idx(trainset_original.targets, subset_prop=subset_prop, interY=False) 189 | idx_testsetlist_loc = get_actual_data_idx(testset_original.targets, subset_prop=subset_prop, interY=False) 190 | else: 191 | idx_trainsetlist_loc = get_actual_data_idx(trainset_original.targets, subset_prop=subset_prop, interY=interY) 192 | idx_testsetlist_loc = get_actual_data_idx(testset_original.targets, subset_prop=subset_prop, interY=interY) 193 | print('actual env=%d trainsize %s, testsize %s' %(m, idx_trainsetlist_loc.shape, idx_testsetlist_loc.shape)) 194 | 195 | trainloaders[m] = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size, num_workers=0, 196 | sampler = torch.utils.data.sampler.SubsetRandomSampler(idx_trainsetlist_loc)) 197 | testloaders[m] = torch.utils.data.DataLoader(testset, batch_size=test_batch_size, num_workers=0, 198 | sampler = torch.utils.data.sampler.SubsetRandomSampler(idx_testsetlist_loc)) 199 | trainloaders[m].dataset.targets_mod = trainloaders[m].dataset.targets[idx_trainsetlist_loc] 200 | testloaders[m].dataset.targets_mod = testloaders[m].dataset.targets[idx_testsetlist_loc] 201 | 202 | elif perturb == 'rotation': 203 | angles = np.zeros(M) 204 | if M == 12: 205 | angles = np.arange(M) * 10 - 45 206 | elif M == 10: 207 | angles = np.arange(M) * 10 - 35 208 | angles[M-1] = 50 209 | elif M == 5: 210 | angles = np.arange(M) * 15 - 30 211 | # now transform the data 212 | for m in range(M): 213 | # load MNIST data 214 | transformer = torchvision.transforms.Compose( 215 | [torchvision.transforms.RandomRotation((angles[m], angles[m])), 216 | torchvision.transforms.ToTensor(), 217 | torchvision.transforms.Normalize((0.1306,), (0.3081,)) 218 | ]) 219 | trainset = torchvision.datasets.MNIST(root=datapath, train=True, 220 | download=True, transform=transformer) 221 | 222 | testset = torchvision.datasets.MNIST(root=datapath, train=False, 223 | download=True, transform=transformer) 224 | 225 | if m != M-1: 226 | idx_trainsetlist_loc = get_actual_data_idx(trainset_original.targets, subset_prop=subset_prop, interY=False) 227 | idx_testsetlist_loc = get_actual_data_idx(testset_original.targets, subset_prop=subset_prop, interY=False) 228 | else: 229 | idx_trainsetlist_loc = get_actual_data_idx(trainset_original.targets, subset_prop=subset_prop, interY=interY) 230 | idx_testsetlist_loc = get_actual_data_idx(testset_original.targets, subset_prop=subset_prop, interY=interY) 231 | print('actual env=%d trainsize %s, testsize %s' %(m, idx_trainsetlist_loc.shape, idx_testsetlist_loc.shape)) 232 | 233 | trainloaders[m] = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size, num_workers=0, 234 | sampler = torch.utils.data.sampler.SubsetRandomSampler(idx_trainsetlist_loc)) 235 | testloaders[m] = torch.utils.data.DataLoader(testset, batch_size=test_batch_size, num_workers=0, 236 | sampler = torch.utils.data.sampler.SubsetRandomSampler(idx_testsetlist_loc)) 237 | trainloaders[m].dataset.targets_mod = trainloaders[m].dataset.targets[idx_trainsetlist_loc] 238 | testloaders[m].dataset.targets_mod = testloaders[m].dataset.targets[idx_testsetlist_loc] 239 | elif perturb == 'rotation2M' or perturb == 'rotation2Ma': 240 | if perturb == 'rotation2M': 241 | angles = [30, 45] 242 | else: 243 | angles = [10, 45] 244 | # now transform the data 245 | for m in range(M): 246 | # load MNIST data 247 | transformer = torchvision.transforms.Compose( 248 | [torchvision.transforms.RandomRotation((angles[m], angles[m])), 249 | torchvision.transforms.ToTensor(), 250 | torchvision.transforms.Normalize((0.1306,), (0.3081,)) 251 | ]) 252 | trainset = torchvision.datasets.MNIST(root=datapath, train=True, 253 | download=True, transform=transformer) 254 | 255 | testset = torchvision.datasets.MNIST(root=datapath, train=False, 256 | download=True, transform=transformer) 257 | 258 | if m != M-1: 259 | idx_trainsetlist_loc = get_actual_data_idx(trainset_original.targets, subset_prop=subset_prop, interY=False) 260 | idx_testsetlist_loc = get_actual_data_idx(testset_original.targets, subset_prop=subset_prop, interY=False) 261 | else: 262 | idx_trainsetlist_loc = get_actual_data_idx(trainset_original.targets, subset_prop=subset_prop, interY=interY) 263 | idx_testsetlist_loc = get_actual_data_idx(testset_original.targets, subset_prop=subset_prop, interY=interY) 264 | print('actual env=%d trainsize %s, testsize %s' %(m, idx_trainsetlist_loc.shape, idx_testsetlist_loc.shape)) 265 | 266 | trainloaders[m] = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size, num_workers=0, 267 | sampler = torch.utils.data.sampler.SubsetRandomSampler(idx_trainsetlist_loc)) 268 | testloaders[m] = torch.utils.data.DataLoader(testset, batch_size=test_batch_size, num_workers=0, 269 | sampler = torch.utils.data.sampler.SubsetRandomSampler(idx_testsetlist_loc)) 270 | trainloaders[m].dataset.targets_mod = trainloaders[m].dataset.targets[idx_trainsetlist_loc] 271 | testloaders[m].dataset.targets_mod = testloaders[m].dataset.targets[idx_testsetlist_loc] 272 | 273 | elif perturb == 'translation2M': 274 | translates = [(0.2, 0), (0, 0.2)] 275 | # now transform the data 276 | for m in range(M): 277 | # load MNIST data 278 | transformer = torchvision.transforms.Compose( 279 | [torchvision.transforms.RandomAffine(0, translates[m]), 280 | torchvision.transforms.ToTensor(), 281 | torchvision.transforms.Normalize((0.1306,), (0.3081,)) 282 | ]) 283 | trainset = torchvision.datasets.MNIST(root=datapath, train=True, 284 | download=True, transform=transformer) 285 | 286 | testset = torchvision.datasets.MNIST(root=datapath, train=False, 287 | download=True, transform=transformer) 288 | 289 | if m != M-1: 290 | idx_trainsetlist_loc = get_actual_data_idx(trainset_original.targets, subset_prop=subset_prop, interY=False) 291 | idx_testsetlist_loc = get_actual_data_idx(testset_original.targets, subset_prop=subset_prop, interY=False) 292 | else: 293 | idx_trainsetlist_loc = get_actual_data_idx(trainset_original.targets, subset_prop=subset_prop, interY=interY) 294 | idx_testsetlist_loc = get_actual_data_idx(testset_original.targets, subset_prop=subset_prop, interY=interY) 295 | print('actual env=%d trainsize %s, testsize %s' %(m, idx_trainsetlist_loc.shape, idx_testsetlist_loc.shape)) 296 | 297 | trainloaders[m] = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size, num_workers=0, 298 | sampler = torch.utils.data.sampler.SubsetRandomSampler(idx_trainsetlist_loc)) 299 | testloaders[m] = torch.utils.data.DataLoader(testset, batch_size=test_batch_size, num_workers=0, 300 | sampler = torch.utils.data.sampler.SubsetRandomSampler(idx_testsetlist_loc)) 301 | trainloaders[m].dataset.targets_mod = trainloaders[m].dataset.targets[idx_trainsetlist_loc] 302 | testloaders[m].dataset.targets_mod = testloaders[m].dataset.targets[idx_testsetlist_loc] 303 | 304 | return trainloaders, testloaders 305 | 306 | -------------------------------------------------------------------------------- /MNIST/submit_simu_MNIST_patches.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import numpy as np 5 | import subprocess 6 | 7 | 8 | # perturb = 'whitepatch' 9 | perturb = 'rotation' 10 | epochs=100 11 | M = 5 12 | lamMatches = [10.**(k) for k in (np.arange(10)-5)] 13 | lamCIPs = [10.**(k) for k in (np.arange(10)-5)] 14 | 15 | tag_DA = 'baseline' 16 | for seed in range(10): 17 | for k in [2]: 18 | subprocess.call(['bsub', '-W 3:50', '-n 4', '-R', "rusage[ngpus_excl_p=1,mem=2048]", 19 | "./simu_MNIST_patches.py --perturb=%s --M=%d --subset_prop=%.1f --seed=%d --epochs=%d --tag_DA=%s" %(perturb, M, k/10, seed, epochs, tag_DA)]) 20 | 21 | for tag_DA in ['DACIPmean']: 22 | for lam in lamCIPs: 23 | for seed in range(10): 24 | for k in [2]: 25 | subprocess.call(['bsub', '-W 3:50', '-n 8', '-R', "rusage[ngpus_excl_p=1,mem=2048]", 26 | "./simu_MNIST_patches.py --perturb=%s --M=%d --subset_prop=%.1f --seed=%d --lamMatch=%f --lamCIP=%f --lamMatchMMD=%f --lamCIPMMD=%f --epochs=%d --tag_DA=%s" %(perturb, M, k/10, seed, 1., lam, 1., lam, epochs, tag_DA)]) 27 | 28 | for tag_DA in ['DACIPMMD']: 29 | for lam in lamCIPs: 30 | for seed in range(10): 31 | for k in [2]: 32 | subprocess.call(['bsub', '-W 23:50', '-n 8', '-R', "rusage[ngpus_excl_p=1,mem=2048]", 33 | "./simu_MNIST_patches.py --perturb=%s --M=%d --subset_prop=%.1f --seed=%d --lamMatch=%f --lamCIP=%f --lamMatchMMD=%f --lamCIPMMD=%f --epochs=%d --tag_DA=%s" %(perturb, M, k/10, seed, 1., lam, 1., lam, epochs, tag_DA)]) 34 | 35 | # pick lamCIP after looking at CIP source results 36 | lamCIP = 1. 37 | for tag_DA in ['DAmean']: 38 | for lam in lamMatches: 39 | for seed in range(10): 40 | for k in [2]: 41 | subprocess.call(['bsub', '-W 23:50', '-n 4', '-R', "rusage[ngpus_excl_p=1,mem=2048]", 42 | "./simu_MNIST_patches.py --perturb=%s --M=%d --subset_prop=%.1f --seed=%d --lamMatch=%f --lamCIP=%f --lamMatchMMD=%f --lamCIPMMD=%f --epochs=%d --tag_DA=%s" %(perturb, M, k/10, seed, lam, lamCIP, lam, lamCIP, epochs, tag_DA)]) 43 | 44 | 45 | lamCIP = 1. 46 | for tag_DA in ['DAMMD']: 47 | for lam in lamMatches: 48 | for seed in range(10): 49 | for k in [2]: 50 | subprocess.call(['bsub', '-W 23:50', '-n 8', '-R', "rusage[ngpus_excl_p=1,mem=2048]", 51 | "./simu_MNIST_patches.py --perturb=%s --M=%d --subset_prop=%.1f --seed=%d --lamMatch=%f --lamCIP=%f --lamMatchMMD=%f --lamCIPMMD=%f --epochs=%d --tag_DA=%s" %(perturb, M, k/10, seed, lam, lamCIP, lam, lamCIP, epochs, tag_DA)]) 52 | -------------------------------------------------------------------------------- /MNIST/submit_simu_MNIST_patches_2M.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import numpy as np 5 | import subprocess 6 | 7 | 8 | #perturb = 'whitepatch2M' 9 | #perturb = 'rotation2M' 10 | perturb = 'rotation2Ma' 11 | #perturb = 'translation2M' 12 | tag_DA = 'baseline' 13 | epochs = 100 14 | for seed in range(10): 15 | for k in [2]: 16 | subprocess.call(['bsub', '-W 3:50', '-n 4', '-R', "rusage[ngpus_excl_p=1,mem=2048]", 17 | "./simu_MNIST_patches_2M.py --perturb=%s --subset_prop=%.1f --seed=%d --epochs=%d --tag_DA=%s" %(perturb, k/10, seed, epochs, tag_DA)]) 18 | 19 | tag_DA = 'DIP' 20 | lamMatches = [10.**(k) for k in (np.arange(10)-5)] 21 | for lam in lamMatches: 22 | for seed in range(10): 23 | for k in [2]: 24 | subprocess.call(['bsub', '-W 3:50', '-n 4', '-R', "rusage[ngpus_excl_p=1,mem=2048]", 25 | "./simu_MNIST_patches_2M.py --perturb=%s --subset_prop=%.1f --seed=%d --lamMatch=%f --lamMatchMMD=%f --epochs=%d --tag_DA=%s" %(perturb, k/10, seed, lam, lam, epochs, tag_DA)]) 26 | 27 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CausalDA 2 | Code to reproduce the numerical experiments in the paper "Domain adaptation under structural causal models" (https://arxiv.org/abs/2010.15764) 3 | by Yuansi Chen and Peter Bühlmann. 4 | 5 | Code written with Python 3.7.1 and PyTorch version as follows 6 | 7 | torch==1.2.0 8 | torchvision==0.4.0 9 | 10 | 11 | 12 | ## User Guide 13 | 14 | - semiclass.py implements the main DA methods 15 | - semitorchclass, semitorchstocclass implement the same functions with PyTorch 16 | - semitorchMNISTclass is tailored to convolutional neural nets 17 | - Linear SCM simulations: first seven experiments 18 | - sim_linearSCM_mean_shift_exp1-7.ipynb can run on a single core and plot 19 | - Linear SCM simulations: last two experiments 20 | - Run the following simulations in a computer cluster 21 | - sim_linearSCM_var_shift_exp8_box_submit.py 22 | - sim_linearSCM_var_shift_exp8_scat_submit.py 23 | - sim_linearSCM_var_shift_exp9_scat_submit.py 24 | - Read the results and plot with sim_linearSCM_variance_shift_exp8-9.ipynb 25 | - MNIST experiments: 26 | - Need to set the MNIST data folder! 27 | - Run mnist_get_pretrained.ipynb to get a pretrained CNN on original MNIST 28 | - Single source exp: run submit_simu_MNIST_patches_2M.py 29 | - Mutiple source exp: run submit_simu_MNIST_patches.py 30 | - Read the results and plot with MNIST_read_and_plot_whitepatch2M.ipynb and MNIST_read_and_plot_rotation5M.ipynb 31 | - Amazon review dataset experiments 32 | - Need to set the Amazon review data folder! 33 | - Preprocess the data with read_and_preprocess_amazon_review_data_2018_subset.ipynb 34 | - Run the simulations with submit_amazon_review_data_2018_subset_regression.py 35 | - Plot with amazon_read_and_plot.ipynb 36 | 37 | 38 | 39 | ## License and Citation 40 | Code is released under MIT License. 41 | Please cite our paper if the code helps your research. 42 | 43 | ```bibtex 44 | @article{chen2020domain, 45 | title={Domain adaptation under structural causal models}, 46 | author={Chen, Yuansi and Peter B{\"u}hlmann}, 47 | journal={arXiv preprint arXiv:2010.15764}, 48 | year={2018} 49 | } 50 | ``` -------------------------------------------------------------------------------- /amazon/amazon_review_data_2018_subset_regression.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import numpy as np 5 | 6 | import pandas as pd 7 | import os 8 | import h5py 9 | import sys 10 | import argparse 11 | 12 | import torch 13 | 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | import torch.optim as optim 17 | 18 | np.set_printoptions(precision=3) 19 | 20 | from sklearn.linear_model import Ridge 21 | from sklearn.model_selection import train_test_split 22 | from sklearn.feature_extraction.text import TfidfVectorizer 23 | 24 | # local packages 25 | import sys 26 | sys.path.append('../') 27 | import semiclass 28 | import semitorchclass 29 | import semitorchstocclass 30 | import util 31 | 32 | # check gpu avail 33 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 34 | 35 | # Assuming that we are on a CUDA machine, this should print a CUDA device: 36 | 37 | print(device) 38 | 39 | parser = argparse.ArgumentParser() 40 | parser.add_argument("--lamMatch", type=float, default=1., help="DIP matching penalty") 41 | parser.add_argument("--lamCIP", type=float, default=0.1, help="CIP matching penalty") 42 | parser.add_argument("--lamL2", type=float, default=1., help="L2 penalty") 43 | parser.add_argument("--tag_DA", type=str, default="baseline", help="choose whether to run baseline methods or DA methods") 44 | parser.add_argument("--seed", type=int, default=0, help="seed of experiment") 45 | parser.add_argument("--target", type=int, default=0, help="target category") 46 | parser.add_argument("--minDf", type=float, default=0.008, help="minimum term frequency") 47 | parser.add_argument("--epochs", type=int, default=2000, help="number of epochs") 48 | parser.add_argument("--lr", type=float, default=1e-3, help="learning rate") 49 | myargs = parser.parse_args() 50 | print(myargs) 51 | 52 | data_folder = 'data/amazon_review_data_2018_subset' 53 | 54 | # 'All_Beauty_5', 'AMAZON_FASHION_5', 'Appliances_5', 'Gift_Cards_5', 'Magazine_Subscriptions_5' 55 | # are removed for now for not enough number of data points 56 | categories = [ 57 | 'Arts_Crafts_and_Sewing_5', 'Automotive_5', 'CDs_and_Vinyl_5', 58 | 'Cell_Phones_and_Accessories_5', 'Digital_Music_5', 59 | 'Grocery_and_Gourmet_Food_5', 'Industrial_and_Scientific_5', 'Luxury_Beauty_5', 60 | 'Musical_Instruments_5', 'Office_Products_5', 61 | 'Patio_Lawn_and_Garden_5', 'Pet_Supplies_5', 'Prime_Pantry_5', 62 | 'Software_5', 'Tools_and_Home_Improvement_5', 'Toys_and_Games_5'] 63 | 64 | nb_reviews = 10000 65 | dfs = {} 66 | for i, cate in enumerate(categories): 67 | df = pd.read_csv('%s/%s_%d.csv' %(data_folder, cate, nb_reviews)) 68 | dfs[i] = df 69 | print(cate, dfs[i].shape) 70 | 71 | allReviews = pd.concat([dfs[i]['reviewText'] for i in range(len(categories))]) 72 | ngramMin = 1 73 | ngramMax = 2 74 | stop_words = 'english' 75 | vectTF = TfidfVectorizer(min_df = myargs.minDf, stop_words=stop_words, ngram_range=(ngramMin, ngramMax)).fit(allReviews.values.astype('U')) 76 | print("Number of reviews=%d, feature size=%d" %(allReviews.shape[0], len(vectTF.get_feature_names()))) 77 | print(vectTF.vocabulary_) 78 | 79 | lamL1 = 0. 80 | 81 | if myargs.tag_DA == 'baseline': 82 | methods = [ 83 | semiclass.Tar(lamL2=myargs.lamL2), 84 | semiclass.SrcPool(lamL2=myargs.lamL2), 85 | ] 86 | elif myargs.tag_DA == 'DAmean': 87 | methods = [ 88 | semiclass.DIP(lamMatch=myargs.lamMatch, lamL2=myargs.lamL2, sourceInd=0), 89 | semiclass.DIPOracle(lamMatch=myargs.lamMatch, lamL2=myargs.lamL2, sourceInd=0), 90 | semiclass.DIPweigh(lamMatch=myargs.lamMatch, lamL2=myargs.lamL2), 91 | semiclass.CIP(lamCIP=myargs.lamCIP, lamL2=myargs.lamL2), 92 | semiclass.CIRMweigh(lamCIP=myargs.lamCIP, lamMatch=myargs.lamMatch, lamL2=myargs.lamL2), 93 | ] 94 | elif myargs.tag_DA == 'DAstd': 95 | methods = [ 96 | semitorchclass.DIP(lamMatch=myargs.lamMatch, lamL2=myargs.lamL2, lamL1=lamL1, sourceInd=0, lr=myargs.lr, 97 | epochs=myargs.epochs, wayMatch='mean+std+25p'), 98 | semitorchclass.DIPweigh(lamMatch=myargs.lamMatch, lamL2=myargs.lamL2, lamL1=lamL1, lr=myargs.lr, 99 | epochs=myargs.epochs, wayMatch='mean+std+25p'), 100 | semitorchclass.CIP(lamCIP=myargs.lamCIP, lamL2=myargs.lamL2, lamL1=lamL1, lr=myargs.lr, 101 | epochs=myargs.epochs, wayMatch='mean+std+25p'), 102 | semitorchclass.CIRMweigh(lamMatch=myargs.lamMatch, lamL2=myargs.lamL2, lamL1=lamL1, lr=myargs.lr, 103 | epochs=myargs.epochs, wayMatch='mean+std+25p'), 104 | ] 105 | elif myargs.tag_DA == 'DAMMD': 106 | methods = [ 107 | semitorchstocclass.DIP(lamMatch=myargs.lamMatch, lamL2=myargs.lamL2, lamL1=lamL1, sourceInd = 0, lr=myargs.lr, 108 | epochs=myargs.epochs, wayMatch='mmd', sigma_list=[1.]), 109 | semitorchstocclass.DIPweigh(lamMatch=myargs.lamMatch, lamL2=myargs.lamL2, lamL1=lamL1, lr=myargs.lr, 110 | epochs=myargs.epochs, wayMatch='mmd', sigma_list=[1.]), 111 | semitorchstocclass.CIP(lamCIP=myargs.lamCIP, lamL2=myargs.lamL2, lamL1=lamL1, lr=myargs.lr, 112 | epochs=myargs.epochs, wayMatch='mmd', sigma_list=[1.]), 113 | semitorchstocclass.CIRMweigh(lamMatch=myargs.lamMatch, lamCIP=myargs.lamCIP, lamL2=myargs.lamL2, lamL1=lamL1, lr=myargs.lr, 114 | epochs=myargs.epochs, wayMatch='mmd', sigma_list=[1.]) 115 | ] 116 | 117 | names = [str(m) for m in methods] 118 | print(names) 119 | 120 | # random data split 121 | random_state = 123456 + myargs.seed 122 | datasets = {} 123 | datasets_test = {} 124 | for i, cate in enumerate(categories): 125 | X_train_raw, X_test_raw, y_train, y_test = train_test_split(dfs[i]['reviewText'].astype('U'), dfs[i]['overall'].astype('U'), test_size=0.10, random_state = random_state) 126 | 127 | 128 | X_train = vectTF.transform(X_train_raw) 129 | X_test = vectTF.transform(X_test_raw) 130 | 131 | 132 | datasets[i] = np.array(X_train.todense()), np.array(y_train, dtype=np.float32) 133 | datasets_test[i] = np.array(X_test.todense()), np.array(y_test, dtype=np.float32) 134 | 135 | print(cate, X_train.shape, X_test.shape) 136 | 137 | # normalize 138 | Xmean = 0 139 | N = 0 140 | for i, cate in enumerate(categories): 141 | Xmean += np.sum(datasets[i][0], axis=0) 142 | N += datasets[i][0].shape[0] 143 | Xmean /= N 144 | 145 | 146 | Xvar = 0 147 | for i, cate in enumerate(categories): 148 | Xvar += np.sum((datasets[i][0] - Xmean.reshape(1, -1))**2, axis=0) 149 | N += datasets[i][0].shape[0] 150 | 151 | Xvar /= N 152 | Xstd = np.sqrt(Xvar) 153 | 154 | for i, cate in enumerate(categories): 155 | x, y = datasets[i] 156 | x_test, y_test = datasets_test[i] 157 | datasets[i] = (x - Xmean)/Xstd, y 158 | datasets_test[i] = (x_test - Xmean)/Xstd, y_test 159 | 160 | # create torch format data 161 | dataTorch = {} 162 | dataTorchTest = {} 163 | 164 | for i in range(len(categories)): 165 | dataTorch[i] = [torch.from_numpy(datasets[i][0].astype(np.float32)).to(device), 166 | torch.from_numpy(datasets[i][1].astype(np.float32)).to(device)] 167 | dataTorchTest[i] = [torch.from_numpy(datasets_test[i][0].astype(np.float32)).to(device), 168 | torch.from_numpy(datasets_test[i][1].astype(np.float32)).to(device)] 169 | 170 | train_batch_size = 500 171 | test_batch_size = 500 172 | 173 | trainloaders = {} 174 | testloaders = {} 175 | 176 | for i in range(len(categories)): 177 | train_dataset = torch.utils.data.TensorDataset(torch.Tensor(datasets[i][0]), 178 | torch.Tensor(datasets[i][1])) 179 | test_dataset = torch.utils.data.TensorDataset(torch.Tensor(datasets_test[i][0]), 180 | torch.Tensor(datasets_test[i][1])) 181 | trainloaders[i] = torch.utils.data.DataLoader(train_dataset, batch_size=train_batch_size) 182 | testloaders[i] = torch.utils.data.DataLoader(test_dataset, batch_size=test_batch_size) 183 | 184 | 185 | M = len(categories) 186 | source = [i for i in range(M)] 187 | source.remove(myargs.target) 188 | 189 | print("source =", source, "target =", myargs.target, flush=True) 190 | results_src_all = np.zeros((M-1, len(methods), 2)) 191 | results_tar_all = np.zeros((len(methods), 2)) 192 | results_minDiffIndx = {} 193 | labeledsize_list = np.arange(1, 11) * 20 194 | results_tar_sub_all = np.zeros((len(methods), len(labeledsize_list))) 195 | for i, m in enumerate(methods): 196 | if m.__module__ == 'semiclass': 197 | me = m.fit(datasets, source=source, target=myargs.target) 198 | if hasattr(me, 'minDiffIndx'): 199 | print("best index="+str(me.minDiffIndx)) 200 | results_minDiffIndx[(myargs.tag_DA, i)] = me.minDiffIndx 201 | xtar, ytar= datasets[myargs.target] 202 | xtar_test, ytar_test= datasets_test[myargs.target] 203 | targetE = util.MSE(me.ypred, ytar) 204 | targetNE = util.MSE(me.predict(xtar_test), ytar_test) 205 | for j, sourcej in enumerate(source): 206 | results_src_all[j, i, 0] = util.MSE(me.predict(datasets[sourcej][0]), datasets[sourcej][1]) 207 | results_src_all[j, i, 1] = util.MSE(me.predict(datasets_test[sourcej][0]), datasets_test[sourcej][1]) 208 | # obtain target error for each labeledsize 209 | for k, labeledsize in enumerate(labeledsize_list): 210 | xtar_sub = xtar[:labeledsize, :] 211 | ytar_sub = ytar[:labeledsize] 212 | results_tar_sub_all[i, k] = util.MSE(me.predict(xtar_sub), ytar_sub) 213 | elif m.__module__ == 'semitorchclass': 214 | me = m.fit(dataTorch, source=source, target=myargs.target) 215 | if hasattr(me, 'minDiffIndx'): 216 | print("best index="+str(me.minDiffIndx)) 217 | results_minDiffIndx[(myargs.tag_DA, i)] = me.minDiffIndx 218 | xtar, ytar= dataTorch[myargs.target] 219 | xtar_test, ytar_test= dataTorchTest[myargs.target] 220 | targetE = util.torchMSE(me.ypred, ytar) 221 | targetNE = util.torchMSE(me.predict(xtar_test), ytar_test) 222 | for j, sourcej in enumerate(source): 223 | results_src_all[j, i, 0] = util.torchMSE(me.predict(dataTorch[sourcej][0]), dataTorch[sourcej][1]) 224 | results_src_all[j, i, 1] = util.torchMSE(me.predict(dataTorchTest[sourcej][0]), dataTorchTest[sourcej][1]) 225 | for k, labeledsize in enumerate(labeledsize_list): 226 | xtar_sub = xtar[:labeledsize, :] 227 | ytar_sub = ytar[:labeledsize] 228 | results_tar_sub_all[i, k] = util.torchMSE(me.predict(xtar_sub), ytar_sub) 229 | elif m.__module__ == 'semitorchstocclass': 230 | me = m.fit(trainloaders, source=source, target=myargs.target) 231 | targetE = util.torchloaderMSE(me, trainloaders[myargs.target], device) 232 | targetNE = util.torchloaderMSE(me, testloaders[myargs.target], device) 233 | for j, sourcej in enumerate(source): 234 | results_src_all[j, i, 0] = util.torchloaderMSE(me, trainloaders[sourcej], device) 235 | results_src_all[j, i, 1] = util.torchloaderMSE(me, testloaders[sourcej], device) 236 | else: 237 | raise ValueError('error') 238 | results_tar_all[i, 0] = targetE 239 | results_tar_all[i, 1] = targetNE 240 | 241 | res_all = {} 242 | res_all['src'] = results_src_all 243 | res_all['tar'] = results_tar_all 244 | res_all['minDiffIndx'] = results_minDiffIndx 245 | res_all['tar_sub'] = results_tar_sub_all 246 | res_all['labeledsize_list'] = labeledsize_list 247 | 248 | np.save('results_amazon/amazon_review_data_2018_N%d_%s_minDf%s_lamL2%s_lamMatch%s_lamCIP%s_target%d_seed%d.npy' %( 249 | nb_reviews, myargs.tag_DA, myargs.minDf, myargs.lamL2, myargs.lamMatch, myargs.lamCIP, myargs.target, myargs.seed), res_all) 250 | 251 | 252 | -------------------------------------------------------------------------------- /amazon/read_and_preprocess_amazon_review_data_2018_subset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import matplotlib.pyplot as plt\n", 11 | "import pandas as pd\n", 12 | "import gzip\n", 13 | "import json\n", 14 | "import os\n", 15 | "import h5py\n", 16 | "\n", 17 | "\n", 18 | "plt.rcParams['axes.facecolor'] = 'lightgray'\n", 19 | "\n", 20 | "np.set_printoptions(precision=3)" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 2, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "def parse(path):\n", 30 | " g = gzip.open(path, 'rb')\n", 31 | " for l in g:\n", 32 | " yield json.loads(l)\n", 33 | "\n", 34 | "def getDF(path):\n", 35 | " i = 0\n", 36 | " df = {}\n", 37 | " for d in parse(path):\n", 38 | " df[i] = d\n", 39 | " i += 1\n", 40 | " return pd.DataFrame.from_dict(df, orient='index')" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 3, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "categories = [\n", 50 | " 'Arts_Crafts_and_Sewing_5', 'Automotive_5', 'CDs_and_Vinyl_5',\n", 51 | " 'Cell_Phones_and_Accessories_5', 'Digital_Music_5',\n", 52 | " 'Grocery_and_Gourmet_Food_5', 'Industrial_and_Scientific_5', 'Luxury_Beauty_5',\n", 53 | " 'Musical_Instruments_5', 'Office_Products_5',\n", 54 | " 'Patio_Lawn_and_Garden_5', 'Pet_Supplies_5', 'Prime_Pantry_5',\n", 55 | " 'Software_5', 'Tools_and_Home_Improvement_5', 'Toys_and_Games_5']" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 4, 61 | "metadata": {}, 62 | "outputs": [ 63 | { 64 | "name": "stdout", 65 | "output_type": "stream", 66 | "text": [ 67 | "Clothing_Shoes_and_Jewelry_5 (10000, 12)\n", 68 | "Electronics_5 (10000, 12)\n", 69 | "Home_and_Kitchen_5 (10000, 12)\n", 70 | "Kindle_Store_5 (10000, 12)\n", 71 | "Movies_and_TV_5 (10000, 12)\n", 72 | "Sports_and_Outdoors_5 (10000, 12)\n", 73 | "Video_Games_5 (10000, 12)\n" 74 | ] 75 | } 76 | ], 77 | "source": [ 78 | "N = 10000\n", 79 | "np.random.seed(123456)\n", 80 | "dfs = {}\n", 81 | "for i, cate in enumerate(categories):\n", 82 | " df = getDF('data/amazon_review_data_2018_subset/%s.json.gz' %cate)\n", 83 | " df = df[~df.reviewText.isna()]\n", 84 | " if df.shape[0] > N:\n", 85 | " df = df.sample(n=N)\n", 86 | " dfs[i] = df\n", 87 | " \n", 88 | " dfs[i].to_csv(\"data/amazon_review_data_2018_subset/%s_%d.csv\" %(cate, N), index=False)\n", 89 | " print(cate, dfs[i].shape)\n", 90 | " \n", 91 | " " 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": null, 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [] 121 | } 122 | ], 123 | "metadata": { 124 | "kernelspec": { 125 | "display_name": "Python 3", 126 | "language": "python", 127 | "name": "python3" 128 | }, 129 | "language_info": { 130 | "codemirror_mode": { 131 | "name": "ipython", 132 | "version": 3 133 | }, 134 | "file_extension": ".py", 135 | "mimetype": "text/x-python", 136 | "name": "python", 137 | "nbconvert_exporter": "python", 138 | "pygments_lexer": "ipython3", 139 | "version": "3.6.5" 140 | } 141 | }, 142 | "nbformat": 4, 143 | "nbformat_minor": 2 144 | } 145 | -------------------------------------------------------------------------------- /amazon/results_amazon/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /amazon/submit_amazon_review_data_2018_subset_regression.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | 5 | import subprocess 6 | import numpy as np 7 | 8 | tag_DA = 'baseline' 9 | lamL2s = [10.**(k) for k in (np.arange(10)-5)] 10 | for target in range(16): 11 | for lamL2 in lamL2s: 12 | for myseed in range(10): 13 | subprocess.call(['bsub', '-W 03:50', '-n 8', '-R', "rusage[mem=4096]", "./amazon_review_data_2018_subset_regression.py --target=%d --tag_DA=%s --lamL2=%s --seed=%d" %(target, tag_DA, lamL2, myseed)]) 14 | 15 | tag_DA = 'DAmean' 16 | lamMatches = [10.**(k) for k in (np.arange(10)-5)] 17 | for target in range(16): 18 | for lam in lamMatches: 19 | for myseed in range(10): 20 | subprocess.call(['bsub', '-W 03:50', '-n 8', '-R', "rusage[mem=4096]", "./amazon_review_data_2018_subset_regression.py --target=%d --tag_DA=%s --lamL2=%s --lamMatch=%s --epochs=%d --seed=%d" %(target, tag_DA, 1.0, lam, 20000, myseed)]) 21 | 22 | 23 | # tag_DA = 'DAstd' 24 | # for target in range(16): 25 | # for lam in lamMatches: 26 | # for myseed in range(10): 27 | # subprocess.call(['bsub', '-W 23:50', '-n 8', '-R', "rusage[mem=4096]", "./amazon_review_data_2018_subset_regression.py --target=%d --tag_DA=%s --lamMatch=%s --epochs=%d --seed=%d" %(target, tag_DA, lam, 20000, myseed)]) 28 | -------------------------------------------------------------------------------- /mmd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def _mix_rbf_kernel(X, Y, sigma_list): 4 | m = X.size(0) 5 | 6 | Z = torch.cat((X, Y), 0) 7 | ZZT = torch.mm(Z, Z.t()) 8 | diag_ZZT = torch.diag(ZZT).unsqueeze(1) 9 | Z_norm_sqr = diag_ZZT.expand_as(ZZT) 10 | exponent = Z_norm_sqr - 2 * ZZT + Z_norm_sqr.t() 11 | 12 | K = 0.0 13 | for sigma in sigma_list: 14 | gamma = 1.0 / (2 * sigma**2) 15 | K += torch.exp(-gamma * exponent) 16 | 17 | return K[:m, :m], K[:m, m:], K[m:, m:], len(sigma_list) 18 | 19 | def mix_rbf_mmd2(X, Y, sigma_list): 20 | K_XX, K_XY, K_YY, d = _mix_rbf_kernel(X, Y, sigma_list) 21 | return _mmd2(K_XX, K_XY, K_YY) 22 | 23 | 24 | ################################################################################ 25 | # Helper functions to compute variances based on kernel matrices 26 | ################################################################################ 27 | 28 | 29 | def _mmd2(K_XX, K_XY, K_YY): 30 | m = K_XX.size(0) 31 | l = K_YY.size(0) 32 | 33 | K_XX_sum = K_XX.sum() 34 | K_YY_sum = K_YY.sum() 35 | K_XY_sum = K_XY.sum() 36 | 37 | mmd2 = (K_XX_sum / (m * m) 38 | + K_YY_sum / (l * l) 39 | - 2.0 * K_XY_sum / (m * l)) 40 | 41 | return mmd2 42 | -------------------------------------------------------------------------------- /myrandom.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from abc import abstractmethod 3 | 4 | class RandomGen(): 5 | """Base class for random number generation""" 6 | @abstractmethod 7 | def generate(self, m, n, d): 8 | pass 9 | 10 | class Gaussian(RandomGen): 11 | def __init__(self, M, meanAs, varAs): 12 | self.M = M 13 | self.meanAs = meanAs 14 | self.varAs = varAs 15 | 16 | def generate(self, m, n, d): 17 | vec = np.random.randn(n, d).dot(np.diag(np.sqrt(self.varAs[m, ]))) 18 | vec += self.meanAs[m, :].reshape(1, -1) 19 | return vec 20 | 21 | class Mix2Gaussian(RandomGen): 22 | def __init__(self, M, meanAsList, varAs): 23 | self.M = M 24 | self.meanAsList = meanAsList 25 | self.varAs = varAs 26 | 27 | def generate(self, m, n, d): 28 | vec = np.random.randn(n, d).dot(np.diag(np.sqrt(self.varAs[m, ]))) 29 | mixture_idx = np.random.choice(2, size=n, replace=True, p=[0.5, 0.5]) 30 | for i in range(n): 31 | vec[i, :] += self.meanAsList[mixture_idx[i]][m, :] 32 | return vec 33 | 34 | class MixkGaussian(RandomGen): 35 | def __init__(self, M, meanAsList, varAs): 36 | self.M = M 37 | self.meanAsList = meanAsList 38 | self.varAs = varAs 39 | self.k = len(self.meanAsList) 40 | 41 | def generate(self, m, n, d): 42 | vec = np.random.randn(n, d).dot(np.diag(np.sqrt(self.varAs[m, ]))) 43 | mixture_idx = np.random.choice(self.k, size=n, replace=True, p=np.ones(self.k)/self.k) 44 | for i in range(n): 45 | vec[i, :] += self.meanAsList[mixture_idx[i]][m, :] 46 | return vec 47 | -------------------------------------------------------------------------------- /sem.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import matplotlib.pyplot as plt 4 | 5 | from networkx import nx 6 | 7 | # basic structural equation 8 | class SEM(): 9 | def __init__(self, B, noisef, interAf, invariantList=[], message=None): 10 | self.B = B 11 | self.dp1 = B.shape[0] 12 | self.M = interAf.M 13 | self.interAf = interAf 14 | self.noisef = noisef 15 | 16 | # just for plot purpose 17 | self.invariantList = invariantList 18 | self.message = message 19 | 20 | # raise an error if it is not invertible 21 | self.IBinv = np.linalg.inv(np.eye(self.dp1) - self.B) 22 | 23 | def generateSamples(self, n, m=0): 24 | # generate n samples from mth environment 25 | noise = self.noisef.generate(0, n, self.dp1) 26 | interA = self.interAf.generate(m, n, self.dp1) 27 | 28 | data = (noise + interA).dot(self.IBinv.T) 29 | 30 | # return x and y separately 31 | return data[:, :-1], data[:, -1] 32 | 33 | def generateAllSamples(self, n): 34 | res = {} 35 | for m in range(self.M): 36 | res[m] = self.generateSamples(n, m) 37 | 38 | return res 39 | 40 | def draw(self, layout='circular', figsize=(12, 8)): 41 | plt.figure(figsize=figsize) 42 | G = nx.DiGraph() 43 | G.add_nodes_from(np.arange(self.dp1)) 44 | for i in range(self.dp1): 45 | for j in range(self.dp1): 46 | if not np.isclose(self.B[i, j], 0): 47 | G.add_edge(j, i, weight=self.B[i, j]) 48 | 49 | # labels 50 | labels={} 51 | for i in range(self.dp1-1): 52 | labels[i] = "X"+str(i) 53 | labels[self.dp1-1] = "Y" 54 | 55 | # position 56 | if layout=="spring": 57 | pos = nx.spring_layout(G) 58 | elif layout=="kamada_kawai": 59 | pos = nx.kamada_kawai_layout(G) 60 | else: 61 | pos = nx.circular_layout(G) 62 | 63 | nx.draw_networkx_nodes(G,pos,nodelist = list(np.arange(self.dp1-1)), node_color='b', node_size=1000, alpha=0.5) 64 | nx.draw_networkx_nodes(G,pos,nodelist = [self.dp1-1], node_color='r', 65 | node_size=1000, alpha=0.8) 66 | nx.draw_networkx_nodes(G,pos,nodelist = list(self.invariantList), node_color='y', node_size=1000, alpha=0.8) 67 | nx.draw_networkx_edges(G,pos,width=3.0,alpha=0.5, arrowsize=40) 68 | 69 | arc_weight=nx.get_edge_attributes(G,'weight') 70 | arc_weight_format = {i:'{:.2f}'.format(arc_weight[i]) for i in arc_weight} 71 | 72 | 73 | nx.draw_networkx_edge_labels(G, pos,edge_color= 'k', label_pos=0.7, edge_labels=arc_weight_format) 74 | 75 | nx.draw_networkx_labels(G,pos,labels,font_size=16) 76 | plt.draw() 77 | plt.show() 78 | 79 | 80 | -------------------------------------------------------------------------------- /semitorchclass.py: -------------------------------------------------------------------------------- 1 | """ 2 | ``semitorchclass`` provides classes implementing various domain adaptation methods using torch and gradient method. 3 | All domain adaptation methods have to be subclass of BaseEstimator. 4 | This implementation takes advantage of gradient method to optimize covariance match or MMD match in addition to mean match. 5 | """ 6 | 7 | from abc import abstractmethod 8 | import numpy as np 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torch.optim as optim 14 | 15 | import mmd 16 | 17 | # check gpu avail 18 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 19 | 20 | 21 | # simple linear model in torch 22 | class LinearModel(nn.Module): 23 | def __init__(self, d): 24 | super(LinearModel, self).__init__() 25 | self.lin1 = nn.Linear(d, 1, bias=True) 26 | 27 | def forward(self, x): 28 | x = self.lin1(x) 29 | return x 30 | 31 | class BaseEstimator(): 32 | """Base class for domain adaptation""" 33 | @abstractmethod 34 | def fit(self, data, source, target): 35 | """Fit model. 36 | Arguments: 37 | data (dict of (X, y) pairs): maps env index to the (X, y) pair in that env 38 | source (list of indexes): indexes of source envs 39 | target (int): single index of the target env 40 | """ 41 | self.source = source 42 | self.target = target 43 | 44 | return self 45 | 46 | @abstractmethod 47 | def predict(self, X): 48 | """Use the learned estimator to predict labels on fresh target data X 49 | """ 50 | 51 | def __str__(self): 52 | """For easy name printing 53 | """ 54 | return self.__class__.__name__ 55 | 56 | class ZeroBeta(BaseEstimator): 57 | """Estimator that sets beta to zero""" 58 | 59 | def fit(self, data, source, target): 60 | super().fit(data, source, target) 61 | 62 | d = data[target][0].shape[1] 63 | model = LinearModel(d).to(device) 64 | with torch.no_grad(): 65 | model.lin1.weight.data = torch.zeros_like(model.lin1.weight) 66 | model.lin1.bias.data = torch.zeros_like(model.lin1.bias) 67 | 68 | self.model = model 69 | 70 | xtar, _ = data[target] 71 | self.ypred = self.model(xtar) 72 | 73 | return self 74 | 75 | def predict(self, X): 76 | ypredX = self.model(X) 77 | return ypredX 78 | 79 | class Tar(BaseEstimator): 80 | """Oracle Linear regression (with l1 or l2 penalty) trained on the target domain""" 81 | def __init__(self, lamL2=0.0, lamL1=0.0, lr=1e-4, epochs=10): 82 | self.lamL2 = lamL2 83 | self.lamL1 = lamL1 84 | self.lr = lr 85 | self.epochs = epochs 86 | 87 | def fit(self, data, source, target): 88 | super().fit(data, source, target) 89 | 90 | d = data[target][0].shape[1] 91 | model = LinearModel(d).to(device) 92 | with torch.no_grad(): 93 | model.lin1.bias.data = torch.zeros_like(model.lin1.bias) 94 | # torch.nn.init.kaiming_normal_(model.lin1.weight, mode='fan_in') 95 | torch.nn.init.xavier_normal_(model.lin1.weight, gain=0.01) 96 | # Define loss function 97 | loss_fn = F.mse_loss 98 | opt = optim.Adam(model.parameters(), lr=self.lr) 99 | # opt = optim.SGD(model.parameters(), lr=self.lr, momentum=0.9) 100 | 101 | self.losses = np.zeros(self.epochs) 102 | xtar, ytar = data[target] 103 | # oracle estimator uses target labels 104 | for epoch in range(self.epochs): 105 | opt.zero_grad() 106 | loss = loss_fn(model(xtar), ytar.view(-1, 1)) + \ 107 | self.lamL2 * torch.sum(model.lin1.weight ** 2) + \ 108 | self.lamL1 * torch.sum(torch.abs(model.lin1.weight)) 109 | # Perform gradient descent 110 | loss.backward() 111 | opt.step() 112 | self.losses[epoch] = loss.item() 113 | 114 | self.model = model 115 | 116 | self.ypred = self.model(xtar) 117 | 118 | return self 119 | 120 | def predict(self, X): 121 | ypredX = self.model(X) 122 | return ypredX 123 | 124 | def __str__(self): 125 | return self.__class__.__name__ + "_L2={:.1f}".format(self.lamL2) + "_L1={:.1f}".format(self.lamL1) 126 | 127 | 128 | class SrcPool(BaseEstimator): 129 | """Pool all source data together and then run linear regression 130 | with l1 or l2 penalty """ 131 | def __init__(self, lamL2=0.0, lamL1=0.0, lr=1e-4, epochs=10): 132 | self.lamL2 = lamL2 133 | self.lamL1 = lamL1 134 | self.lr = lr 135 | self.epochs = epochs 136 | 137 | def fit(self, data, source, target): 138 | super().fit(data, source, target) 139 | 140 | d = data[target][0].shape[1] 141 | model = LinearModel(d).to(device) 142 | # custom initialization 143 | with torch.no_grad(): 144 | model.lin1.bias.data = torch.zeros_like(model.lin1.bias) 145 | torch.nn.init.xavier_normal_(model.lin1.weight, gain=0.01) 146 | # Define loss function 147 | loss_fn = F.mse_loss 148 | opt = optim.Adam(model.parameters(), lr=self.lr) 149 | # opt = optim.SGD(model.parameters(), lr=self.lr, momentum=0.9) 150 | 151 | self.losses = np.zeros(self.epochs) 152 | 153 | for epoch in range(self.epochs): 154 | loss = 0 155 | opt.zero_grad() 156 | for m in source: 157 | x, y = data[m] 158 | loss += loss_fn(model(x), y.view(-1, 1))/len(source) 159 | 160 | loss += self.lamL2 * torch.sum(model.lin1.weight ** 2) 161 | loss += self.lamL1 * torch.sum(torch.abs(model.lin1.weight)) 162 | # Perform gradient descent 163 | loss.backward() 164 | opt.step() 165 | self.losses[epoch] = loss.item() 166 | self.model = model 167 | 168 | xtar, _ = data[target] 169 | self.ypred = self.model(xtar) 170 | 171 | return self 172 | 173 | def predict(self, X): 174 | ypredX = self.model(X) 175 | return ypredX 176 | 177 | def __str__(self): 178 | return self.__class__.__name__ + "_L2={:.1f}".format(self.lamL2) + "_L1={:.1f}".format(self.lamL1) 179 | 180 | def wayMatchSelector(wayMatch='mean'): 181 | if wayMatch == 'mean': 182 | return [lambda x: torch.mean(x, dim=0)] 183 | elif wayMatch == 'std': 184 | return [lambda x: torch.std(x, dim=0)] 185 | elif wayMatch == '25p': 186 | return [lambda x: torch.kthvalue(x, (1 + round(.25 * (x.shape[0] - 1))), dim=0)] 187 | elif wayMatch == '75p': 188 | return [lambda x: torch.kthvalue(x, (1 + round(.75 * (x.shape[0] - 1))), dim=0)] 189 | elif wayMatch == 'mean+std': 190 | return [lambda x: torch.mean(x, dim=0), lambda x: torch.std(x, dim=0)] 191 | elif wayMatch == 'mean+std+25p': 192 | return [lambda x: torch.mean(x, dim=0), lambda x: torch.std(x, dim=0), lambda x: torch.kthvalue(x, (1 + round(.25 * (x.shape[0] - 1))), dim=0)[0]] 193 | elif wayMatch == 'mean+std+25p+75p': 194 | return [lambda x: torch.mean(x, dim=0), lambda x: torch.std(x, dim=0), lambda x: torch.kthvalue(x, (1 + round(.25 * (x.shape[0] - 1))), dim=0)[0], 195 | lambda x: torch.kthvalue(x, (1 + round(.75 * (x.shape[0] - 1))), dim=0)[0]] 196 | else: 197 | print("Error: wayMatch not specified correctly, using mean") 198 | return [lambda x: torch.mean(x, 0)] 199 | 200 | 201 | 202 | class DIP(BaseEstimator): 203 | """Pick one source, match mean of X * beta between source and target""" 204 | def __init__(self, lamMatch=10., lamL2=0., lamL1=0., sourceInd = 0, lr=1e-4, epochs=10, wayMatch='mean'): 205 | self.lamMatch = lamMatch 206 | self.lamL2 = lamL2 207 | self.lamL1 = lamL1 208 | self.sourceInd = sourceInd 209 | self.lr = lr 210 | self.epochs = epochs 211 | self.wayMatch = wayMatchSelector(wayMatch) 212 | 213 | def fit(self, data, source, target): 214 | super().fit(data, source, target) 215 | 216 | d = data[target][0].shape[1] 217 | model = LinearModel(d).to(device) 218 | # custom initialization 219 | with torch.no_grad(): 220 | model.lin1.bias.data = torch.zeros_like(model.lin1.bias) 221 | torch.nn.init.xavier_normal_(model.lin1.weight, gain=0.01) 222 | # Define loss function 223 | loss_fn = F.mse_loss 224 | opt = optim.Adam(model.parameters(), lr=self.lr) 225 | 226 | self.losses = np.zeros(self.epochs) 227 | 228 | for epoch in range(self.epochs): 229 | x, y = data[source[self.sourceInd]] 230 | xtar, ytar = data[target] 231 | opt.zero_grad() 232 | loss = loss_fn(model(x), y.view(-1, 1)) + \ 233 | self.lamL2 * torch.sum(model.lin1.weight ** 2) + \ 234 | self.lamL1 * torch.sum(torch.abs(model.lin1.weight)) 235 | for wayMatchLocal in self.wayMatch: 236 | loss += self.lamMatch * loss_fn(wayMatchLocal(model(x)), wayMatchLocal(model(xtar))) 237 | 238 | # Perform gradient descent 239 | loss.backward() 240 | opt.step() 241 | 242 | self.losses[epoch] = loss.item() 243 | self.model = model 244 | 245 | xtar, _ = data[target] 246 | self.ypred = self.model(xtar) 247 | 248 | return self 249 | 250 | def predict(self, X): 251 | ypredX = self.model(X) 252 | return ypredX 253 | 254 | def __str__(self): 255 | return self.__class__.__name__ + "_Match{:.1f}".format(self.lamMatch) + "_L2={:.1f}".format(self.lamL2) + "_L1={:.1f}".format(self.lamL1) 256 | 257 | 258 | class DIPOracle(BaseEstimator): 259 | """Pick one source, match mean of X * beta between source and target, use target labels to fit (oracle)""" 260 | def __init__(self, lamMatch=10., lamL2=0., lamL1=0., sourceInd = 0, lr=1e-4, epochs=10, wayMatch='mean'): 261 | self.lamMatch = lamMatch 262 | self.lamL2 = lamL2 263 | self.lamL1 = lamL1 264 | self.sourceInd = sourceInd 265 | self.lr = lr 266 | self.epochs = epochs 267 | self.wayMatch = wayMatchSelector(wayMatch) 268 | 269 | def fit(self, data, source, target): 270 | super().fit(data, source, target) 271 | 272 | d = data[target][0].shape[1] 273 | model = LinearModel(d).to(device) 274 | # custom initialization 275 | with torch.no_grad(): 276 | model.lin1.bias.data = torch.zeros_like(model.lin1.bias) 277 | torch.nn.init.xavier_normal_(model.lin1.weight, gain=0.01) 278 | # Define loss function 279 | loss_fn = F.mse_loss 280 | opt = optim.Adam(model.parameters(), lr=self.lr) 281 | 282 | self.losses = np.zeros(self.epochs) 283 | 284 | for epoch in range(self.epochs): 285 | x, y = data[source[self.sourceInd]] 286 | xtar, ytar = data[target] 287 | opt.zero_grad() 288 | loss = loss_fn(model(xtar), ytar.view(-1, 1)) + \ 289 | self.lamL2 * torch.sum(model.lin1.weight ** 2) + \ 290 | self.lamL1 * torch.sum(torch.abs(model.lin1.weight)) 291 | 292 | for wayMatchLocal in self.wayMatch: 293 | loss += self.lamMatch * loss_fn(wayMatchLocal(model(x)), wayMatchLocal(model(xtar))) 294 | 295 | # Perform gradient descent 296 | loss.backward() 297 | opt.step() 298 | 299 | self.losses[epoch] = loss.item() 300 | self.model = model 301 | 302 | xtar, _ = data[target] 303 | self.ypred = self.model(xtar) 304 | 305 | return self 306 | 307 | def predict(self, X): 308 | ypredX = self.model(X) 309 | return ypredX 310 | 311 | def __str__(self): 312 | return self.__class__.__name__ + "_Match{:.1f}".format(self.lamMatch) + "_L2={:.1f}".format(self.lamL2) + "_L1={:.1f}".format(self.lamL1) 313 | 314 | 315 | class DIPweigh(BaseEstimator): 316 | '''loop throught all source envs, match the mean of X * beta between source env i and target, weigh the final prediction based loss of env i''' 317 | def __init__(self, lamMatch=10., lamL2=0., lamL1=0., lr=1e-4, 318 | epochs=10, wayMatch='mean'): 319 | self.lamMatch = lamMatch 320 | self.lamL2 = lamL2 321 | self.lamL1 = lamL1 322 | self.lr = lr 323 | self.epochs = epochs 324 | self.wayMatch = wayMatchSelector(wayMatch) 325 | 326 | def fit(self, data, source, target): 327 | super().fit(data, source, target) 328 | 329 | d = data[target][0].shape[1] 330 | models = {} 331 | diffs = {} 332 | ypreds = {} 333 | losses_all = {} 334 | for m in source: 335 | model = LinearModel(d).to(device) 336 | models[m] = model 337 | # custom initialization 338 | with torch.no_grad(): 339 | model.lin1.bias.data = torch.zeros_like(model.lin1.bias) 340 | torch.nn.init.xavier_normal_(model.lin1.weight, gain=0.01) 341 | # Define loss function 342 | loss_fn = F.mse_loss 343 | opt = optim.Adam(model.parameters(), lr=self.lr) 344 | 345 | losses_all[m] = np.zeros(self.epochs) 346 | 347 | for epoch in range(self.epochs): 348 | x, y = data[m] 349 | xtar, ytar = data[target] 350 | opt.zero_grad() 351 | 352 | loss = loss_fn(model(x), y.view(-1, 1)) + \ 353 | self.lamL2* torch.sum(model.lin1.weight ** 2) + \ 354 | self.lamL1 * torch.sum(torch.abs(model.lin1.weight)) 355 | 356 | for wayMatchLocal in self.wayMatch: 357 | loss += self.lamMatch * loss_fn(wayMatchLocal(model(x)), wayMatchLocal(model(xtar))) 358 | # Perform gradient descent 359 | loss.backward() 360 | opt.step() 361 | 362 | losses_all[m][epoch] = loss.item() 363 | 364 | diffs[m] =0. 365 | for wayMatchLocal in self.wayMatch: 366 | diffs[m] += loss_fn(wayMatchLocal(model(x)), wayMatchLocal(model(xtar))) 367 | ypreds[m] = models[m](xtar) 368 | 369 | # take the min diff loss to be current best losses and model 370 | minDiff = diffs[source[0]] 371 | minDiffIndx = source[0] 372 | self.losses = losses_all[source[0]] 373 | for m in source: 374 | if diffs[m] < minDiff: 375 | minDiff = diffs[m] 376 | minDiffIndx = m 377 | self.losses = losses_all[m] 378 | self.model = models[m] 379 | 380 | self.minDiffIndx = minDiffIndx 381 | self.total_weight = 0 382 | self.ypred = 0 383 | for m in self.source: 384 | self.ypred += torch.exp(-100.*diffs[m]) * ypreds[m] 385 | self.total_weight += torch.exp(-100.*diffs[m]) 386 | self.ypred /= self.total_weight 387 | self.models = models 388 | self.diffs = diffs 389 | 390 | return self 391 | 392 | def predict(self, X): 393 | ypredX1 = 0 394 | for m in self.source: 395 | ypredX1 += torch.exp(-100.*self.diffs[m]) * self.models[m](X) 396 | ypredX1 /= self.total_weight 397 | return ypredX1 398 | 399 | def __str__(self): 400 | return self.__class__.__name__ + "_Match{:.1f}".format(self.lamMatch) + "_L2={:.1f}".format(self.lamL2) + "_L1={:.1f}".format(self.lamL1) 401 | 402 | 403 | class CIP(BaseEstimator): 404 | """Match the conditional (on Y) mean of X * beta across source envs, no target env is needed""" 405 | def __init__(self, lamCIP=10., lamL2=0., lamL1=0., lr=1e-4, epochs=10, wayMatch='mean'): 406 | self.lamCIP = lamCIP 407 | self.lamL2 = lamL2 408 | self.lamL1 = lamL1 409 | self.lr = lr 410 | self.epochs = epochs 411 | self.wayMatch = wayMatchSelector(wayMatch) 412 | 413 | def fit(self, data, source, target): 414 | super().fit(data, source, target) 415 | 416 | d = data[target][0].shape[1] 417 | model = LinearModel(d).to(device) 418 | # custom initialization 419 | with torch.no_grad(): 420 | model.lin1.bias.data = torch.zeros_like(model.lin1.bias) 421 | torch.nn.init.xavier_normal_(model.lin1.weight, gain=0.01) 422 | 423 | # Define loss function 424 | loss_fn = F.mse_loss 425 | opt = optim.Adam(model.parameters(), lr=self.lr) 426 | 427 | self.losses = np.zeros(self.epochs) 428 | for epoch in range(self.epochs): 429 | loss = 0 430 | avgmodelxList = [0.] * len(self.wayMatch) 431 | opt.zero_grad() 432 | for m in source: 433 | x, y = data[m] 434 | # do the conditional on y 435 | xmod = x - torch.mm(y.view(-1, 1), torch.mm(y.view(1, -1), x))/torch.sum(y**2) 436 | for i, wayMatchLocal in enumerate(self.wayMatch): 437 | avgmodelxList[i] += wayMatchLocal(model(xmod))/len(source) 438 | for m in source: 439 | x, y = data[m] 440 | xmod = x - torch.mm(y.view(-1, 1), torch.mm(y.view(1, -1), x))/torch.sum(y**2) 441 | loss += loss_fn(model(x), y.view(-1, 1))/len(source) 442 | for i, wayMatchLocal in enumerate(self.wayMatch): 443 | loss += self.lamCIP * loss_fn(avgmodelxList[i], wayMatchLocal(model(xmod)))/len(source) 444 | loss += self.lamL2 * torch.sum(model.lin1.weight ** 2) 445 | loss += self.lamL1 * torch.sum(torch.abs(model.lin1.weight)) 446 | # Perform gradient descent 447 | loss.backward() 448 | opt.step() 449 | 450 | self.losses[epoch] = loss.item() 451 | self.model = model 452 | 453 | xtar, _ = data[target] 454 | self.ypred = self.model(xtar) 455 | 456 | return self 457 | 458 | def predict(self, X): 459 | ypredX = self.model(X) 460 | return ypredX 461 | 462 | def __str__(self): 463 | return self.__class__.__name__ + "_CIP{:.1f}".format(self.lamCIP) + "_L2={:.1f}".format(self.lamL2) + "_L1={:.1f}".format(self.lamL1) 464 | 465 | 466 | class CIRMweigh(BaseEstimator): 467 | """Match the conditional (on Y) mean of X * beta across source envs, use Yhat as proxy of Y to remove the Y parts in X. 468 | Match on the residual between one source env and target env""" 469 | def __init__(self, lamMatch=10., lamL2=0., lamL1=0., lr=1e-4, epochs=10, wayMatch='mean'): 470 | self.lamMatch = lamMatch 471 | self.lamL2 = lamL2 472 | self.lamL1 = lamL1 473 | self.lr = lr 474 | self.epochs = epochs 475 | self.wayMatch = wayMatchSelector(wayMatch) 476 | 477 | def fit(self, data, source, target): 478 | super().fit(data, source, target) 479 | 480 | d = data[target][0].shape[1] 481 | # Step 1: use source envs to match the conditional mean 482 | # find beta_invariant 483 | models1 = LinearModel(d).to(device) 484 | # custom initialization 485 | with torch.no_grad(): 486 | models1.lin1.bias.data = torch.zeros_like(models1.lin1.bias) 487 | torch.nn.init.xavier_normal_(models1.lin1.weight, gain=0.01) 488 | 489 | # Define loss function 490 | loss_fn = F.mse_loss 491 | opt = optim.Adam(models1.parameters(), lr=self.lr) 492 | 493 | losses1 = np.zeros(self.epochs) 494 | for epoch in range(self.epochs): 495 | loss = 0 496 | avgmodelxList = [0.] * len(self.wayMatch) 497 | opt.zero_grad() 498 | for m in source: 499 | x, y = data[m] 500 | # do the conditional on y 501 | xmod = x - torch.mm(y.view(-1, 1), torch.mm(y.view(1, -1), x))/torch.sum(y**2) 502 | for i, wayMatchLocal in enumerate(self.wayMatch): 503 | avgmodelxList[i] += wayMatchLocal(models1(xmod))/len(source) 504 | for m in source: 505 | x, y = data[m] 506 | xmod = x - torch.mm(y.view(-1, 1), torch.mm(y.view(1, -1), x))/torch.sum(y**2) 507 | loss += loss_fn(models1(x), y.view(-1, 1))/len(source) 508 | for i, wayMatchLocal in enumerate(self.wayMatch): 509 | loss += self.lamMatch * loss_fn(avgmodelxList[i], wayMatchLocal(models1(xmod)))/len(source) 510 | loss += self.lamL2 * torch.sum(models1.lin1.weight ** 2) 511 | loss += self.lamL1 * torch.sum(torch.abs(models1.lin1.weight)) 512 | # Perform gradient descent 513 | loss.backward() 514 | opt.step() 515 | losses1[epoch] = loss.item() 516 | 517 | self.models1 = models1 518 | 519 | # fix grads now 520 | for param in models1.lin1.parameters(): 521 | param.requires_grad = False 522 | 523 | # Step 2: remove the invariant part on all source envs, so that everything is independent of Y 524 | # get that coefficient b 525 | YsrcMean = 0 526 | ntotal = 0 527 | for m in source: 528 | YsrcMean += torch.sum(data[m][1]) 529 | ntotal += data[m][1].shape[0] 530 | YsrcMean /= ntotal 531 | 532 | YTX = 0 533 | YTY = 0 534 | for m in source: 535 | x, y = data[m] 536 | yguess = self.models1(x) 537 | yCentered = y - YsrcMean 538 | YTY += torch.sum(yguess.t() * yCentered) 539 | YTX += torch.mm(yCentered.view(1, -1), x) 540 | 541 | b = YTX / YTY 542 | self.b = b 543 | 544 | 545 | # Step 3: mean match between source and target on the residual, after transforming the covariates X - (X * beta_invariant) * b_invariant 546 | models = {} 547 | diffs = {} 548 | ypreds = {} 549 | losses_all = {} 550 | for m in source: 551 | models[m] = LinearModel(d).to(device) 552 | # custom initialization 553 | with torch.no_grad(): 554 | models[m].lin1.bias.data = torch.zeros_like(models[m].lin1.bias) 555 | torch.nn.init.xavier_normal_(models[m].lin1.weight, gain=0.01) 556 | # Define loss function 557 | loss_fn = F.mse_loss 558 | opt = optim.Adam(models[m].parameters(), lr=self.lr) 559 | 560 | losses_all[m] = np.zeros(self.epochs) 561 | x, y = data[m] 562 | xmod = x - torch.mm(self.models1(x), b) 563 | xtar, ytar = data[target] 564 | xtarmod = xtar - torch.mm(self.models1(xtar), b) 565 | 566 | for epoch in range(self.epochs): 567 | loss = loss_fn(models[m](x), y.view(-1, 1)) + \ 568 | self.lamL2 * torch.sum(models[m].lin1.weight ** 2) + \ 569 | self.lamL1 * torch.sum(torch.abs(models[m].lin1.weight)) 570 | 571 | for wayMatchLocal in self.wayMatch: 572 | loss += self.lamMatch * loss_fn(wayMatchLocal(models[m](xmod)), wayMatchLocal(models[m](xtarmod))) 573 | # Perform gradient descent 574 | loss.backward() 575 | opt.step() 576 | opt.zero_grad() 577 | losses_all[m][epoch] = loss.item() 578 | 579 | diffs[m] = 0. 580 | for wayMatchLocal in self.wayMatch: 581 | diffs[m] += loss_fn(wayMatchLocal(models[m](xmod)), wayMatchLocal(models[m](xtarmod))) 582 | ypreds[m] = models[m](xtar) 583 | 584 | # take the min diff loss to be current best losses and model 585 | minDiff = diffs[source[0]] 586 | minDiffIndx = source[0] 587 | self.losses = losses_all[source[0]] 588 | for m in source: 589 | if diffs[m] < minDiff: 590 | minDiff = diffs[m] 591 | minDiffIndx = m 592 | self.losses = losses_all[m] 593 | self.model = models[m] 594 | 595 | self.minDiffIndx = minDiffIndx 596 | self.total_weight = 0 597 | self.ypred = 0 598 | for m in self.source: 599 | self.ypred += torch.exp(-100.*diffs[m]) * ypreds[m] 600 | self.total_weight += torch.exp(-100.*diffs[m]) 601 | self.ypred /= self.total_weight 602 | self.models = models 603 | self.diffs = diffs 604 | 605 | return self 606 | 607 | def predict(self, X): 608 | ypredX1 = 0 609 | for m in self.source: 610 | ypredX1 += torch.exp(-100.*self.diffs[m]) * self.models[m](X) 611 | ypredX1 /= self.total_weight 612 | return ypredX1 613 | 614 | def __str__(self): 615 | return self.__class__.__name__ + "_Match{:.1f}".format(self.lamMatch) + "_L2={:.1f}".format(self.lamL2) + "_L1={:.1f}".format(self.lamL1) 616 | 617 | 618 | 619 | -------------------------------------------------------------------------------- /semitorchstocclass.py: -------------------------------------------------------------------------------- 1 | """ 2 | ``semitorchstocclass`` provides classes implementing various domain adaptation methods using torch and stochastic gradient method. 3 | All domain adaptation methods have to be subclass of BaseEstimator. 4 | This implementation takes advantage of gradient method to optimize covariance match or MMD match in addition to mean match. 5 | """ 6 | 7 | from abc import abstractmethod 8 | import numpy as np 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torch.optim as optim 14 | 15 | import mmd 16 | 17 | # check gpu avail 18 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 19 | 20 | 21 | # simple linear model in torch 22 | class LinearModel(nn.Module): 23 | def __init__(self, d): 24 | super(LinearModel, self).__init__() 25 | self.lin1 = nn.Linear(d, 1, bias=True) 26 | 27 | def forward(self, x): 28 | x = self.lin1(x) 29 | return x 30 | 31 | class BaseEstimator(): 32 | """Base class for domain adaptation""" 33 | @abstractmethod 34 | def fit(self, data, source, target): 35 | """Fit model. 36 | Arguments: 37 | data (dict of (X, y) pairs): maps env index to the (X, y) pair in that env 38 | source (list of indexes): indexes of source envs 39 | target (int): single index of the target env 40 | """ 41 | self.source = source 42 | self.target = target 43 | 44 | return self 45 | 46 | @abstractmethod 47 | def predict(self, X): 48 | """Use the learned estimator to predict labels on fresh target data X 49 | """ 50 | 51 | def __str__(self): 52 | """For easy name printing 53 | """ 54 | return self.__class__.__name__ 55 | 56 | 57 | class Tar(BaseEstimator): 58 | """Oracle Linear regression (with l1 or l2 penalty) trained on the target domain""" 59 | def __init__(self, lamL2=0.0, lamL1=0.0, lr=1e-4, epochs=10): 60 | self.lamL2 = lamL2 61 | self.lamL1 = lamL1 62 | self.lr = lr 63 | self.epochs = epochs 64 | 65 | def fit(self, dataloaders, source, target): 66 | super().fit(dataloaders, source, target) 67 | 68 | # get the input dimension 69 | # assume it is a TensorDataset 70 | d = dataloaders[target].dataset[0][0].shape[0] 71 | model = LinearModel(d).to(device) 72 | with torch.no_grad(): 73 | model.lin1.bias.data = torch.zeros_like(model.lin1.bias) 74 | torch.nn.init.xavier_normal_(model.lin1.weight, gain=0.01) 75 | # Define loss function 76 | loss_fn = F.mse_loss 77 | opt = optim.Adam(model.parameters(), lr=self.lr) 78 | 79 | self.losses = np.zeros(self.epochs) 80 | for epoch in range(self.epochs): 81 | running_loss = 0.0 82 | for i, data in enumerate(dataloaders[target]): 83 | xtar, ytar = data[0].to(device), data[1].to(device) 84 | opt.zero_grad() 85 | loss = loss_fn(model(xtar).view(-1), ytar) + \ 86 | self.lamL2 * torch.sum(model.lin1.weight ** 2) + \ 87 | self.lamL1 * torch.sum(torch.abs(model.lin1.weight)) 88 | # Perform gradient descent 89 | loss.backward() 90 | opt.step() 91 | running_loss += loss.item() 92 | 93 | self.losses[epoch] = running_loss 94 | 95 | self.model = model 96 | 97 | return self 98 | 99 | def predict(self, X): 100 | ypredX = self.model(X) 101 | return ypredX 102 | 103 | def __str__(self): 104 | return self.__class__.__name__ + "_L2={:.1f}".format(self.lamL2) + "_L1={:.1f}".format(self.lamL1) 105 | 106 | 107 | 108 | class Src(BaseEstimator): 109 | """Src Linear regression (with l1 or l2 penalty) trained on the source domain""" 110 | def __init__(self, lamL2=0.0, lamL1=0.0, sourceInd = 0, lr=1e-4, epochs=10): 111 | self.lamL2 = lamL2 112 | self.lamL1 = lamL1 113 | self.sourceInd = sourceInd 114 | self.lr = lr 115 | self.epochs = epochs 116 | 117 | def fit(self, dataloaders, source, target): 118 | super().fit(dataloaders, source, target) 119 | 120 | # get the input dimension 121 | # assume it is a TensorDataset 122 | d = dataloaders[target].dataset[0][0].shape[0] 123 | model = LinearModel(d).to(device) 124 | with torch.no_grad(): 125 | model.lin1.bias.data = torch.zeros_like(model.lin1.bias) 126 | torch.nn.init.xavier_normal_(model.lin1.weight, gain=0.01) 127 | # Define loss function 128 | loss_fn = F.mse_loss 129 | opt = optim.Adam(model.parameters(), lr=self.lr) 130 | 131 | self.losses = np.zeros(self.epochs) 132 | for epoch in range(self.epochs): 133 | running_loss = 0.0 134 | for i, data in enumerate(dataloaders[source[self.sourceInd]]): 135 | x, y = data[0].to(device), data[1].to(device) 136 | opt.zero_grad() 137 | loss = loss_fn(model(x).view(-1), y) + \ 138 | self.lamL2 * torch.sum(model.lin1.weight ** 2) + \ 139 | self.lamL1 * torch.sum(torch.abs(model.lin1.weight)) 140 | # Perform gradient descent 141 | loss.backward() 142 | opt.step() 143 | running_loss += loss.item() 144 | 145 | self.losses[epoch] = running_loss 146 | 147 | self.model = model 148 | 149 | return self 150 | 151 | def predict(self, X): 152 | ypredX = self.model(X) 153 | return ypredX 154 | 155 | def __str__(self): 156 | return self.__class__.__name__ + "_L2={:.1f}".format(self.lamL2) + "_L1={:.1f}".format(self.lamL1) 157 | 158 | 159 | 160 | class SrcPool(BaseEstimator): 161 | """Pool all source data together and then run linear regression 162 | with l1 or l2 penalty """ 163 | def __init__(self, lamL2=0.0, lamL1=0.0, lr=1e-4, epochs=10): 164 | self.lamL2 = lamL2 165 | self.lamL1 = lamL1 166 | self.lr = lr 167 | self.epochs = epochs 168 | 169 | def fit(self, dataloaders, source, target): 170 | super().fit(dataloaders, source, target) 171 | 172 | # get the input dimension 173 | # assume it is a TensorDataset 174 | d = dataloaders[target].dataset[0][0].shape[0] 175 | model = LinearModel(d).to(device) 176 | # custom initialization 177 | with torch.no_grad(): 178 | model.lin1.bias.data = torch.zeros_like(model.lin1.bias) 179 | torch.nn.init.xavier_normal_(model.lin1.weight, gain=0.01) 180 | # Define loss function 181 | loss_fn = F.mse_loss 182 | opt = optim.Adam(model.parameters(), lr=self.lr) 183 | 184 | self.losses = np.zeros(self.epochs) 185 | 186 | for epoch in range(self.epochs): 187 | running_loss = 0.0 188 | for i, data in enumerate(zip(*[dataloaders[m] for m in source])): 189 | opt.zero_grad() 190 | loss = 0 191 | for mindex, m in enumerate(source): 192 | x, y = data[mindex][0].to(device), data[mindex][1].to(device) 193 | loss += loss_fn(model(x).view(-1), y) / len(source) 194 | 195 | loss += self.lamL2 * torch.sum(model.lin1.weight ** 2) 196 | loss += self.lamL1 * torch.sum(torch.abs(model.lin1.weight)) 197 | 198 | loss.backward() 199 | opt.step() 200 | 201 | running_loss += loss.item() 202 | 203 | self.losses[epoch] = running_loss 204 | self.model = model 205 | 206 | return self 207 | 208 | def predict(self, X): 209 | ypredX = self.model(X) 210 | return ypredX 211 | 212 | def __str__(self): 213 | return self.__class__.__name__ + "_L2={:.1f}".format(self.lamL2) + "_L1={:.1f}".format(self.lamL1) 214 | 215 | 216 | class DIP(BaseEstimator): 217 | """Pick one source, match mean of X * beta between source and target""" 218 | def __init__(self, lamMatch=10., lamL2=0., lamL1=0., sourceInd = 0, lr=1e-4, epochs=10, 219 | wayMatch='mean', sigma_list=[0.1, 1, 10, 100]): 220 | self.lamMatch = lamMatch 221 | self.lamL2 = lamL2 222 | self.lamL1 = lamL1 223 | self.sourceInd = sourceInd 224 | self.lr = lr 225 | self.epochs = epochs 226 | self.wayMatch = wayMatch 227 | self.sigma_list = sigma_list 228 | 229 | def fit(self, dataloaders, source, target): 230 | super().fit(dataloaders, source, target) 231 | 232 | d = dataloaders[target].dataset[0][0].shape[0] 233 | model = LinearModel(d).to(device) 234 | # custom initialization 235 | with torch.no_grad(): 236 | model.lin1.bias.data = torch.zeros_like(model.lin1.bias) 237 | torch.nn.init.xavier_normal_(model.lin1.weight, gain=0.01) 238 | # Define loss function 239 | loss_fn = F.mse_loss 240 | opt = optim.Adam(model.parameters(), lr=self.lr) 241 | self.losses = np.zeros(self.epochs) 242 | 243 | for epoch in range(self.epochs): 244 | running_loss = 0.0 245 | 246 | for i, data in enumerate(zip(dataloaders[source[self.sourceInd]], dataloaders[target])): 247 | opt.zero_grad() 248 | loss = 0 249 | x, y = data[0][0].to(device), data[0][1].to(device) 250 | xtar = data[1][0].to(device) 251 | loss += loss_fn(model(x).view(-1), y) 252 | if self.wayMatch == 'mean': 253 | discrepancy = torch.nn.MSELoss() 254 | loss += self.lamMatch * discrepancy(model(x), model(xtar)) 255 | elif self.wayMatch == 'mmd': 256 | loss += self.lamMatch * mmd.mix_rbf_mmd2(model(x), model(xtar), self.sigma_list) 257 | else: 258 | print('error discrepancy') 259 | loss += self.lamL2 * torch.sum(model.lin1.weight ** 2) + \ 260 | self.lamL1 * torch.sum(torch.abs(model.lin1.weight)) 261 | 262 | 263 | loss.backward() 264 | opt.step() 265 | 266 | running_loss += loss.item() 267 | 268 | self.losses[epoch] = running_loss 269 | self.model = model 270 | 271 | return self 272 | 273 | def predict(self, X): 274 | ypredX = self.model(X) 275 | return ypredX 276 | 277 | def __str__(self): 278 | return self.__class__.__name__ + self.wayMatch + "_Match{:.1f}".format(self.lamMatch) + "_L2={:.1f}".format(self.lamL2) + "_L1={:.1f}".format(self.lamL1) 279 | 280 | 281 | class DIPweigh(BaseEstimator): 282 | '''loop throught all source envs, match the mean of X * beta between source env i and target, weigh the final prediction based loss of env i''' 283 | def __init__(self, lamMatch=10., lamL2=0., lamL1=0., lr=1e-4, 284 | epochs=10, wayMatch='mean', sigma_list=[0.1, 1, 10, 100]): 285 | self.lamMatch = lamMatch 286 | self.lamL2 = lamL2 287 | self.lamL1 = lamL1 288 | self.lr = lr 289 | self.epochs = epochs 290 | self.wayMatch = wayMatch 291 | self.sigma_list = sigma_list 292 | 293 | def fit(self, dataloaders, source, target): 294 | super().fit(dataloaders, source, target) 295 | 296 | d = dataloaders[target].dataset[0][0].shape[0] 297 | models = {} 298 | diffs = {} 299 | ypreds = {} 300 | losses_all = {} 301 | self.total_weight = 0 302 | for m in source: 303 | model = LinearModel(d).to(device) 304 | models[m] = model 305 | # custom initialization 306 | with torch.no_grad(): 307 | model.lin1.bias.data = torch.zeros_like(model.lin1.bias) 308 | torch.nn.init.xavier_normal_(model.lin1.weight, gain=0.01) 309 | # Define loss function 310 | loss_fn = F.mse_loss 311 | opt = optim.Adam(model.parameters(), lr=self.lr) 312 | 313 | losses_all[m] = np.zeros(self.epochs) 314 | 315 | for epoch in range(self.epochs): 316 | running_loss = 0.0 317 | 318 | for i, data in enumerate(zip(dataloaders[m], dataloaders[target])): 319 | opt.zero_grad() 320 | loss = 0 321 | x, y = data[0][0].to(device), data[0][1].to(device) 322 | xtar = data[1][0].to(device) 323 | loss += loss_fn(model(x).view(-1), y) 324 | if self.wayMatch == 'mean': 325 | discrepancy = torch.nn.MSELoss() 326 | loss += self.lamMatch * discrepancy(model(x), model(xtar)) 327 | elif self.wayMatch == 'mmd': 328 | loss += self.lamMatch * mmd.mix_rbf_mmd2(model(x), model(xtar), self.sigma_list) 329 | else: 330 | raise('error discrepancy') 331 | loss += self.lamL2 * torch.sum(model.lin1.weight ** 2) + \ 332 | self.lamL1 * torch.sum(torch.abs(model.lin1.weight)) 333 | 334 | 335 | loss.backward() 336 | opt.step() 337 | 338 | running_loss += loss.item() 339 | 340 | 341 | losses_all[m][epoch] = running_loss 342 | 343 | # need to calculate the diffs 344 | diffs[m] = 0 345 | with torch.no_grad(): 346 | for i, data in enumerate(zip(dataloaders[m], dataloaders[target])): 347 | x, y = data[0][0].to(device), data[0][1].to(device) 348 | xtar = data[1][0].to(device) 349 | if self.wayMatch == 'mean': 350 | discrepancy = torch.nn.MSELoss() 351 | local_match_res = discrepancy(model(x), model(xtar)) 352 | elif self.wayMatch == 'mmd': 353 | local_match_res = mmd.mix_rbf_mmd2(model(x), model(xtar), self.sigma_list) 354 | else: 355 | raise('error discrepancy') 356 | diffs[m] += local_match_res / self.epochs / (len(dataloaders[m].dataset)/dataloaders[m].batch_size) 357 | self.total_weight += torch.exp(-100.*diffs[m]) 358 | 359 | self.models = models 360 | self.diffs = diffs 361 | 362 | minDiff = diffs[source[0]] 363 | minDiffIndx = source[0] 364 | for m in source: 365 | if diffs[m] < minDiff: 366 | minDiff = diffs[m] 367 | minDiffIndx = m 368 | self.minDiffIndx = minDiffIndx 369 | print(minDiffIndx) 370 | self.losses = losses_all[minDiffIndx] 371 | 372 | return self 373 | 374 | def predict(self, X): 375 | ypredX1 = 0 376 | for m in self.source: 377 | ypredX1 += torch.exp(-100.*self.diffs[m]) * self.models[m](X) 378 | ypredX1 /= self.total_weight 379 | return ypredX1 380 | 381 | def __str__(self): 382 | return self.__class__.__name__ + self.wayMatch + "_Match{:.1f}".format(self.lamMatch) + "_L2={:.1f}".format(self.lamL2) + "_L1={:.1f}".format(self.lamL1) 383 | 384 | 385 | 386 | 387 | class CIP(BaseEstimator): 388 | """Match the conditional (on Y) mean of X * beta across source envs, no target env is needed""" 389 | def __init__(self, lamCIP=10., lamL2=0., lamL1=0., lr=1e-4, epochs=10, 390 | wayMatch='mean', sigma_list = [0.1, 1, 10, 100]): 391 | self.lamCIP = lamCIP 392 | self.lamL2 = lamL2 393 | self.lamL1 = lamL1 394 | self.lr = lr 395 | self.epochs = epochs 396 | self.wayMatch = wayMatch 397 | self.sigma_list = sigma_list 398 | 399 | def fit(self, dataloaders, source, target): 400 | super().fit(dataloaders, source, target) 401 | 402 | d = dataloaders[target].dataset[0][0].shape[0] 403 | model = LinearModel(d).to(device) 404 | # custom initialization 405 | with torch.no_grad(): 406 | model.lin1.bias.data = torch.zeros_like(model.lin1.bias) 407 | torch.nn.init.xavier_normal_(model.lin1.weight, gain=0.01) 408 | 409 | # Define loss function 410 | loss_fn = F.mse_loss 411 | opt = optim.Adam(model.parameters(), lr=self.lr) 412 | 413 | self.losses = np.zeros(self.epochs) 414 | for epoch in range(self.epochs): 415 | running_loss = 0.0 416 | for i, data in enumerate(zip(*[dataloaders[m] for m in source])): 417 | opt.zero_grad() 418 | loss = 0 419 | for mindex, m in enumerate(source): 420 | x, y = data[mindex][0].to(device), data[mindex][1].to(device) 421 | loss += loss_fn(model(x).view(-1), y)/float(len(source)) 422 | xmod = x - torch.mm(y.view(-1, 1), torch.mm(y.view(1, -1), x))/torch.sum(y**2) 423 | 424 | # conditional invariance penalty 425 | for jindex, j in enumerate(source): 426 | if j > m: 427 | xj, yj = data[jindex][0].to(device), data[jindex][1].to(device) 428 | xmodj = xj - torch.mm(yj.view(-1, 1), torch.mm(yj.view(1, -1), xj))/torch.sum(yj**2) 429 | if self.wayMatch == 'mean': 430 | discrepancy = torch.nn.MSELoss() 431 | loss += self.lamCIP/float(len(source)**2) * discrepancy(model(xmod), model(xmodj)) 432 | elif self.wayMatch == 'mmd': 433 | loss += self.lamCIP/float(len(source)**2) * \ 434 | mmd.mix_rbf_mmd2(model(xmod), model(xmodj), self.sigma_list) 435 | else: 436 | raise('error discrepancy') 437 | 438 | loss += self.lamL2 * torch.sum(model.lin1.weight ** 2) + \ 439 | self.lamL1 * torch.sum(torch.abs(model.lin1.weight)) 440 | # Perform gradient descent 441 | loss.backward() 442 | opt.step() 443 | running_loss += loss.item() 444 | 445 | self.losses[epoch] = running_loss 446 | self.model = model 447 | 448 | return self 449 | 450 | def predict(self, X): 451 | ypredX = self.model(X) 452 | return ypredX 453 | 454 | def __str__(self): 455 | return self.__class__.__name__ + self.wayMatch + "_CIP{:.1f}".format(self.lamCIP) + "_L2={:.1f}".format(self.lamL2) + "_L1={:.1f}".format(self.lamL1) 456 | 457 | 458 | 459 | class CIRMweigh(BaseEstimator): 460 | """Match the conditional (on Y) mean of X * beta across source envs, use Yhat as proxy of Y to remove the Y parts in X. 461 | Match on the residual between one source env and target env""" 462 | def __init__(self, lamMatch=10., lamCIP=10., lamL2=0., lamL1=0., lr=1e-4, epochs=10, 463 | wayMatch='mean', sigma_list=[0.1, 1, 10, 100]): 464 | self.lamMatch = lamMatch 465 | self.lamCIP = lamCIP 466 | self.lamL2 = lamL2 467 | self.lamL1 = lamL1 468 | self.lr = lr 469 | self.epochs = epochs 470 | self.wayMatch = wayMatch 471 | self.sigma_list = sigma_list 472 | 473 | def fit(self, dataloaders, source, target): 474 | super().fit(dataloaders, source, target) 475 | 476 | d = dataloaders[target].dataset[0][0].shape[0] 477 | # Step 1: use source envs to match the conditional mean 478 | # find beta_invariant 479 | models1 = LinearModel(d).to(device) 480 | # custom initialization 481 | with torch.no_grad(): 482 | models1.lin1.bias.data = torch.zeros_like(models1.lin1.bias) 483 | torch.nn.init.xavier_normal_(models1.lin1.weight, gain=0.01) 484 | 485 | # Define loss function 486 | loss_fn = F.mse_loss 487 | opt = optim.Adam(models1.parameters(), lr=self.lr) 488 | 489 | losses1 = np.zeros(self.epochs) 490 | for epoch in range(self.epochs): 491 | running_loss = 0.0 492 | for i, data in enumerate(zip(*[dataloaders[m] for m in source])): 493 | loss = 0 494 | for mindex, m in enumerate(source): 495 | x, y = data[mindex][0].to(device), data[mindex][1].to(device) 496 | loss += loss_fn(models1(x).view(-1), y)/float(len(source)) 497 | xmod = x - torch.mm(y.view(-1, 1), torch.mm(y.view(1, -1), x))/torch.sum(y**2) 498 | 499 | # conditional invariance penalty 500 | for jindex, j in enumerate(source): 501 | xj, yj = data[jindex][0].to(device), data[jindex][1].to(device) 502 | if j > m: 503 | xmodj = xj - torch.mm(yj.view(-1, 1), torch.mm(yj.view(1, -1), xj))/torch.sum(yj**2) 504 | if self.wayMatch == 'mean': 505 | discrepancy = torch.nn.MSELoss() 506 | loss += self.lamCIP/float(len(source)**2) * discrepancy(models1(xmod), models1(xmodj)) 507 | elif self.wayMatch == 'mmd': 508 | loss += self.lamCIP/float(len(source)**2) * \ 509 | mmd.mix_rbf_mmd2(models1(xmod), models1(xmodj), self.sigma_list) 510 | else: 511 | raise('error discrepancy') 512 | loss += self.lamL2 * torch.sum(models1.lin1.weight ** 2) + \ 513 | self.lamL1 * torch.sum(torch.abs(models1.lin1.weight)) 514 | # Perform gradient descent 515 | loss.backward() 516 | opt.step() 517 | opt.zero_grad() 518 | running_loss += loss.item() 519 | losses1[epoch] = running_loss 520 | 521 | self.models1 = models1 522 | 523 | # fix grads now 524 | for param in models1.lin1.parameters(): 525 | param.requires_grad = False 526 | 527 | # Step 2: remove the invariant part on all source envs, so that everything is independent of Y 528 | # get that coefficient b 529 | YsrcMean = 0 530 | ntotal = 0 531 | for m in source: 532 | YsrcMean += torch.sum(dataloaders[m].dataset.tensors[1]) 533 | ntotal += dataloaders[m].dataset.tensors[1].shape[0] 534 | YsrcMean /= ntotal 535 | 536 | YTX = 0 537 | YTY = 0 538 | for m in source: 539 | for i, data in enumerate(dataloaders[m]): 540 | x, y = data[0].to(device), data[1].to(device) 541 | yguess = self.models1(x) 542 | yCentered = y - YsrcMean 543 | YTY += torch.sum(yguess.t() * yCentered) 544 | YTX += torch.mm(yCentered.view(1, -1), x) 545 | 546 | b = YTX / YTY 547 | self.b = b 548 | 549 | 550 | # Step 3: mean match between source and target on the residual, after transforming the covariates X - (X * beta_invariant) * b_invariant 551 | models = {} 552 | diffs = {} 553 | losses_all = {} 554 | self.total_weight = 0 555 | 556 | for m in source: 557 | models[m] = LinearModel(d).to(device) 558 | # custom initialization 559 | with torch.no_grad(): 560 | models[m].lin1.bias.data = torch.zeros_like(models[m].lin1.bias) 561 | torch.nn.init.xavier_normal_(models[m].lin1.weight, gain=0.01) 562 | # Define loss function 563 | loss_fn = F.mse_loss 564 | opt = optim.Adam(models[m].parameters(), lr=self.lr) 565 | 566 | losses_all[m] = np.zeros(self.epochs) 567 | 568 | for epoch in range(self.epochs): # loop over the dataset multiple times 569 | running_loss = 0.0 570 | 571 | for i, data in enumerate(zip(dataloaders[m], dataloaders[target])): 572 | opt.zero_grad() 573 | loss = 0 574 | x, y = data[0][0].to(device), data[0][1].to(device) 575 | yguess = self.models1(x) 576 | xmod = x - torch.mm(yguess, b) 577 | 578 | xtar = data[1][0].to(device) 579 | ytarguess = self.models1(xtar) 580 | xtarmod = xtar - torch.mm(ytarguess, b) 581 | 582 | loss += loss_fn(models[m](x).view(-1), y) 583 | if self.wayMatch == 'mean': 584 | discrepancy = torch.nn.MSELoss() 585 | loss += self.lamMatch * discrepancy(models[m](xmod), 586 | models[m](xtarmod)) 587 | elif self.wayMatch == 'mmd': 588 | loss += self.lamMatch * mmd.mix_rbf_mmd2(models[m](xmod), 589 | models[m](xtarmod), 590 | self.sigma_list) 591 | else: 592 | raise('error discrepancy') 593 | loss += self.lamL2 * torch.sum(models[m].lin1.weight ** 2) + \ 594 | self.lamL1 * torch.sum(torch.abs(models[m].lin1.weight)) 595 | 596 | loss.backward() 597 | opt.step() 598 | 599 | running_loss += loss.item() 600 | 601 | losses_all[m][epoch] = running_loss 602 | 603 | # need to compute diff after training 604 | 605 | 606 | diffs[m] = 0. 607 | with torch.no_grad(): 608 | for i, data in enumerate(zip(dataloaders[m], dataloaders[target])): 609 | x, y = data[0][0].to(device), data[0][1].to(device) 610 | yguess = self.models1(x) 611 | xmod = x - torch.mm(yguess, b) 612 | 613 | xtar = data[1][0].to(device) 614 | ytarguess = self.models1(xtar) 615 | xtarmod = xtar - torch.mm(ytarguess, b) 616 | 617 | if self.wayMatch == 'mean': 618 | discrepancy = torch.nn.MSELoss() 619 | diffs[m] += discrepancy(models[m](xmod), models[m](xtarmod)) / \ 620 | self.epochs / (len(dataloaders[m].dataset)/dataloaders[m].batch_size) 621 | elif self.wayMatch == 'mmd': 622 | diffs[m] += mmd.mix_rbf_mmd2(models[m](xmod), models[m](xtarmod), self.sigma_list) / \ 623 | self.epochs / (len(dataloaders[m].dataset)/dataloaders[m].batch_size) 624 | else: 625 | raise('error discrepancy') 626 | 627 | 628 | self.total_weight += torch.exp(-100.*diffs[m]) 629 | 630 | # take the min diff loss to be current best losses and model 631 | minDiff = diffs[source[0]] 632 | minDiffIndx = source[0] 633 | self.losses = losses_all[source[0]] 634 | for m in source: 635 | if diffs[m] < minDiff: 636 | minDiff = diffs[m] 637 | minDiffIndx = m 638 | self.losses = losses_all[m] 639 | self.model = models[m] 640 | self.minDiffIndx = minDiffIndx 641 | 642 | self.models = models 643 | self.diffs = diffs 644 | 645 | return self 646 | 647 | def predict(self, X): 648 | ypredX1 = 0 649 | for m in self.source: 650 | ypredX1 += torch.exp(-100.*self.diffs[m]) * self.models[m](X) 651 | ypredX1 /= self.total_weight 652 | return ypredX1 653 | 654 | def __str__(self): 655 | return self.__class__.__name__ + self.wayMatch + "_Match{:.1f}".format(self.lamMatch) + "_L2={:.1f}".format(self.lamL2) + "_L1={:.1f}".format(self.lamL1) 656 | -------------------------------------------------------------------------------- /sim/sim_linearSCM_var_shift_exp8_box_run.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import sys 7 | import argparse 8 | 9 | import torch 10 | 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torch.optim as optim 14 | 15 | # local packages 16 | import sys 17 | sys.path.append('../') 18 | import semiclass 19 | import semitorchclass 20 | import semitorchstocclass 21 | import util 22 | import simudata 23 | 24 | # check gpu avail 25 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 26 | # Assuming that we are on a CUDA machine, this should print a CUDA device: 27 | print(device) 28 | 29 | # parse args 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument("--interv_type", type=str, default="sv1", help="type of intervention") 32 | parser.add_argument("--lamMatch", type=float, default=1., help="DIP matching penalty") 33 | parser.add_argument("--epochs", type=int, default=4000, help="number of epochs") 34 | parser.add_argument("--lr", type=float, default=1e-3, help="learning rate") 35 | parser.add_argument("--tag_DA", type=str, default="baseline", help="choose whether to run baseline methods or DA methods") 36 | parser.add_argument("--n", type=int, default=5000, help="sample size") 37 | myargs = parser.parse_args() 38 | print(myargs) 39 | 40 | lamL2 = 0. 41 | lamL1 = 0. 42 | 43 | if myargs.tag_DA == "baseline": 44 | methods = [ 45 | semiclass.Tar(lamL2=lamL2), 46 | semiclass.SrcPool(lamL2=lamL2), 47 | semitorchclass.Tar(lamL2=lamL2, lamL1=lamL1, lr=myargs.lr, epochs=myargs.epochs), 48 | semitorchclass.SrcPool(lamL2=lamL2, lamL1=lamL1, lr=myargs.lr, epochs=myargs.epochs), 49 | ] 50 | elif myargs.tag_DA == "DAmean": 51 | methods = [ 52 | semiclass.DIP(lamMatch=myargs.lamMatch, lamL2=lamL2, sourceInd=0), 53 | semiclass.DIPOracle(lamMatch=myargs.lamMatch, lamL2=lamL2, sourceInd=0), 54 | ] 55 | elif myargs.tag_DA == "DAstd": 56 | methods = [ 57 | semitorchclass.DIP(lamMatch=myargs.lamMatch, lamL2=lamL2, lamL1=lamL1, sourceInd=0, lr=myargs.lr, epochs=myargs.epochs, wayMatch='std'), 58 | semitorchclass.DIP(lamMatch=myargs.lamMatch, lamL2=lamL2, lamL1=lamL1, sourceInd=0, lr=myargs.lr, epochs=myargs.epochs, wayMatch='mean+std+25p'), 59 | ] 60 | else: # "DAMMD" 61 | methods = [ 62 | semitorchstocclass.DIP(lamMatch=myargs.lamMatch, lamL2=lamL2, lamL1=lamL1, sourceInd = 0, lr=myargs.lr, 63 | epochs=myargs.epochs, wayMatch='mmd', sigma_list=[1.]), 64 | ] 65 | 66 | 67 | names = [str(m) for m in methods] 68 | print(names) 69 | names_short = [str(m).split('_')[0] for m in methods] 70 | print(names_short) 71 | 72 | seed1 = int(123456 + np.exp(1) * 1000) 73 | 74 | def simple_run_sem(sem_nums, ds, methods, i2n_ratios=[1.], n=2000, repeats=10): 75 | res_all = {} 76 | for i, sem_num in enumerate(sem_nums): 77 | for j, inter2noise_ratio_local in enumerate(i2n_ratios): 78 | print("Number of envs M=%d, inter2noise_ratio=%.1f" %(2, inter2noise_ratio_local), flush=True) 79 | params = {'M': 2, 'inter2noise_ratio': inter2noise_ratio_local, 'd': ds[i]} 80 | 81 | sem1 = simudata.pick_sem(sem_num, 82 | params=params, 83 | seed=seed1) 84 | 85 | 86 | # run methods on data generated from sem 87 | results_src_all, results_tar_all = util.run_all_methods(sem1, methods, n=n, repeats=repeats) 88 | res_all[(i, j)] = results_src_all, results_tar_all 89 | return res_all 90 | 91 | repeats = 10 92 | res_all = simple_run_sem(sem_nums=['r0%sd3x1' %myargs.interv_type, 93 | 'r0%sd?x1' %myargs.interv_type, 94 | 'r0%sd?x1' %myargs.interv_type], 95 | ds=[3, 10, 20], 96 | i2n_ratios=[1.], 97 | methods=methods, 98 | n=myargs.n, 99 | repeats=repeats) 100 | 101 | np.save("simu_results/sim_exp8_box_r0%sd31020_%s_lamMatch%s_n%d_epochs%d_repeats%d.npy" %(myargs.interv_type, 102 | myargs.tag_DA, myargs.lamMatch, myargs.n, myargs.epochs, repeats), res_all) -------------------------------------------------------------------------------- /sim/sim_linearSCM_var_shift_exp8_box_submit.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import numpy as np 5 | import subprocess 6 | 7 | interv_type = 'sv1' 8 | epochs = 20000 9 | epochs_MMD = 2000 10 | n = 5000 11 | 12 | for tag_DA in ['baseline']: 13 | subprocess.call(['bsub', '-W 03:50', '-n 4', "./sim_linearSCM_var_shift_exp8_box_run.py --interv_type=%s --tag_DA=%s --n=%d --epochs=%d" %(interv_type, tag_DA, n, epochs)]) 14 | 15 | lamMatches = [10.**(k) for k in (np.arange(10)-5)] 16 | for tag_DA in ['DAmean', 'DAstd']: 17 | for lam in lamMatches: 18 | subprocess.call(['bsub', '-W 03:50', '-n 4', "./sim_linearSCM_var_shift_exp8_box_run.py --interv_type=%s --lamMatch=%f --tag_DA=%s --n=%d --epochs=%d" %(interv_type, lam, tag_DA, n, epochs)]) 19 | 20 | for tag_DA in ['DAMMD']: 21 | for lam in lamMatches: 22 | subprocess.call(['bsub', '-W 23:50', '-n 4', "./sim_linearSCM_var_shift_exp8_box_run.py --interv_type=%s --lamMatch=%f --tag_DA=%s --n=%d --epochs=%d" %(interv_type, lam, tag_DA, n, epochs_MMD)]) -------------------------------------------------------------------------------- /sim/sim_linearSCM_var_shift_exp8_scat_run.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import sys 7 | import argparse 8 | 9 | import torch 10 | 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torch.optim as optim 14 | 15 | # local packages 16 | import sys 17 | sys.path.append('../') 18 | import semiclass 19 | import semitorchclass 20 | import semitorchstocclass 21 | import util 22 | import simudata 23 | 24 | # check gpu avail 25 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 26 | # Assuming that we are on a CUDA machine, this should print a CUDA device: 27 | print(device) 28 | 29 | # parse args 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument("--interv_type", type=str, default="sv1", help="type of intervention") 32 | parser.add_argument("--lamMatch", type=float, default=1., help="DIP matching penalty") 33 | parser.add_argument("--epochs", type=int, default=4000, help="number of epochs") 34 | parser.add_argument("--lr", type=float, default=1e-3, help="learning rate") 35 | parser.add_argument("--tag_DA", type=str, default="baseline", help="choose whether to run baseline methods or DA methods") 36 | parser.add_argument("--n", type=int, default=5000, help="sample size") 37 | parser.add_argument("--seed", type=int, default=0, help="seed of experiment") 38 | myargs = parser.parse_args() 39 | print(myargs) 40 | 41 | lamL2 = 0. 42 | lamL1 = 0. 43 | 44 | if myargs.tag_DA == "baseline": 45 | methods = [ 46 | semiclass.Tar(lamL2=lamL2), 47 | semiclass.SrcPool(lamL2=lamL2), 48 | semitorchclass.Tar(lamL2=lamL2, lamL1=lamL1, lr=myargs.lr, epochs=myargs.epochs), 49 | semitorchclass.SrcPool(lamL2=lamL2, lamL1=lamL1, lr=myargs.lr, epochs=myargs.epochs), 50 | ] 51 | elif myargs.tag_DA == "DAmean": 52 | methods = [ 53 | semiclass.DIP(lamMatch=myargs.lamMatch, lamL2=lamL2, sourceInd=0), 54 | semiclass.DIPOracle(lamMatch=myargs.lamMatch, lamL2=lamL2, sourceInd=0), 55 | ] 56 | elif myargs.tag_DA == "DAstd": 57 | methods = [ 58 | semitorchclass.DIP(lamMatch=myargs.lamMatch, lamL2=lamL2, lamL1=lamL1, sourceInd=0, lr=myargs.lr, epochs=myargs.epochs, wayMatch='std'), 59 | semitorchclass.DIP(lamMatch=myargs.lamMatch, lamL2=lamL2, lamL1=lamL1, sourceInd=0, lr=myargs.lr, epochs=myargs.epochs, wayMatch='mean+std+25p'), 60 | ] 61 | else: # "DAMMD" 62 | methods = [ 63 | semitorchstocclass.DIP(lamMatch=myargs.lamMatch, lamL2=lamL2, lamL1=lamL1, sourceInd = 0, lr=myargs.lr, 64 | epochs=myargs.epochs, wayMatch='mmd', sigma_list=[1.]), 65 | ] 66 | 67 | names = [str(m) for m in methods] 68 | print(names) 69 | names_short = [str(m).split('_')[0] for m in methods] 70 | print(names_short) 71 | 72 | seed1 = int(123456 + np.exp(1) * 1000) 73 | 74 | params = {'M': 2, 'inter2noise_ratio': 1.0, 'd': 10} 75 | 76 | sem1 = simudata.pick_sem('r0%sd?x1' %myargs.interv_type, 77 | params=params, 78 | seed=seed1+myargs.seed) 79 | # run methods on data generated from sem 80 | results_src_all, results_tar_all = util.run_all_methods(sem1, methods, n=myargs.n, repeats=1) 81 | res_all = {} 82 | res_all['src'] = results_src_all 83 | res_all['tar'] = results_tar_all 84 | 85 | np.save("simu_results/sim_exp8_scat_r0%sd10_%s_lamMatch%s_n%d_epochs%d_seed%d.npy" %(myargs.interv_type, 86 | myargs.tag_DA, myargs.lamMatch, myargs.n, myargs.epochs, myargs.seed), res_all) 87 | 88 | 89 | -------------------------------------------------------------------------------- /sim/sim_linearSCM_var_shift_exp8_scat_submit.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import numpy as np 5 | import subprocess 6 | 7 | interv_type = 'sv1' 8 | epochs = 20000 9 | epochs_MMD = 2000 10 | n = 5000 11 | 12 | for tag_DA in ['baseline']: 13 | for myseed in range(100): 14 | subprocess.call(['bsub', '-W 03:50', '-n 4', "./sim_linearSCM_var_shift_exp8_scat_run.py --interv_type=%s --tag_DA=%s --n=%d --epochs=%d --seed=%d" %(interv_type, tag_DA, n, epochs, myseed)]) 15 | 16 | 17 | lamMatches = [10.**(k) for k in (np.arange(10)-5)] 18 | for tag_DA in ['DAmean', 'DAstd']: 19 | for myseed in range(100): 20 | for lam in lamMatches: 21 | subprocess.call(['bsub', '-W 03:50', '-n 4', "./sim_linearSCM_var_shift_exp8_scat_run.py --interv_type=%s --lamMatch=%f --tag_DA=%s --n=%d --epochs=%d --seed=%d" %(interv_type, lam, tag_DA, n, epochs, myseed)]) 22 | 23 | for tag_DA in ['DAMMD']: 24 | for myseed in range(100): 25 | for lam in lamMatches: 26 | subprocess.call(['bsub', '-W 23:50', '-n 4', "./sim_linearSCM_var_shift_exp8_scat_run.py --interv_type=%s --lamMatch=%f --tag_DA=%s --n=%d --epochs=%d --seed=%d" %(interv_type, lam, tag_DA, n, epochs_MMD, myseed)]) -------------------------------------------------------------------------------- /sim/sim_linearSCM_var_shift_exp9_scat_run.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import sys 7 | import argparse 8 | 9 | import torch 10 | 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torch.optim as optim 14 | 15 | # local packages 16 | import sys 17 | sys.path.append('../') 18 | import semiclass 19 | import semitorchclass 20 | import semitorchstocclass 21 | import util 22 | import simudata 23 | 24 | # check gpu avail 25 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 26 | 27 | # Assuming that we are on a CUDA machine, this should print a CUDA device: 28 | 29 | print(device) 30 | 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument("--interv_type", type=str, default="smv1", help="type of intervention") 33 | parser.add_argument("--lamMatch", type=float, default=1., help="DIP matching penalty") 34 | parser.add_argument("--lamCIP", type=float, default=0.1, help="CIP penalty") 35 | parser.add_argument("--epochs", type=int, default=4000, help="number of epochs") 36 | parser.add_argument("--lr", type=float, default=1e-3, help="learning rate") 37 | parser.add_argument("--tag_DA", type=str, default="baseline", help="choose whether to run baseline methods or DA methods") 38 | parser.add_argument("--n", type=int, default=5000, help="sample size") 39 | parser.add_argument("--seed", type=int, default=0, help="seed of experiment") 40 | myargs = parser.parse_args() 41 | print(myargs) 42 | 43 | 44 | lamL2 = 0. 45 | lamL1 = 0. 46 | 47 | if myargs.tag_DA == 'baseline': 48 | methods = [ 49 | semiclass.Tar(lamL2=lamL2), 50 | semiclass.SrcPool(lamL2=lamL2), 51 | semitorchclass.Tar(lamL2=lamL2, lamL1=lamL1, lr=myargs.lr, epochs=myargs.epochs), 52 | semitorchclass.SrcPool(lamL2=lamL2, lamL1=lamL1, lr=myargs.lr, epochs=myargs.epochs), 53 | ] 54 | elif myargs.tag_DA == 'DAmean': 55 | methods = [ 56 | semiclass.DIP(lamMatch=myargs.lamMatch, lamL2=lamL2, sourceInd=0), 57 | semiclass.DIPOracle(lamMatch=myargs.lamMatch, lamL2=lamL2, sourceInd=0), 58 | semiclass.DIPweigh(lamMatch=myargs.lamMatch, lamL2=lamL2), 59 | semiclass.CIP(lamCIP=myargs.lamCIP, lamL2=lamL2), 60 | semiclass.CIRMweigh(lamCIP=myargs.lamCIP, lamMatch=myargs.lamMatch, lamL2=lamL2), 61 | ] 62 | elif myargs.tag_DA == 'DAstd': 63 | methods = [ 64 | semitorchclass.DIP(lamMatch=myargs.lamMatch, lamL2=lamL2, lamL1=lamL1, sourceInd=0, lr=myargs.lr, 65 | epochs=myargs.epochs, wayMatch='mean+std+25p'), 66 | semitorchclass.DIPweigh(lamMatch=myargs.lamMatch, lamL2=lamL2, lamL1=lamL1, lr=myargs.lr, 67 | epochs=myargs.epochs, wayMatch='mean+std+25p'), 68 | semitorchclass.CIP(lamCIP=myargs.lamCIP, lamL2=lamL2, lamL1=lamL1, lr=myargs.lr, 69 | epochs=myargs.epochs, wayMatch='mean+std+25p'), 70 | semitorchclass.CIRMweigh(lamMatch=myargs.lamMatch, lamL2=lamL2, lamL1=lamL1, lr=myargs.lr, 71 | epochs=myargs.epochs, wayMatch='mean+std+25p'), 72 | ] 73 | elif myargs.tag_DA == 'DAMMD': 74 | methods = [ 75 | semitorchstocclass.DIP(lamMatch=myargs.lamMatch, lamL2=lamL2, lamL1=lamL1, sourceInd = 0, lr=myargs.lr, 76 | epochs=myargs.epochs, wayMatch='mmd', sigma_list=[1.]), 77 | semitorchstocclass.DIPweigh(lamMatch=myargs.lamMatch, lamL2=lamL2, lamL1=lamL1, lr=myargs.lr, 78 | epochs=myargs.epochs, wayMatch='mmd', sigma_list=[1.]), 79 | semitorchstocclass.CIP(lamCIP=myargs.lamCIP, lamL2=lamL2, lamL1=lamL1, lr=myargs.lr, 80 | epochs=myargs.epochs, wayMatch='mmd', sigma_list=[1.]), 81 | semitorchstocclass.CIRMweigh(lamMatch=myargs.lamMatch, lamCIP=myargs.lamCIP, lamL2=lamL2, lamL1=lamL1, lr=myargs.lr, 82 | epochs=myargs.epochs, wayMatch='mmd', sigma_list=[1.]) 83 | ] 84 | elif myargs.tag_DA == 'DACIP': 85 | methods = [ 86 | semiclass.CIP(lamCIP=myargs.lamCIP, lamL2=lamL2), 87 | semitorchclass.CIP(lamCIP=myargs.lamCIP, lamL2=lamL2, lamL1=lamL1, lr=myargs.lr, 88 | epochs=myargs.epochs, wayMatch='mean+std+25p'), 89 | ] 90 | elif myargs.tag_DA == 'DACIPMMD': 91 | methods = [ 92 | semitorchstocclass.CIP(lamCIP=myargs.lamCIP, lamL2=lamL2, lamL1=lamL1, lr=myargs.lr, 93 | epochs=myargs.epochs, wayMatch='mmd', sigma_list=[1.]), 94 | ] 95 | 96 | 97 | names = [str(m) for m in methods] 98 | print(names) 99 | names_short = [str(m).split('_')[0] for m in methods] 100 | print(names_short) 101 | 102 | seed1 = int(123456 + np.exp(2) * 1000) 103 | 104 | params = {'M': 15, 'inter2noise_ratio': 1., 'd': 20, 'cicnum': 10,'interY': 1.} 105 | 106 | sem1 = simudata.pick_sem('r0%sd?x4' %myargs.interv_type, 107 | params=params, 108 | seed=seed1+myargs.seed) 109 | 110 | # run methods on data generated from sem 111 | results_src_all, results_tar_all, results_minDiffIndx = util.run_all_methods(sem1, 112 | methods, 113 | n=myargs.n, 114 | repeats=1, 115 | returnMinDiffIndx=True, 116 | tag_DA=myargs.tag_DA) 117 | res_all = {} 118 | res_all['src'] = results_src_all 119 | res_all['tar'] = results_tar_all 120 | res_all['minDiffIndx'] = results_minDiffIndx 121 | 122 | 123 | np.save("simu_results/sim_exp9_scat_r0%sd20x4_%s_lamMatch%s_lamCIP%s_n%d_epochs%d_seed%d.npy" %(myargs.interv_type, 124 | myargs.tag_DA, myargs.lamMatch, myargs.lamCIP, myargs.n, myargs.epochs, myargs.seed), res_all) 125 | 126 | 127 | -------------------------------------------------------------------------------- /sim/sim_linearSCM_var_shift_exp9_scat_submit.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | 5 | import numpy as np 6 | import subprocess 7 | import os 8 | 9 | interv_type = 'smv1' 10 | epochs = 20000 11 | epochs_MMD = 2000 12 | n = 5000 13 | 14 | for tag_DA in ['baseline']: 15 | for myseed in range(100): 16 | subprocess.call(['bsub', '-W 03:50', '-n 4', "./sim_linearSCM_var_shift_exp9_scat_run.py --interv_type=%s --tag_DA=%s --epochs=%d --seed=%d" %(interv_type, tag_DA, 20000, myseed)]) 17 | 18 | 19 | # You don't have to run all the lamMatch choices!!! 20 | lamMatches = [10.**(k) for k in (np.arange(10)-5)] 21 | for tag_DA in ['DAmean', 'DAstd']: 22 | for myseed in range(100): 23 | for lam in lamMatches: 24 | subprocess.call(['bsub', '-W 03:50', '-n 4', "./sim_linearSCM_var_shift_exp9_scat_run.py --interv_type=%s --lamMatch=%f --tag_DA=%s --epochs=%d --seed=%d" %(interv_type, lam, tag_DA, epochs, myseed)]) 25 | 26 | for tag_DA in ['DAMMD']: 27 | for myseed in range(100): 28 | for lam in lamMatches: 29 | subprocess.call(['bsub', '-W 119:50', '-n 4', "./sim_linearSCM_var_shift_exp9_scat_run.py --interv_type=%s --lamMatch=%f --tag_DA=%s --epochs=%d --seed=%d" %(interv_type, lam, tag_DA, epochs_MMD, myseed)]) 30 | 31 | 32 | # You don't have to run all the lamCIP choices!!! 33 | lamCIPs = [10.**(k) for k in (np.arange(10)-5)] 34 | for tag_DA in ['DACIP']: 35 | for myseed in range(100): 36 | for lam in lamCIPs: 37 | subprocess.call(['bsub', '-W 03:50', '-n 4', "./sim_linearSCM_var_shift_exp9_scat_run.py --interv_type=%s --lamCIP=%f --tag_DA=%s --epochs=%d --seed=%d" %(interv_type, lam, tag_DA, epochs, myseed)]) 38 | 39 | for tag_DA in ['DACIPMMD']: 40 | for myseed in range(100): 41 | for lam in lamCIPs: 42 | subprocess.call(['bsub', '-W 23:50', '-n 4', "./sim_linearSCM_var_shift_exp9_scat_run.py --interv_type=%s --lamCIP=%f --tag_DA=%s --epochs=%d --seed=%d" %(interv_type, lam, tag_DA, epochs_MMD, myseed)]) 43 | -------------------------------------------------------------------------------- /sim/simu_results/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /sim/simudata.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | sys.path.append('../') 4 | import sem 5 | import myrandom 6 | 7 | def pick_intervention_and_noise(M, dp1, inter2noise_ratio, interY=0., cic=[], typeshift='sm1', varAs=None, varnoiseY=1.): 8 | if typeshift == 'sm1': 9 | meanAs = inter2noise_ratio * np.random.randn(M, dp1) 10 | 11 | # meanAs[:, -1] = interY * np.random.randn(M) # 0 means no intervention on Y 12 | meanAs[:, -1] = 0. 13 | meanAs[-1, -1] = interY * np.random.randn(1) # 0 means no intervention on Y 14 | 15 | meanAs[:, cic] = 0 # set conditional invariant components 16 | 17 | if not varAs: 18 | varAs = np.zeros((M, dp1)) 19 | interAf = myrandom.Gaussian(M, meanAs, varAs) 20 | 21 | noisevar = np.ones((1, dp1)) 22 | noisevar[-1] = varnoiseY 23 | noisef = myrandom.Gaussian(1, np.zeros((1, dp1)), noisevar) 24 | elif typeshift == 'sm2': 25 | # mixture 26 | meanAs1 = inter2noise_ratio * np.random.randn(M, dp1) 27 | meanAs2 = inter2noise_ratio * np.random.randn(M, dp1) 28 | meanAs1[:, -1] = 0. # 0 intervY only for the target env 29 | meanAs2[:, -1] = 0. # 0 intervY only for the target env 30 | meanAs1[-1, -1] = interY * np.random.randn(1) # 0 means no intervention on Y 31 | meanAs2[-1, -1] = interY * np.random.randn(1) # 0 means no intervention on Y 32 | meanAs1[:, cic] = 0 # set conditional invariant components 33 | meanAs2[:, cic] = 0 # set conditional invariant components 34 | meanAsList = [meanAs1, meanAs2] 35 | 36 | if not varAs: 37 | varAs = np.zeros((M, dp1)) 38 | interAf = myrandom.Mix2Gaussian(M, meanAsList, varAs) 39 | 40 | noisevar = np.ones((1, dp1)) 41 | noisevar[-1] = varnoiseY 42 | noisef = myrandom.Gaussian(1, np.zeros((1, dp1)), noisevar) 43 | elif typeshift == 'sm3': 44 | # mixture 45 | meanAs1 = inter2noise_ratio * np.random.randn(M, dp1) 46 | meanAs2 = inter2noise_ratio * np.random.randn(M, dp1) 47 | meanAs3 = inter2noise_ratio * np.random.randn(M, dp1) 48 | meanAs1[:, -1] = 0. # 0 intervY only for the target env 49 | meanAs2[:, -1] = 0. # 0 intervY only for the target env 50 | meanAs3[:, -1] = 0. # 0 intervY only for the target env 51 | meanAs1[-1, -1] = interY * np.random.randn(1) # 0 means no intervention on Y 52 | meanAs2[-1, -1] = interY * np.random.randn(1) # 0 means no intervention on Y 53 | meanAs3[-1, -1] = interY * np.random.randn(1) # 0 means no intervention on Y 54 | meanAs1[:, cic] = 0 # set conditional invariant components 55 | meanAs2[:, cic] = 0 # set conditional invariant components 56 | meanAs3[:, cic] = 0 # set conditional invariant components 57 | meanAsList = [meanAs1, meanAs2, meanAs3] 58 | 59 | if not varAs: 60 | varAs = np.zeros((M, dp1)) 61 | interAf = myrandom.MixkGaussian(M, meanAsList, varAs) 62 | 63 | noisevar = np.ones((1, dp1)) 64 | noisevar[-1] = varnoiseY 65 | noisef = myrandom.Gaussian(1, np.zeros((1, dp1)), noisevar) 66 | 67 | elif typeshift == 'sv1': 68 | meanAs = np.zeros((M, dp1)) 69 | varAs_inter = inter2noise_ratio * (np.abs(np.random.randn(M, dp1))) 70 | varAs_inter[:, -1] = 1. 71 | varAs = varAs_inter 72 | 73 | interAf = myrandom.Gaussian(M, meanAs, varAs) 74 | 75 | # the noise is always zero, is taken care of by interAf 76 | noisef = myrandom.Gaussian(1, np.zeros((1, dp1)), np.zeros((1, dp1))) 77 | 78 | elif typeshift == 'smv1': 79 | meanAs = inter2noise_ratio * np.random.randn(M, dp1) 80 | # meanAs[:, -1] = interY * np.random.randn(M) # 0 means no intervention on Y 81 | meanAs[:, -1] = 0. 82 | meanAs[-1, -1] = interY * np.random.randn(1) # 0 means no intervention on Y 83 | meanAs[:, cic] = 0 # set conditional invariant components 84 | varAs_inter = inter2noise_ratio * np.abs(np.random.randn(M, dp1)) 85 | varAs_inter[:, -1] = 1. 86 | varAs_inter[:, cic] = 1. # set conditional invariant components 87 | varAs = varAs_inter 88 | 89 | interAf = myrandom.Gaussian(M, meanAs, varAs) 90 | 91 | # the noise is always zero, is taken care of by interAf 92 | noisef = myrandom.Gaussian(1, np.zeros((1, dp1)), np.zeros((1, dp1))) 93 | 94 | return interAf, noisef 95 | 96 | def pick_random_B(pred_dir = 'anticausal', dp1=1): 97 | B = np.zeros((dp1, dp1)) 98 | if pred_dir == 'anticausal': 99 | # triangular B 100 | for i in range(0, dp1-1): 101 | for j in range(i+1, dp1-1): 102 | B[j, i] = 0.5 * np.random.randn(1) 103 | B[:, -1] = 1.0 * np.random.randn(dp1) 104 | B[-1, -1] = 0 105 | elif pred_dir == 'causal': 106 | for i in range(0, dp1-1): 107 | for j in range(i+1, dp1-1): 108 | B[j, i] = 0.5 * np.random.randn(1) 109 | # so y should not change X 110 | B[:, -1] = 0 111 | # causal prediction 112 | B[-1, :-1] = 1.0 * np.random.randn(dp1-1) 113 | B[-1, -1] = 0 114 | elif pred_dir == 'halfhalf': 115 | # make them mixed causal and anti-causal 116 | B = np.zeros((dp1, dp1)) 117 | for i in range(0, dp1-1): 118 | for j in range(i+1, dp1-1): 119 | B[j, i] = 0.5 * np.random.randn(1) 120 | for i in range((dp1-1)//2): 121 | # half anti-causal 122 | B[2*i, -1] = 1.0 * np.random.randn(1) 123 | # half causal 124 | if 2*i+1 <= dp1-2: 125 | B[-1, 2*i+1] = 1.0 * np.random.randn(1) 126 | for j in range((dp1-1)//2): 127 | # anti-causal node should not point to causal node, to ensure acyclic graph 128 | B[2*i+1, 2*j] = 0 129 | B[-1, -1] = 0 130 | else: 131 | raise ValueError('case not recognized.') 132 | 133 | 134 | return B 135 | 136 | 137 | def pick_sem(data_num, params = None, seed=123456): 138 | np.random.seed(seed) 139 | # name rules 140 | # r0: r0 regression Y is cause 141 | # r1 regression Y is effect 142 | # r2 regression Y is in the middle 143 | # sm1: 1 dimensional mean shift 144 | # case 1 1 dimensional mean shift 145 | # case 2 2 mixture of mean shift 146 | # sv1 change of variance 147 | # d3: dimension of the problem is 3 148 | # x1: no intervention on Y, only intervention on X, 149 | # case 1 (no conditional invariant components, no inter Y) 150 | # case 2 (with conditional invariant components, no inter Y) 151 | # case 3 (no conditional invariant components, inter Y) 152 | # case 4 (with conditional invariant components, inter Y) 153 | if 'd3' in data_num: 154 | # Y cause, d = 3 155 | M = params['M'] 156 | # d plus 1 157 | dp1 = 4 158 | 159 | 160 | inter2noise_ratio = params['inter2noise_ratio'] 161 | if 'x1' in data_num: 162 | # conditional invariant components 163 | cic = [] 164 | # intervention on Y 165 | interY = 0 166 | elif 'x3' in data_num: 167 | cic = [] 168 | if 'interY' in params.keys(): 169 | interY = params['interY'] 170 | else: 171 | interY = 1. 172 | elif 'x2' in data_num: 173 | # conditional invariant components 174 | cic = [0] 175 | # intervention on Y 176 | interY = 0 177 | elif 'x4' in data_num: 178 | cic = [0] 179 | if 'interY' in params.keys(): 180 | interY = params['interY'] 181 | else: 182 | interY = 1. 183 | else: 184 | raise ValueError('case not recognized.') 185 | 186 | if 'sm1' in data_num: 187 | typeshift = 'sm1' 188 | elif 'sm2' in data_num: 189 | typeshift = 'sm2' 190 | elif 'sm3' in data_num: 191 | typeshift = 'sm3' 192 | elif 'sv1' in data_num: 193 | typeshift = 'sv1' 194 | elif 'smv1' in data_num: 195 | typeshift = 'smv1' 196 | elif 'smm2' in data_num: 197 | typeshift = 'smm2' 198 | else: 199 | typeshift = 'sm1' 200 | 201 | 202 | interAf, noisef = pick_intervention_and_noise(M, dp1, inter2noise_ratio, interY=interY, 203 | cic=cic, typeshift=typeshift, varAs=None, varnoiseY=1.) 204 | 205 | if 'r0' in data_num: 206 | pred_dir = 'anticausal' 207 | B = np.array([[0, 0, 0, 1], [0, 0, 0, -1], [0, 0, 0, 3], [0, 0, 0, 0]]) 208 | elif 'r1' in data_num: 209 | pred_dir = 'causal' 210 | B = np.array([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [1, -1, 3, 0]]) 211 | # elif 'r2' in data_num: 212 | # pred_dir = 'halfhalf' 213 | else: 214 | raise ValueError('case not recognized.') 215 | 216 | 217 | 218 | invariantList = [] 219 | message = "%sM%di%d, fixed simple B" %(data_num, M, inter2noise_ratio) 220 | elif 'd?' in data_num: 221 | # Y cause, d = ? 222 | M = params['M'] 223 | # d plus 1 224 | if 'd' in params.keys(): 225 | dp1 = params['d'] + 1 226 | else: 227 | # set default dimension to 10 228 | dp1 = 11 229 | if 'interY' in params.keys(): 230 | interY = params['interY'] 231 | else: 232 | interY = 1. 233 | 234 | inter2noise_ratio = params['inter2noise_ratio'] 235 | 236 | if 'r0' in data_num: 237 | pred_dir = 'anticausal' 238 | varnoiseY = 1. 239 | elif 'r1' in data_num: 240 | pred_dir = 'causal' 241 | varnoiseY = 1. 242 | elif 'r2' in data_num: 243 | pred_dir = 'halfhalf' 244 | varnoiseY = 0.01 245 | else: 246 | raise ValueError('case not recognized.') 247 | 248 | B = pick_random_B(pred_dir, dp1) 249 | 250 | if 'x1' in data_num: 251 | # conditional invariant components 252 | cic = [] 253 | # intervention on Y 254 | interY = 0 255 | elif 'x3' in data_num: 256 | cic = [] 257 | if 'interY' in params.keys(): 258 | interY = params['interY'] 259 | else: 260 | interY = 1. 261 | elif 'x2' in data_num: 262 | if 'cicnum' in params.keys(): 263 | cicnum = params['cicnum'] 264 | else: 265 | cicnum = int(dp1/2) 266 | # conditional invariant components 267 | cic = np.random.choice(dp1-1, cicnum, replace=False) 268 | # intervention on Y 269 | interY = 0 270 | elif 'x4' in data_num: 271 | if 'cicnum' in params.keys(): 272 | cicnum = params['cicnum'] 273 | else: 274 | cicnum = int(dp1/2) 275 | # cic = np.arange(0, cicnum) 276 | cic = np.random.choice(dp1-1, cicnum, replace=False) 277 | if 'interY' in params.keys(): 278 | interY = params['interY'] 279 | else: 280 | interY = 1. 281 | else: 282 | raise ValueError('case not recognized.') 283 | 284 | if 'sm1' in data_num: 285 | typeshift = 'sm1' 286 | elif 'sm2' in data_num: 287 | typeshift = 'sm2' 288 | elif 'sv1' in data_num: 289 | typeshift = 'sv1' 290 | elif 'smv1' in data_num: 291 | typeshift = 'smv1' 292 | elif 'smm2' in data_num: 293 | typeshift = 'smm2' 294 | else: 295 | typeshift = 'sm1' 296 | 297 | interAf, noisef = pick_intervention_and_noise(M, dp1, inter2noise_ratio, interY=interY, cic=cic, typeshift=typeshift, varAs=None, varnoiseY=varnoiseY) 298 | 299 | invariantList = cic 300 | message = "%sM%dd%di%d, fixed simple B" %(data_num, M, dp1-1, inter2noise_ratio) 301 | # generate sem 302 | sem1 = sem.SEM(B, noisef, interAf, invariantList=invariantList, message=message) 303 | 304 | return sem1 305 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sem 3 | 4 | import torch 5 | 6 | # check gpu avail 7 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 8 | 9 | 10 | def MSE(yhat, y): 11 | return np.mean((yhat-y)**2) 12 | 13 | def torchMSE(a, b): 14 | return torch.mean((a.squeeze() - b)**2) 15 | 16 | def torchloaderMSE(me, dataloader, device): 17 | # get MSE from a torch model with dataloader 18 | error = 0 19 | n = 0 20 | with torch.no_grad(): 21 | for data in dataloader: 22 | x, y = data[0].to(device), data[1].to(device) 23 | ypred = me.predict(x) 24 | n += x.shape[0] 25 | error += torch.sum((ypred.squeeze() - y)**2) 26 | return error.item()/n 27 | 28 | 29 | # given a SEM, run all methods and return target risks and target test risks 30 | def run_all_methods(sem1, methods, n=1000, repeats=10, returnMinDiffIndx=False, tag_DA='DAMMD'): 31 | M = sem1.M 32 | results_src_all = np.zeros((M-1, len(methods), 2, repeats)) 33 | results_tar_all = np.zeros((len(methods), 2, repeats)) 34 | results_minDiffIndx = {} 35 | 36 | # generate data 37 | for repeat in range(repeats): 38 | data = sem1.generateAllSamples(n) 39 | dataTest = sem1.generateAllSamples(n) 40 | # may use other target as well 41 | source = np.arange(M-1) 42 | target = M-1 43 | 44 | # prepare torch format data 45 | dataTorch = {} 46 | dataTorchTest = {} 47 | 48 | for i in range(M): 49 | dataTorch[i] = [torch.from_numpy(data[i][0].astype(np.float32)).to(device), 50 | torch.from_numpy(data[i][1].astype(np.float32)).to(device)] 51 | dataTorchTest[i] = [torch.from_numpy(dataTest[i][0].astype(np.float32)).to(device), 52 | torch.from_numpy(dataTest[i][1].astype(np.float32)).to(device)] 53 | 54 | # prepare torch format data for batch stochastic gradient descent 55 | train_batch_size = 500 56 | test_batch_size = 500 57 | 58 | trainloaders = {} 59 | testloaders = {} 60 | 61 | for i in range(M): 62 | train_dataset = torch.utils.data.TensorDataset(torch.Tensor(data[i][0]), 63 | torch.Tensor(data[i][1])) 64 | test_dataset = torch.utils.data.TensorDataset(torch.Tensor(dataTest[i][0]), 65 | torch.Tensor(dataTest[i][1])) 66 | trainloaders[i] = torch.utils.data.DataLoader(train_dataset, batch_size=train_batch_size) 67 | testloaders[i] = torch.utils.data.DataLoader(test_dataset, batch_size=test_batch_size) 68 | 69 | for i, m in enumerate(methods): 70 | if m.__module__ == 'semiclass': 71 | me = m.fit(data, source=source, target=target) 72 | if hasattr(me, 'minDiffIndx'): 73 | print("best index="+str(me.minDiffIndx)) 74 | results_minDiffIndx[(tag_DA, i, repeat)] = me.minDiffIndx 75 | xtar, ytar = data[target] 76 | xtar_test, ytar_test = dataTest[target] 77 | targetE = MSE(me.ypred, ytar) 78 | targetNE = MSE(me.predict(xtar_test), ytar_test) 79 | for j, sourcej in enumerate(source): 80 | results_src_all[j, i, 0, repeat] = MSE(me.predict(data[sourcej][0]), data[sourcej][1]) 81 | results_src_all[j, i, 1, repeat] = MSE(me.predict(dataTest[sourcej][0]), dataTest[sourcej][1]) 82 | elif m.__module__ == 'semitorchclass': 83 | me = m.fit(dataTorch, source=source, target=target) 84 | if hasattr(me, 'minDiffIndx'): 85 | print("best index="+str(me.minDiffIndx)) 86 | results_minDiffIndx[(tag_DA, i, repeat)] = me.minDiffIndx 87 | xtar, ytar= dataTorch[target] 88 | xtar_test, ytar_test= dataTorchTest[target] 89 | targetE = torchMSE(me.ypred, ytar) 90 | targetNE = torchMSE(me.predict(xtar_test), ytar_test) 91 | for j, sourcej in enumerate(source): 92 | results_src_all[j, i, 0, repeat] = torchMSE(me.predict(dataTorch[sourcej][0]), dataTorch[sourcej][1]) 93 | results_src_all[j, i, 1, repeat] = torchMSE(me.predict(dataTorchTest[sourcej][0]), dataTorchTest[sourcej][1]) 94 | elif m.__module__ == 'semitorchstocclass': 95 | me = m.fit(trainloaders, source=source, target=target) 96 | targetE = torchloaderMSE(me, trainloaders[target], device) 97 | targetNE = torchloaderMSE(me, testloaders[target], device) 98 | for j, sourcej in enumerate(source): 99 | results_src_all[j, i, 0, repeat] = torchloaderMSE(me, trainloaders[sourcej], device) 100 | results_src_all[j, i, 1, repeat] = torchloaderMSE(me, testloaders[sourcej], device) 101 | else: 102 | raise ValueError("Unexpected method class") 103 | results_tar_all[i, 0, repeat] = targetE 104 | results_tar_all[i, 1, repeat] = targetNE 105 | print("Repeat %d Target %-30s error=%.3f errorTest=%.3f" %(repeat, str(m), targetE, targetNE), flush=True) 106 | if returnMinDiffIndx: 107 | return results_src_all, results_tar_all, results_minDiffIndx 108 | else: 109 | return results_src_all, results_tar_all 110 | --------------------------------------------------------------------------------