├── .gitignore ├── LICENSE ├── README.md ├── setup.py └── single_pulse_ml ├── __init__.py ├── classify.py ├── data └── data.txt ├── dataproc.py ├── frbkeras.py ├── model └── model.txt ├── plot_tools.py ├── plots └── Freq_train.png ├── reader.py ├── run_frb_simulation.py ├── run_single_pulse_DL.py ├── sim_parameters.py ├── simulate_frb.py ├── simulate_multibeam.py ├── telescope.py ├── tests ├── __init__.py ├── test_frbkeras.py ├── test_reader.py ├── test_run_frb_simulation.py └── test_simulate_frb.py └── tools.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.npy 3 | *.hdf5 4 | *.py~ 5 | 6 | ./dist/ 7 | 8 | ./single_pulse_ml/*pkl 9 | ./single_pulse_ml/*npy 10 | ./single_pulse_ml/*hdf5 11 | ./single_pulse_ml/plots/*png 12 | ./single_pulse_ml/model 13 | ./single_pulse_ml/run_frb_simulation.py 14 | ./single_pulse_ml/run_single_pulse_DL.py 15 | 16 | /*.egg-info 17 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 2, June 1991 3 | 4 | Copyright (C) 1989, 1991 Free Software Foundation, Inc., 5 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA 6 | Everyone is permitted to copy and distribute verbatim copies 7 | of this license document, but changing it is not allowed. 8 | 9 | Preamble 10 | 11 | The licenses for most software are designed to take away your 12 | freedom to share and change it. By contrast, the GNU General Public 13 | License is intended to guarantee your freedom to share and change free 14 | software--to make sure the software is free for all its users. This 15 | General Public License applies to most of the Free Software 16 | Foundation's software and to any other program whose authors commit to 17 | using it. (Some other Free Software Foundation software is covered by 18 | the GNU Lesser General Public License instead.) You can apply it to 19 | your programs, too. 20 | 21 | When we speak of free software, we are referring to freedom, not 22 | price. Our General Public Licenses are designed to make sure that you 23 | have the freedom to distribute copies of free software (and charge for 24 | this service if you wish), that you receive source code or can get it 25 | if you want it, that you can change the software or use pieces of it 26 | in new free programs; and that you know you can do these things. 27 | 28 | To protect your rights, we need to make restrictions that forbid 29 | anyone to deny you these rights or to ask you to surrender the rights. 30 | These restrictions translate to certain responsibilities for you if you 31 | distribute copies of the software, or if you modify it. 32 | 33 | For example, if you distribute copies of such a program, whether 34 | gratis or for a fee, you must give the recipients all the rights that 35 | you have. You must make sure that they, too, receive or can get the 36 | source code. And you must show them these terms so they know their 37 | rights. 38 | 39 | We protect your rights with two steps: (1) copyright the software, and 40 | (2) offer you this license which gives you legal permission to copy, 41 | distribute and/or modify the software. 42 | 43 | Also, for each author's protection and ours, we want to make certain 44 | that everyone understands that there is no warranty for this free 45 | software. 
If the software is modified by someone else and passed on, we 46 | want its recipients to know that what they have is not the original, so 47 | that any problems introduced by others will not reflect on the original 48 | authors' reputations. 49 | 50 | Finally, any free program is threatened constantly by software 51 | patents. We wish to avoid the danger that redistributors of a free 52 | program will individually obtain patent licenses, in effect making the 53 | program proprietary. To prevent this, we have made it clear that any 54 | patent must be licensed for everyone's free use or not licensed at all. 55 | 56 | The precise terms and conditions for copying, distribution and 57 | modification follow. 58 | 59 | GNU GENERAL PUBLIC LICENSE 60 | TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION 61 | 62 | 0. This License applies to any program or other work which contains 63 | a notice placed by the copyright holder saying it may be distributed 64 | under the terms of this General Public License. The "Program", below, 65 | refers to any such program or work, and a "work based on the Program" 66 | means either the Program or any derivative work under copyright law: 67 | that is to say, a work containing the Program or a portion of it, 68 | either verbatim or with modifications and/or translated into another 69 | language. (Hereinafter, translation is included without limitation in 70 | the term "modification".) Each licensee is addressed as "you". 71 | 72 | Activities other than copying, distribution and modification are not 73 | covered by this License; they are outside its scope. The act of 74 | running the Program is not restricted, and the output from the Program 75 | is covered only if its contents constitute a work based on the 76 | Program (independent of having been made by running the Program). 77 | Whether that is true depends on what the Program does. 78 | 79 | 1. You may copy and distribute verbatim copies of the Program's 80 | source code as you receive it, in any medium, provided that you 81 | conspicuously and appropriately publish on each copy an appropriate 82 | copyright notice and disclaimer of warranty; keep intact all the 83 | notices that refer to this License and to the absence of any warranty; 84 | and give any other recipients of the Program a copy of this License 85 | along with the Program. 86 | 87 | You may charge a fee for the physical act of transferring a copy, and 88 | you may at your option offer warranty protection in exchange for a fee. 89 | 90 | 2. You may modify your copy or copies of the Program or any portion 91 | of it, thus forming a work based on the Program, and copy and 92 | distribute such modifications or work under the terms of Section 1 93 | above, provided that you also meet all of these conditions: 94 | 95 | a) You must cause the modified files to carry prominent notices 96 | stating that you changed the files and the date of any change. 97 | 98 | b) You must cause any work that you distribute or publish, that in 99 | whole or in part contains or is derived from the Program or any 100 | part thereof, to be licensed as a whole at no charge to all third 101 | parties under the terms of this License. 
102 | 103 | c) If the modified program normally reads commands interactively 104 | when run, you must cause it, when started running for such 105 | interactive use in the most ordinary way, to print or display an 106 | announcement including an appropriate copyright notice and a 107 | notice that there is no warranty (or else, saying that you provide 108 | a warranty) and that users may redistribute the program under 109 | these conditions, and telling the user how to view a copy of this 110 | License. (Exception: if the Program itself is interactive but 111 | does not normally print such an announcement, your work based on 112 | the Program is not required to print an announcement.) 113 | 114 | These requirements apply to the modified work as a whole. If 115 | identifiable sections of that work are not derived from the Program, 116 | and can be reasonably considered independent and separate works in 117 | themselves, then this License, and its terms, do not apply to those 118 | sections when you distribute them as separate works. But when you 119 | distribute the same sections as part of a whole which is a work based 120 | on the Program, the distribution of the whole must be on the terms of 121 | this License, whose permissions for other licensees extend to the 122 | entire whole, and thus to each and every part regardless of who wrote it. 123 | 124 | Thus, it is not the intent of this section to claim rights or contest 125 | your rights to work written entirely by you; rather, the intent is to 126 | exercise the right to control the distribution of derivative or 127 | collective works based on the Program. 128 | 129 | In addition, mere aggregation of another work not based on the Program 130 | with the Program (or with a work based on the Program) on a volume of 131 | a storage or distribution medium does not bring the other work under 132 | the scope of this License. 133 | 134 | 3. You may copy and distribute the Program (or a work based on it, 135 | under Section 2) in object code or executable form under the terms of 136 | Sections 1 and 2 above provided that you also do one of the following: 137 | 138 | a) Accompany it with the complete corresponding machine-readable 139 | source code, which must be distributed under the terms of Sections 140 | 1 and 2 above on a medium customarily used for software interchange; or, 141 | 142 | b) Accompany it with a written offer, valid for at least three 143 | years, to give any third party, for a charge no more than your 144 | cost of physically performing source distribution, a complete 145 | machine-readable copy of the corresponding source code, to be 146 | distributed under the terms of Sections 1 and 2 above on a medium 147 | customarily used for software interchange; or, 148 | 149 | c) Accompany it with the information you received as to the offer 150 | to distribute corresponding source code. (This alternative is 151 | allowed only for noncommercial distribution and only if you 152 | received the program in object code or executable form with such 153 | an offer, in accord with Subsection b above.) 154 | 155 | The source code for a work means the preferred form of the work for 156 | making modifications to it. For an executable work, complete source 157 | code means all the source code for all modules it contains, plus any 158 | associated interface definition files, plus the scripts used to 159 | control compilation and installation of the executable. 
However, as a 160 | special exception, the source code distributed need not include 161 | anything that is normally distributed (in either source or binary 162 | form) with the major components (compiler, kernel, and so on) of the 163 | operating system on which the executable runs, unless that component 164 | itself accompanies the executable. 165 | 166 | If distribution of executable or object code is made by offering 167 | access to copy from a designated place, then offering equivalent 168 | access to copy the source code from the same place counts as 169 | distribution of the source code, even though third parties are not 170 | compelled to copy the source along with the object code. 171 | 172 | 4. You may not copy, modify, sublicense, or distribute the Program 173 | except as expressly provided under this License. Any attempt 174 | otherwise to copy, modify, sublicense or distribute the Program is 175 | void, and will automatically terminate your rights under this License. 176 | However, parties who have received copies, or rights, from you under 177 | this License will not have their licenses terminated so long as such 178 | parties remain in full compliance. 179 | 180 | 5. You are not required to accept this License, since you have not 181 | signed it. However, nothing else grants you permission to modify or 182 | distribute the Program or its derivative works. These actions are 183 | prohibited by law if you do not accept this License. Therefore, by 184 | modifying or distributing the Program (or any work based on the 185 | Program), you indicate your acceptance of this License to do so, and 186 | all its terms and conditions for copying, distributing or modifying 187 | the Program or works based on it. 188 | 189 | 6. Each time you redistribute the Program (or any work based on the 190 | Program), the recipient automatically receives a license from the 191 | original licensor to copy, distribute or modify the Program subject to 192 | these terms and conditions. You may not impose any further 193 | restrictions on the recipients' exercise of the rights granted herein. 194 | You are not responsible for enforcing compliance by third parties to 195 | this License. 196 | 197 | 7. If, as a consequence of a court judgment or allegation of patent 198 | infringement or for any other reason (not limited to patent issues), 199 | conditions are imposed on you (whether by court order, agreement or 200 | otherwise) that contradict the conditions of this License, they do not 201 | excuse you from the conditions of this License. If you cannot 202 | distribute so as to satisfy simultaneously your obligations under this 203 | License and any other pertinent obligations, then as a consequence you 204 | may not distribute the Program at all. For example, if a patent 205 | license would not permit royalty-free redistribution of the Program by 206 | all those who receive copies directly or indirectly through you, then 207 | the only way you could satisfy both it and this License would be to 208 | refrain entirely from distribution of the Program. 209 | 210 | If any portion of this section is held invalid or unenforceable under 211 | any particular circumstance, the balance of the section is intended to 212 | apply and the section as a whole is intended to apply in other 213 | circumstances. 
214 | 215 | It is not the purpose of this section to induce you to infringe any 216 | patents or other property right claims or to contest validity of any 217 | such claims; this section has the sole purpose of protecting the 218 | integrity of the free software distribution system, which is 219 | implemented by public license practices. Many people have made 220 | generous contributions to the wide range of software distributed 221 | through that system in reliance on consistent application of that 222 | system; it is up to the author/donor to decide if he or she is willing 223 | to distribute software through any other system and a licensee cannot 224 | impose that choice. 225 | 226 | This section is intended to make thoroughly clear what is believed to 227 | be a consequence of the rest of this License. 228 | 229 | 8. If the distribution and/or use of the Program is restricted in 230 | certain countries either by patents or by copyrighted interfaces, the 231 | original copyright holder who places the Program under this License 232 | may add an explicit geographical distribution limitation excluding 233 | those countries, so that distribution is permitted only in or among 234 | countries not thus excluded. In such case, this License incorporates 235 | the limitation as if written in the body of this License. 236 | 237 | 9. The Free Software Foundation may publish revised and/or new versions 238 | of the General Public License from time to time. Such new versions will 239 | be similar in spirit to the present version, but may differ in detail to 240 | address new problems or concerns. 241 | 242 | Each version is given a distinguishing version number. If the Program 243 | specifies a version number of this License which applies to it and "any 244 | later version", you have the option of following the terms and conditions 245 | either of that version or of any later version published by the Free 246 | Software Foundation. If the Program does not specify a version number of 247 | this License, you may choose any version ever published by the Free Software 248 | Foundation. 249 | 250 | 10. If you wish to incorporate parts of the Program into other free 251 | programs whose distribution conditions are different, write to the author 252 | to ask for permission. For software which is copyrighted by the Free 253 | Software Foundation, write to the Free Software Foundation; we sometimes 254 | make exceptions for this. Our decision will be guided by the two goals 255 | of preserving the free status of all derivatives of our free software and 256 | of promoting the sharing and reuse of software generally. 257 | 258 | NO WARRANTY 259 | 260 | 11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY 261 | FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN 262 | OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES 263 | PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED 264 | OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF 265 | MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS 266 | TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE 267 | PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, 268 | REPAIR OR CORRECTION. 269 | 270 | 12. 
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 271 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR 272 | REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, 273 | INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING 274 | OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED 275 | TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY 276 | YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER 277 | PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE 278 | POSSIBILITY OF SUCH DAMAGES. 279 | 280 | END OF TERMS AND CONDITIONS 281 | 282 | How to Apply These Terms to Your New Programs 283 | 284 | If you develop a new program, and you want it to be of the greatest 285 | possible use to the public, the best way to achieve this is to make it 286 | free software which everyone can redistribute and change under these terms. 287 | 288 | To do so, attach the following notices to the program. It is safest 289 | to attach them to the start of each source file to most effectively 290 | convey the exclusion of warranty; and each file should have at least 291 | the "copyright" line and a pointer to where the full notice is found. 292 | 293 | {description} 294 | Copyright (C) {year} {fullname} 295 | 296 | This program is free software; you can redistribute it and/or modify 297 | it under the terms of the GNU General Public License as published by 298 | the Free Software Foundation; either version 2 of the License, or 299 | (at your option) any later version. 300 | 301 | This program is distributed in the hope that it will be useful, 302 | but WITHOUT ANY WARRANTY; without even the implied warranty of 303 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 304 | GNU General Public License for more details. 305 | 306 | You should have received a copy of the GNU General Public License along 307 | with this program; if not, write to the Free Software Foundation, Inc., 308 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 309 | 310 | Also add information on how to contact you by electronic and paper mail. 311 | 312 | If the program is interactive, make it output a short notice like this 313 | when it starts in an interactive mode: 314 | 315 | Gnomovision version 69, Copyright (C) year name of author 316 | Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 317 | This is free software, and you are welcome to redistribute it 318 | under certain conditions; type `show c' for details. 319 | 320 | The hypothetical commands `show w' and `show c' should show the appropriate 321 | parts of the General Public License. Of course, the commands you use may 322 | be called something other than `show w' and `show c'; they could even be 323 | mouse-clicks or menu items--whatever suits your program. 324 | 325 | You should also get your employer (if you work as a programmer) or your 326 | school, if any, to sign a "copyright disclaimer" for the program, if 327 | necessary. Here is a sample; alter the names: 328 | 329 | Yoyodyne, Inc., hereby disclaims all copyright interest in the program 330 | `Gnomovision' (which makes passes at compilers) written by James Hacker. 331 | 332 | {signature of Ty Coon}, 1 April 1989 333 | Ty Coon, President of Vice 334 | 335 | This General Public License does not permit incorporating your program into 336 | proprietary programs. 
If your program is a subroutine library, you may
consider it more useful to permit linking proprietary applications with the
library.  If this is what you want to do, use the GNU Lesser General
Public License instead of this License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
### single_pulse_ml

Build, train, and apply deep neural networks to single-pulse candidates.

run_frb_simulation.py constructs a training set that includes simulated FRBs.

run_single_pulse_DL.py trains deep neural networks on several
input data products, including:

- dedispersed dynamic spectra (2D CNN)
- DM/time intensity array (2D CNN)
- frequency-collapsed pulse profile (1D CNN)
- multi-beam S/N information (1D feed-forward DNN)

run_single_pulse_DL.py can also be used to classify new candidates once a
trained model exists.

This code has been used on CHIME Pathfinder incoherent data as well as
commissioning data on Apertif.

### Requirements

You will need the following:

- numpy
- scipy
- h5py
- matplotlib
- tensorflow
- keras

### Tests

In the single_pulse_ml/tests/ directory,
"test_run_frb_simulation.py" can be run to generate 100 simulated FRBs
to ensure the simulation backend works.

"test_frbkeras.py" will generate 1000 Gaussian-noise
dynamic spectrum candidates of dimension 32x64, then
build, train, and test a CNN using the tools in frbkeras.
This allows a test of the keras/tensorflow code.
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup

def readme():
    # the repository ships README.md, not README.rst
    with open('README.md') as f:
        return f.read()

setup(name='single_pulse_ml',
      version='0.1',
      description='Deep learning implementation of single-pulse search',
      url='http://github.com/liamconnor/single_pulse_ml',
      author='Liam Connor',
      author_email='liam.dean.connor@gmail.com',
      license='GPL v2.0',
      packages=['single_pulse_ml'],
      install_requires=[
          'numpy',
          'scipy',
          'h5py',
          'matplotlib',
          'tensorflow-gpu',
          'keras',
      ],
      test_suite='nose.collector',
      tests_require=['nose'],
      zip_safe=False)
--------------------------------------------------------------------------------
/single_pulse_ml/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liamconnor/single_pulse_ml/88b6b76ebf3d3939214d9785d4e1c5076f653c38/single_pulse_ml/__init__.py
--------------------------------------------------------------------------------
/single_pulse_ml/classify.py:
--------------------------------------------------------------------------------
# Liam Connor 25 July 2018
# Script to classify single pulses
# using a tensorflow/keras model. Output probabilities
# can be saved and plotted

import optparse
import numpy as np
import h5py

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import matplotlib as mpl
mpl.use('pdf')

import frbkeras
import reader
import plot_tools

def classify(data, model, save_ranked=False,
             plot_ranked=False, prob_threshold=0.5,
             fnout='ranked'):

    model = frbkeras.load_model(model)

    mshape = model.input.shape
    dshape = data.shape

    # normalize data
    data = data.reshape(len(data), -1)
    data -= np.median(data, axis=-1)[:, None]
    data /= np.std(data, axis=-1)[:, None]

    # zero out nans
    data[data != data] = 0.0
    data = data.reshape(dshape)

    if dshape[-1] != 1:
        data = data[..., None]

    if len(mshape) == 3:
        data = data.mean(1)
        dshape = data.shape

    # the model's expected input shape must match the data;
    # otherwise report both shapes and bail out
    if mshape[1] != dshape[1]:
        print("Model expects:", mshape)
        print("Data has:", dshape)
        return

    if mshape[2] != dshape[2]:
        print("Model expects:", mshape)
        print("Data has:", dshape)
        return

    y_pred_prob = model.predict(data)
    y_pred_prob = y_pred_prob[:, 1]

    ind_frb = np.where(y_pred_prob > prob_threshold)[0]

    print("\n%d out of %d events with probability > %.2f:\n %s" %
          (len(ind_frb), len(y_pred_prob),
           prob_threshold, ind_frb))

    low_to_high_ind = np.argsort(y_pred_prob)

    if save_ranked is True:
        print("Need to fix the file naming")
        # strip the '.hdf5' suffix; rstrip would strip characters, not the suffix
        fnout_ranked = fn_data.split('.hdf5')[0] + \
                       'freq_time_candidates.hdf5'

        g = h5py.File(fnout_ranked, 'w')
        g.create_dataset('data_frb_candidate', data=data[ind_frb])
        g.create_dataset('frb_index', data=ind_frb)
        g.create_dataset('probability', data=y_pred_prob)
        g.close()
        print("\nSaved them and all probabilities to: \n%s" % fnout_ranked)

    if plot_ranked is True:
        if save_ranked is False:
            argtup = (data[ind_frb], ind_frb, y_pred_prob)

            plot_tools.plot_multiple_ranked(argtup, nside=10,
                                            fnfigout=fnout, ascending=False)
        else:
            plot_tools.plot_multiple_ranked(fnout_ranked, nside=10,
                                            fnfigout=fnout, ascending=False)
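
# Example (sketch) of calling classify() directly on an array of candidate
# dynamic spectra. The model file name 'freq_time_model.hdf5' is hypothetical;
# substitute a model you have actually trained and saved with frbkeras.
def _example_classify():
    ntrig, nfreq, ntime = 100, 32, 64
    # white-noise stand-ins for dedispersed (ntrig, nfreq, ntime) candidates
    data = np.random.normal(0, 1, [ntrig, nfreq, ntime])
    classify(data, 'freq_time_model.hdf5',
             save_ranked=False, plot_ranked=False,
             prob_threshold=0.5, fnout='ranked_example')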
Default None", \ 117 | default=None) 118 | 119 | parser.add_option('--pthresh', dest='prob_threshold', type='float', \ 120 | help="probability treshold", default=0.5) 121 | 122 | parser.add_option('--save_ranked', dest='save_ranked', 123 | action='store_true', \ 124 | help="save FRB events + probabilities", \ 125 | default=False) 126 | 127 | parser.add_option('--plot_ranked', dest='plot_ranked', \ 128 | action='store_true',\ 129 | help="plot triggers", default=False) 130 | 131 | parser.add_option('--twindow', dest='twindow', type='int', \ 132 | help="time width, default 64", default=64) 133 | 134 | parser.add_option('--fnout', dest='fnout', type='str', \ 135 | help="beginning of figure names", \ 136 | default='ranked_trig') 137 | 138 | options, args = parser.parse_args() 139 | 140 | assert len(args)==2, "Arguments are FN_DATA FN_MODEL [OPTIONS]" 141 | 142 | fn_data = args[0] 143 | fn_model_freq = args[1] 144 | 145 | print("Using datafile %s" % fn_data) 146 | print("Using keras model in %s" % fn_model_freq) 147 | 148 | data_freq, y, data_dm, data_mb = reader.read_hdf5(fn_data) 149 | 150 | NFREQ = data_freq.shape[1] 151 | NTIME = data_freq.shape[2] 152 | WIDTH = options.twindow 153 | 154 | # low time index, high time index 155 | tl, th = NTIME//2-WIDTH//2, NTIME//2+WIDTH//2 156 | 157 | if data_freq.shape[-1] > (th-tl): 158 | data_freq = data_freq[..., tl:th] 159 | 160 | fn_fig_out = options.fnout + '_freq_time' 161 | print("\nCLASSIFYING FREQ/TIME DATA\n") 162 | classify(data_freq, fn_model_freq, 163 | save_ranked=options.save_ranked, 164 | plot_ranked=options.plot_ranked, 165 | prob_threshold=options.prob_threshold, 166 | fnout=fn_fig_out) 167 | 168 | if options.fn_model_dm is not None: 169 | if len(data_dm)>0: 170 | print("\nCLASSIFYING DM/TIME DATA\n)") 171 | print(data_dm.shape) 172 | fn_fig_out = options.fnout + '_dm_time' 173 | classify(data_dm, options.fn_model_dm, 174 | save_ranked=options.save_ranked, 175 | plot_ranked=options.plot_ranked, 176 | prob_threshold=options.prob_threshold, 177 | fnout=fn_fig_out) 178 | else: 179 | print("No DM/time data to classify") 180 | 181 | if options.fn_model_time is not None: 182 | print("\nCLASSIFYING 1D TIME DATA\n)") 183 | fn_fig_out = options.fnout + '_1d_time' 184 | classify(data_freq, options.fn_model_time, 185 | save_ranked=options.save_ranked, 186 | plot_ranked=options.plot_ranked, 187 | prob_threshold=options.prob_threshold, 188 | fnout=fn_fig_out) 189 | 190 | if options.fn_model_mb is not None: 191 | classify(data_mb, options.fn_model_mb, 192 | save_ranked=options.save_ranked, 193 | plot_ranked=options.plot_ranked, 194 | prob_threshold=options.prob_threshold, 195 | fnout=options.fnout) 196 | 197 | exit() 198 | 199 | dshape = data_freq.shape 200 | 201 | # normalize data 202 | data_freq = data_freq.reshape(len(data_freq), -1) 203 | data_freq -= np.median(data_freq, axis=-1)[:, None] 204 | data_freq /= np.std(data_freq, axis=-1)[:, None] 205 | 206 | # zero out nans 207 | data_freq[data_freq!=data_freq] = 0.0 208 | data_freq = data_freq.reshape(dshape) 209 | 210 | if data_freq.shape[-1]!=1: 211 | data_freq = data_freq[..., None] 212 | 213 | model = frbkeras.load_model(fn_model_freq) 214 | 215 | if len(model.input.shape)==3: 216 | data_freq = data_freq.mean(1) 217 | 218 | y_pred_prob = model.predict(data_freq) 219 | y_pred_prob = y_pred_prob[:,1] 220 | 221 | ind_frb = np.where(y_pred_prob>options.prob_threshold)[0] 222 | 223 | print("\n%d out of %d events with probability > %.2f:\n %s" % 224 | (len(ind_frb), len(y_pred_prob), 225 | 
           options.prob_threshold, ind_frb))

    low_to_high_ind = np.argsort(y_pred_prob)

    if options.save_ranked is True:
        fnout_ranked = fn_data.split('.hdf5')[0] + 'freq_time_candidates.hdf5'

        g = h5py.File(fnout_ranked, 'w')
        g.create_dataset('data_frb_candidate', data=data_freq[ind_frb])
        g.create_dataset('frb_index', data=ind_frb)
        g.create_dataset('probability', data=y_pred_prob)
        g.close()
        print("\nSaved them and all probabilities to: \n%s" % fnout_ranked)

    if options.plot_ranked is True:
        if options.save_ranked is False:
            argtup = (data_freq[ind_frb], ind_frb, y_pred_prob)
            plot_tools.plot_multiple_ranked(argtup, nside=5,
                                            fnfigout=options.fnout)
        else:
            plot_tools.plot_multiple_ranked(fnout_ranked, nside=5,
                                            fnfigout=options.fnout)
--------------------------------------------------------------------------------
/single_pulse_ml/data/data.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liamconnor/single_pulse_ml/88b6b76ebf3d3939214d9785d4e1c5076f653c38/single_pulse_ml/data/data.txt
--------------------------------------------------------------------------------
/single_pulse_ml/dataproc.py:
--------------------------------------------------------------------------------
""" Tools for preprocessing data
"""

import numpy as np

def normalize_data(data):
    """ Normalize data to zero-median and
    unit standard deviation

    Parameters
    ----------
    data : np.array
        (nfreq, ntimes)
    """
    # subtract each channel's median
    data -= np.median(data, axis=-1)[:, None]
    # demand unit variance
    # data /= np.std(data, axis=-1)[:, None]
    # Try dividing by global variance.
    data /= np.std(data)
    # Replace nans with zero
    data[data != data] = 0.

    return data


def dedisp(data, dm, freq=np.linspace(800, 400, 1024), dt=512*2.56e-6):
    """ Dedisperse data by shifting freq bins

    Parameters
    ----------
    data : np.array
        (nfreq, ntimes)
    dm : np.float
        dispersion measure in pc cm**-3
    freq : np.array
        (nfreq) vector in MHz
    dt : np.float
        time resolution of data in seconds
    """
    dm_del = 4.148808e3 * dm * (freq**(-2) - 600.0**(-2))
    data_out = np.zeros_like(data)

    for ii, ff in enumerate(freq):
        dmd = int(round(dm_del[ii] / dt))
        data_out[ii] = np.roll(data[ii], -dmd, axis=-1)

    return data_out

def dm_delays(dm, freq, f_ref):
    """ Calculate dispersion delays in seconds

    Parameters
    ----------
    dm : np.float
        dispersion measure in pc cm**-3
    freq : np.array
        (nfreq) vector in MHz
    f_ref: np.float
        reference frequency in MHz
    """
    return 4.148808e3 * dm * (freq**(-2) - f_ref**(-2))


def straighten_arr(data):
    """ Step through each freq, find DM shift
    that gives largest S/N, realign bins

    Parameters
    ----------
    data : np.array
        (nfreq, ntimes)
    """

    sn = []

    dms = np.linspace(-5, 5, 100)

    for dm in dms:
        d_ = dedisp(data.copy(), dm, freq=np.linspace(800, 400, 16))
        sn.append(d_.mean(0).max() / np.std(d_.mean(0)))

    d_ = dedisp(data, dms[np.argmax(sn)], freq=np.linspace(800, 400, 16))

    return d_

def run_straightening(fn):
    """ Take filename, read in data, shift
    to remove any excess dm-delay.

    Parameters
    ----------
    fn : str
        filename of numpy array
    """
    f = np.load(fn)

    y = f[:, -1]

    d = f[y == 1, :-1].copy()

    for ii in range(len(d)):
        dd_ = d[ii].reshape(-1, 250)
        d[ii] = (straighten_arr(dd_)).reshape(-1)

    f[y == 1, :-1] = d

    for jj in range(len(f)):
        dd_ = f[jj, :-1].reshape(-1, 250)
        dd_ = normalize_data(dd_)
        f[jj, :-1] = dd_.flatten()

    return f
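
# Worked example (sketch): dispersion delay across an 800-400 MHz band for
# DM = 100 pc cm**-3, using dm_delays() above. The numbers are illustrative.
def _example_dm_delay():
    freq = np.linspace(800., 400., 1024)      # band in MHz
    delays = dm_delays(100., freq, f_ref=800.)
    # delay at 400 MHz relative to 800 MHz:
    # 4.148808e3 * 100 * (400**-2 - 800**-2) ~ 1.94 s
    print("Max delay across band: %.3f s" % delays.max())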
--------------------------------------------------------------------------------
/single_pulse_ml/frbkeras.py:
--------------------------------------------------------------------------------
""" Tools for building and training deep neural
networks in keras using the tensorflow backend.
"""


from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys

import numpy as np
from numpy.random import seed
import h5py

import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten, Merge
from keras.layers import Conv1D, Conv2D
from keras.layers import MaxPooling2D, MaxPooling1D, GlobalAveragePooling1D, BatchNormalization
from keras.optimizers import SGD
from keras.models import load_model


def get_predictions(model, data, true_labels=None):
    """ Take a keras.model object, a data array,
    and true_labels, and return the probability of
    each feature being a TP, the prediction itself,
    and the mistakes.
    """
    prob = model.predict(data)
    predictions = np.round(prob[:, 1])

    if true_labels is not None:
        # accept either one-hot (ntrig, 2) or flat (ntrig,) labels
        if len(true_labels.shape) == 2:
            true_labels = true_labels[:, 1]
        mistakes = np.where(predictions != true_labels)[0]
    else:
        mistakes = []

    return prob, predictions, mistakes

def get_classification_results(y_true, y_pred):
    """ Take true labels (y_true) and model-predicted
    label (y_pred) for a binary classifier, and return
    true_positives, false_positives, true_negatives, false_negatives
    """

    true_positives = np.where((y_true == 1) & (y_pred == 1))[0]
    false_positives = np.where((y_true == 0) & (y_pred == 1))[0]
    true_negatives = np.where((y_true == 0) & (y_pred == 0))[0]
    false_negatives = np.where((y_true == 1) & (y_pred == 0))[0]

    return true_positives, false_positives, true_negatives, false_negatives

def confusion_mat(y_true, y_pred):
    """ Generate a confusion matrix for a
    binary classifier based on true labels
    (y_true) and model-predicted label (y_pred)

    returns np.array([[TP, FP],[FN, TN]])
    """
    TP, FP, TN, FN = get_classification_results(y_true, y_pred)

    NTP = len(TP)
    NFP = len(FP)
    NTN = len(TN)
    NFN = len(FN)

    conf_mat = np.array([[NTP, NFP], [NFN, NTN]])

    return conf_mat

def print_metric(y_true, y_pred):
    """ Take true labels (y_true) and model-predicted
    label (y_pred) for a binary classifier
    and print a confusion matrix, metrics,
    return accuracy, precision, recall, fscore
    """
    conf_mat = confusion_mat(y_true, y_pred)

    NTP, NFP, NTN, NFN = conf_mat[0, 0], conf_mat[0, 1], conf_mat[1, 1], conf_mat[1, 0]

    print("Confusion matrix:")

    print('\n'.join([''.join(['{:8}'.format(item) for item in row])
                     for row in conf_mat]))

    accuracy = float(NTP + NTN) / conf_mat.sum()
    precision = float(NTP) / (NTP + NFP + 1e-19)
    recall = float(NTP) / (NTP + NFN + 1e-19)
    # guard against division by zero when precision = recall = 0
    fscore = 2*precision*recall / (precision + recall + 1e-19)

    print("accuracy: %f" % accuracy)
    print("precision: %f" % precision)
    print("recall: %f" % recall)
    print("fscore: %f" % fscore)

    return accuracy, precision, recall, fscore

def construct_ff1d(features_only=False, fit=False,
                   train_data=None, train_labels=None,
                   eval_data=None, eval_labels=None,
                   nbeam=32, epochs=5,
                   nlayer1=32, nlayer2=64, batch_size=32):
    """ Build a one-dimensional feed forward neural network
    with a binary classifier. Can be used for, e.g.,
    multi-beam detections.

    Parameters:
    ----------
    features_only : bool
        Don't construct full model, only features layers
    fit : bool
        Fit model
    train_data : ndarray
        (ntrain, nbeam) float64 array with training data
    train_labels : ndarray
        (ntrigger, 2) binary labels of training data [0, 1] = FRB, [1, 0] = RFI
    eval_data : ndarray
        (neval, nbeam) float64 array with evaluation data
    eval_labels :
        (neval, 2) binary labels of eval data
    nbeam : int
        Number of input beams (more generally, number of data inputs)
    epochs : int
        Number of training epochs
    nlayer1 : int
        Number of neurons in first hidden layer
    nlayer2 : int
        Number of neurons in second hidden layer
    batch_size : int
        Number of batches for training

    Returns
    -------
    model : keras.models.Sequential
        compiled model
    score : list
        [loss, accuracy] from model.evaluate; empty if fit is False
    """
    model = Sequential()
    model.add(Dense(nlayer1, input_dim=nbeam, activation='relu'))
    model.add(Dropout(0.4))
    model.add(Dense(nlayer2, init='normal', activation='relu'))

    if features_only is True:
        model.add(BatchNormalization())  # hack
        return model, []

    model.add(Dropout(0.4))
    model.add(Dense(2, activation='sigmoid'))

    model.compile(loss='binary_crossentropy',
                  optimizer='rmsprop',
                  metrics=['accuracy'])

    score = []

    if fit is True:
        model.fit(train_data, train_labels, batch_size=batch_size, epochs=epochs)
        score = model.evaluate(eval_data, eval_labels, batch_size=batch_size)

    return model, score
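
# Quick sanity check (sketch) of the metric helpers above on a toy binary
# problem; the labels are made up for illustration.
def _example_metrics():
    y_true = np.array([1, 1, 0, 0, 1, 0, 1, 0])
    y_pred = np.array([1, 0, 0, 1, 1, 0, 1, 0])
    # one false negative (index 1) and one false positive (index 3),
    # so accuracy = 6/8 and precision = recall = 3/4
    return print_metric(y_true, y_pred)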

def construct_conv2d(features_only=False, fit=False,
                     train_data=None, train_labels=None,
                     eval_data=None, eval_labels=None,
                     nfreq=16, ntime=250, epochs=5,
                     nfilt1=32, nfilt2=64, batch_size=32):
    """ Build a two-dimensional convolutional neural network
    with a binary classifier. Can be used for, e.g.,
    freq-time dynamic spectra of pulsars, dm-time intensity array.

    Parameters:
    ----------
    features_only : bool
        Don't construct full model, only features layers
    fit : bool
        Fit model
    train_data : ndarray
        (ntrain, nfreq, ntime, 1) float64 array with training data
    train_labels : ndarray
        (ntrigger, 2) binary labels of training data [0, 1] = FRB, [1, 0] = RFI
    eval_data : ndarray
        (neval, nfreq, ntime, 1) float64 array with evaluation data
    eval_labels :
        (neval, 2) binary labels of eval data
    epochs : int
        Number of training epochs
    nfilt1 : int
        Number of filters in first convolutional layer
    nfilt2 : int
        Number of filters in second convolutional layer
    batch_size : int
        Number of batches for training

    Returns
    -------
    model : keras.models.Sequential
        compiled model
    score : list
        [loss, accuracy] from model.evaluate; accuracy is the fraction
        of predictions that are correct. Empty if fit is False.
    """

    if train_data is not None:
        nfreq = train_data.shape[1]
        ntime = train_data.shape[2]

    model = Sequential()
    # this applies nfilt1 convolution filters of size 5x5 each.
    model.add(Conv2D(nfilt1, (5, 5), activation='relu', input_shape=(nfreq, ntime, 1)))

    # model.add(Conv2D(32, (3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    # Randomly drop some fraction of nodes (set weights to 0)
    model.add(Dropout(0.4))
    model.add(Conv2D(nfilt2, (5, 5), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.4))
    model.add(Flatten())

    if features_only is True:
        model.add(BatchNormalization())  # hack
        return model, []

    model.add(Dense(256, activation='relu'))  # should be 1024 hack

    # model.add(Dense(1024, activation='relu'))  # remove for now hack
    model.add(Dropout(0.5))
    model.add(Dense(2, activation='softmax'))

    sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
    model.compile(loss='binary_crossentropy', optimizer=sgd, metrics=['accuracy'])

    # train_labels = keras.utils.to_categorical(train_labels)
    # eval_labels = keras.utils.to_categorical(eval_labels)

    score = []

    if fit is True:
        print("Using batch_size: %d" % batch_size)
        print("Using %d epochs" % epochs)
        cb = keras.callbacks.TensorBoard(log_dir='./logs', histogram_freq=0,
                                         batch_size=32, write_graph=True, write_grads=False,
                                         write_images=True, embeddings_freq=0,
                                         embeddings_layer_names=None,
                                         embeddings_metadata=None)

        model.fit(train_data, train_labels, batch_size=batch_size, epochs=epochs, callbacks=[cb])
        score = model.evaluate(eval_data, eval_labels, batch_size=batch_size)
        print("Conv2d only")
        print(score)

    return model, score
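
# Training sketch for construct_conv2d() on synthetic data: pure-noise "RFI"
# candidates and noise-plus-a-bright-column "FRB" candidates. The label
# encoding ([1, 0] = RFI, [0, 1] = FRB) follows the docstring above; all
# shapes and numbers are illustrative.
def _example_train_conv2d():
    ntrig, nfreq, ntime = 500, 32, 64
    data = np.random.normal(0, 1, [ntrig, nfreq, ntime, 1])
    labels = np.zeros([ntrig, 2])
    labels[::2, 0] = 1                    # even events: noise only
    labels[1::2, 1] = 1                   # odd events: injected pulse
    data[1::2, :, ntime//2, 0] += 2.0     # crude stand-in for a dedispersed pulse

    ntrain = 400
    model, score = construct_conv2d(fit=True,
                                    train_data=data[:ntrain],
                                    train_labels=labels[:ntrain],
                                    eval_data=data[ntrain:],
                                    eval_labels=labels[ntrain:],
                                    epochs=3)
    return model, score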

def construct_conv1d(features_only=False, fit=False,
                     train_data=None, train_labels=None,
                     eval_data=None, eval_labels=None,
                     nfilt1=64, nfilt2=128,
                     batch_size=16, epochs=5):
    """ Build a one-dimensional convolutional neural network
    with a binary classifier. Can be used for, e.g.,
    pulse profiles.

    Parameters:
    ----------
    features_only : bool
        Don't construct full model, only features layers
    fit : bool
        Fit model
    train_data : ndarray
        (ntrain, ntime, 1) float64 array with training data
    train_labels : ndarray
        (ntrigger, 2) binary labels of training data [0, 1] = FRB, [1, 0] = RFI
    eval_data : ndarray
        (neval, ntime, 1) float64 array with evaluation data
    eval_labels :
        (neval, 2) binary labels of eval data
    epochs : int
        Number of training epochs
    nfilt1 : int
        Number of filters in first convolutional layer
    nfilt2 : int
        Number of filters in second convolutional layer
    batch_size : int
        Number of batches for training

    Returns
    -------
    model : keras.models.Sequential
        compiled model
    score : list
        [loss, accuracy] from model.evaluate; empty if fit is False
    """

    if train_data is not None:
        NTIME = train_data.shape[1]

    model = Sequential()
    model.add(Conv1D(nfilt1, 3, activation='relu', input_shape=(NTIME, 1)))
    model.add(Conv1D(nfilt1, 3, activation='relu'))
    model.add(MaxPooling1D(3))
    model.add(Conv1D(nfilt2, 3, activation='relu'))
    model.add(Conv1D(nfilt2, 3, activation='relu'))
    model.add(GlobalAveragePooling1D())

    if features_only is True:
        return model, []

    model.add(Dropout(0.5))
    model.add(Dense(2, activation='sigmoid'))

    model.compile(loss='binary_crossentropy',
                  optimizer='rmsprop',
                  metrics=['accuracy'])

    score = []

    if fit is True:
        model.fit(train_data, train_labels, batch_size=batch_size, epochs=epochs)
        score = model.evaluate(eval_data, eval_labels, batch_size=16)
        print("Conv1d only")

    return model, score
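
# The same idea in one dimension (sketch): frequency-collapsed pulse profiles
# with a Gaussian bump added to the "FRB" class. Shapes follow the docstring
# above; the bump amplitude and width are made up.
def _example_train_conv1d():
    ntrig, ntime = 500, 64
    t = np.arange(ntime)
    data = np.random.normal(0, 1, [ntrig, ntime, 1])
    labels = np.zeros([ntrig, 2])
    labels[::2, 0] = 1
    labels[1::2, 1] = 1
    data[1::2, :, 0] += 2.0 * np.exp(-(t - ntime//2)**2 / 8.0)

    model, score = construct_conv1d(fit=True,
                                    train_data=data[:400], train_labels=labels[:400],
                                    eval_data=data[400:], eval_labels=labels[400:],
                                    epochs=3)
    return model, score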

def merge_models(model_list, train_data_list,
                 train_labels, eval_data_list, eval_labels,
                 batch_size=32, epochs=5):
    """ Take list of models, list of training data,
    merge models and train as a single network.
    """

    model = Sequential()
    model.add(Merge(model_list, mode='concat'))
    model.add(Dense(256, activation='relu'))
    model.add(Dense(2, init='normal', activation='sigmoid'))
    sgd = SGD(lr=0.1, momentum=0.9, decay=0, nesterov=False)
    model.compile(loss='binary_crossentropy',
                  optimizer=sgd,
                  metrics=['accuracy'])
    seed(2017)
    model.fit(train_data_list, train_labels,
              batch_size=batch_size, nb_epoch=epochs, verbose=1)
    score = model.evaluate(eval_data_list, eval_labels, batch_size=batch_size)

    return model, score
--------------------------------------------------------------------------------
/single_pulse_ml/model/model.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liamconnor/single_pulse_ml/88b6b76ebf3d3939214d9785d4e1c5076f653c38/single_pulse_ml/model/model.txt
--------------------------------------------------------------------------------
/single_pulse_ml/plot_tools.py:
--------------------------------------------------------------------------------
import numpy as np

try:
    import matplotlib
    matplotlib.use('Agg')

    import matplotlib.pyplot as plt
    from matplotlib import gridspec
except ImportError:
    # matplotlib is optional; the plotting functions below
    # will fail without it
    pass

def plot_simulated_events(data, labels, figname,
                          NSIDE, NFREQ, NTIME, cmap='RdBu'):
    """ Make series of waterfall plots of training / test
    set.
    """

    NFIG = NSIDE**2
    lab_dict = {0: 'RFI', 1: 'FRB'}

    fig = plt.figure(figsize=(15, 15))
    for ii in range(NFIG):
        plt.subplot(NSIDE, NSIDE, ii+1)
        plt.imshow(data[ii].reshape(-1, NTIME),
                   aspect='auto', interpolation='nearest',
                   cmap=cmap, vmin=-3, vmax=3)
        plt.axis('off')
        plt.colorbar()
        plt.title(lab_dict[labels[ii]])
        plt.xlim(125-32, 125+32)

    fig.savefig('%s_rfi.png' % figname)

    fig = plt.figure(figsize=(15, 15))
    for ii in range(NFIG):
        plt.subplot(NSIDE, NSIDE, ii+1)
        plt.imshow(data[-ii-1].reshape(-1, NTIME),
                   aspect='auto', interpolation='nearest',
                   cmap=cmap, vmin=-3, vmax=3)
        plt.axis('off')
        plt.colorbar()
        # index labels from the end, to match data[-ii-1]
        plt.title(lab_dict[labels[-ii-1]])
        plt.xlim(125-32, 125+32)

    fig.savefig(figname)

def plot_gallery(data_arr, titles, h, w, n_row=3, n_col=4,
                 figname=None, cmap='RdBu', suptitle=''):
    """Helper function to plot a gallery of portraits"""
    plt.figure(figsize=(1.8 * n_col, 2.4 * n_row))
    plt.suptitle(suptitle, fontsize=35, color='blue', alpha=0.5)
    plt.subplots_adjust(bottom=0, left=.01, right=.99, top=.90, hspace=.35)
    for i in range(min(n_row * n_col, len(data_arr))):
        d_arr = data_arr[i].reshape((h, w))
        d_arr -= np.median(d_arr)
        plt.subplot(n_row, n_col, i + 1)
        plt.imshow(d_arr, cmap=cmap, aspect='auto')
        plt.title(titles[i], size=12, color='red')
        plt.xticks(())
        plt.yticks(())
    if figname:
        plt.savefig(figname)


def get_title(y, target_names):
    prediction_titles = y.astype(str)
    prediction_titles[prediction_titles == '0'] = target_names[0]
    prediction_titles[prediction_titles == '1'] = target_names[1]

    return prediction_titles

def get_title2(y_pred, y_test, target_names, i):
    pred_name = target_names[y_pred[i]]
    true_name = target_names[y_test[i]]
    return 'predicted: %s\ntrue: %s' % (pred_name, true_name)

def plot_ranked_trigger(data, prob_arr, h=6, w=6,
                        ascending=False, outname='out',
                        cmap='RdBu', vmax=3, vmin=-3,
                        yaxlabel='Freq'):
    """ Plot single-pulse triggers ranked by the
    classifier's assigned probability.

    Parameters
    ----------
    data : np.array
        data array with triggers
    prob_arr : np.array
        probability of event being a true FRB
    h : np.int
        number of rows of triggers
    w : np.int
        number of columns of triggers
    ascending : bool / str
        plot in ascending order (True, False, 'mid')
    outname : str
        figure name
    cmap : str
        colormap to use in imshow

    Returns
    -------
    None
    """

    if len(prob_arr.shape) > 1:
        prob_arr = prob_arr[:, 1]

    ranking = np.argsort(prob_arr)

    if ascending == True:
        ranking = ranking[::-1]
        title_str = 'RFI most probable'
    elif ascending == 'mid':
        # cp = np.argsort(abs(prob_arr[:,0]-0.5))
        # ranking = cp[:h*w]
        inflection = np.argmax(abs(np.diff(prob_arr[ranking])))
        # integer division so the slice indices stay ints
        ranking = ranking[inflection - h*w//2:inflection + h*w//2]
        title_str = 'Marginal events'
    else:
        title_str = 'FRB most probable'

    fig = plt.figure(figsize=(15, 15))

    for ii in range(min(h*w, len(prob_arr))):
        plt.subplot(h, w, ii+1)
        if len(data.shape) == 3:
            plt.imshow(data[ranking[ii]],
                       cmap=cmap, interpolation='nearest',
                       aspect='auto', vmin=vmin, vmax=vmax,
                       extent=[0, 1, 400, 800])
        elif len(data.shape) == 2:
            plt.plot(data[ranking[ii]])
        else:
            print("Wrong data input shape")
            return

        # plt.axis('off')
        plt.xticks([])
        plt.yticks([])
        plt.title('p=' + str(np.round(prob_arr[ranking[ii]], 5)), fontsize=12)

        if ii % w == 0:
            plt.ylabel(yaxlabel, fontsize=14)
        if ii >= (h*w - w):
            plt.xlabel("Time", fontsize=14)

    if outname is not None:
        fig.savefig(outname)
    else:
        plt.show()
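
# Usage sketch for plot_ranked_trigger(): random dynamic spectra with
# made-up probabilities, written to 'ranked_example.png' (hypothetical name).
def _example_plot_ranked():
    data = np.random.normal(0, 1, [36, 32, 64])
    prob = np.random.uniform(0, 1, 36)
    plot_ranked_trigger(data, prob, h=6, w=6,
                        ascending=False, outname='ranked_example.png')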

def plot_multiple_ranked(argin, nside=5, fnfigout='ranked_trig',
                         ascending=True):
    """ Generate multiple multi-panel figures
    using plot_ranked_trigger

    Parameters
    ----------
    argin : str/tuple
        input arguments, either
        (data_frb_candidate, frb_index, probability)
        or a filename
    nside : np.int
        number of figures per row/col
    fnfigout : str
        fig name
    """
    import h5py

    if type(argin) == tuple:
        data_frb_candidate, frb_index, probability = argin
        fn = './'
    elif type(argin) == str:
        fn = argin
        f = h5py.File(fn, 'r')
        data_frb_candidate = f['data_frb_candidate'][:]
        frb_index = f['frb_index'][:]
        probability = f['probability'][:]
        f.close()
    else:
        print("Wrong input argument")
        return

    ntrig = len(frb_index)
    probability = probability[frb_index]
    ind = np.argsort(probability)[::-1]
    data = data_frb_candidate[ind]
    probability_ = probability[ind]

    for ii in range(ntrig//nside**2 + 1):
        data_sub = data[nside**2*ii:nside**2*(ii+1), ..., 0]
        prob_sub = probability_[nside**2*ii:nside**2*(ii+1)]

        # the last chunk can be empty when ntrig divides evenly
        if len(prob_sub) == 0:
            break

        pmin, pmax = prob_sub.min(), prob_sub.max()

        fnfigout_ = fnfigout + 'prob:%.2f-%.2f.pdf' % (pmin, pmax)
        print("Saving to %s" % fnfigout_)

        plot_ranked_trigger(data_sub, prob_sub,
                            h=nside, w=nside, ascending=ascending,
                            outname=fnfigout_, cmap=None)


def plot_image_probabilities(FT_arr, DT_arr, FT_prob_spec, DT_prob_spec):

    assert (len(FT_arr.shape) == 2) and (len(DT_arr.shape) == 2), \
        "Input data should be (nfreq, ntimes)"

    gs2 = gridspec.GridSpec(4, 3)
    ax1 = plt.subplot(gs2[:2, :2])
    ax1.xaxis.set_ticklabels('')
    ax1.yaxis.set_ticklabels('')
    plt.ylabel('Freq', fontsize=18)
    plt.xlabel('Time', fontsize=18)
    ax1.imshow(FT_arr, cmap='RdBu', interpolation='nearest', aspect='auto')

    ax2 = plt.subplot(gs2[:2, 2:])
    ax2.yaxis.tick_right()
    ax2.yaxis.set_label_position('right')
    plt.ylabel('probability', fontsize=18)
    ax2.bar([0, 1], FT_prob_spec, color='red', alpha=0.75)
    plt.xticks([0.5, 1.5], ['RFI', 'Pulse'])
    plt.ylim(0, 1)
    plt.xlim(-.25, 2.)

    ax3 = plt.subplot(gs2[2:, :2])
    ax3.xaxis.set_ticklabels('')
    ax3.yaxis.set_ticklabels('')
    plt.ylabel('Freq', fontsize=18)
    plt.xlabel('Time', fontsize=18)
    ax3.imshow(DT_arr, cmap='RdBu', interpolation='nearest',
               aspect='auto')

    ax4 = plt.subplot(gs2[2:, 2:])
    ax4.yaxis.set_label_position('right')
    ax4.yaxis.tick_right()
    plt.ylabel('probability', fontsize=18)
    ax4.bar([0, 1], DT_prob_spec, color='red', alpha=0.75)
    plt.xticks([0.5, 1.5], ['RFI', 'Pulse'])
    plt.ylim(0, 1)
    plt.xlim(-.25, 2.)

    plt.suptitle('TensorFlow Deep Learn', fontsize=45)


class VisualizeLayers:
    """ Class to visualize the hidden
    layers of a deep neural network in
    keras.
    """

    def __init__(self, model):
        self._model = model
        self._NFREQ = model.get_input_shape_at(0)[1]
        self._NTIME = model.get_input_shape_at(0)[2]
        self.grid_counter = 0
        # Create empty list for non-redundant activations
        self._activations_nonred = []
        self._NFREQ_min = min([mm.input.shape[1] for mm in model.layers])

    def print_layers(self):
        """ Print layer names and shapes of keras model
        """
        for layer in self._model.layers:
            print("%s: %10s" % (layer.name, layer.input.shape))

    def imshow_custom(self, data, **kwargs):
        """ matplotlib imshow with custom arguments
        """
        plt.imshow(data, aspect='auto', interpolation='nearest',
                   **kwargs)

    def remove_doubles(self, activations):
        """ Remove layers with identical shapes, e.g.
        dropout layers
        """
        self._activations_nonred.append(activations[0])

        # Start from first element, skip input data
        for ii, activation in enumerate(activations[1:]):
            act_shape = activation.shape
            if act_shape != activations[ii].shape:
                self._activations_nonred.append(activation)

    def get_activations(self, model_inputs,
                        print_shape_only=True,
                        layer_name=None):
        # imported here so the rest of the module works without keras
        import keras.backend as backend

        print('----- activations -----')
        activations = []
        inp = self._model.input

        model_multi_inputs_cond = True
        if not isinstance(inp, list):
            # only one input! let's wrap it in a list.
            inp = [inp]
            model_multi_inputs_cond = False

        outputs = [layer.output for layer in self._model.layers if
                   layer.name == layer_name or layer_name is None]  # all layer outputs

        funcs = [backend.function(inp +
                                  [backend.learning_phase()], [out])
                 for out in outputs]  # evaluation functions

        if model_multi_inputs_cond:
            list_inputs = []
            list_inputs.extend(model_inputs)
            list_inputs.append(0.)
        else:
            list_inputs = [model_inputs, 0.]

        # Learning phase. 0 = Test mode (no dropout or batch normalization)
        # layer_outputs = [func([model_inputs, 0.])[0] for func in funcs]
        layer_outputs = [func(list_inputs)[0] for func in funcs]

        # Append input data
        activations.append(model_inputs)

        for layer_activations in layer_outputs:
            activations.append(layer_activations)

        return activations

    def plot_feature_layer(self, activation, NSIDE=16):
        N_SUBFIG = activation.shape[-1]

        if N_SUBFIG == 1:
            ax = plt.subplot2grid((NSIDE, NSIDE),
                                  (self.grid_counter, 3*NSIDE//8),
                                  colspan=NSIDE//4, rowspan=NSIDE//4)
            plt.plot(activation[0, :, 0])
            return

        for ii in range(N_SUBFIG):
            size = int(activation.shape[1] / self._NFREQ_min)
            # size=int(np.round(4*activation.shape[1]/self._NFREQ * NSIDE//32))
            # size=min(size, NSIDE//8)
            start_grid = NSIDE//2 - N_SUBFIG*size//2
            print(NSIDE, self.grid_counter, start_grid + ii*size, size)
            ax = plt.subplot2grid((NSIDE, NSIDE),
                                  (self.grid_counter, start_grid + ii*size),
                                  colspan=size, rowspan=size)
            plt.plot(activation[0, :, ii])
            plt.axis('off')

    def im_feature_layer(self, activation, cmap='Greys', NSIDE=16,
                         start_grid=0, N_SUBFIG=None, skip=1):
        N_SUBFIG = activation.shape[-1] if N_SUBFIG is None else N_SUBFIG

        if N_SUBFIG == 1:
            # cmap = 'RdBu'
            ax = plt.subplot2grid((NSIDE, NSIDE),
                                  (self.grid_counter, 3*NSIDE//8),
                                  colspan=NSIDE//4, rowspan=NSIDE//4)

            print(self.grid_counter, '0')
            self.grid_counter += (NSIDE//4 + NSIDE//16)  # Add one extra unit of space
            print(activation.shape)
            data = activation[0, :, :, 0]
            data -= np.median(data)
            vmax = 6*np.std(data)
            vmin = -1*np.std(data)
            self.imshow_custom(data, cmap=cmap, extent=[0, 1, 400, 800],
                               vmax=vmax, vmin=vmin)
            print(self.grid_counter, '1')

            plt.xlabel('Time')
            plt.ylabel('Freq [MHz]')

            return

    def im_layers(self, activations, loc_obj, cmap='Greys', NSIDE=32):

        sizes = loc_obj[0]
        loc = loc_obj[1]

        for jj, activation in enumerate(activations):
            for ii in range(activation.shape[-1]):
                ax = plt.subplot2grid((NSIDE, NSIDE), (self.grid_counter, loc[jj][ii]),
                                      colspan=sizes[jj], rowspan=sizes[jj])

                self.imshow_custom(activation[0, :, :, ii], cmap='Greys')
                plt.axis('off')

            self.grid_counter += (NSIDE//32 + int(sizes[jj]))

        plt.show()

    def get_image_index(self, NSIDE=100):
        offset = 0
        sizes = np.array([8, 4, 4, 2])
        N_SUBFIG = np.array([8, 8, 16, 16])
        offset = NSIDE//2 - N_SUBFIG*sizes//2
        loc1 = (offset[0] + np.arange(8)*sizes[0]).astype(int)
        loc2 = (loc1 + (sizes[0]/2 - sizes[1]/2)).astype(int)
        loc3 = (offset[2] + np.arange(16)*(1 + sizes[2])).astype(int)
        offset3 = NSIDE//2 - (loc3[0] + (loc3[-1] - loc3[0])/2.)
409 |         loc3 += int(offset3)
410 |         loc4 = (loc3 + (sizes[2]/2 - sizes[3]/2)).astype(int)
411 |         loc = [loc1, loc2, loc3, loc4]
412 | 
413 |         loc_obj = (sizes, loc)
414 | 
415 |         return loc_obj
416 | 
417 |     def im_all(self, activations, NSIDE=32, figname=None, color='linen'):
418 |         fig = plt.figure(figsize=(15,15))
419 |         self.grid_counter = 0
420 |         start_grid_map = np.zeros([len(activations)]).astype(int)
421 |         n_neuron_map = [activation.shape[-1] for activation in activations]
422 |         loc_obj = self.get_image_index()
423 | 
424 |         for kk, activation in enumerate(activations[:]):
425 |             print(self.grid_counter, kk, activation.shape)
426 |             if kk==0:
427 |                 self.im_layers(activation, loc_obj, cmap='Greys')
428 |             elif activation.shape[-1]==2: # For binary classification
429 |                 activation = activation[0]
430 |                 activation[0] = 0.025 # Hack for now, visualizing.
431 |                 ind = np.array([0, 1])
432 |                 width = 0.75
433 |                 ax = plt.subplot2grid((NSIDE,NSIDE),
434 |                                       (self.grid_counter, 3*NSIDE//8),
435 |                                       colspan=NSIDE//4, rowspan=NSIDE//4)
436 | 
437 |                 rects1 = ax.bar(ind[1], activation[1], width, color='r', alpha=0.5)
438 |                 rects2 = ax.bar(ind[0], activation[0], width, color='green', alpha=0.5)
439 | 
440 |                 ax.set_xticks(ind + width / 2)
441 |                 ax.set_xticklabels(('Noise', 'FRB'))
442 |                 ax.set_ylim(0, 1.25)
443 |                 ax.set_xlim(-0.25, 2.0)
444 | 
445 |             elif kk==1:
446 |                 self.im_layers(activations[1:5], loc_obj, cmap='Greys')
447 | 
448 |         if figname is not None:
449 |             plt.savefig(figname)#, facecolor=color)
450 | 
451 |     def make_figure(self, data, NSIDE=32, figname=None):
452 |         dsh = data.shape
453 | 
454 |         if len(dsh)==2:
455 |             data = data[None,:,:,None]
456 |         elif len(dsh)==3:
457 |             if dsh[0]==1:
458 |                 data = data[..., None]
459 |             elif dsh[-1]==1:
460 |                 data = data[None]
461 | 
462 |         activations = self.get_activations(data)  # compute per-layer activations
463 |         # Make sure no activation has more filters than NSIDE
464 |         for activation in activations:
465 |             if len(activation.shape) > 2:
466 |                 NSIDE = max(NSIDE, activation.shape[-1])
467 | 
468 |         print("Using NSIDE: %d" % NSIDE)
469 | 
470 |         self.remove_doubles(activations)
471 |         self.im_all(self._activations_nonred, NSIDE=NSIDE, figname=figname)
472 | 
473 | if __name__=='__main__':
474 |     import sys
475 | 
476 |     import h5py
477 | 
478 |     try:
479 |         fn = sys.argv[1]
480 |     except IndexError:
481 |         print("\nExpected input datafile as argument\n")
482 |         sys.exit()
483 | 
484 |     plot_multiple_ranked(fn, nside=5)
485 | 
486 | 
487 | 
488 | 
-------------------------------------------------------------------------------- /single_pulse_ml/plots/Freq_train.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/liamconnor/single_pulse_ml/88b6b76ebf3d3939214d9785d4e1c5076f653c38/single_pulse_ml/plots/Freq_train.png
-------------------------------------------------------------------------------- /single_pulse_ml/reader.py: --------------------------------------------------------------------------------
1 | """ Tools for I/O as well as creating training
2 | and test data sets.
3 | """ 4 | 5 | import os 6 | 7 | import time 8 | import numpy as np 9 | import h5py 10 | import glob 11 | import pickle 12 | 13 | try: 14 | import matplotlib.pylab as plt 15 | except: 16 | pass 17 | 18 | try: 19 | import filterbank 20 | except: 21 | pass 22 | 23 | 24 | def read_hdf5(fn): 25 | """ Read in data from .hdf5 file 26 | containing dynamic spectra, dm-time array, 27 | and data labels 28 | """ 29 | 30 | f = h5py.File(fn, 'r') 31 | data_freq = f['data_freq_time'][:] 32 | 33 | try: 34 | y = f['labels'][:] 35 | except: 36 | print("labels dataset not there") 37 | y = -1*np.zeros([len(data_freq)]) 38 | 39 | try: 40 | data_dm = f['data_dm_time'][:] 41 | except: 42 | print("dm-time dataset not there") 43 | data_dm = None 44 | 45 | try: 46 | data_mb = f['multibeam_snr'][:] 47 | except: 48 | print("multibeam dataset not there") 49 | data_mb = None 50 | 51 | return data_freq, y, data_dm, data_mb 52 | 53 | def write_to_fil(data, header, fn): 54 | filterbank.create_filterbank_file( 55 | fn, header, spectra=data, mode='readwrite') 56 | print("Writing to %s" % fn) 57 | 58 | def read_fil_data(fn, start=0, stop=1e7): 59 | print("Reading filterbank file %s \n" % fn) 60 | fil_obj = filterbank.FilterbankFile(fn) 61 | header = fil_obj.header 62 | delta_t = fil_obj.header['tsamp'] # delta_t in milliseconds 63 | fch1 = header['fch1'] 64 | nchans = header['nchans'] 65 | foff = header['foff'] 66 | fch_f = fch1 + nchans*foff 67 | freq = np.linspace(fch1, fch_f, nchans) 68 | data = fil_obj.get_spectra(start, stop) 69 | # turn array into time-major, for preprocess 70 | # data = data.transpose() 71 | 72 | return data, freq, delta_t, header 73 | 74 | def read_pathfinder_npy(fn): 75 | data = np.load(fn) 76 | nfreq, ntimes = data.shape[0], data.shape[1] 77 | 78 | if len(data)!=16: 79 | data = data.reshape(-1, nfreq//16, ntimes).mean(1) 80 | 81 | return data 82 | 83 | def rebin_arr(data, n0_f=1, n1_f=1): 84 | """ Rebin 2d array data to have shape 85 | (n0_f, n1_f) 86 | """ 87 | assert len(data.shape)==2 88 | 89 | n0, n1 = data.shape 90 | data_rb = data[:n0//n0_f * n0_f, :n1//n1_f * n1_f] 91 | data_rb = data_rb.reshape(n0_f, n0//n0_f, n1_f, n1//n1_f) 92 | data_rb = data_rb.mean(1).mean(-1) 93 | 94 | return data_rb 95 | 96 | def im(data, title='',figname='out.png'): 97 | fig = plt.figure()# 98 | plt.imshow(data, aspect='auto', interpolation='nearest', cmap='Greys') 99 | plt.savefig(figname) 100 | plt.title(title) 101 | plt.show() 102 | 103 | def combine_data_DT(fn): 104 | """ Combine the training set data in DM / Time space, 105 | assuming text file with lines: 106 | 107 | # filepath label 108 | DM20-100_vdif_assembler+a=00+n=02_DM-T_ +11424.89s.npy 0 109 | DM20-100_vdif_assembler+a=00+n=02_DM-T_ +19422.29s.npy 1 110 | DM20-100_vdif_assembler+a=00+n=02_DM-T_ +21658.40s.npy 0 111 | 112 | e.g. 
usage: combine_data_DT('./single_pulse_ml/data/test/data_list_DM.txt') 113 | """ 114 | 115 | f = open(fn,'r') 116 | 117 | data_full, y = [], [] 118 | k=0 119 | for ff in f: 120 | fn = './single_pulse_ml/data/' + ff.strip()[:-2] 121 | try: 122 | data = np.load(fn) 123 | except ValueError: 124 | continue 125 | k+=1 126 | label = int(ff[-2]) 127 | y.append(label) 128 | data = normalize_data(data) 129 | data = rebin_arr(data, 64, 250) 130 | 131 | data_full.append(data) 132 | 133 | ndm, ntimes = data.shape 134 | 135 | data_full = np.concatenate(data_full, axis=0) 136 | data_full.shape = (k, -1) 137 | 138 | return data_full, np.array(y) 139 | 140 | def combine_data_FT(fn): 141 | """ combine_data_FT('./single_pulse_ml/data/data_list') 142 | """ 143 | f = open(fn,'r') 144 | 145 | # data and its label class 146 | data_full, y = [], [] 147 | 148 | for ff in f: 149 | line = ff.split(' ') 150 | 151 | fn, label = line[0], int(line[1]) 152 | 153 | y.append(label) 154 | print(fn) 155 | tstamp = fn.split('+')[-2] 156 | 157 | #fdm = glob.glob('./*DM-T*%s*.npy' % tstamp) 158 | fn = './single_pulse_ml/data/test/' + fn 159 | data = read_pathfinder_npy(fn) 160 | data = normalize_data(data) 161 | data_full.append(data) 162 | 163 | nfreq, ntimes = data.shape[0], data.shape[-1] 164 | 165 | data_full = np.concatenate(data_full, axis=0) 166 | data_full.shape = (-1, nfreq*ntimes) 167 | 168 | return data_full, np.array(y) 169 | 170 | def write_data(data, y, fname='out'): 171 | training_arr = np.concatenate((data, y[:, None]), axis=-1) 172 | 173 | np.save(fname, training_arr) 174 | 175 | 176 | def read_data(fn): 177 | arr = np.load(fn) 178 | data, y = arr[:, :-1], arr[:, -1] 179 | 180 | return data, y 181 | 182 | def read_pkl(fn): 183 | if fn[-4:]!='.pkl': fn+='.pkl' 184 | 185 | file = open(fn, 'rb') 186 | 187 | model = pickle.load(file) 188 | 189 | return model 190 | 191 | def write_pkl(model, fn): 192 | if fn[-4:]!='.pkl': fn+='.pkl' 193 | 194 | file = open(fn, 'wb') 195 | pickle.dump(model, file) 196 | 197 | print("Wrote to pkl file: %s" % fn) 198 | 199 | def get_labels(): 200 | """ Cross reference DM-T files with Freq-T 201 | files and create a training set in DM-T space. 202 | """ 203 | 204 | fin = open('./single_pulse_ml/data/data_list','r') 205 | fout = open('./single_pulse_ml/data/data_list_DM','a') 206 | 207 | for ff in fin: 208 | x = ff.split(' ') 209 | n, c = x[0], int(x[1]) 210 | try: 211 | t0 = n.split('+')[-2] 212 | float(t0) 213 | except ValueError: 214 | t0 = n.split('+')[-1].split('s')[0] 215 | 216 | newlist = glob.glob('./single_pulse_ml/data/DM*DM*%s*' % t0) 217 | 218 | if len(newlist) > 0: 219 | string = "%s %s\n" % (newlist[0].split('/')[-1], c) 220 | fout.write(string) 221 | 222 | def create_training_set(freqtime=True, 223 | fout='./single_pulse_ml/data/data_freqtime_train'): 224 | if freqtime: 225 | data, y = combine_data_FT('test') 226 | else: 227 | data, y = combine_data_DT('test') 228 | 229 | write_data(data, y, fname=fout) 230 | 231 | def shuffle_array(data_1, data_2=None): 232 | """ Take one or two data array(s), shuffle 233 | in place, and shuffle the second array in the same 234 | ordering, if applicable. 
235 |     """
236 |     ntrigger = len(data_1)
237 |     index = np.arange(ntrigger)
238 | 
239 |     if len(data_1.shape) > 2:
240 |         data_1 = data_1.reshape(ntrigger, -1)
241 |     if data_2 is not None:
242 |         data_2 = data_2.reshape(ntrigger, -1)
243 |     data_1_ = np.concatenate((data_1, index[:, None]), axis=-1)
244 |     np.random.shuffle(data_1_)
245 |     index_shuffle = (data_1_[:, -1]).astype(int)
246 |     data_2 = data_2[index_shuffle] if data_2 is not None else None
247 | 
248 |     return data_1_[:, :-1], data_2
249 | 
250 | 
251 | 
252 | 
253 | 
254 | 
255 | 
-------------------------------------------------------------------------------- /single_pulse_ml/run_frb_simulation.py: --------------------------------------------------------------------------------
1 | """ Script to build a dataset out of simulated
2 | single pulses + false-positive triggers
3 | """
4 | 
5 | from single_pulse_ml import sim_parameters
6 | from single_pulse_ml import telescope
7 | #from single_pulse_ml import simulate_frb
8 | import simulate_frb
9 | 
10 | # TELESCOPE PARAMETERS:
11 | freq = (800, 400) # (FREQ_LOW, FREQ_UP) in MHz
12 | FREQ_REF = 600 # reference frequency in MHz
13 | DELTA_T = 0.0016 # time res in seconds
14 | NAME = "CHIMEPathfinder"
15 | 
16 | # SIMULATION PARAMETERS
17 | NFREQ = 32 # Number of frequencies. Must agree with FP data
18 | NTIME = 250 # Number of time stamps per trigger
19 | dm = (-0.05, 0.05)
20 | fluence = (1, 10)
21 | width = (0.0016, 0.75) # lognormal width distribution, in seconds
22 | spec_ind = (-4., 4.)
23 | disp_ind = 2.
24 | scat_factor = (-4., -1.5)
25 | NRFI = 5000
26 | SNR_MIN = 5.0
27 | SNR_MAX = 25.00
28 | out_file_name = None
29 | mk_plot = False
30 | NSIDE = 8
31 | dm_time_array = False
32 | outname_tag = 'apertif_250'
33 | 
34 | #fn_rfi = './data/arts_FPs_33583.hdf5'
35 | #fn_noise = './data/apertif_background3669.npy'
36 | 
37 | # If no background data available, use None option
38 | fn_rfi = None # Use Gaussian noise as false positive data
39 | fn_noise = None # Use Gaussian noise for simulated FRBs
40 | 
41 | sim_obj = sim_parameters.SimParams(dm=dm, fluence=fluence,
42 |                                    width=width, spec_ind=spec_ind,
43 |                                    disp_ind=disp_ind, scat_factor=scat_factor,
44 |                                    SNR_MIN=SNR_MIN, SNR_MAX=SNR_MAX,
45 |                                    out_file_name=out_file_name, NRFI=NRFI,
46 |                                    NTIME=NTIME, NFREQ=NFREQ,
47 |                                    mk_plot=mk_plot, NSIDE=NSIDE, )
48 | 
49 | tel_obj = telescope.Telescope(freq=freq, FREQ_REF=FREQ_REF,
50 |                               DELTA_T=DELTA_T, name=NAME)
51 | 
52 | data, labels, params, snr = simulate_frb.run_full_simulation(
53 |                                 sim_obj, tel_obj, fn_rfi=fn_rfi,
54 |                                 fn_noise=fn_noise,
55 |                                 dm_time_array=dm_time_array,
56 |                                 outname_tag=outname_tag)
57 | 
-------------------------------------------------------------------------------- /single_pulse_ml/run_single_pulse_DL.py: --------------------------------------------------------------------------------
1 | """ Script to train and test one or multiple deep
2 | neural networks. Input models are expected to be
3 | sequential keras models saved as .hdf5 files.
4 | Input data are .hdf5 files with data sets
5 | 'labels', 'data_freq_time', 'data_dm_time'.
6 | They can be read by single_pulse_ml.reader.read_hdf5 7 | 8 | """ 9 | import sys 10 | 11 | import numpy as np 12 | import time 13 | import h5py 14 | 15 | #from single_pulse_ml import reader 16 | #from single_pulse_ml import frbkeras 17 | #from single_pulse_ml import plot_tools 18 | 19 | import reader 20 | import frbkeras 21 | import plot_tools 22 | 23 | try: 24 | import matplotlib 25 | matplotlib.use('Agg') 26 | 27 | import matplotlib.pyplot as plt 28 | from matplotlib import gridspec 29 | print("Worked") 30 | except: 31 | "Didn't work" 32 | pass 33 | 34 | FREQTIME=True # train 2D frequency-time CNN 35 | TIME1D=False # train 1D pulse-profile CNN 36 | DMTIME=False # train 2D DM-time CNN 37 | MULTIBEAM=False # train feed-forward NN on simulated multibeam data 38 | 39 | # If True, the different nets will be clipped after 40 | # feature extraction layers and will not be compiled / fit 41 | MERGE=False 42 | 43 | MK_PLOT=False 44 | CLASSIFY_ONLY=False 45 | save_classification=True 46 | model_nm = "./model/model_name" 47 | prob_threshold = 0.0 48 | 49 | ## Input hdf5 file. 50 | fn = './data/input_data.hdf5' 51 | 52 | # Save tf model as .hdf5 53 | save_model = True 54 | fnout = "./model/model_out_name" 55 | 56 | NDM=300 # number of DMs in input array 57 | WIDTH=64 # width to use of arrays along time axis 58 | train_size=0.5 # fraction of dataset to train on 59 | 60 | ftype = fn.split('.')[-1] 61 | 62 | # Create empty lists for final merged model 63 | model_list = [] 64 | train_data_list = [] 65 | eval_data_list = [] 66 | 67 | # Configure the accuracy metric for evaluation 68 | metrics = ["accuracy", "precision", "false_negatives", "recall"] 69 | 70 | if __name__=='__main__': 71 | # read in time-freq data, labels, dm-time data 72 | data_freq, y, data_dm, data_mb = reader.read_hdf5(fn) 73 | NTRIGGER = len(y) 74 | 75 | print("Using %s" % fn) 76 | 77 | NFREQ = data_freq.shape[1] 78 | NTIME = data_freq.shape[2] 79 | 80 | # low time index, high time index 81 | tl, th = NTIME//2-WIDTH//2, NTIME//2+WIDTH//2 82 | 83 | if data_freq.shape[-1] > (th-tl): 84 | data_freq = data_freq[..., tl:th] 85 | 86 | dshape = data_freq.shape 87 | 88 | # normalize data 89 | data_freq = data_freq.reshape(len(data_freq), -1) 90 | data_freq -= np.median(data_freq, axis=-1)[:, None] 91 | data_freq /= np.std(data_freq, axis=-1)[:, None] 92 | 93 | # zero out nans 94 | data_freq[data_freq!=data_freq] = 0.0 95 | data_freq = data_freq.reshape(dshape) 96 | 97 | if DMTIME is True: 98 | if data_dm.shape[-1] > (th-tl): 99 | data_dm = data_dm[:, :, tl:th] 100 | 101 | if data_dm.shape[-2] > 100: 102 | data_dm = data_dm[:, NDM//2-50:NDM//2+50] 103 | 104 | # tf/keras expects 4D tensors 105 | data_dm = data_dm[..., None] 106 | 107 | if TIME1D is True: 108 | data_1d = data_freq.mean(1)[..., None] 109 | from scipy.signal import detrend 110 | data_1d = detrend(data_1d, axis=1) 111 | 112 | if FREQTIME is True: 113 | # tf/keras expects 4D tensors 114 | data_freq = data_freq[..., None] 115 | 116 | if CLASSIFY_ONLY is False: 117 | # total number of triggers 118 | NTRIGGER = len(y) 119 | 120 | # fraction of true positives vs. 
total triggers 121 | TP_FRAC = np.float(y.sum()) / NTRIGGER 122 | 123 | # number of events on which to train 124 | NTRAIN = int(train_size * NTRIGGER) 125 | 126 | ind = np.arange(NTRIGGER) 127 | np.random.shuffle(ind) 128 | 129 | ind_train = ind[:NTRAIN] 130 | ind_eval = ind[NTRAIN:] 131 | 132 | train_labels, eval_labels = y[ind_train], y[ind_eval] 133 | 134 | # Convert labels (integers) to binary class matrix 135 | train_labels = frbkeras.keras.utils.to_categorical(train_labels) 136 | eval_labels = frbkeras.keras.utils.to_categorical(eval_labels) 137 | 138 | 139 | if FREQTIME is True: 140 | 141 | if CLASSIFY_ONLY is True: 142 | print("Classifying freq-time data") 143 | model_freq_time_nm = model_nm + 'freq_time.hdf5' 144 | eval_data_list.append(data_freq) 145 | 146 | model_freq_time = frbkeras.load_model(model_freq_time_nm) 147 | y_pred_prob = model_freq_time.predict(data_freq) 148 | y_pred_prob = y_pred_prob[:,1] 149 | y_pred_freq_time = np.round(y_pred_prob) 150 | 151 | ind_frb = np.where(y_pred_prob>prob_threshold)[0] 152 | 153 | frbkeras.print_metric(y, y_pred_freq_time) 154 | 155 | print("\n%d out of %d events with probability > %.2f:\n %s" % 156 | (len(ind_frb), len(y_pred_prob), 157 | prob_threshold, ind_frb)) 158 | 159 | low_to_high_ind = np.argsort(y_pred_prob) 160 | fnout_ranked = fn.rstrip('.hdf5') + 'freq_time_candidates.hdf5' 161 | 162 | eval_data_freq = data_freq #hack 163 | eval_labels = y 164 | 165 | if MK_PLOT is True: 166 | plot_tools.plot_ranked_trigger(data_freq[..., 0], 167 | y_pred_prob[:, None], h=5, w=5, ascending=False, 168 | outname='out') 169 | 170 | print("\nSaved them and all probabilities to: \n%s" % fnout_ranked) 171 | else: 172 | print("Learning frequency-time array") 173 | 174 | # split up data into training and evaluation sets 175 | train_data_freq, eval_data_freq = data_freq[ind_train], data_freq[ind_eval] 176 | 177 | # Build and train 2D CNN 178 | model_freq_time, score_freq_time = frbkeras.construct_conv2d( 179 | features_only=MERGE, fit=True, 180 | train_data=train_data_freq, eval_data=eval_data_freq, 181 | train_labels=train_labels, eval_labels=eval_labels, 182 | epochs=5, nfilt1=32, nfilt2=64, 183 | nfreq=NFREQ, ntime=WIDTH) 184 | 185 | model_list.append(model_freq_time) 186 | train_data_list.append(train_data_freq) 187 | eval_data_list.append(eval_data_freq) 188 | 189 | if save_model is True: 190 | if MERGE is True: 191 | fnout_freqtime = fnout+'freq_time_features.hdf5' 192 | else: 193 | fnout_freqtime = fnout + 'freq_time.hdf5' 194 | model_freq_time.save(fnout_freqtime) 195 | print("Saving freq-time model to: %s" % fnout_freqtime) 196 | 197 | fnout_ranked = fn.rstrip('.hdf5') + 'freq_time_candidates.hdf5' 198 | y_pred_prob = model_freq_time.predict(eval_data_freq) 199 | y_pred_prob = y_pred_prob[:,1] 200 | ind_frb = np.where(y_pred_prob>prob_threshold)[0] 201 | 202 | if save_classification is True: 203 | fnout_ranked = fn.rstrip('.hdf5') + 'freq_time_candidates.hdf5' 204 | g = h5py.File(fnout_ranked, 'w') 205 | g.create_dataset('data_frb_candidate', data=eval_data_freq) 206 | g.create_dataset('frb_index', data=ind_frb) 207 | g.create_dataset('probability', data=y_pred_prob) 208 | g.create_dataset('labels', data=eval_labels) 209 | g.close() 210 | print("\nSaved classification results to: \n%s" % fnout_ranked) 211 | 212 | if DMTIME is True: 213 | 214 | if CLASSIFY_ONLY is True: 215 | print("Classifying dm-time data") 216 | 217 | model_dm_time_nm = model_nm + 'dm_time.hdf5' 218 | eval_data_list.append(data_dm) 219 | 220 | model_dm_time = 
frbkeras.load_model(model_dm_time_nm) 221 | y_pred_prob = model_dm_time.predict(data_dm) 222 | y_pred_dm_time = np.round(y_pred_prob[:,1]) 223 | 224 | eval_data_dm = data_dm #hack 225 | eval_labels = y 226 | mistakes = np.where(y_pred_dm_time!=y)[0] 227 | print("\nMistakes: %s" % mistakes) 228 | 229 | frbkeras.print_metric(y, y_pred_dm_time) 230 | print("") 231 | else: 232 | print("Learning DM-time array") 233 | # split up data into training and evaluation sets 234 | train_data_dm, eval_data_dm = data_dm[ind_train], data_dm[ind_eval] 235 | 236 | # split up data into training and evaluation sets 237 | train_data_dm, eval_data_dm = data_dm[ind_train], data_dm[ind_eval] 238 | 239 | # Build and train 2D CNN 240 | model_dm_time, score_dm_time = frbkeras.construct_conv2d( 241 | features_only=MERGE, fit=True, 242 | train_data=train_data_dm, eval_data=eval_data_dm, 243 | train_labels=train_labels, eval_labels=eval_labels, 244 | epochs=5, nfilt1=32, nfilt2=64, 245 | nfreq=NDM, ntime=WIDTH) 246 | 247 | model_list.append(model_dm_time) 248 | train_data_list.append(train_data_dm) 249 | eval_data_list.append(eval_data_dm) 250 | 251 | if save_model is True: 252 | if MERGE is True: 253 | fnout_dmtime = fnout+'dm_time_features.hdf5' 254 | else: 255 | fnout_dmtime = fnout+'dm_time.hdf5' 256 | model_dm_time.save(fnout_dmtime) 257 | print("Saving dm-time model to: %s" % fnout_dmtime) 258 | 259 | if save_classification is True: 260 | fnout_ranked = fn.rstrip('.hdf5') + 'dm_time_candidates.hdf5' 261 | g = h5py.File(fnout_ranked, 'w') 262 | g.create_dataset('data_frb_candidate', data=eval_data_dm) 263 | g.create_dataset('frb_index', data=ind_frb) 264 | g.create_dataset('probability', data=y_pred_prob) 265 | g.create_dataset('labels', data=eval_labels) 266 | g.close() 267 | print("\nSaved classification results to: \n%s" % fnout_ranked) 268 | 269 | if TIME1D is True: 270 | 271 | if CLASSIFY_ONLY is True: 272 | print("Classifying pulse profile") 273 | 274 | model_time_nm = model_nm + '1d_time.hdf5' 275 | eval_data_list.append(data_1d) 276 | 277 | model_1d_time = frbkeras.load_model(model_time_nm) 278 | y_pred_prob = model_1d_time.predict(data_1d) 279 | y_pred_time = np.round(y_pred_prob[:,1]) 280 | ind_frb = np.where(y_pred_prob>prob_threshold)[0] 281 | 282 | eval_data_1d = data_1d #hack 283 | eval_labels = y 284 | 285 | print("\nMistakes: %s" % np.where(y_pred_time!=y)[0]) 286 | 287 | frbkeras.print_metric(y, y_pred_time) 288 | print("") 289 | else: 290 | print("Learning pulse profile") 291 | # split up data into training and evaluation sets 292 | train_data_1d, eval_data_1d = data_1d[ind_train], data_1d[ind_eval] 293 | 294 | # Build and train 1D CNN 295 | model_1d_time, score_1d_time = frbkeras.construct_conv1d( 296 | features_only=MERGE, fit=True, 297 | train_data=train_data_1d, eval_data=eval_data_1d, 298 | train_labels=train_labels, eval_labels=eval_labels, 299 | nfilt1=64, nfilt2=128) 300 | 301 | model_list.append(model_1d_time) 302 | train_data_list.append(train_data_1d) 303 | eval_data_list.append(eval_data_1d) 304 | 305 | if save_model is True: 306 | if MERGE is True: 307 | fnout_1dtime = fnout+'1d_time_features.hdf5' 308 | else: 309 | fnout_1dtime = fnout+'1d_time.hdf5' 310 | model_1d_time.save(fnout_1dtime) 311 | print("Saving 1d-time model to: %s" % fnout_1dtime) 312 | 313 | y_pred_prob = model_1d_time.predict(eval_data_1d) 314 | y_pred_prob = y_pred_prob[:,1] 315 | ind_frb = np.where(y_pred_prob>prob_threshold)[0] 316 | 317 | if save_classification is True: 318 | fnout_ranked = fn.rstrip('.hdf5') + 
'1d_time_candidates.hdf5' 319 | g = h5py.File(fnout_ranked, 'w') 320 | g.create_dataset('data_frb_candidate', data=eval_data_1d) 321 | g.create_dataset('frb_index', data=ind_frb) 322 | g.create_dataset('probability', data=y_pred_prob) 323 | g.create_dataset('labels', data=eval_labels) 324 | g.close() 325 | print("\nSaved classification results to: \n%s" % fnout_ranked) 326 | 327 | if MULTIBEAM is True: 328 | 329 | if CLASSIFY_ONLY is True: 330 | print("Classifying multibeam SNR") 331 | 332 | model_multibeam_nm = model_nm + '_multibeam.hdf5' 333 | eval_data_list.append(data_mb) 334 | 335 | model_1d_multibeam = frbkeras.load_model(model_multibeam_nm) 336 | y_pred_prob = model_1d_multibeam.predict(data_mb) 337 | y_pred_time = np.round(y_pred_prob[:,1]) 338 | 339 | print("\nMistakes: %s" % np.where(y_pred_time!=y)[0]) 340 | 341 | frbkeras.print_metric(y, y_pred_time) 342 | print("") 343 | else: 344 | print("Learning multibeam data") 345 | 346 | # Right now just simulate multibeam, simulate S/N per beam. 347 | import simulate_multibeam as sm 348 | 349 | nbeam = 40 350 | # Simulate a multibeam dataset 351 | data_mb, labels_mb = sm.make_multibeam_data(ntrigger=NTRIGGER) 352 | 353 | data_mb_fp = data_mb[labels_mb[:,1]==0] 354 | data_mb_tp = data_mb[labels_mb[:,1]==1] 355 | 356 | train_data_mb = np.zeros([NTRAIN, nbeam]) 357 | eval_data_mb = np.zeros([NTRIGGER-NTRAIN, nbeam]) 358 | 359 | data_ = np.empty_like(data_mb) 360 | labels_ = np.empty_like(labels_mb) 361 | 362 | kk, ll = 0, 0 363 | for ii in range(NTRAIN): 364 | if train_labels[ii,1]==0: 365 | train_data_mb[ii] = data_mb_fp[kk] 366 | kk+=1 367 | elif train_labels[ii,1]==1: 368 | train_data_mb[ii] = data_mb_tp[ll] 369 | ll+=1 370 | 371 | for ii in range(NTRIGGER-NTRAIN): 372 | if eval_labels[ii,1]==0: 373 | eval_data_mb[ii] = data_mb_fp[kk] 374 | kk+=1 375 | elif eval_labels[ii,1]==1: 376 | eval_data_mb[ii] = data_mb_tp[ll] 377 | ll+=1 378 | 379 | model_mb, score_mb = frbkeras.construct_ff1d( 380 | features_only=MERGE, fit=True, 381 | train_data=train_data_mb, 382 | train_labels=train_labels, 383 | eval_data=eval_data_mb, 384 | eval_labels=eval_labels, 385 | nbeam=nbeam, epochs=5, 386 | nlayer1=32, nlayer2=32, batch_size=32) 387 | 388 | model_list.append(model_mb) 389 | train_data_list.append(train_data_mb) 390 | eval_data_list.append(eval_data_mb) 391 | 392 | if save_model is True: 393 | if MERGE is True: 394 | fnout_mb = fnout+'_multibeam_features.hdf5' 395 | else: 396 | fnout_mb = fnout+'_multibeam.hdf5' 397 | model_mb.save(fnout_mb) 398 | 399 | fnout_ranked = fn.rstrip('.hdf5') + 'multibeam_candidates.hdf5' 400 | y_pred_prob = model_mb.predict(eval_data_mb) 401 | y_pred_prob = y_pred_prob[:,1] 402 | ind_frb = np.where(y_pred_prob>prob_threshold)[0] 403 | 404 | print(fnout_ranked) 405 | print(eval_data_mb.shape) 406 | 407 | g = h5py.File(fnout_ranked, 'w') 408 | g.create_dataset('data_frb_candidate', data=eval_data_mb) 409 | g.create_dataset('frb_index', data=ind_frb) 410 | g.create_dataset('probability', data=y_pred_prob) 411 | g.create_dataset('labels', data=eval_labels) 412 | g.close() 413 | 414 | 415 | if len(model_list)==1: 416 | score = model_list[0].evaluate(eval_data_list[0], eval_labels, batch_size=32) 417 | prob, predictions, mistakes = frbkeras.get_predictions( 418 | model_list[0], eval_data_list[0], 419 | true_labels=eval_labels) 420 | print(mistakes) 421 | print("" % score) 422 | 423 | elif MERGE is True: 424 | 425 | if CLASSIFY_ONLY is True: 426 | print("Classifying merged model") 427 | model_merged_nm = model_nm + 
'_merged.hdf5' 428 | 429 | model_merged = frbkeras.load_model(model_merged_nm) 430 | y_pred_prob = model_merged.predict(data_list) 431 | y_pred = np.round(y_pred_prob[:,1]) 432 | 433 | print("Mistakes: %s" % np.where(y_pred!=y)[0]) 434 | frbkeras.print_metric(y, y_pred) 435 | print("") 436 | else: 437 | 438 | print("\n=================================") 439 | print(" Merging & training %d models" % len(model_list)) 440 | print("=================================\n") 441 | 442 | model, score = frbkeras.merge_models( 443 | model_list, train_data_list, 444 | train_labels, eval_data_list, eval_labels, 445 | epochs=5) 446 | 447 | prob, predictions, mistakes = frbkeras.get_predictions( 448 | model, eval_data_list, 449 | true_labels=eval_labels[:, 1]) 450 | 451 | 452 | if save_model is True: 453 | fnout_merged = fnout+'_merged.hdf5' 454 | model.save(fnout_merged) 455 | 456 | print("\nMerged NN accuracy: %f" % score[1]) 457 | print("\nIndex of mistakes: %s\n" % mistakes) 458 | frbkeras.print_metric(eval_labels[:, 1], predictions) 459 | 460 | if CLASSIFY_ONLY is False: 461 | print('\n==========Results==========') 462 | try: 463 | print("\nFreq-time accuracy:\n--------------------") 464 | y_pred_prob = model_freq_time.predict(eval_data_freq) 465 | y_pred = np.round(y_pred_prob[:,1]) 466 | tfreq_acc, tfreq_prec, tfreq_rec, tfreq_f = frbkeras.print_metric(eval_labels[:,1], y_pred) 467 | 468 | mistakes_freq = np.where(y_pred!=eval_labels[:,1])[0] 469 | print("\nMistakes: %s" % mistakes_freq) 470 | except: 471 | pass 472 | try: 473 | print("\nDM-time accuracy:\n--------------------") 474 | y_pred_prob = model_dm_time.predict(eval_data_dm) 475 | y_pred = np.round(y_pred_prob[:,1]) 476 | dm_acc, dm_prec, dm_rec, dm_f = frbkeras.print_metric(eval_labels[:,1], y_pred) 477 | 478 | mistakes_dm = np.where(y_pred!=eval_labels[:,1])[0] 479 | # np.save('data_dm_mistakes', eval_data_dm[mistakes]) 480 | print("\nMistakes: %s" % mistakes_dm) 481 | except: 482 | pass 483 | try: 484 | print("\nPulse-profile Results:\n--------------------") 485 | y_pred_prob = model_1d_time.predict(eval_data_1d) 486 | y_pred = np.round(y_pred_prob[:,1]) 487 | pp_acc, pp_prec, pp_rec, pp_f = frbkeras.print_metric(eval_labels[:,1], y_pred) 488 | 489 | mistakes_1d = np.where(y_pred!=eval_labels[:,1])[0] 490 | # np.save('data_1d_mistakes', eval_1d_dm[mistakes]) 491 | print("\nMistakes: %s" % mistakes_1d) 492 | except: 493 | pass 494 | try: 495 | print("\nMultibeam Results:\n--------------------") 496 | y_pred_prob = model_mb.predict(eval_data_mb) 497 | y_pred = np.round(y_pred_prob[:,1]) 498 | mb_acc, mb_prec, mb_rec, mb_f = frbkeras.print_metric(eval_labels[:,1], y_pred) 499 | print("\nMistakes: %s" % np.where(y_pred!=eval_labels[:,1])[0]) 500 | except: 501 | pass 502 | 503 | 504 | 505 | 506 | 507 | 508 | -------------------------------------------------------------------------------- /single_pulse_ml/sim_parameters.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import numpy as np 5 | import h5py 6 | 7 | class SimParams: 8 | 9 | def __init__(self, dm=(-0.01, 0.01), fluence=(0.1, 0.3), 10 | width=(3*0.0016, 0.75), spec_ind=(-3., 3.), 11 | disp_ind=2., scat_factor=(-4., -1.), NRFI=None, NSIM=None, 12 | SNR_MIN=10., SNR_MAX=100., out_file_name=None, 13 | NTIME=250, NFREQ=16, mk_plot=False, NSIDE=8): 14 | 15 | self._dm = dm 16 | self._fluence = fluence 17 | self._width = width 18 | self._spec_ind = spec_ind 19 | self._disp_ind = disp_ind 20 | self._scat_factor = 
scat_factor
21 | 
22 |         self._SNR_MIN = SNR_MIN
23 |         self._SNR_MAX = SNR_MAX
24 |         self._NTIME = NTIME
25 |         self._NFREQ = NFREQ
26 |         self._out_file_name = out_file_name
27 | 
28 |         self._NRFI = NRFI
29 |         self._NSIM = NSIM
30 |         self.data_rfi = None
31 |         self.y = None # FP labels
32 | 
33 |         self._mk_plot = mk_plot
34 |         self._NSIDE = NSIDE
35 | 
36 |     def generate_noise(self):
37 |         y = np.zeros([self._NRFI])
38 |         noise = np.random.normal(0, 1, self._NRFI*self._NTIME*self._NFREQ)
39 |         noise = noise.reshape(-1, self._NFREQ*self._NTIME)
40 |         self._NSIM = self._NRFI
41 | 
42 |         return noise, y
43 | 
44 |     def get_false_positives(self, fn):
45 | 
46 |         ftype = fn.split('.')[-1]
47 | 
48 |         if ftype in ('hdf5', 'h5'):
49 |             f = h5py.File(fn, 'r')
50 |             data_rfi = f['data_freq_time'][:]
51 |             data_rfi = data_rfi.reshape(len(data_rfi), -1)
52 |             y = f['labels'][:]
53 |         elif ftype in ('npy',):
54 |             f_rfi = np.load(fn)
55 |             # Important step! Need to scramble RFI triggers.
56 |             np.random.shuffle(f_rfi)
57 |             # Read in data array and labels from RFI file
58 |             data_rfi, y = f_rfi[:, :-1], f_rfi[:, -1]
59 |         else:
60 |             return
61 | 
62 |         if self._NRFI is not None:
63 |             if self._NSIM is None:
64 |                 self._NSIM = self._NRFI
65 | 
66 |             self.data_rfi = data_rfi[:self._NRFI]
67 |             self.y = y[:self._NRFI]
68 |         else:
69 |             self._NRFI = len(y)
70 |             self._NSIM = self._NRFI
71 |             self.data_rfi = data_rfi[:self._NSIM]
72 |             self.y = y[:self._NSIM]
73 | 
74 |         return data_rfi, y
75 | 
76 |     def write_sim_data(self, data_freq_time, labels, fnout,
77 |                        data_dm_time=None, params=None, snr=None,
78 |                        ):
79 | 
80 |         ftype = fnout.split('.')[-1]
81 | 
82 |         if os.path.exists(fnout):
83 |             t0_str = time.strftime("_%Y_%m_%d_%H:%M:%S", time.gmtime())
84 |             fnout = fnout.split(ftype)[0][:-1] + t0_str + '.' + ftype
85 | 
86 |         if ftype in ('hdf5', 'h5'):
87 | 
88 |             f = h5py.File(fnout, 'w')
89 |             f.create_dataset('data_freq_time', data=data_freq_time)
90 |             f.create_dataset('labels', data=labels)
91 | 
92 |             if data_dm_time is not None:
93 |                 f.create_dataset('data_dm_time', data=data_dm_time)
94 |             if params is not None:
95 |                 f.create_dataset('params', data=params)
96 |             if snr is not None:
97 |                 f.create_dataset('snr', data=snr)
98 | 
99 |             f.close()
100 | 
101 |         elif ftype == 'npy':
102 |             np.save(fnout, data_freq_time)
103 | 
104 | 
105 | 
106 | 
107 | 
108 | 
109 | 
110 | 
-------------------------------------------------------------------------------- /single_pulse_ml/simulate_frb.py: --------------------------------------------------------------------------------
1 | import random
2 | 
3 | import numpy as np
4 | import glob
5 | from scipy import signal
6 | 
7 | try:
8 |     import matplotlib
9 |     matplotlib.use('Agg')
10 |     import matplotlib.pyplot as plt
11 | except ImportError:
12 |     plt = None
13 |     pass
14 | 
15 | from single_pulse_ml import reader
16 | from single_pulse_ml import dataproc
17 | from single_pulse_ml import tools
18 | 
19 | try:
20 |     from single_pulse_ml import plot_tools
21 | except ImportError:
22 |     plot_tools = None
23 | 
24 | 
25 | class Event(object):
26 |     """ Class to generate a realistic fast radio burst and
27 |     add the event to data, including scintillation, temporal
28 |     scattering, spectral index variation, and DM smearing.
29 | 30 | This class was expanded from real-time FRB injection 31 | in Kiyoshi Masui's 32 | https://github.com/kiyo-masui/burst\_search 33 | """ 34 | def __init__(self, t_ref, f_ref, dm, fluence, width, 35 | spec_ind, disp_ind=2, scat_factor=0): 36 | self._t_ref = t_ref 37 | self._f_ref = f_ref 38 | self._dm = dm 39 | self._fluence = fluence 40 | self._width = width 41 | self._spec_ind = spec_ind 42 | self._disp_ind = disp_ind 43 | self._scat_factor = min(1, scat_factor + 1e-18) # quick bug fix hack 44 | 45 | def disp_delay(self, f, _dm, _disp_ind=-2.): 46 | """ Calculate dispersion delay in seconds for 47 | frequency,f, in MHz, _dm in pc cm**-3, and 48 | a dispersion index, _disp_ind. 49 | """ 50 | return 4.148808e3 * _dm * (f**(-_disp_ind)) 51 | 52 | def arrival_time(self, f): 53 | t = self.disp_delay(f, self._dm, self._disp_ind) 54 | t = t - self.disp_delay(self._f_ref, self._dm, self._disp_ind) 55 | return self._t_ref + t 56 | 57 | def calc_width(self, dm, freq_c, bw=400.0, NFREQ=1024, 58 | ti=0.001, tsamp=0.001, tau=0): 59 | """ Calculated effective width of pulse 60 | including DM smearing, sample time, etc. 61 | Input/output times are in seconds. 62 | """ 63 | 64 | ti *= 1e3 65 | tsamp *= 1e3 66 | delta_freq = bw/NFREQ 67 | 68 | # taudm in milliseconds 69 | tdm = 8.3e-3 * dm * delta_freq / freq_c**3 70 | tI = np.sqrt(ti**2 + tsamp**2 + tdm**2 + tau**2) 71 | 72 | return 1e-3*tI 73 | 74 | def dm_smear(self, DM, freq_c, bw=400.0, NFREQ=1024, 75 | ti=1, tsamp=0.0016, tau=0): 76 | """ Calculate DM smearing SNR reduction 77 | """ 78 | tau *= 1e3 # make ms 79 | ti *= 1e3 80 | tsamp *= 1e3 81 | 82 | delta_freq = bw / NFREQ 83 | 84 | tI = np.sqrt(ti**2 + tsamp**2 + (8.3 * DM * delta_freq / freq_c**3)**2) 85 | 86 | return (np.sqrt(ti**2 + tau**2) / tI)**0.5 87 | 88 | def scintillation(self, freq): 89 | """ Include spectral scintillation across 90 | the band. Approximate effect as a sinusoid, 91 | with a random phase and a random decorrelation 92 | bandwidth. 93 | """ 94 | # Make location of peaks / troughs random 95 | scint_phi = np.random.rand() 96 | f = np.linspace(0, 1, len(freq)) 97 | 98 | # Make number of scintils between 0 and 10 (ish) 99 | nscint = np.exp(np.random.uniform(np.log(1e-3), np.log(7))) 100 | 101 | if nscint<1: 102 | nscint = 0 103 | # envelope = np.cos(nscint*(freq - self._f_ref)/self._f_ref + scint_phi) 104 | envelope = np.cos(2*np.pi*nscint*freq**-2/self._f_ref**-2 + scint_phi) 105 | envelope[envelope<0] = 0 106 | return envelope 107 | 108 | def gaussian_profile(self, nt, width, t0=0.): 109 | """ Use a normalized Gaussian window for the pulse, 110 | rather than a boxcar. 111 | """ 112 | t = np.linspace(-nt//2, nt//2, nt) 113 | g = np.exp(-(t-t0)**2 / width**2) 114 | 115 | if not np.all(g > 0): 116 | g += 1e-18 117 | 118 | g /= g.max() 119 | 120 | return g 121 | 122 | def scat_profile(self, nt, f, tau=1.): 123 | """ Include exponential scattering profile. 124 | """ 125 | tau_nu = tau * (f / self._f_ref)**-4. 126 | t = np.linspace(0., nt//2, nt) 127 | 128 | prof = 1 / tau_nu * np.exp(-t / tau_nu) 129 | return prof / prof.max() 130 | 131 | def pulse_profile(self, nt, width, f, tau=100., t0=0.): 132 | """ Convolve the gaussian and scattering profiles 133 | for final pulse shape at each frequency channel. 
134 | """ 135 | gaus_prof = self.gaussian_profile(nt, width, t0=t0) 136 | scat_prof = self.scat_profile(nt, f, tau) 137 | # pulse_prof = np.convolve(gaus_prof, scat_prof, mode='full')[:nt] 138 | pulse_prof = signal.fftconvolve(gaus_prof, scat_prof)[:nt] 139 | 140 | return pulse_prof 141 | 142 | def add_to_data(self, delta_t, freq, data, scintillate=True): 143 | """ Method to add already-dedispersed pulse 144 | to background noise data. Includes frequency-dependent 145 | width (smearing, scattering, etc.) and amplitude 146 | (scintillation, spectral index). 147 | """ 148 | 149 | NFREQ = data.shape[0] 150 | NTIME = data.shape[1] 151 | tmid = NTIME//2 152 | 153 | scint_amp = self.scintillation(freq) 154 | self._fluence /= np.sqrt(NFREQ) 155 | stds = np.std(data) 156 | roll_ind = int(np.random.normal(0, 2)) 157 | 158 | for ii, f in enumerate(freq): 159 | width_ = self.calc_width(self._dm, self._f_ref*1e-3, 160 | bw=400.0, NFREQ=NFREQ, 161 | ti=self._width, tsamp=delta_t, tau=0) 162 | 163 | index_width = max(1, (np.round((width_/ delta_t))).astype(int)) 164 | tpix = int(self.arrival_time(f) / delta_t) 165 | 166 | if abs(tpix) >= tmid: 167 | # ensure that edges of data are not crossed 168 | continue 169 | 170 | pp = self.pulse_profile(NTIME, index_width, f, 171 | tau=self._scat_factor, t0=tpix) 172 | val = pp.copy() 173 | val /= (val.max()*stds) 174 | val *= self._fluence 175 | val /= (width_ / delta_t) 176 | val = val * (f / self._f_ref) ** self._spec_ind 177 | 178 | if scintillate is True: 179 | val = (0.1 + scint_amp[ii]) * val 180 | 181 | data[ii] += val 182 | data[ii] = np.roll(data[ii], roll_ind) 183 | 184 | def dm_transform(self, delta_t, data, freq, maxdm=5.0, NDM=50): 185 | """ Transform freq/time data to dm/time data. 186 | """ 187 | 188 | if len(freq)<3: 189 | NFREQ = data.shape[0] 190 | freq = np.linspace(freq[0], freq[1], NFREQ) 191 | 192 | dm = np.linspace(-maxdm, maxdm, NDM) 193 | ndm = len(dm) 194 | ntime = data.shape[-1] 195 | 196 | data_full = np.zeros([ndm, ntime]) 197 | 198 | for ii, dm in enumerate(dm): 199 | for jj, f in enumerate(freq): 200 | self._dm = dm 201 | tpix = int(self.arrival_time(f) / delta_t) 202 | data_rot = np.roll(data[jj], tpix, axis=-1) 203 | data_full[ii] += data_rot 204 | 205 | return data_full 206 | 207 | class EventSimulator(): 208 | """Generates simulated fast radio bursts. 209 | Events occurrences are drawn from a Poissonian distribution. 210 | 211 | 212 | This class was expanded from real-time FRB injection 213 | in Kiyoshi Masui's 214 | https://github.com/kiyo-masui/burst\_search 215 | """ 216 | 217 | def __init__(self, dm=(0.,2000.), fluence=(0.03,0.3), 218 | width=(2*0.0016, 1.), spec_ind=(-4.,4), 219 | disp_ind=2., scat_factor=(0, 0.5), freq=(800., 400.)): 220 | """ 221 | Parameters 222 | ---------- 223 | datasource : datasource.DataSource object 224 | Source of the data, specifying the data rate and band parameters. 225 | dm : float or pair of floats 226 | Burst dispersion measure or dispersion measure range (pc cm^-2). 227 | fluence : float or pair of floats 228 | Burst fluence (at band centre) or fluence range (s). 229 | width : float or pair of floats. 230 | Burst width or width range (s). 231 | spec_ind : float or pair of floats. 232 | Burst spectral index or spectral index range. 233 | disp_ind : float or pair of floats. 234 | Burst dispersion index or dispersion index range. 235 | freq : tuple 236 | Min and max of frequency range in MHz. Assumes low freq 237 | is first freq in array, not necessarily the lowest value. 
238 | 239 | """ 240 | 241 | self.width = width 242 | self.freq_low = freq[0] 243 | self.freq_up = freq[1] 244 | 245 | if hasattr(dm, '__iter__') and len(dm) == 2: 246 | self._dm = tuple(dm) 247 | else: 248 | self._dm = (float(dm), float(dm)) 249 | if hasattr(fluence, '__iter__') and len(fluence) == 2: 250 | fluence = (fluence[1]**-1, fluence[0]**-1) 251 | self._fluence = tuple(fluence) 252 | else: 253 | self._fluence = (float(fluence)**-1, float(fluence)**-1) 254 | if hasattr(width, '__iter__') and len(width) == 2: 255 | self._width = tuple(width) 256 | else: 257 | self._width = (float(width), float(width)) 258 | if hasattr(spec_ind, '__iter__') and len(spec_ind) == 2: 259 | self._spec_ind = tuple(spec_ind) 260 | else: 261 | self._spec_ind = (float(spec_ind), float(spec_ind)) 262 | if hasattr(disp_ind, '__iter__') and len(disp_ind) == 2: 263 | self._disp_ind = tuple(disp_ind) 264 | else: 265 | self._disp_ind = (float(disp_ind), float(disp_ind)) 266 | if hasattr(scat_factor, '__iter__') and len(scat_factor) == 2: 267 | self._scat_factor = tuple(scat_factor) 268 | else: 269 | self._scat_factor = (float(scat_factor), float(scat_factor)) 270 | 271 | # self._freq = datasource.freq 272 | # self._delta_t = datasource.delta_t 273 | 274 | self._freq = np.linspace(self.freq_low, self.freq_up, 256) # tel parameter 275 | 276 | def draw_event_parameters(self): 277 | dm = uniform_range(*self._dm) 278 | fluence = uniform_range(*self._fluence)**(-2/3.) 279 | # Convert to Jy ms from Jy s 280 | fluence *= 1e3*self._fluence[0]**(-2/3.) 281 | spec_ind = uniform_range(*self._spec_ind) 282 | disp_ind = uniform_range(*self._disp_ind) 283 | # turn this into a log uniform dist. Note not *that* many 284 | # FRBs have been significantly scattered. Should maybe turn this 285 | # knob down. 286 | scat_factor = np.exp(np.random.uniform(*self._scat_factor)) 287 | # change width from uniform to lognormal 288 | width = np.random.lognormal(np.log(self._width[0]), self._width[1]) 289 | width = max(min(width, 100*self._width[0]), 0.5*self._width[0]) 290 | return dm, fluence, width, spec_ind, disp_ind, scat_factor 291 | 292 | def uniform_range(min_, max_): 293 | return random.uniform(min_, max_) 294 | 295 | 296 | def gen_simulated_frb(NFREQ=16, NTIME=250, sim=True, fluence=(0.03,0.3), 297 | spec_ind=(-4, 4), width=(2*0.0016, 1), dm=(-0.01, 0.01), 298 | scat_factor=(-3, -0.5), background_noise=None, delta_t=0.0016, 299 | plot_burst=False, freq=(800, 400), FREQ_REF=600., scintillate=True, 300 | ): 301 | """ Simulate fast radio bursts using the EventSimulator class. 302 | 303 | Parameters 304 | ---------- 305 | NFREQ : np.int 306 | number of frequencies for simulated array 307 | NTIME : np.int 308 | number of times for simulated array 309 | sim : bool 310 | whether or not to simulate FRB or just create noise array 311 | spec_ind : tuple 312 | range of spectral index 313 | width : tuple 314 | range of widths in seconds (atm assumed dt=0.0016) 315 | scat_factor : tuple 316 | range of scattering measure (atm arbitrary units) 317 | background_noise : 318 | if None, simulates white noise. 
Otherwise should be an array (NFREQ, NTIME)
319 |     plot_burst : bool
320 |         generates a plot of the simulated burst
321 | 
322 |     Returns
323 |     -------
324 |     data : np.array
325 |         data array (NFREQ, NTIME)
326 |     parameters : tuple
327 |         [dm, fluence, width, spec_ind, disp_ind, scat_factor]
328 | 
329 |     """
330 |     # plotting is controlled by the plot_burst keyword argument above
331 | 
332 |     # Hard code incoherent Pathfinder data time resolution
333 |     # Maybe instead this should take a telescope class, which
334 |     # has all of these things already.
335 |     t_ref = 0. # hack
336 | 
337 |     if len(freq) < 3:
338 |         freq=np.linspace(freq[0], freq[1], NFREQ)
339 | 
340 |     if background_noise is None:
341 |         # Generate background noise with unit variance
342 |         data = np.random.normal(0, 1, NTIME*NFREQ).reshape(NFREQ, NTIME)
343 |     else:
344 |         data = background_noise
345 | 
346 |     # What about reading in noisy background?
347 |     if sim is False:
348 |         return data, []
349 | 
350 |     # Call class using parameter ranges
351 |     ES = EventSimulator(dm=dm, scat_factor=scat_factor, fluence=fluence,
352 |                         width=width, spec_ind=spec_ind)
353 |     # Realize event parameters for a single FRB
354 |     dm, fluence, width, spec_ind, disp_ind, scat_factor = ES.draw_event_parameters()
355 |     # Create event class with those parameters
356 |     E = Event(t_ref, FREQ_REF, dm, 10e-4*fluence,
357 |               width, spec_ind, disp_ind, scat_factor)
358 |     # Normalize background before injecting the FRB
359 |     data -= np.median(data)
360 |     data /= np.std(data)
361 |     # Add FRB to data array
362 |     E.add_to_data(delta_t, freq, data, scintillate=scintillate)
363 | 
364 |     if plot_burst:
365 |         plt.subplot(211)
366 |         plt.imshow(data.reshape(-1, NTIME), aspect='auto',
367 |                    interpolation='nearest', vmin=0, vmax=10)
368 |         plt.subplot(313)
369 |         plt.plot(data.reshape(-1, NTIME).mean(0))
370 | 
371 |     return data, [dm, fluence, width, spec_ind, disp_ind, scat_factor]
372 | 
373 | 
374 | def inject_in_filterbank_background(fn_fil):
375 |     """ Build background (non-FRB) training cutouts from chunks of a
376 |     filterbank file dedispersed to random DMs. Default params are for Apertif data.
377 | """ 378 | 379 | chunksize = 5e5 380 | ii=0 381 | 382 | data_full =[] 383 | nchunks = 250 384 | nfrb_chunk = 8 385 | chunksize = 2**16 386 | 387 | for ii in range(nchunks): 388 | downsamp = 2**((np.random.rand(nfrb_chunk)*6).astype(int)) 389 | 390 | try: 391 | # drop FRB in random location in data chunk 392 | rawdatafile = filterbank.filterbank(fn_fil) 393 | dt = rawdatafile.header['tsamp'] 394 | freq_up = rawdatafile.header['fch1'] 395 | nfreq = rawdatafile.header['nchans'] 396 | freq_low = freq_up + nfreq*rawdatafile.header['foff'] 397 | data = rawdatafile.get_spectra(ii*chunksize, chunksize) 398 | except: 399 | continue 400 | 401 | 402 | #dms = np.random.uniform(50, 750, nfrb_chunk) 403 | dm0 = np.random.uniform(90, 750) 404 | end_width = abs(4e3 * dm0 * (freq_up**-2 - freq_low**-2)) 405 | data.dedisperse(dm0) 406 | NFREQ, NT = data.data.shape 407 | 408 | print("Chunk %d with DM=%.1f" % (ii, dm0)) 409 | for jj in xrange(nfrb_chunk): 410 | if 8192*(jj+1) > (NT - end_width): 411 | print("Skipping at ", 8192*(jj+1)) 412 | continue 413 | data_event = data.data[:, jj*8192:(jj+1)*8192] 414 | data_event = data_event.reshape(NFREQ, -1, downsamp[jj]).mean(-1) 415 | print(data_event.shape) 416 | data_event = data_event.reshape(32, 48, -1).mean(1) 417 | 418 | NTIME = data_event.shape[-1] 419 | data_event = data_event[..., NTIME//2-125:NTIME//2+125] 420 | data_event -= np.mean(data_event, axis=-1, keepdims=True) 421 | data_full.append(data_event) 422 | 423 | data_full = np.concatenate(data_full) 424 | data_full = data_full.reshape(-1, 32, 250) 425 | 426 | np.save('data_250.npy', data_full) 427 | 428 | 429 | def inject_in_filterbank(fn_fil, fn_fil_out, N_FRBs=1, 430 | NFREQ=1536, NTIME=2**15): 431 | """ Inject an FRB in each chunk of data 432 | at random times. Default params are for Apertif data. 433 | """ 434 | 435 | chunksize = 5e5 436 | ii=0 437 | 438 | params_full_arr = [] 439 | 440 | for ii in xrange(N_FRBs): 441 | start, stop = chunksize*ii, chunksize*(ii+1) 442 | # drop FRB in random location in data chunk 443 | offset = int(np.random.uniform(0.1*chunksize, 0.9*chunksize)) 444 | 445 | data, freq, delta_t, header = reader.read_fil_data(fn_fil, 446 | start=start, stop=stop) 447 | 448 | # injected pulse time in seconds since start of file 449 | t0_ind = offset+NTIME//2+chunksize*ii 450 | t0 = t0_ind * delta_t 451 | 452 | if len(data[0])==0: 453 | break 454 | 455 | data_event = (data[offset:offset+NTIME].transpose()).astype(np.float) 456 | 457 | data_event, params = gen_simulated_frb(NFREQ=NFREQ, 458 | NTIME=NTIME, sim=True, fluence=(0.01, 1.), 459 | spec_ind=(-4, 4), width=(delta_t, 2), 460 | dm=(100, 1000), scat_factor=(-4, -0.5), 461 | background_noise=data_event, 462 | delta_t=delta_t, plot_burst=False, 463 | freq=(1550, 1250), 464 | FREQ_REF=1550.) 
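# Sanity check on the time span used above (illustrative numbers, not from
# the original source): the dispersive sweep across the band follows
# Event.disp_delay(), dt = 4.148808e3 * DM * (f_lo**-2 - f_hi**-2) seconds
# with f in MHz. For the (1550, 1250) MHz band assumed here and an
# injected DM = 500 pc cm**-3:
#     4.148808e3 * 500 * (1250.**-2 - 1550.**-2) ~ 0.46 s,
# so the NTIME = 2**15 cutout must span roughly half a second at the file's
# native tsamp for the whole sweep to land inside the injected window; this
# is the same quantity the end_width estimate in
# inject_in_filterbank_background() approximates with a 4e3 prefactor.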
465 | 466 | params.append(offset) 467 | print("Injecting with DM:%f width: %f offset: %d" % 468 | (params[0], params[2], offset)) 469 | 470 | data[offset:offset+NTIME] = data_event.transpose() 471 | 472 | #params_full_arr.append(params) 473 | width = params[2] 474 | downsamp = max(1, int(width/delta_t)) 475 | 476 | params_full_arr.append([params[0], 20.0, t0, t0_ind, downsamp]) 477 | 478 | if ii==0: 479 | fn_rfi_clean = reader.write_to_fil(data, header, fn_fil_out) 480 | elif ii>0: 481 | fil_obj = reader.filterbank.FilterbankFile(fn_fil_out, mode='readwrite') 482 | fil_obj.append_spectra(data) 483 | 484 | del data 485 | 486 | params_full_arr = np.array(params_full_arr) 487 | 488 | np.savetxt('/home/arts/connor/arts-analysis/simulated.singlepulse', params_full_arr) 489 | 490 | return params_full_arr 491 | 492 | # a, p = gen_simulated_frb(NFREQ=1536, NTIME=2**15, sim=True, fluence=(2), 493 | # spec_ind=(-4, 4), width=(dt), dm=(40.0), 494 | # scat_factor=(-3, -0.5), background_noise=None, delta_t=dt, 495 | # plot_burst=False, freq=(1550, 1250), FREQ_REF=1400., 496 | # # ) 497 | 498 | # a, p = gen_simulated_frb(NFREQ=32, NTIME=250, sim=True, fluence=(5, 100), 499 | # spec_ind=(-4, 4), width=(dt, 1), dm=(-0.1, 0.1), 500 | # scat_factor=(-3, -0.5), background_noise=None, delta_t=dt, 501 | # plot_burst=False, freq=(800, 400), FREQ_REF=600., 502 | # ) 503 | 504 | 505 | def run_full_simulation(sim_obj, tel_obj, mk_plot=False, 506 | fn_rfi='./data/all_RFI_8001.npy', 507 | fn_noise=None, 508 | ftype='hdf5', dm_time_array=True, 509 | outname_tag='', outdir = './data/', 510 | figname='./plots/simulated_frb.pdf'): 511 | 512 | outfn = outdir + "data_nt%d_nf%d_dm%d_snr%d-%d_%s.%s" \ 513 | % (sim_obj._NTIME, sim_obj._NFREQ, 514 | round(max(sim_obj._dm)), sim_obj._SNR_MIN, 515 | sim_obj._SNR_MAX, outname_tag, ftype) 516 | 517 | if fn_rfi is not None: 518 | data_rfi, y = sim_obj.get_false_positives(fn_rfi) 519 | else: 520 | data_rfi, y = sim_obj.generate_noise() 521 | 522 | if fn_noise is not None: 523 | noise_arr = np.load(fn_noise) # Hack 524 | 525 | sim_obj._NRFI = min(sim_obj._NRFI, data_rfi.shape[0]) 526 | print("\nUsing %d false-positive triggers" % sim_obj._NRFI) 527 | print("Simulating %d FRBs\n" % sim_obj._NSIM) 528 | 529 | arr_sim_full = [] # data array with all events 530 | yfull = [] # label array FP=0, TP=1 531 | arr_dm_time_full = [] 532 | 533 | params_full_arr = [] 534 | width_full_arr = [] 535 | 536 | snr = [] # Keep track of simulated FRB signal-to-noise 537 | ii = -1 538 | jj = 0 539 | 540 | # Loop through total number of events 541 | while jj < (sim_obj._NRFI + sim_obj._NSIM): 542 | jj = len(arr_sim_full) 543 | ii += 1 544 | if ii % 500 == 0: 545 | print("simulated:%d kept:%d" % (ii, jj)) 546 | 547 | # If ii is greater than the number of RFI events in f, 548 | # simulate an FRB 549 | #sim = bool(ii >= NRFI) 550 | 551 | if ii < sim_obj._NRFI: 552 | data = data_rfi[ii].reshape(sim_obj._NFREQ, sim_obj._NTIME) 553 | 554 | # Normalize data to have unit variance and zero median 555 | data = reader.rebin_arr(data, sim_obj._NFREQ, sim_obj._NTIME) 556 | data = dataproc.normalize_data(data) 557 | 558 | arr_sim_full.append(data.reshape(sim_obj._NFREQ*sim_obj._NTIME)[None]) 559 | yfull.append(0) # Label the RFI with '0' 560 | continue 561 | 562 | elif (ii >=sim_obj._NRFI and jj < (sim_obj._NRFI + sim_obj._NSIM)): 563 | 564 | if fn_noise is not None: 565 | noise_ind = (jj-sim_obj._NRFI) % len(noise_arr) # allow for roll-over 566 | noise = (noise_arr[noise_ind]).copy() 567 | noise[noise!=noise] = 0.0 
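# Note (added explanation): noise != noise is True only for NaN entries,
# since NaN is the one float value not equal to itself, so the line above
# zeroes bad samples before the per-channel normalization that follows.
# An equivalent, more explicit spelling would be:
#     noise[np.isnan(noise)] = 0.0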
568 |             noise -= np.median(noise, axis=-1)[..., None]
569 |             noise -= np.median(noise)
570 |             noise /= np.std(noise)
571 |             # noise[:, 21] = 0 # hack mask out single bad channel
572 |         else:
573 |             noise = None
574 | 
575 |         # maybe should feed gen_sim a tel object and
576 |         # a set of burst parameters...
577 |         arr_sim, params = gen_simulated_frb(NFREQ=sim_obj._NFREQ,
578 |                                             NTIME=sim_obj._NTIME,
579 |                                             delta_t=tel_obj._DELTA_T,
580 |                                             freq=tel_obj._freq,
581 |                                             FREQ_REF=tel_obj._FREQ_REF,
582 |                                             spec_ind=sim_obj._spec_ind,
583 |                                             width=sim_obj._width,
584 |                                             scat_factor=sim_obj._scat_factor,
585 |                                             dm=sim_obj._dm,
586 |                                             fluence=sim_obj._fluence,
587 |                                             background_noise=noise,
588 |                                             plot_burst=False,
589 |                                             sim=True,
590 |                                             )
591 | 
592 |         # Normalize data to have unit variance and zero median
593 |         arr_sim = reader.rebin_arr(arr_sim, sim_obj._NFREQ, sim_obj._NTIME)
594 |         arr_sim = dataproc.normalize_data(arr_sim)
595 |         # get SNR of simulated pulse. Center should be at ntime//2
596 |         # rebin until max SNR is found.
597 |         snr_ = tools.calc_snr(arr_sim.mean(0), fast=False)
598 | 
599 |         # Only use events within a range of signal-to-noise
600 |         if snr_ > sim_obj._SNR_MIN and snr_ < sim_obj._SNR_MAX:
601 |             arr_sim_full.append(arr_sim.reshape(-1, sim_obj._NFREQ*sim_obj._NTIME))
602 |             yfull.append(1) # Label the simulated FRB with '1'
603 |             params_full_arr.append(params) # Save the burst parameters
604 |             snr.append(snr_)
605 |             continue
606 |         else:
607 |             continue
608 | 
609 |     if dm_time_array is True:
610 |         E = Event(0, tel_obj._FREQ_REF, 0.0, 1.0, tel_obj._DELTA_T, 0., )
611 | 
612 |         for ii, data in enumerate(arr_sim_full):
613 |             if ii%500==0:
614 |                 print("DM-transformed:%d" % ii)
615 | 
616 |             data = data.reshape(-1, sim_obj._NTIME)
617 |             data = dataproc.normalize_data(data)
618 |             data_dm_time = E.dm_transform(tel_obj._DELTA_T, data, tel_obj._freq)
619 |             data_dm_time = dataproc.normalize_data(data_dm_time)
620 |             arr_dm_time_full.append(data_dm_time)
621 | 
622 |         NDM = data_dm_time.shape[0]
623 |         arr_dm_time_full = np.concatenate(arr_dm_time_full)
624 |         arr_dm_time_full = arr_dm_time_full.reshape(-1, NDM, sim_obj._NTIME)
625 |     else:
626 |         arr_dm_time_full = None
627 | 
628 |     params_full_arr = np.concatenate(params_full_arr).reshape(-1, 6)
629 |     snr = np.array(snr)
630 |     yfull = np.array(yfull)
631 | 
632 |     arr_sim_full = np.concatenate(arr_sim_full, axis=-1)
633 |     arr_sim_full = arr_sim_full.reshape(-1, sim_obj._NFREQ*sim_obj._NTIME)
634 | 
635 |     print("\nGenerated %d simulated FRBs with mean SNR: %f"
636 |           % (sim_obj._NSIM, snr.mean()))
637 |     print("Used %d RFI triggers" % sim_obj._NRFI)
638 |     print("Total triggers with SNR>%.1f: %d" % (sim_obj._SNR_MIN, arr_sim_full.shape[0]))
639 | 
640 |     if ftype == 'hdf5':
641 |         arr_sim_full = arr_sim_full.reshape(-1, sim_obj._NFREQ, sim_obj._NTIME)
642 |         sim_obj.write_sim_data(arr_sim_full, yfull, outfn,
643 |                                data_dm_time=arr_dm_time_full,
644 |                                params=params_full_arr,
645 |                                snr=snr)
646 |         print("Saving training/label data to:\n%s" % outfn)
647 |     else:
648 |         full_label_arr = np.concatenate((arr_sim_full, yfull[:, None]), axis=-1)
649 |         print("Saving training/label data to:\n%s" % outfn)
650 | 
651 |         # save down the training data with labels
652 |         np.save(outfn, full_label_arr)
653 | 
654 |     # Skip plotting if matplotlib / plot_tools could not be imported
655 |     mk_plot = sim_obj._mk_plot and (plt is not None) and (plot_tools is not None)
656 | 
657 |     if mk_plot:
658 |         kk=0
659 | 
660 |         plot_tools.plot_simulated_events(
661 |             arr_sim_full, yfull, figname,
662 |             sim_obj._NSIDE, sim_obj._NFREQ,
663 |             sim_obj._NTIME, cmap='Greys')
664 | 
665 |     return arr_sim_full, yfull, params_full_arr, snr
666 | 
667 | 
668 | 
669 | 
670 | 
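# A minimal sketch of driving gen_simulated_frb() on its own to build a
# tiny labeled set (half noise-only, half injected FRBs), bypassing
# run_full_simulation(). The parameter ranges are illustrative, borrowed
# from the commented-out examples above, not a prescribed configuration.
if __name__ == '__main__':
    dt = 0.0016
    cutouts, labels = [], []
    for kk in range(8):
        inject = bool(kk % 2)  # alternate noise-only / FRB-injected cutouts
        arr, params = gen_simulated_frb(NFREQ=32, NTIME=250, sim=inject,
                                        fluence=(5, 100), spec_ind=(-4, 4),
                                        width=(dt, 1), dm=(-0.1, 0.1),
                                        scat_factor=(-3, -0.5),
                                        background_noise=None, delta_t=dt,
                                        freq=(800, 400), FREQ_REF=600.)
        cutouts.append(arr)
        labels.append(int(inject))
    data = np.array(cutouts)  # shape (8, NFREQ, NTIME)
    print(data.shape, labels)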
-------------------------------------------------------------------------------- /single_pulse_ml/simulate_multibeam.py: -------------------------------------------------------------------------------- 1 | # Script for simulating multi-beam detections 2 | # 5 December 2017 3 | # Liam Connor 4 | import sys 5 | 6 | import numpy as np 7 | from numpy.random import seed 8 | import h5py 9 | 10 | import keras 11 | from keras.models import Sequential 12 | from keras.layers import Dense, Dropout, Flatten, Merge 13 | from keras.layers import Conv1D, Conv2D 14 | from keras.layers import MaxPooling2D, MaxPooling1D, GlobalAveragePooling1D 15 | from keras.optimizers import SGD 16 | from keras.models import load_model 17 | 18 | import frbkeras 19 | 20 | def gauss(x, xo, sig): 21 | return np.exp(-(x-xo)**2/sig**2) 22 | 23 | def generate_multibeam(nbeam=40, rows=8, cols=5, width=27, nside=1000): 24 | """ width in arcminutes 25 | """ 26 | # convert arcminutes to degrees 27 | width /= 60. 28 | 29 | # theta in degrees 30 | theta = np.linspace(-1, 1, 100) 31 | 32 | # compute 1D gaussian beam 33 | beam_theta = gauss(theta, 0, width) 34 | 35 | # compute 1D beam outer product with itself for 2D 36 | beam_2d = beam_theta[None]*beam_theta[:, None] 37 | 38 | # create nbeam arrays 39 | beam_arr = np.zeros([nside, nside, nbeam]) 40 | 41 | # Make each beam 42 | kk=0 43 | for ii in range(rows): 44 | for jj in range(cols): 45 | # get x,y coordinates of each beam center 46 | xx, yy = 500-4*50+ii*50, 500-2*50+jj*50 47 | beam_arr[xx:xx+100, yy:yy+100, kk] += beam_2d 48 | kk+=1 49 | 50 | return beam_arr 51 | 52 | def test_merge_model(n=32, m=64, ntrigger=10000): 53 | data = np.random.normal(0, 1, n*m*ntrigger).reshape(ntrigger, n, m) 54 | data[ntrigger//2:, :, m//2-2:m//2+1] += 0.25 55 | data /= np.std(data.reshape(-1, n*m), -1)[:, None, None] 56 | data -= np.median(data, 2)[:, :, None] 57 | 58 | # set RFI labels to 0, FRBs to 1 59 | labels = np.zeros([ntrigger]) 60 | labels[ntrigger//2:] = 1 61 | 62 | # convert to categorical array with shape (-1, 2) 63 | labels = labels.astype(int) 64 | labels = keras.utils.to_categorical(labels) 65 | 66 | data = data[..., None] 67 | 68 | model_2d_freq_time, score_freq_time = frbkeras.construct_conv2d( 69 | features_only=False, fit=True, 70 | train_data=data[::2], eval_data=data[1::2], 71 | train_labels=labels[::2], eval_labels=labels[1::2], 72 | epochs=5, nfilt1=32, nfilt2=64, 73 | nfreq=n, ntime=m) 74 | print(score_freq_time) 75 | 76 | train_data_mb, train_labels, eval_data_mb, eval_labels, model_mb = run_model(ntrigger) 77 | 78 | model_list = [model_mb, model_2d_freq_time] 79 | train_data_list = [train_data_mb, data[::2]] 80 | eval_data_list = [eval_data_mb, data[1::2]] 81 | 82 | model, score = frbkeras.merge_models(model_list, train_data_list, 83 | train_labels, eval_data_list, eval_labels, 84 | epoch=5) 85 | 86 | print(score) 87 | 88 | return data, labels, train_data_mb, train_labels, model 89 | 90 | def make_multibeam_data(ntrigger=2304, tp_frac=0.5, 91 | nbeam=40, rows=8, cols=5): 92 | 93 | A = generate_multibeam(nbeam=nbeam, rows=rows, cols=cols) 94 | # Take a euclidean flux distribution 95 | sn = np.random.uniform(1, 1000, 100*ntrigger)**-(2/3.) 
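# Why U**(-2/3) gives a "Euclidean" flux distribution (added explanation):
# for standard candles distributed uniformly in Euclidean space, the number
# of sources brighter than S scales as N(>S) ~ S**(-3/2). If U is drawn
# uniformly, S = U**(-2/3) inverts that cumulative law, so P(S > s) ~
# s**(-3/2). Quick check (illustrative, not part of the pipeline): the
# fraction of draws above 4x any threshold should be ~ 4**-1.5 = 1/8 of
# the fraction above the threshold itself.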
96 |     sn /= np.median(sn)
97 |     sn *= 15
98 |     #sn[sn > 150] = 150
99 | 
100 |     det_ = []
101 |     sn_ = []
102 |     multis = 0
103 | 
104 |     # drop FRBs at random locations with random flux
105 |     for ss in sn:
106 |         xi = np.random.uniform(400, 650)
107 |         yi = np.random.uniform(300, 750)
108 |         abeams = A[int(xi), int(yi)] * ss
109 |         beamdet = np.where(abeams >= 6)[0]  # beams in which the event exceeds the S/N 6 threshold
110 |         if len(beamdet) > 0:
111 |             det_.append(beamdet)
112 |             sn_.append(abeams[beamdet])
113 |             if len(beamdet) > 1:
114 |                 multis += 1
115 | 
116 |     ntrigger = min(2*len(det_), ntrigger)  # ensure enough detections for the TP half
117 |     data = np.zeros([nbeam*ntrigger]).reshape(-1, nbeam)
118 |     N_FP = int((1-tp_frac)*ntrigger)
119 |     N_TP = int(tp_frac*ntrigger)
120 | 
121 |     for ii in range(N_FP):
122 |         # nbeam_ii = int(np.random.uniform(1, 32))
123 | 
124 |         # Generate the number of beams the RFI shows up in
125 |         nbeam_ii = min(nbeam, int(np.random.lognormal(1.25, 0.8)))
126 |         # draw that many beam indices; set() keeps the unique ones
127 |         ind = set(np.random.randint(0, nbeam, nbeam_ii))
128 |         data[ii][list(ind)] = np.random.normal(20, 5, len(ind))
129 | 
130 |     for ii in range(N_TP):
131 |         # beam = int(np.random.uniform(1, 32))
132 |         data[N_FP+ii][det_[ii]] = sn_[ii]  # np.random.normal(20, 5, 1)
133 | 
134 |     # set RFI labels to 0, FRBs to 1
135 |     labels = np.zeros([ntrigger])
136 |     labels[N_FP:] = 1
137 | 
138 |     # convert to categorical array with shape (-1, 2)
139 |     labels = labels.astype(int)
140 |     labels = keras.utils.to_categorical(labels)
141 | 
142 |     # Check that the fraction of multibeam detections is as expected
143 |     print("Multibeam detection fraction: %f" % (float(multis) / len(det_)))
144 | 
145 |     return data, labels
146 | 
147 | def run_model(n, nbeam=40):
148 | 
149 | 
150 |     data_mb, labels = make_multibeam_data(nbeam=nbeam, ntrigger=n, tp_frac=0.5)
151 |     train_data_mb = data_mb[::2]
152 |     train_labels = labels[::2]
153 |     eval_data_mb = data_mb[1::2]
154 |     eval_labels = labels[1::2]
155 | 
156 |     model_mb, score_mb = frbkeras.construct_ff1d(
157 |                             features_only=False, fit=True,
158 |                             train_data=train_data_mb,
159 |                             train_labels=train_labels,
160 |                             eval_data=eval_data_mb,
161 |                             eval_labels=eval_labels,
162 |                             nbeam=nbeam, epochs=5,
163 |                             nlayer1=32, nlayer2=32,
164 |                             batch_size=32)
165 | 
166 |     if len(score_mb) > 1:
167 |         prob, predictions, mistakes = frbkeras.get_predictions(
168 |                                 model_mb, eval_data_mb,
169 |                                 true_labels=eval_labels)
170 |         print(score_mb)
171 | 
172 |     return train_data_mb, train_labels, eval_data_mb, eval_labels, model_mb
173 | 
174 | 
--------------------------------------------------------------------------------
/single_pulse_ml/telescope.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | 
4 | class Telescope():
5 | 
6 |     def __init__(self, freq=(800, 400), FREQ_REF=600,
7 |                  DELTA_T=0.0016, name=None):
8 |         """ Telescope class that can be fed to the simulation
9 | 
10 |         Parameters
11 |         ----------
12 |         freq : tuple
13 |             two-element tuple with (FREQ_LOW, FREQ_UP) in MHz,
14 |             e.g. (800., 400.) for CHIME
15 |         FREQ_REF : float
16 |             reference frequency in MHz
17 |         DELTA_T : float
18 |             time resolution in seconds
19 |         name : str
20 |             telescope name, e.g. CHIME_PATHFINDER
21 | 
22 |         (NFREQ and NTIME are simulation parameters; see SimParams.)
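
        Example (editor's sketch, using the CHIME Pathfinder values
        from the test script):

        >>> tel = Telescope(freq=(800, 400), FREQ_REF=600,
        ...                 DELTA_T=0.0016, name="CHIMEPathfinder")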
23 | 
24 |         """
25 |         self._FREQ_LOW = freq[0]
26 |         self._FREQ_UP = freq[-1]
27 |         self._freq = freq
28 |         self._FREQ_REF = FREQ_REF
29 |         self._DELTA_T = DELTA_T
30 |         self._telname = name
31 | 
--------------------------------------------------------------------------------
/single_pulse_ml/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liamconnor/single_pulse_ml/88b6b76ebf3d3939214d9785d4e1c5076f653c38/single_pulse_ml/tests/__init__.py
--------------------------------------------------------------------------------
/single_pulse_ml/tests/test_frbkeras.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import numpy as np
3 | 
4 | from single_pulse_ml import frbkeras
5 | 
6 | class TestFRBkeras(unittest.TestCase):
7 | 
8 |     def test_get_classification_results(self):
9 |         """ Test that the true/false positives/negatives
10 |         are correctly identified.
11 |         """
12 |         y_true = np.round(np.random.rand(10000))
13 |         y_pred = np.round(np.random.rand(10000))
14 | 
15 |         TP, FP, TN, FN = frbkeras.get_classification_results(y_true, y_pred)
16 |         minlen = min(len(TP), len(FP), len(TN), len(FN))
17 |         assert minlen > 0, "all four categories should be non-empty for random labels"
18 | 
19 |         # Now create 1000 false events that are all predicted true
20 |         y_true = np.zeros([1000])
21 |         y_pred = np.ones([1000])
22 | 
23 |         TP, FP, TN, FN = frbkeras.get_classification_results(y_true, y_pred)
24 | 
25 |         assert len(TP)==0
26 |         assert len(FP)!=0
27 |         assert len(TN)==0
28 |         assert len(FN)==0
29 | 
30 |     def test_construct_conv2d(self):
31 |         """ Test the 2D CNN by generating fake
32 |         data (Gaussian noise) and fitting the model
33 |         """
34 |         ntime = 64
35 |         nfreq = 32
36 |         ntrigger = 1000
37 | 
38 |         data = np.random.normal(0, 1, ntrigger*nfreq*ntime)
39 |         data.shape = (ntrigger, nfreq, ntime, 1)
40 |         labels = np.round(np.random.rand(ntrigger))
41 |         labels = frbkeras.keras.utils.to_categorical(labels)
42 | 
43 |         # try training a model on random noise; it should not do
44 |         # better than ~50% accuracy
45 |         model, score = frbkeras.construct_conv2d(train_data=data[::2],
46 |                                                  train_labels=labels[::2],
47 |                                                  eval_data=data[1::2],
48 |                                                  eval_labels=labels[1::2],
49 |                                                  fit=True, epochs=3)
50 |         assert score[1] < 0.9, "Trained on random noise. Should not have high acc"
51 |         self.model_conv2d = model
52 | 
53 | 
54 |     def test_construct_conv1d(self):
55 |         """ Test the 1D CNN by generating fake
56 |         data (Gaussian noise) and fitting the model
57 |         """
58 |         ntime = 64
59 |         ntrigger = 1000
60 | 
61 |         data = np.random.normal(0, 1, ntrigger*ntime)
62 |         data.shape = (ntrigger, ntime, 1)
63 |         labels = np.round(np.random.rand(ntrigger))
64 |         labels = frbkeras.keras.utils.to_categorical(labels)
65 | 
66 |         # try training a model on random noise; it should not do
67 |         # better than ~50% accuracy
68 |         model, score = frbkeras.construct_conv1d(fit=True, train_data=data[::2],
69 |                                                  train_labels=labels[::2],
70 |                                                  eval_data=data[1::2],
71 |                                                  eval_labels=labels[1::2],
72 |                                                  batch_size=16, epochs=3)
73 | 
74 |         assert score[1] < 0.9, "Trained on random noise. Should not have high acc"
75 | 
76 |         self.model_conv1d = model
77 | 
78 |     def test_construct_ff1d(self):
79 |         nbeam = 32
80 |         ntrigger = 1000
81 | 
82 |         data = np.random.normal(0, 1, ntrigger*nbeam)
83 |         data.shape = (ntrigger, nbeam)  # construct_ff1d takes one flat beam vector per trigger
84 |         labels = np.round(np.random.rand(ntrigger))
85 |         labels = frbkeras.keras.utils.to_categorical(labels)
86 | 
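        # Since the labels are drawn independently of the data, no model can
        # beat chance accuracy in expectation; a score well above 0.5 here
        # would itself point to a bug such as label leakage.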
87 |         # try training a model on random noise; it should not do
88 |         # better than ~50% accuracy
89 |         model, score = frbkeras.construct_ff1d(fit=True, train_data=data[::2],
90 |                                                train_labels=labels[::2],
91 |                                                eval_data=data[1::2],
92 |                                                eval_labels=labels[1::2],
93 |                                                nbeam=nbeam, batch_size=16, epochs=3)
94 |         assert score[1] < 0.9, "Trained on random noise. Should not have high acc"
95 | 
96 | if __name__ == '__main__':
97 |     unittest.main()
98 | 
99 | 
100 | 
101 | 
102 | 
103 | 
104 | 
105 | 
--------------------------------------------------------------------------------
/single_pulse_ml/tests/test_reader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import unittest
3 | from unittest import TestCase
4 | 
5 | import h5py
6 | import numpy as np
7 | 
8 | from single_pulse_ml import reader
9 | 
10 | class TestReader(TestCase):
11 | 
12 |     def test_read_hdf5(self):
13 |         NFREQ = 64
14 |         NTIME = 250
15 |         NCANDIDATES = 100
16 |         data_freq_time = np.random.normal(0, 1, NFREQ*NTIME*NCANDIDATES)
17 |         data_freq_time.shape = (NCANDIDATES, NFREQ, NTIME)
18 |         labels = np.ones([NCANDIDATES])
19 |         fn = './test.hdf5'
20 | 
21 |         g = h5py.File(fn, 'w')
22 |         g.create_dataset('data_freq_time', data=data_freq_time)
23 |         g.create_dataset('labels', data=labels)
24 |         g.create_dataset('data_dm_time', data=[])
25 |         g.close()
26 | 
27 |         data_freq, y, data_dm, data_mb = reader.read_hdf5(fn)
28 | 
29 |         # basic sanity checks on the round trip
30 |         assert data_freq.shape == (NCANDIDATES, NFREQ, NTIME)
31 |         assert len(y) == NCANDIDATES
32 | 
33 |         os.remove(fn)  # clean up the temporary file
34 | 
35 | 
36 | if __name__ == '__main__':
37 |     unittest.main()
--------------------------------------------------------------------------------
/single_pulse_ml/tests/test_run_frb_simulation.py:
--------------------------------------------------------------------------------
1 | """ Test script generating 100 RFI events and
2 | 100 simulated FRBs in Gaussian noise, using
3 | the parameters of the CHIME Pathfinder.
4 | Data are saved to an hdf5 file and a plot of
5 | the simulated FRBs is made.
6 | """
7 | 
8 | from single_pulse_ml import sim_parameters
9 | from single_pulse_ml import telescope
10 | from single_pulse_ml import simulate_frb
11 | 
12 | # TELESCOPE PARAMETERS:
13 | freq = (800, 400)       # (FREQ_LOW, FREQ_UP) in MHz
14 | FREQ_REF = 600          # reference frequency in MHz
15 | DELTA_T = 0.0016        # time resolution in seconds
16 | NAME = "CHIMEPathfinder"
17 | 
18 | # SIMULATION PARAMETERS
19 | NFREQ = 32              # number of frequencies; must match the false-positive (RFI) data
20 | NTIME = 250             # number of time stamps per trigger
21 | dm = (-0.05, 0.05)
22 | fluence = (5, 100)
23 | width = (2*0.0016, 0.75)  # lognormal width distribution, in seconds
24 | spec_ind = (-4., 4.)
25 | disp_ind = 2.
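# Two-element tuples are (low, high) ranges from which each simulated
# event's parameter is drawn; scalar values such as disp_ind stay fixed.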
26 | scat_factor = (-4., -1.5)
27 | NRFI = 100
28 | SNR_MIN = 8.0
29 | SNR_MAX = 100.0
30 | out_file_name = None
31 | mk_plot = True
32 | NSIDE = 8
33 | dm_time_array = False
34 | outname_tag = 'test'
35 | outdir = '../data/'
36 | figname = '../plots/test_out_fig.pdf'
37 | 
38 | fn_rfi = None
39 | fn_noise = None
40 | 
41 | sim_obj = sim_parameters.SimParams(dm=dm, fluence=fluence,
42 |                                    width=width, spec_ind=spec_ind,
43 |                                    disp_ind=disp_ind, scat_factor=scat_factor,
44 |                                    SNR_MIN=SNR_MIN, SNR_MAX=SNR_MAX,
45 |                                    out_file_name=out_file_name, NRFI=NRFI,
46 |                                    NTIME=NTIME, NFREQ=NFREQ,
47 |                                    mk_plot=mk_plot, NSIDE=NSIDE)
48 | 
49 | tel_obj = telescope.Telescope(freq=freq, FREQ_REF=FREQ_REF,
50 |                               DELTA_T=DELTA_T, name=NAME)
51 | 
52 | data, labels, params, snr = simulate_frb.run_full_simulation(
53 |                                 sim_obj, tel_obj, fn_rfi=fn_rfi,
54 |                                 fn_noise=fn_noise,
55 |                                 dm_time_array=dm_time_array,
56 |                                 outname_tag=outname_tag, outdir=outdir,
57 |                                 figname=figname)
58 | 
--------------------------------------------------------------------------------
/single_pulse_ml/tests/test_simulate_frb.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import numpy as np
3 | 
4 | from single_pulse_ml import simulate_frb
5 | 
6 | class TestSimulate_FRB(unittest.TestCase):
7 | 
8 |     def test_gen_simulated_frb(self):
9 | 
10 | 
11 |         sim_data, params = simulate_frb.gen_simulated_frb(NFREQ=16, NTIME=250, sim=True,
12 |                             fluence=(0.03, 0.3),
13 |                             spec_ind=(-4, 4), width=(2*0.0016, 1), dm=(-0.15, 0.15),
14 |                             scat_factor=(-3, -0.5), background_noise=None, delta_t=0.0016,
15 |                             plot_burst=False, freq=(800, 400), FREQ_REF=600.,
16 |                             )
17 | 
18 |         dm, fluence, width, spec_ind, disp_ind, scat_factor = params
19 | 
20 |         print(dm)
21 |         assert np.abs(dm) < 0.2, "DM was not drawn from the input range"
22 |         assert width > 0, "Width must be positive"
23 |         assert disp_ind == 2, "Dispersion index doesn't match the input"
24 | 
25 | 
26 | if __name__ == '__main__':
27 |     unittest.main()
--------------------------------------------------------------------------------
/single_pulse_ml/tools.py:
--------------------------------------------------------------------------------
1 | # 30 September 2017
2 | # miscellaneous tools for preparing and processing
3 | # machine learning training data
4 | 
5 | import numpy as np
6 | import glob
7 | import scipy.signal
8 | 
9 | from single_pulse_ml import dataproc
10 | 
11 | def save_background_data(fdir, outfile=None, nfreq=32):
12 |     """ Read in randomly selected Pathfinder data matching the
13 |         glob pattern fdir, dedisperse to a DM between 25 and
14 |         2000 pc cm**-3, and create a large array of (nfreq, ntime_pulse)
15 |         arrays over which FRBs can be injected.
16 |         These data haven't been RFI cleaned! Could cause problems.
17 |     """
18 |     fl = glob.glob(fdir)
19 |     fl.sort()
20 |     arr_full = []
21 | 
22 |     freq_rebin = 1
23 |     ntime_pulse = 250
24 | 
25 |     for ff in fl[:75]:  # use at most 75 files
26 |         print(ff)
27 |         arr = np.load(ff)[:, 0]  # keep the first slice along axis 1
28 |         arr[arr!=arr] = 0.
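        # arr != arr is True only for NaN entries, so the line above zeroes
        # NaNs (equivalent to arr[np.isnan(arr)] = 0.)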
29 |         nfreq_arr, ntime = arr.shape
30 |         print(arr.shape)
31 | 
32 |         # Dedisperse the data at a random DM
33 |         _dm = np.random.uniform(25, 2000.0)
34 |         arr = dedisperse_data(arr, _dm)
35 | 
36 |         # rebin in frequency, then divide the data into blocks of length ntime_pulse
37 |         arr = np.nansum(arr.reshape(-1, freq_rebin, ntime), axis=1)/freq_rebin
38 |         arr = arr[:, :ntime//ntime_pulse*ntime_pulse]
39 |         arr = arr.reshape(nfreq, -1, ntime_pulse)
40 |         arr_full.append(arr)
41 | 
42 |     # Reorganize array to be (ntriggers, nfreq, ntime_pulse)
43 |     arr_full = np.concatenate(arr_full)[:, :ntime//ntime_pulse*ntime_pulse]
44 |     arr_full = arr_full.reshape(-1, nfreq, ntime//ntime_pulse, ntime_pulse)
45 |     arr_full = np.transpose(arr_full, (0, 2, 1, 3)).reshape(-1, nfreq, ntime_pulse)
46 | 
47 |     # Normalize each noise trigger in place
48 |     for ii, arr in enumerate(arr_full):
49 |         arr_full[ii] = dataproc.normalize_data(arr)
50 | 
51 |     # Reshape to have same shape as RFI triggers
52 |     #arr_full = arr_full.reshape(-1, nfreq*ntime_pulse)
53 |     np.random.shuffle(arr_full)
54 | 
55 |     if outfile is not None:
56 |         np.save(outfile, arr_full)
57 | 
58 |     return arr_full
59 | 
60 | def dedisperse_data(f, _dm, freq_bounds=(800, 400), dt=0.0016, freq_ref=600):
61 |     """ Dedisperse data to some dispersion measure _dm.
62 |         Frequencies are in MHz, dt is the time step in seconds,
63 |         and f is the data to be dedispersed, shaped (nfreq, ntime).
64 |     """
65 | 
66 |     # Calculate the number of bins to shift for each frequency
67 |     NFREQ = f.shape[0]
68 |     freq = np.linspace(freq_bounds[0], freq_bounds[1], NFREQ)
69 |     ind_delay = ((4.148808e3 * _dm * (freq**(-2.) - freq_ref**(-2.))) / dt).astype(int)  # 4.148808e3 s MHz^2 / (pc cm^-3)
70 |     for ii, nu in enumerate(freq):
71 |         f[ii] = np.roll(f[ii], -ind_delay[ii])
72 | 
73 |     return f
74 | 
75 | def calc_snr(arr, fast=False):
76 |     """ Calculate the S/N of a pulse profile after
77 |         trying a range of boxcar widths.
78 | 
79 |     Parameters
80 |     ----------
81 |     arr : np.array
82 |         (ntime,) vector of pulse profile
83 |     fast : bool
84 |         if True, use a quick max/std estimate instead of the robust detrended one
85 | 
86 |     Returns
87 |     -------
88 |     snr : np.float
89 |         S/N of pulse
90 |     """
91 |     assert len(arr.shape) == 1
92 | 
93 |     ntime = len(arr)
94 |     snr_max = 0
95 |     widths = [1, 2, 4, 8, 16, 32, 64, 128]
96 | 
97 | 
98 |     for ii in widths:
99 | 
100 |         # skip if the boxcar width is greater than 1/8th of ntime
101 |         if ii > ntime//8:
102 |             continue
103 | 
104 |         arr_copy = arr.copy()
105 |         arr_ = arr_copy[:len(arr)//ii*ii].reshape(-1, ii).mean(-1)
106 | 
107 |         if fast is False:
108 |             std_chunk = scipy.signal.detrend(arr_, type='linear')
109 |             std_chunk.sort()
110 |             ntime_r = len(std_chunk)
111 |             stds = 1.148*np.sqrt((std_chunk[ntime_r//40:-ntime_r//40]**2.0).sum() /
112 |                                  (0.95*ntime_r))  # 1.148 rescales the rms of the central 95% to the full Gaussian sigma
113 |             snr_ = std_chunk[-1] / stds
114 |         else:
115 |             sig = np.std(arr_[:len(arr_)//3])  # noise from the first third, assumed off-pulse
116 |             snr_ = arr_.max() / sig
117 | 
118 |         if snr_ > snr_max:
119 |             snr_max = snr_
120 |             width_max = ii
121 | 
122 |     return snr_max
123 | 
124 | 
125 | 
126 | 
127 | 
128 | 
129 | 
130 | 
131 | 
--------------------------------------------------------------------------------
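A quick sanity check of calc_snr, as an editor's sketch (the injected
amplitude is illustrative; the recovered S/N should come out near 10):

    import numpy as np
    from single_pulse_ml import tools

    ntime = 250
    profile = np.random.normal(0, 1, ntime)
    profile[ntime//2] += 10.      # inject a narrow pulse at the centre
    print(tools.calc_snr(profile))                        # roughly 10
    print(tools.calc_snr(np.random.normal(0, 1, ntime)))  # noise only: small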