├── src ├── README.md ├── data_prep_final.py ├── final_test_150.py ├── final_test.py └── final_inner_test.py ├── models ├── README.md ├── merge_final_test.py ├── final_test.py ├── scan_center_final.ipynb └── layer_prep_final.ipynb ├── doc ├── README.md └── trackML.pdf ├── eda └── README.md ├── .github └── workflows │ └── lint_python.yml ├── README.md ├── environment.yml └── LICENSE /src/README.md: -------------------------------------------------------------------------------- 1 | Source code 2 | -------------------------------------------------------------------------------- /models/README.md: -------------------------------------------------------------------------------- 1 | Ignore this folder; it was temporary. 2 | -------------------------------------------------------------------------------- /doc/README.md: -------------------------------------------------------------------------------- 1 | Documentation of the approach used in this challenge. 2 | -------------------------------------------------------------------------------- /doc/trackML.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jfpuget/Kaggle_TrackML/HEAD/doc/trackML.pdf -------------------------------------------------------------------------------- /eda/README.md: -------------------------------------------------------------------------------- 1 | Notebook used to produce some of the figures in the documentation. Refer to the documentation for explanations. 2 | -------------------------------------------------------------------------------- /.github/workflows/lint_python.yml: -------------------------------------------------------------------------------- 1 | name: lint_python 2 | on: 3 | pull_request: 4 | push: 5 | # branches: [master] 6 | jobs: 7 | lint_python: 8 | runs-on: ubuntu-latest 9 | # strategy: 10 | # matrix: 11 | # os: [ubuntu-latest, macos-latest, windows-latest] 12 | # python-version: [2.7, 3.5, 3.6, 3.7, 3.8] # , pypy3] 13 | steps: 14 | - uses: actions/checkout@master 15 | - uses: actions/setup-python@master 16 | # with: 17 | # python-version: ${{ matrix.python-version }} 18 | - run: pip install black codespell flake8 isort pytest reorder-python-imports 19 | - run: black --check . || true 20 | # - run: black --diff . || true 21 | # - if: matrix.python-version >= 3.6 22 | # run: | 23 | # pip install black 24 | # black --check . || true 25 | - run: codespell --quiet-level=2 || true # --ignore-words-list="" --skip="" 26 | - run: flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 27 | # isort and reorder-python-imports are two ways of doing the same thing 28 | - run: isort --recursive . || true 29 | - run: reorder-python-imports . || true 30 | - run: pip install -r requirements.txt || true 31 | - run: pytest . 
|| true 32 | -------------------------------------------------------------------------------- /src/data_prep_final.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | import pandas as pd 5 | import numpy as np 6 | import pickle as pkl 7 | from tqdm import tqdm_notebook as tqdm 8 | from collections import Counter 9 | 10 | from trackml.dataset import load_event 11 | 12 | def main(): 13 | # load data for 90 train events 14 | data_l = [] 15 | for i in range(10,100): 16 | event = '../input/train_1/event0000010%d' % i 17 | print('event:', event) 18 | hits, cells, particles, truth = load_event(event) 19 | data = hits 20 | data = data.merge(truth, how='left', on='hit_id') 21 | data = data.merge(particles, how='left', on='particle_id') 22 | 23 | # keep hits from tracks originating from the vertex 24 | data['rv'] = np.sqrt(data.vx**2 + data.vy**2) 25 | data = data[(data.rv <= 1) & (data.vz <= 50) & (data.vz >=-50)].copy() 26 | data = data[data.weight > 0] 27 | data['event_id'] = i 28 | 29 | data['pt'] = np.sqrt(data.px ** 2 + data.py ** 2) 30 | 31 | # use a simple relationship to compute alpha0 from pt, see documentation or EDA notebook. 32 | data['alpha0'] = np.exp(-8.115 - np.log(data.pt)) 33 | 34 | data_l.append(data) 35 | 36 | data = pd.concat(data_l, axis=0) 37 | 38 | # compute track-level statistics: one (alpha0, vz) pair per track, used later as scan centers 39 | df = data.groupby(['event_id', 'particle_id'])[['alpha0', 'vz']].first() 40 | df = df.dropna() 41 | np.save('../data/scan_center.npy', df.values) 42 | 43 | # compute tracklet frequencies 44 | # tracklets are sub-tracks of length 4 45 | 46 | # assign a unique layer to each hit 47 | data['layer'] = 100 * data.volume_id + data.layer_id 48 | 49 | # for each track compute a string containing the sequence of layers traversed by the track 50 | data = data.sort_values(by=['particle_id', 'z']).reset_index(drop=True) 51 | df = data.groupby(['event_id', 'particle_id']).layer.apply(lambda s: ' '.join([str(i) for i in s])) 52 | df = df.to_frame('layers') 53 | 54 | # count each tracklet's occurrences 55 | cnt = Counter() 56 | for x in tqdm(df.itertuples(name=None, index=False)): 57 | layers = x[0].split() 58 | for i in range(len(layers) - 3): 59 | s = ' '.join(layers[i:i+4]) 60 | cnt[s] += 1 61 | 62 | # save result 63 | with open('../data/layers_4_center_fix.pkl', 'wb') as file: 64 | pkl.dump(cnt, file) 65 | 66 | if __name__ == "__main__": 67 | main() 68 | 69 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Kaggle_TrackML 2 | Code for the TrackML competition on Kaggle: https://www.kaggle.com/c/trackml-particle-identification 3 | 4 | This solution ranked 9th in the competition. Read the documentation before looking at the code ;) The content below provides useful information if you want to run the code. 5 | 6 | 1. Hardware used. For EDA and model tuning on train events: an Intel box with a 4-core i7 at 4.2 GHz and 64 GB of memory, running Ubuntu 16.04. For computing tracks on test events, either a Dell T810 with a 20-core Xeon CPU at 2.4 GHz and 64 GB of memory, running Ubuntu 14.04, or an IBM AC922 server with 40 P9 cores and 256 GB of memory, running RHEL 7.5 (ppc64le). 7 | 8 | 2. The code consumes about 3 GB per worker, so memory is not really an issue. One should favor a large number of cores, as tracks are computed for a number of events in parallel. 9 | 10 | 3. 
We used various Linux versions, depending on the machine used; see item 1 above. 11 | 12 | 4. An environment.yml is provided, but we use a much smaller set of packages than indicated. The code can run with only numpy, pandas, pickle, and scikit-learn installed on top of Python 3.6. EDA notebooks require anaconda, matplotlib, and seaborn. Version numbers are provided in the yaml file. 13 | 14 | 5. Running the code is rather simple: 15 | - Complete the cloned repo with additional directories as follows: 16 | 17 | ./data/ 18 | 19 | ./input/ 20 | 21 | ./submissions/final/ 22 | 23 | ./submissions/final_inner/ 24 | 25 | ./submissions/merge_final/ 26 | 27 | The test event data should then be downloaded into the ./input directory. 28 | 29 | - Edit the base_path value in the scripts in the src directory to match where you cloned the code. 30 | - Edit the number of Pool workers in the scripts final_test.py, final_inner_test.py and merge_final_test.py to match the number of processors of your machine. 31 | 32 | - Run the scripts in this order (a minimal driver sketch is given after these instructions): 33 | 34 | data_prep_final.py 35 | 36 | final_test.py 37 | 38 | final_inner_test.py 39 | 40 | merge_final_test.py 41 | 42 | The last script produces a file named merge_final.csv in the submissions/ directory. This file can be submitted to the Kaggle server to get a private LB score slightly above 0.800. 43 | 44 | Running only the first two scripts produces a submission for a simplified model, named sub_final.csv, in the submissions/ directory. This file can be submitted to the Kaggle server to get a private LB score slightly above 0.787. 45 | 46 | Running the above scripts can take days. We provide final_test_150.py for faster runs. Edit the base_path and the number of workers in the pool before running it. Also create a directory final_150 in the submissions directory. That script runs in about 150 seconds per event with one i7 core and produces an output file named sub_final_150.csv in the submissions/ directory. This file can be submitted to the Kaggle server to yield a private LB score above 0.51. 
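The sketch below is a hypothetical driver for the procedure above and is not part of the original code: the directory list and script order come from this README, while the base_path placeholder and the use of subprocess to invoke the scripts are assumptions to adapt to your setup (note that merge_final_test.py currently sits in the models directory).

```python
# Minimal driver sketch for the run procedure above (hypothetical helper, not part of the repo).
# Assumes the directory layout and script order documented in this README.
import os
import subprocess

base_path = '/path/to/Kaggle_TrackML'  # edit to match where you cloned the code

# create the additional directories expected by the scripts
for d in ['data', 'input', 'submissions/final', 'submissions/final_inner',
          'submissions/merge_final', 'submissions/final_150']:
    os.makedirs(os.path.join(base_path, d), exist_ok=True)

# run the pipeline in order; the scripts also use relative paths such as ../data,
# so run each one from its own directory
for script in ['src/data_prep_final.py', 'src/final_test.py',
               'src/final_inner_test.py', 'models/merge_final_test.py']:
    script_path = os.path.join(base_path, script)
    subprocess.run(['python', script_path],
                   cwd=os.path.dirname(script_path), check=True)
```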
47 | -------------------------------------------------------------------------------- /models/merge_final_test.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | import pandas as pd 5 | import numpy as np 6 | from tqdm import tqdm_notebook as tqdm 7 | 8 | from trackml.dataset import load_event 9 | from multiprocessing import Pool 10 | 11 | base_path = '/home/jfpuget/Kaggle/TrackML/' 12 | 13 | def merge_track(data, center, threshold_base, threshold_center, threshold): 14 | # merges tracks from two models for a given event 15 | # data['track_id'] contains the first set of tracks 16 | # center['track_id'] contains the second set of tracks 17 | # thresholds are used to fitler whoch pairs of tracks should be merged 18 | 19 | # save original data columns for output 20 | data_cols = data.columns 21 | 22 | # reindex center to align with data[data.volume_id < 10] index 23 | data_center = data[data.volume_id < 10] 24 | data_center = data_center[['hit_id']].merge(center, how='left', on='hit_id') 25 | data['center'] = 0 26 | data.loc[data.volume_id < 10, 'center'] = data_center.track_id.values 27 | 28 | # computes track overlap 29 | data['count_both'] = data.groupby(['track_id', 'center']).hit_id.transform('count') 30 | data['count_center'] = data.groupby(['center']).hit_id.transform('count') 31 | data['count_track'] = data.groupby(['track_id']).hit_id.transform('count') 32 | 33 | # compute pairs of tracks that overlap more than input thresholds 34 | mapping = data.groupby(['track_id', 'center']).count_both.max().to_frame('count_max').reset_index() 35 | data = data.merge(mapping, how='left', on=['track_id', 'center']) 36 | data['valid'] = ((data.count_max == data.count_both) & \ 37 | (data.count_both > threshold) & \ 38 | (data.count_both > threshold_center*data.count_center) & \ 39 | (data.count_both > threshold_base*data.count_track)) 40 | mapping = data[data.valid].groupby(['track_id', 'center'])[['track_id', 'center']].first() 41 | 42 | # merge tracks 43 | data['new_center'] = 0 44 | for t,c in mapping.itertuples(index=False, name=None): 45 | data.loc[data.center == c, 'new_center'] = t 46 | data.loc[(data.track_id > 0) & (data.new_center == 0), 'new_center'] = data.track_id[(data.track_id > 0) & (data.new_center == 0)].values 47 | # use remaining tracks for remaining hits 48 | track_max = data.new_center.max() 49 | data.loc[data.center > 0, 'center'] += track_max 50 | data.loc[(data.center > 0) & (data.new_center == 0), 'new_center'] = data.center[(data.center > 0) & (data.new_center == 0)].values 51 | data.track_id = data.new_center 52 | return data[data_cols].copy() 53 | 54 | 55 | def get_event(i): 56 | return 'event000000%03d' % i 57 | 58 | def work_sub(param): 59 | # merges tracks for event i and saves result into a file 60 | (i, ) = param 61 | th_b = 0.16 62 | th_c = 0.45 63 | 64 | event = get_event(i) 65 | print('event:', event) 66 | hits, cells = load_event('../input/test/' + event, parts=['hits', 'cells']) 67 | data = pd.read_csv('../submissions/final/'+event) 68 | data = data.merge(hits, how='left', on='hit_id') 69 | inner = pd.read_csv('../submissions/final_inner/'+event) 70 | data = merge_track(data, inner, th_b, th_c, 1) 71 | data['event_id'] = i 72 | data[['event_id', 'hit_id', 'track_id']].to_csv('../submissions/merge_final/' + event +'.csv', 73 | index=False) 74 | return i 75 | 76 | 77 | def main(): 78 | # merge each event tracks in parallel. 
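# each worker reads one event's 'final' and 'final_inner' per-event submissions, merges them with merge_track, and writes the result to submissions/merge_final/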
79 | # number of process in pool should be close to the number of processors 80 | params = [(i, ) for i in range(125)] 81 | 82 | if 1: 83 | pool = Pool(processes=20, maxtasksperchild=1) 84 | ls = pool.map( work_sub, params, chunksize=1 ) 85 | pool.close() 86 | else: 87 | ls = [work_sub(param) for param in params] 88 | 89 | # computes submission by concatenating each event tracks 90 | submissions = [] 91 | for i in tqdm(range(125)): 92 | event = get_event(i) 93 | data0 = pd.read_csv('../submissions/merge_final/' + event + '.csv') 94 | submissions.append(data0) 95 | 96 | submission = pd.concat(submissions, axis=0) 97 | submission.track_id = (submission.track_id).astype('int64') 98 | submission.to_csv('../submissions/merge_final.csv', index=False) 99 | 100 | if __name__ == "__main__": 101 | main() 102 | 103 | 104 | -------------------------------------------------------------------------------- /models/final_test.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | import pandas as pd 5 | import numpy as np 6 | import pickle as pkl 7 | from sklearn.cluster import DBSCAN 8 | from itertools import combinations 9 | from multiprocessing import Pool 10 | 11 | base_path = '/home/jfpuget/Kaggle/TrackML/' 12 | 13 | class Clusterer(object): 14 | 15 | def __init__(self, eps, max_cluster, scan_center, quality_threshold, cnt, event): 16 | self.eps_ = eps 17 | self.max_cluster_ = max_cluster 18 | self.scan_center_ = scan_center 19 | self.quality_threshold_ = quality_threshold 20 | self.cnt_ = cnt 21 | self.event_ = event 22 | 23 | def fit_alpha(self, data, alpha0, z): 24 | cond0 = np.abs(data.rt * alpha0) < 1 25 | data1 = data[cond0].copy() 26 | data1['theta0'] = np.arcsin(data1.rt * alpha0) 27 | data1['rt1'] = data1.theta0 / alpha0 28 | data1['theta0'] *= data1.theta_ratio 29 | data1['theta'] = data1.theta_base + data1.theta0 30 | 31 | data1['xcos'] = np.cos(data1.theta) 32 | data1['ysin'] = np.sin(data1.theta) 33 | data1['zr'] = (np.arcsinh((data1.z - z) / (0.7 * data1.rt1)) / 3.5) 34 | dfs = data1[['xcos', 'ysin', 'zr']] 35 | clusters0 = DBSCAN(eps=self.eps_, 36 | min_samples=2, 37 | metric='euclidean', 38 | n_jobs=1).fit(dfs).labels_ 39 | clusters = np.zeros(data.shape[0]) 40 | clusters[cond0] = clusters0 + 1 41 | 42 | maxs1 = data['s1'].max() 43 | clusters[clusters > 0] += maxs1 44 | data['s2'] = clusters 45 | data['N2'] = count_module(data, 's2') 46 | data.loc[data.N2 < 2, 's2'] = 0 47 | data.loc[data.N2 > self.max_cluster_, 's2'] = 0 48 | data['Q2'] = get_all_layer_quality(data, 's2', self.cnt_) 49 | data['WN2'] = data.N2 * data.Q2 50 | data.loc[data.WN2 <= data.WN1, 's2'] = 0 51 | data['N2'] = count_module(data, 's2') 52 | data['WN2'] = data.N2 * data.Q2 53 | cond = ( (data['WN2'] > data['WN1']) \ 54 | & (data['Q2'] > self.quality_threshold_) \ 55 | & (data['N2'] < self.max_cluster_) 56 | ) 57 | data.loc[cond, 's1'] = data.loc[cond, 's2'] 58 | data.loc[cond, 'Q1'] = data.loc[cond, 'Q2'] 59 | 60 | data['N1'] = count_module(data, 's1') 61 | cond = ((data['N1'] >= self.max_cluster_) ) 62 | data.loc[cond, 'N1'] = 0 63 | data.loc[cond, 's1'] = 0 64 | data.loc[cond, 'Q1'] = 0 65 | data['WN1'] = data.N1 * data.Q1 66 | 67 | data['track_id'] = data['s1'] 68 | return data 69 | 70 | def fit_predict(self, data, n_iter): 71 | data['theta_base'] = np.arctan2(data.y, data.x) 72 | data['rt'] = np.sqrt(data.x**2 + data.y**2) 73 | data['theta0'] = 0.0 74 | data['theta'] = 0.0 75 | data['rt1'] = 0.0 76 | data['zr'] = 0.0 77 | data['s1'] = data.track_id 
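# bookkeeping columns for the current assignment: s1 is the track candidate id of each hit, Q1 its tracklet-frequency quality, N1 the number of distinct (volume, layer) pairs it spans, and WN1 = N1 * Q1 is the score a competing DBSCAN cluster must beat in fit_alpha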
78 | data['Q1'] = get_all_layer_quality(data, 's1', self.cnt_) 79 | data['N1'] = count_module(data, 's1') 80 | data['WN1'] = data.N1 * data.Q1 81 | 82 | scan_center = self.scan_center_ 83 | 84 | np.random.seed(0) 85 | mm = 1 86 | for ii in (range(n_iter)): 87 | n = np.random.randint(0, len(scan_center)) 88 | alpha0 = scan_center[n, 0] 89 | z = scan_center[n, 1] 90 | data = self.fit_alpha(data, mm * alpha0, z) 91 | mm = - mm 92 | if ii % 1000 == 0: 93 | print(self.event_, '%05d' % ii) 94 | return data 95 | 96 | def count_module(dfh, col): 97 | dfmod = dfh.groupby([col, 'volume_id', 'layer_id']).hit_id.count() 98 | dfmod = dfmod.to_frame('n_volume_layer').reset_index().groupby(col).n_volume_layer.count().reset_index() 99 | dfmod = dfh[[col]].merge(dfmod[[col, 'n_volume_layer']], how='left', on=col) 100 | dfmod1 = dfh.groupby([col, 'volume_id', 'layer_id', 'module_id']).hit_id.count() 101 | dfmod1 = dfmod1.to_frame('n_volume_layer_module').reset_index().groupby(col).n_volume_layer_module.count().reset_index() 102 | dfmod = dfmod.merge(dfmod1[[col, 'n_volume_layer_module']], how='left', on=col) 103 | dfmod.loc[dfmod[col] == 0, 'n_volume_layer'] = 0 104 | dfmod.loc[dfmod['n_volume_layer_module'] <= 3, 'n_volume_layer'] = 0 105 | return dfmod.n_volume_layer.values 106 | 107 | def get_all_layer_quality(dfh, col, cnt): 108 | 109 | def get_layer_quality(layers, cnt=cnt): 110 | layers = [str(x) for x in layers] 111 | if len(layers) <= 3: 112 | return 0 113 | quality = ([cnt[' '.join(layers[i:i+4])] for i in range(len(layers) - 3)]) 114 | quality = np.mean(np.log1p(quality)) 115 | return quality 116 | 117 | dfmod = dfh.sort_values(by=[col, 'z']) 118 | df = dfmod[dfmod[col] > 0].groupby([col]).layer.apply(get_layer_quality).to_frame('layer_quality').reset_index() 119 | dfh = dfh[[col]].merge(df, how='left', on=col) 120 | dfh.layer_quality.fillna(0, inplace=True) 121 | return dfh.layer_quality.values 122 | 123 | def get_event(i): 124 | return 'event000000%03d' % int(i) 125 | 126 | def work_sub(param): 127 | (i, num_iter, max_cluster, quality_threshold) = param 128 | 129 | event = get_event(i) 130 | print('event:', event) 131 | hits = pd.read_csv(base_path+'input/test/'+event + '-hits.csv') 132 | data = hits 133 | print(data.shape) 134 | data['event_id'] = i 135 | data['track_id'] = 0 136 | data['layer'] = 100 * data.volume_id + data.layer_id 137 | data['theta_ratio'] = 1 - (np.abs(data.z + 200) / 6000)**2.4 + 0.005 138 | 139 | scan_center = np.load('../data/scan_center.npy') 140 | 141 | with open('../data/layers_4_center_fix.pkl', 'rb') as file: 142 | cnt = pkl.load(file) 143 | 144 | model = Clusterer(eps=0.0028, max_cluster=max_cluster, scan_center=scan_center, cnt=cnt, 145 | quality_threshold=quality_threshold, event=event) 146 | data = model.fit_predict(data, num_iter) 147 | 148 | data[['event_id', 'hit_id', 'track_id']].to_csv(base_path+'submissions/final/' + event, index=False) 149 | return i 150 | 151 | def main(): 152 | max_cluster = 20 153 | quality_threshold = 5 154 | num_iter=60000 155 | 156 | params = [(i, num_iter, max_cluster, quality_threshold) for i in range(125)] 157 | 158 | if 1: 159 | pool = Pool(processes=42, maxtasksperchild=1) 160 | ls = pool.map( work_sub, params, chunksize=1 ) 161 | pool.close() 162 | else: 163 | ls = [work_sub(param) for param in params] 164 | 165 | if __name__ == "__main__": 166 | main() 167 | 168 | 169 | 170 | -------------------------------------------------------------------------------- /src/final_test_150.py: 
-------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | import pandas as pd 5 | import numpy as np 6 | import pickle as pkl 7 | from sklearn.cluster import DBSCAN 8 | from multiprocessing import Pool 9 | from tqdm import tqdm_notebook as tqdm 10 | 11 | base_path = '/home/jfpuget/Kaggle/TrackML/' 12 | 13 | class Clusterer(object): 14 | 15 | def __init__(self, eps, max_cluster, scan_center, quality_threshold, cnt, event): 16 | self.eps_ = eps 17 | self.max_cluster_ = max_cluster 18 | self.scan_center_ = scan_center 19 | self.quality_threshold_ = quality_threshold 20 | self.cnt_ = cnt 21 | self.event_ = event 22 | 23 | def fit_alpha(self, data, alpha0, z): 24 | cond0 = np.abs(data.rt * alpha0) < 1 25 | data1 = data[cond0].copy() 26 | data1['theta0'] = np.arcsin(data1.rt * alpha0) 27 | data1['rt1'] = data1.theta0 / alpha0 28 | data1['theta0'] *= data1.theta_ratio 29 | data1['theta'] = data1.theta_base + data1.theta0 30 | 31 | data1['xcos'] = np.cos(data1.theta) 32 | data1['ysin'] = np.sin(data1.theta) 33 | data1['zr'] = (np.arcsinh((data1.z - z) / (data1.rt1)) / 3.5) 34 | dfs = data1[['xcos', 'ysin', 'zr']] 35 | clusters0 = DBSCAN(eps=self.eps_, 36 | min_samples=2, 37 | metric='euclidean', 38 | n_jobs=1).fit(dfs).labels_ 39 | clusters = np.zeros(data.shape[0]) 40 | clusters[cond0] = clusters0 + 1 41 | 42 | maxs1 = data['s1'].max() 43 | clusters[clusters > 0] += maxs1 44 | data['s2'] = clusters 45 | data['N2'] = count_module(data, 's2') 46 | data.loc[data.N2 < 2, 's2'] = 0 47 | data.loc[data.N2 > self.max_cluster_, 's2'] = 0 48 | data['Q2'] = get_all_layer_quality(data, 's2', self.cnt_) 49 | data['WN2'] = data.N2 * data.Q2 50 | data.loc[data.WN2 <= data.WN1, 's2'] = 0 51 | data['N2'] = count_module(data, 's2') 52 | data['WN2'] = data.N2 * data.Q2 53 | cond = ( (data['WN2'] > data['WN1']) \ 54 | & (data['Q2'] > self.quality_threshold_) \ 55 | & (data['N2'] < self.max_cluster_) 56 | ) 57 | data.loc[cond, 's1'] = data.loc[cond, 's2'] 58 | data.loc[cond, 'Q1'] = data.loc[cond, 'Q2'] 59 | 60 | data['N1'] = count_module(data, 's1') 61 | cond = ((data['N1'] >= self.max_cluster_) ) 62 | data.loc[cond, 'N1'] = 0 63 | data.loc[cond, 's1'] = 0 64 | data.loc[cond, 'Q1'] = 0 65 | data['WN1'] = data.N1 * data.Q1 66 | 67 | data['track_id'] = data['s1'] 68 | return data 69 | 70 | def fit_predict(self, data, n_iter): 71 | data['theta_base'] = np.arctan2(data.y, data.x) 72 | data['rt'] = np.sqrt(data.x**2 + data.y**2) 73 | data['theta0'] = 0.0 74 | data['theta'] = 0.0 75 | data['rt1'] = 0.0 76 | data['zr'] = 0.0 77 | data['s1'] = data.track_id 78 | data['Q1'] = get_all_layer_quality(data, 's1', self.cnt_) 79 | data['N1'] = count_module(data, 's1') 80 | data['WN1'] = data.N1 * data.Q1 81 | 82 | scan_center = self.scan_center_ 83 | 84 | np.random.seed(0) 85 | mm = 1 86 | for ii in (range(n_iter)): 87 | n = np.random.randint(0, len(scan_center)) 88 | alpha0 = scan_center[n, 0] 89 | z = scan_center[n, 1] 90 | data = self.fit_alpha(data, mm * alpha0, z) 91 | mm = - mm 92 | if ii % 1000 == 0: 93 | print(self.event_, '%05d' % ii) 94 | return data 95 | 96 | def count_module(dfh, col): 97 | dfmod = dfh.groupby([col, 'volume_id', 'layer_id']).hit_id.count() 98 | dfmod = dfmod.to_frame('n_volume_layer').reset_index().groupby(col).n_volume_layer.count().reset_index() 99 | dfmod = dfh[[col]].merge(dfmod[[col, 'n_volume_layer']], how='left', on=col) 100 | dfmod1 = dfh.groupby([col, 'volume_id', 'layer_id', 'module_id']).hit_id.count() 101 | dfmod1 = 
dfmod1.to_frame('n_volume_layer_module').reset_index().groupby(col).n_volume_layer_module.count().reset_index() 102 | dfmod = dfmod.merge(dfmod1[[col, 'n_volume_layer_module']], how='left', on=col) 103 | dfmod.loc[dfmod[col] == 0, 'n_volume_layer'] = 0 104 | dfmod.loc[dfmod['n_volume_layer_module'] <= 3, 'n_volume_layer'] = 0 105 | return dfmod.n_volume_layer.values 106 | 107 | def get_all_layer_quality(dfh, col, cnt): 108 | 109 | def get_layer_quality(layers, cnt=cnt): 110 | layers = [str(x) for x in layers] 111 | if len(layers) <= 3: 112 | return 0 113 | quality = ([cnt[' '.join(layers[i:i+4])] for i in range(len(layers) - 3)]) 114 | quality = np.mean(np.log1p(quality)) 115 | return quality 116 | 117 | dfmod = dfh.sort_values(by=[col, 'z']) 118 | df = dfmod[dfmod[col] > 0].groupby([col]).layer.apply(get_layer_quality).to_frame('layer_quality').reset_index() 119 | dfh = dfh[[col]].merge(df, how='left', on=col) 120 | dfh.layer_quality.fillna(0, inplace=True) 121 | return dfh.layer_quality.values 122 | 123 | def get_event(i): 124 | return 'event000000%03d' % int(i) 125 | 126 | def work_sub(param): 127 | (i, num_iter, max_cluster, quality_threshold) = param 128 | 129 | event = get_event(i) 130 | print('event:', event) 131 | hits = pd.read_csv(base_path+'input/test/'+event + '-hits.csv') 132 | data = hits 133 | print(data.shape) 134 | data['event_id'] = i 135 | data['track_id'] = 0 136 | data['layer'] = 100 * data.volume_id + data.layer_id 137 | data['theta_ratio'] = 1 - (np.abs(data.z + 200) / 6000)**2.4 + 0.005 138 | 139 | scan_center = np.load('../data/scan_center.npy') 140 | 141 | with open('../data/layers_4_center_fix.pkl', 'rb') as file: 142 | cnt = pkl.load(file) 143 | 144 | model = Clusterer(eps=0.0028, max_cluster=max_cluster, scan_center=scan_center, cnt=cnt, 145 | quality_threshold=quality_threshold, event=event) 146 | data = model.fit_predict(data, num_iter) 147 | 148 | data[['event_id', 'hit_id', 'track_id']].to_csv(base_path+'submissions/final_150/' + event, index=False) 149 | return i 150 | 151 | def main(): 152 | max_cluster = 20 153 | quality_threshold = 5 154 | num_iter=150 155 | 156 | params = [(i, num_iter, max_cluster, quality_threshold) for i in range(125)] 157 | 158 | if 1: 159 | pool = Pool(processes=4, maxtasksperchild=1) 160 | ls = pool.map( work_sub, params, chunksize=1 ) 161 | pool.close() 162 | else: 163 | ls = [work_sub(param) for param in params] 164 | 165 | submissions = [] 166 | for i in tqdm(range(125)): 167 | event = get_event(i) 168 | data0 = pd.read_csv('../submissions/final_150/' + event) 169 | submissions.append(data0) 170 | 171 | submission = pd.concat(submissions, axis=0) 172 | submission.track_id = (submission.track_id).astype('int64') 173 | submission.to_csv('../submissions/sub_final_150.csv', index=False) 174 | 175 | if __name__ == "__main__": 176 | main() 177 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: base 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - arrow-cpp=0.8.0=py36_4 7 | - blas=1.1=openblas 8 | - ca-certificates=2018.4.16=0 9 | - certifi=2018.4.16=py36_0 10 | - conda=4.5.2=py36_0 11 | - conda-env=2.6.0=0 12 | - dask=0.15.4=py_0 13 | - dask-core=0.15.4=py_0 14 | - distributed=1.19.3=py36_0 15 | - feather-format=0.4.0=py36_2 16 | - joblib=0.11=py36_0 17 | - numpy=1.12.1=py36_blas_openblas_200 18 | - openblas=0.2.19=2 19 | - openssl=1.0.2o=0 20 | - pandas=0.22.0=py36_0 21 | 
- parquet-cpp=1.4.0.pre=2 22 | - pyarrow=0.8.0=py36_0 23 | - regex=2018.02.21=py36_0 24 | - scikit-learn=0.19.0=py36_blas_openblas_201 25 | - scipy=0.19.1=py36_blas_openblas_202 26 | - setuptools=36.3.0=py36_0 27 | - tqdm=4.15.0=py_1 28 | - zict=0.1.3=py_0 29 | - _license=1.1=py36_1 30 | - alabaster=0.7.10=py36_0 31 | - anaconda-client=1.6.3=py36_0 32 | - anaconda=custom=py36_0 33 | - anaconda-navigator=1.6.2=py36_0 34 | - anaconda-project=0.6.0=py36_0 35 | - asn1crypto=0.22.0=py36_0 36 | - astroid=1.4.9=py36_0 37 | - astropy=1.3.2=np112py36_0 38 | - babel=2.4.0=py36_0 39 | - backports=1.0=py36_0 40 | - beautifulsoup4=4.6.0=py36_0 41 | - bitarray=0.8.1=py36_0 42 | - blaze=0.10.1=py36_0 43 | - bleach=1.5.0=py36_0 44 | - bokeh=0.12.5=py36_1 45 | - boto=2.46.1=py36_0 46 | - bottleneck=1.2.1=np112py36_0 47 | - cairo=1.14.8=0 48 | - cffi=1.10.0=py36_0 49 | - chardet=3.0.3=py36_0 50 | - click=6.7=py36_0 51 | - cloudpickle=0.2.2=py36_0 52 | - clyent=1.2.2=py36_0 53 | - colorama=0.3.9=py36_0 54 | - contextlib2=0.5.5=py36_0 55 | - cryptography=1.8.1=py36_0 56 | - curl=7.52.1=0 57 | - cycler=0.10.0=py36_0 58 | - cython=0.25.2=py36_0 59 | - cytoolz=0.8.2=py36_0 60 | - datashape=0.5.4=py36_0 61 | - dbus=1.10.10=0 62 | - decorator=4.0.11=py36_0 63 | - docutils=0.13.1=py36_0 64 | - entrypoints=0.2.2=py36_1 65 | - et_xmlfile=1.0.1=py36_0 66 | - expat=2.1.0=0 67 | - fastcache=1.0.2=py36_1 68 | - flask=0.12.2=py36_0 69 | - flask-cors=3.0.2=py36_0 70 | - fontconfig=2.12.1=3 71 | - freetype=2.5.5=2 72 | - get_terminal_size=1.0.0=py36_0 73 | - gevent=1.2.1=py36_0 74 | - glib=2.50.2=1 75 | - greenlet=0.4.12=py36_0 76 | - gst-plugins-base=1.8.0=0 77 | - gstreamer=1.8.0=0 78 | - h5py=2.7.0=np112py36_0 79 | - harfbuzz=0.9.39=2 80 | - hdf5=1.8.17=1 81 | - heapdict=1.0.0=py36_1 82 | - html5lib=0.999=py36_0 83 | - icu=54.1=0 84 | - idna=2.5=py36_0 85 | - imagesize=0.7.1=py36_0 86 | - ipykernel=4.6.1=py36_0 87 | - ipython=5.3.0=py36_0 88 | - ipython_genutils=0.2.0=py36_0 89 | - ipywidgets=6.0.0=py36_0 90 | - isort=4.2.5=py36_0 91 | - itsdangerous=0.24=py36_0 92 | - jbig=2.1=0 93 | - jdcal=1.3=py36_0 94 | - jedi=0.10.2=py36_2 95 | - jinja2=2.9.6=py36_0 96 | - jpeg=9b=0 97 | - jsonschema=2.6.0=py36_0 98 | - jupyter=1.0.0=py36_3 99 | - jupyter_client=5.0.1=py36_0 100 | - jupyter_console=5.1.0=py36_0 101 | - jupyter_core=4.3.0=py36_0 102 | - lazy-object-proxy=1.2.2=py36_0 103 | - libffi=3.2.1=1 104 | - libgcc=4.8.5=2 105 | - libgfortran=3.0.0=1 106 | - libiconv=1.14=0 107 | - libpng=1.6.27=0 108 | - libsodium=1.0.10=0 109 | - libtiff=4.0.6=3 110 | - libtool=2.4.2=0 111 | - libxcb=1.12=1 112 | - libxml2=2.9.4=0 113 | - libxslt=1.1.29=0 114 | - llvmlite=0.18.0=py36_0 115 | - locket=0.2.0=py36_1 116 | - lxml=3.7.3=py36_0 117 | - markupsafe=0.23=py36_2 118 | - matplotlib=2.0.2=np112py36_0 119 | - mistune=0.7.4=py36_0 120 | - mkl=2017.0.1=0 121 | - mkl-service=1.1.2=py36_3 122 | - mpmath=0.19=py36_1 123 | - msgpack-python=0.4.8=py36_0 124 | - multipledispatch=0.4.9=py36_0 125 | - navigator-updater=0.1.0=py36_0 126 | - nbconvert=5.1.1=py36_0 127 | - nbformat=4.3.0=py36_0 128 | - networkx=1.11=py36_0 129 | - nltk=3.2.3=py36_0 130 | - nose=1.3.7=py36_1 131 | - notebook=5.0.0=py36_0 132 | - numba=0.33.0=np112py36_0 133 | - numexpr=2.6.2=np112py36_0 134 | - numpydoc=0.6.0=py36_0 135 | - odo=0.5.0=py36_1 136 | - olefile=0.44=py36_0 137 | - openpyxl=2.4.7=py36_0 138 | - packaging=16.8=py36_0 139 | - pandocfilters=1.4.1=py36_0 140 | - pango=1.40.3=1 141 | - partd=0.3.8=py36_0 142 | - path.py=10.3.1=py36_0 143 | - 
pathlib2=2.2.1=py36_0 144 | - patsy=0.4.1=py36_0 145 | - pcre=8.39=1 146 | - pep8=1.7.0=py36_0 147 | - pexpect=4.2.1=py36_0 148 | - pickleshare=0.7.4=py36_0 149 | - pillow=4.1.1=py36_0 150 | - pip=9.0.1=py36_1 151 | - pixman=0.34.0=0 152 | - ply=3.10=py36_0 153 | - prompt_toolkit=1.0.14=py36_0 154 | - psutil=5.2.2=py36_0 155 | - ptyprocess=0.5.1=py36_0 156 | - py=1.4.33=py36_0 157 | - pycosat=0.6.2=py36_0 158 | - pycparser=2.17=py36_0 159 | - pycrypto=2.6.1=py36_6 160 | - pycurl=7.43.0=py36_2 161 | - pyflakes=1.5.0=py36_0 162 | - pygments=2.2.0=py36_0 163 | - pylint=1.6.4=py36_1 164 | - pyodbc=4.0.16=py36_0 165 | - pyopenssl=17.0.0=py36_0 166 | - pyparsing=2.1.4=py36_0 167 | - pyqt=5.6.0=py36_2 168 | - pytables=3.3.0=np112py36_0 169 | - pytest=3.0.7=py36_0 170 | - python=3.6.1=2 171 | - python-dateutil=2.6.0=py36_0 172 | - pytz=2017.2=py36_0 173 | - pywavelets=0.5.2=np112py36_0 174 | - pyyaml=3.12=py36_0 175 | - pyzmq=16.0.2=py36_0 176 | - qt=5.6.2=4 177 | - qtawesome=0.4.4=py36_0 178 | - qtconsole=4.3.0=py36_0 179 | - qtpy=1.2.1=py36_0 180 | - readline=6.2=2 181 | - requests=2.14.2=py36_0 182 | - rope=0.9.4=py36_1 183 | - ruamel_yaml=0.11.14=py36_1 184 | - scikit-image=0.13.0=np112py36_0 185 | - seaborn=0.7.1=py36_0 186 | - simplegeneric=0.8.1=py36_1 187 | - singledispatch=3.4.0.3=py36_0 188 | - sip=4.18=py36_0 189 | - six=1.10.0=py36_0 190 | - snowballstemmer=1.2.1=py36_0 191 | - sortedcollections=0.5.3=py36_0 192 | - sortedcontainers=1.5.7=py36_0 193 | - sphinx=1.5.6=py36_0 194 | - spyder=3.1.4=py36_0 195 | - sqlalchemy=1.1.9=py36_0 196 | - sqlite=3.13.0=0 197 | - statsmodels=0.8.0=np112py36_0 198 | - sympy=1.0=py36_0 199 | - tblib=1.3.2=py36_0 200 | - terminado=0.6=py36_0 201 | - testpath=0.3=py36_0 202 | - tk=8.5.18=0 203 | - toolz=0.8.2=py36_0 204 | - tornado=4.5.1=py36_0 205 | - traitlets=4.3.2=py36_0 206 | - unicodecsv=0.14.1=py36_0 207 | - unixodbc=2.3.4=0 208 | - wcwidth=0.1.7=py36_0 209 | - werkzeug=0.12.2=py36_0 210 | - wheel=0.29.0=py36_0 211 | - widgetsnbextension=2.0.0=py36_0 212 | - wrapt=1.10.10=py36_0 213 | - xlrd=1.0.0=py36_0 214 | - xlsxwriter=0.9.6=py36_0 215 | - xlwt=1.2.0=py36_0 216 | - xz=5.2.2=1 217 | - yaml=0.1.6=0 218 | - zeromq=4.1.5=0 219 | - zlib=1.2.8=3 220 | - pip: 221 | - backports.shutil-get-terminal-size==1.0.0 222 | - bz2file==0.98 223 | - catboost==0.2.5 224 | - category-encoders==1.2.4 225 | - cplex==12.8.0.0 226 | - dill==0.2.7.1 227 | - docplex==2.4.61 228 | - gensim==3.0.0 229 | - hdbscan==0.8.13 230 | - keras==2.0.8 231 | - lightgbm==2.0.12 232 | - markdown==2.6.9 233 | - mlcrate==0.1.0 234 | - multiprocess==0.70.5 235 | - ortools==6.6.4656 236 | - pathos==0.2.1 237 | - plotly==2.2.1 238 | - pox==0.2.3 239 | - ppft==1.6.4.7.1 240 | - protobuf==3.5.1 241 | - rgf-python==2.0.3 242 | - rope-py3k==0.9.4.post1 243 | - smart-open==1.5.3 244 | - tables==3.3.0 245 | - tensorflow==1.3.0 246 | - tensorflow-tensorboard==0.1.8 247 | - trackml==2 248 | - urllib3==1.22 249 | - xgboost==0.7 250 | - xlearn==0.30a1 251 | prefix: /home/jfpuget/anaconda3 252 | 253 | -------------------------------------------------------------------------------- /src/final_test.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | import pandas as pd 5 | import numpy as np 6 | import pickle as pkl 7 | from sklearn.cluster import DBSCAN 8 | from itertools import combinations 9 | from multiprocessing import Pool 10 | 11 | base_path = '/home/jfpuget/Kaggle/TrackML/' 12 | 13 | class Clusterer(object): 14 | 15 | def 
__init__(self, eps, max_cluster, scan_center, quality_threshold, cnt, event): 16 | self.eps_ = eps 17 | self.max_cluster_ = max_cluster 18 | self.scan_center_ = scan_center 19 | self.quality_threshold_ = quality_threshold 20 | self.cnt_ = cnt 21 | self.event_ = event 22 | 23 | def fit_alpha(self, data, alpha0, z): 24 | # find tracks originating from (0, 0, z) and diameter 1/alpha0 25 | cond0 = np.abs(data.rt * alpha0) < 1 26 | data1 = data[cond0].copy() 27 | data1['theta0'] = np.arcsin(data1.rt * alpha0) 28 | data1['rt1'] = data1.theta0 / alpha0 29 | data1['theta0'] *= data1.theta_ratio 30 | data1['theta'] = data1.theta_base + data1.theta0 31 | 32 | data1['xcos'] = np.cos(data1.theta) 33 | data1['ysin'] = np.sin(data1.theta) 34 | data1['zr'] = (np.arcsinh((data1.z - z) / (0.7 * data1.rt1)) / 3.5) 35 | dfs = data1[['xcos', 'ysin', 'zr']] 36 | clusters0 = DBSCAN(eps=self.eps_, 37 | min_samples=2, 38 | metric='euclidean', 39 | n_jobs=1).fit(dfs).labels_ 40 | clusters = np.zeros(data.shape[0]) 41 | clusters[cond0] = clusters0 + 1 42 | 43 | # merge track candidates form DBSCAN with previously found track candidates 44 | maxs1 = data['s1'].max() 45 | clusters[clusters > 0] += maxs1 46 | data['s2'] = clusters 47 | data['N2'] = count_module(data, 's2') 48 | data.loc[data.N2 < 2, 's2'] = 0 49 | data.loc[data.N2 > self.max_cluster_, 's2'] = 0 50 | data['Q2'] = get_all_layer_quality(data, 's2', self.cnt_) 51 | data['WN2'] = data.N2 * data.Q2 52 | data.loc[data.WN2 <= data.WN1, 's2'] = 0 53 | data['N2'] = count_module(data, 's2') 54 | data['WN2'] = data.N2 * data.Q2 55 | cond = ( (data['WN2'] > data['WN1']) \ 56 | & (data['Q2'] > self.quality_threshold_) \ 57 | & (data['N2'] < self.max_cluster_) 58 | ) 59 | data.loc[cond, 's1'] = data.loc[cond, 's2'] 60 | data.loc[cond, 'Q1'] = data.loc[cond, 'Q2'] 61 | 62 | # update current tracks statistics 63 | data['N1'] = count_module(data, 's1') 64 | cond = ((data['N1'] >= self.max_cluster_) ) 65 | data.loc[cond, 'N1'] = 0 66 | data.loc[cond, 's1'] = 0 67 | data.loc[cond, 'Q1'] = 0 68 | data['WN1'] = data.N1 * data.Q1 69 | 70 | data['track_id'] = data['s1'] 71 | return data 72 | 73 | def fit_predict(self, data, n_iter): 74 | # initialize track statistics with no tracks 75 | data['theta_base'] = np.arctan2(data.y, data.x) 76 | data['rt'] = np.sqrt(data.x**2 + data.y**2) 77 | data['theta0'] = 0.0 78 | data['theta'] = 0.0 79 | data['rt1'] = 0.0 80 | data['zr'] = 0.0 81 | data['s1'] = data.track_id 82 | data['Q1'] = get_all_layer_quality(data, 's1', self.cnt_) 83 | data['N1'] = count_module(data, 's1') 84 | data['WN1'] = data.N1 * data.Q1 85 | 86 | scan_center = self.scan_center_ 87 | 88 | np.random.seed(0) 89 | mm = 1 90 | 91 | # loop over randoly chose track parameters 92 | for ii in (range(n_iter)): 93 | n = np.random.randint(0, len(scan_center)) 94 | alpha0 = scan_center[n, 0] 95 | z = scan_center[n, 1] 96 | data = self.fit_alpha(data, mm * alpha0, z) 97 | mm = - mm 98 | if ii % 1000 == 0: 99 | print(self.event_, '%05d' % ii) 100 | return data 101 | 102 | def count_module(dfh, col): 103 | # count the number of volumes traversed by track candidates define by values in column col 104 | dfmod = dfh.groupby([col, 'volume_id', 'layer_id']).hit_id.count() 105 | dfmod = dfmod.to_frame('n_volume_layer').reset_index().groupby(col).n_volume_layer.count().reset_index() 106 | dfmod = dfh[[col]].merge(dfmod[[col, 'n_volume_layer']], how='left', on=col) 107 | dfmod1 = dfh.groupby([col, 'volume_id', 'layer_id', 'module_id']).hit_id.count() 108 | dfmod1 = 
dfmod1.to_frame('n_volume_layer_module').reset_index().groupby(col).n_volume_layer_module.count().reset_index() 109 | dfmod = dfmod.merge(dfmod1[[col, 'n_volume_layer_module']], how='left', on=col) 110 | dfmod.loc[dfmod[col] == 0, 'n_volume_layer'] = 0 111 | dfmod.loc[dfmod['n_volume_layer_module'] <= 3, 'n_volume_layer'] = 0 112 | return dfmod.n_volume_layer.values 113 | 114 | def get_all_layer_quality(dfh, col, cnt): 115 | # computes quality for all track candidates defined by column col 116 | # cnt is a dictionary containing tracklet frequencies 117 | 118 | def get_layer_quality(layers, cnt=cnt): 119 | # compute smoothed geometric average of tracklet frequencies for one track candidate. 120 | layers = [str(x) for x in layers] 121 | if len(layers) <= 3: 122 | return 0 123 | quality = ([cnt[' '.join(layers[i:i+4])] for i in range(len(layers) - 3)]) 124 | quality = np.mean(np.log1p(quality)) 125 | return quality 126 | 127 | dfmod = dfh.sort_values(by=[col, 'z']) 128 | df = dfmod[dfmod[col] > 0].groupby([col]).layer.apply(get_layer_quality).to_frame('layer_quality').reset_index() 129 | dfh = dfh[[col]].merge(df, how='left', on=col) 130 | dfh.layer_quality.fillna(0, inplace=True) 131 | return dfh.layer_quality.values 132 | 133 | def get_event(i): 134 | return 'event000000%03d' % int(i) 135 | 136 | def work_sub(param): 137 | # computes tracks for event i 138 | (i, num_iter, max_cluster, quality_threshold) = param 139 | 140 | event = get_event(i) 141 | print('event:', event) 142 | hits = pd.read_csv(base_path+'input/test/'+event + '-hits.csv') 143 | data = hits 144 | print(data.shape) 145 | data['event_id'] = i 146 | data['track_id'] = 0 147 | data['layer'] = 100 * data.volume_id + data.layer_id 148 | data['theta_ratio'] = 1 - (np.abs(data.z + 200) / 6000)**2.4 + 0.005 149 | 150 | # load (alpha0, z) values from train events 151 | scan_center = np.load('../data/scan_center.npy') 152 | 153 | # load tracklet frequencies 154 | with open('../data/layers_4_center_fix.pkl', 'rb') as file: 155 | cnt = pkl.load(file) 156 | 157 | # compute tracks and save them 158 | model = Clusterer(eps=0.0028, max_cluster=max_cluster, scan_center=scan_center, cnt=cnt, 159 | quality_threshold=quality_threshold, event=event) 160 | data = model.fit_predict(data, num_iter) 161 | data[['event_id', 'hit_id', 'track_id']].to_csv(base_path+'submissions/final/' + event, index=False) 162 | return i 163 | 164 | def main(): 165 | max_cluster = 20 166 | quality_threshold = 5 167 | num_iter=60000 168 | 169 | # compute each event tracks in parallel. 
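# each worker performs num_iter randomized (alpha0, z) scans (one DBSCAN run per scan) on a single event and writes its tracks to submissions/final/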
170 | # number of process in pool should be close to the number of processors 171 | params = [(i, num_iter, max_cluster, quality_threshold) for i in range(125)] 172 | 173 | if 1: 174 | pool = Pool(processes=21, maxtasksperchild=1) 175 | ls = pool.map( work_sub, params, chunksize=1 ) 176 | pool.close() 177 | else: 178 | ls = [work_sub(param) for param in params] 179 | 180 | # computes submission by concatenating each event tracks 181 | submissions = [] 182 | for i in (range(125)): 183 | event = get_event(i) 184 | data0 = pd.read_csv('../submissions/final/' + event) 185 | submissions.append(data0) 186 | 187 | submission = pd.concat(submissions, axis=0) 188 | submission.track_id = (submission.track_id).astype('int64') 189 | submission.to_csv('../submissions/sub_final.csv', index=False) 190 | 191 | if __name__ == "__main__": 192 | main() 193 | 194 | 195 | 196 | -------------------------------------------------------------------------------- /src/final_inner_test.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import pandas as pd 4 | import numpy as np 5 | import pickle as pkl 6 | from sklearn.cluster import DBSCAN 7 | from itertools import combinations 8 | from multiprocessing import Pool 9 | #from tqdm import tqdm_notebook as tqdm 10 | 11 | base_path = '/home/jfpuget/Kaggle/TrackML/' 12 | 13 | class Clusterer(object): 14 | 15 | def __init__(self, eps, max_cluster, scan_center, quality_threshold, cnt, event): 16 | self.eps_ = eps 17 | self.max_cluster_ = max_cluster 18 | self.scan_center_ = scan_center 19 | self.quality_threshold_ = quality_threshold 20 | self.cnt_ = cnt 21 | self.event_ = event 22 | 23 | def fit_alpha(self, data, alpha0, z): 24 | # find tracks originating from (0, 0, z) and diameter 1/alpha0 25 | cond0 = np.abs(data.rt * alpha0) < 1 26 | data1 = data[cond0].copy() 27 | data1['theta0'] = np.arcsin(data1.rt * alpha0) 28 | data1['rt1'] = data1.theta0 / alpha0 29 | data1['theta'] = data1.theta_base + data1.theta0 30 | 31 | data1['xcos'] = np.cos(data1.theta) 32 | data1['ysin'] = np.sin(data1.theta) 33 | data1['zr'] = (np.arcsinh((data1.z - z) / (0.7 * data1.rt1)) / 3.5) 34 | dfs = data1[['xcos', 'ysin', 'zr']] 35 | clusters0 = DBSCAN(eps=self.eps_, 36 | min_samples=2, 37 | metric='euclidean', 38 | n_jobs=1).fit(dfs).labels_ 39 | clusters = np.zeros(data.shape[0]) 40 | clusters[cond0] = clusters0 + 1 41 | 42 | # merge track candidates form DBSCAN with previously found track candidates 43 | maxs1 = data['s1'].max() 44 | clusters[clusters > 0] += maxs1 45 | data['s2'] = clusters 46 | data['N2'] = count_module(data, 's2') 47 | data.loc[data.N2 < 2, 's2'] = 0 48 | data.loc[data.N2 > self.max_cluster_, 's2'] = 0 49 | data['Q2'] = get_all_layer_quality(data, 's2', self.cnt_) 50 | data['WN2'] = data.N2 * data.Q2 51 | data.loc[data.WN2 <= data.WN1, 's2'] = 0 52 | data['N2'] = count_module(data, 's2') 53 | data['WN2'] = data.N2 * data.Q2 54 | cond = ( (data['WN2'] > data['WN1']) \ 55 | & (data['Q2'] > self.quality_threshold_) \ 56 | & (data['N2'] < self.max_cluster_) 57 | ) 58 | data.loc[cond, 's1'] = data.loc[cond, 's2'] 59 | data.loc[cond, 'Q1'] = data.loc[cond, 'Q2'] 60 | 61 | # update current tracks statistics 62 | data['N1'] = count_module(data, 's1') 63 | cond = ((data['N1'] >= self.max_cluster_) ) 64 | data.loc[cond, 'N1'] = 0 65 | data.loc[cond, 's1'] = 0 66 | data.loc[cond, 'Q1'] = 0 67 | data['WN1'] = data.N1 * data.Q1 68 | 69 | data['track_id'] = data['s1'] 70 | return data 71 | 72 | def fit_predict(self, data, 
n_iter): 73 | # initialize track statistics with no tracks 74 | data['theta_base'] = np.arctan2(data.y, data.x) 75 | data['rt'] = np.sqrt(data.x**2 + data.y**2) 76 | data['theta0'] = 0.0 77 | data['theta'] = 0.0 78 | data['rt1'] = 0.0 79 | data['zr'] = 0.0 80 | data['s1'] = data.track_id 81 | data['Q1'] = get_all_layer_quality(data, 's1', self.cnt_) 82 | data['N1'] = count_module(data, 's1') 83 | data['WN1'] = data.N1 * data.Q1 84 | 85 | scan_center = self.scan_center_ 86 | 87 | np.random.seed(0) 88 | mm = 1 89 | 90 | # loop over randoly chose track parameters 91 | for ii in range(n_iter): #tqdm(range(n_iter)): 92 | n = np.random.randint(0, len(scan_center)) 93 | alpha0 = scan_center[n, 0] 94 | n = np.random.randint(0, len(scan_center)) 95 | z = scan_center[n, 1] 96 | data = self.fit_alpha(data, mm * alpha0, z) 97 | mm = - mm 98 | if ii % 1000 == 0: 99 | print(self.event_, '%05d' % ii) 100 | return data 101 | 102 | def count_module(dfh, col): 103 | # count the number of volumes traversed by track candidates define by values in column col 104 | dfmod = dfh.groupby([col, 'volume_id', 'layer_id']).hit_id.count() 105 | dfmod = dfmod.to_frame('n_volume_layer').reset_index().groupby(col).n_volume_layer.count().reset_index() 106 | dfmod = dfh[[col]].merge(dfmod[[col, 'n_volume_layer']], how='left', on=col) 107 | dfmod1 = dfh.groupby([col, 'volume_id', 'layer_id', 'module_id']).hit_id.count() 108 | dfmod1 = dfmod1.to_frame('n_volume_layer_module').reset_index().groupby(col).n_volume_layer_module.count().reset_index() 109 | dfmod = dfmod.merge(dfmod1[[col, 'n_volume_layer_module']], how='left', on=col) 110 | dfmod.loc[dfmod[col] == 0, 'n_volume_layer'] = 0 111 | dfmod.loc[dfmod['n_volume_layer_module'] <= 3, 'n_volume_layer'] = 0 112 | return dfmod.n_volume_layer.values 113 | 114 | def get_all_layer_quality(dfh, col, cnt): 115 | # computes quality for all track candidates defined by column col 116 | # cnt is a dictionary containing tracklet frequencies 117 | 118 | def get_layer_quality(layers, cnt=cnt): 119 | # compute smoothed geometric average of tracklet frequencies for one track candidate. 
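# the candidate's layer sequence is split into all consecutive length-4 tracklets; the score is mean(log1p(count)) over those tracklets, i.e. the log of an add-one-smoothed geometric mean of their training-set frequencies, and 0 when the candidate has fewer than 4 hits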
120 | layers = [str(x) for x in layers] 121 | if len(layers) <= 3: 122 | return 0 123 | quality = ([cnt[' '.join(layers[i:i+4])] for i in range(len(layers) - 3)]) 124 | quality = np.mean(np.log1p(quality)) 125 | return quality 126 | 127 | dfmod = dfh.sort_values(by=[col, 'z']) 128 | df = dfmod[dfmod[col] > 0].groupby([col]).layer.apply(get_layer_quality).to_frame('layer_quality').reset_index() 129 | dfh = dfh[[col]].merge(df, how='left', on=col) 130 | dfh.layer_quality.fillna(0, inplace=True) 131 | return dfh.layer_quality.values 132 | 133 | def get_event(i): 134 | return 'event000000%03d' % int(i) 135 | 136 | def work_sub(param): 137 | # computes tracks for event i 138 | (i, num_iter, max_cluster, quality_threshold) = param 139 | 140 | event = get_event(i) 141 | print('event:', event) 142 | hits = pd.read_csv(base_path+'input/test/'+event + '-hits.csv') 143 | data = hits 144 | # only consider hits from inner volumes 145 | data = data[data.volume_id < 10].copy() 146 | print(data.shape) 147 | data['event_id'] = i 148 | data['track_id'] = 0 149 | data['layer'] = 100 * data.volume_id + data.layer_id 150 | 151 | # load (alpha0, z) values from train events 152 | scan_center = np.load('../data/scan_center.npy') 153 | 154 | # load tracklet frequencies 155 | with open('../data/layers_4_center_fix.pkl', 'rb') as file: 156 | cnt = pkl.load(file) 157 | 158 | # compute tracks and save them 159 | model = Clusterer(eps=0.0022, max_cluster=max_cluster, scan_center=scan_center, cnt=cnt, 160 | quality_threshold=quality_threshold, event=event) 161 | data = model.fit_predict(data, num_iter) 162 | data[['event_id', 'hit_id', 'track_id']].to_csv(base_path+'submissions/final_inner/' + event, index=False) 163 | return i 164 | 165 | def main(): 166 | max_cluster = 20 167 | quality_threshold = 5 168 | num_iter=25000 169 | 170 | # compute each event tracks in parallel. 171 | # number of process in pool should be close to the number of processors 172 | params = [(i, num_iter, max_cluster, quality_threshold) for i in range(125)] 173 | 174 | if 1: 175 | pool = Pool(processes=21, maxtasksperchild=1) 176 | ls = pool.map( work_sub, params, chunksize=1 ) 177 | pool.close() 178 | else: 179 | ls = [work_sub(param) for param in params] 180 | 181 | # computes submission by concatenating each event tracks 182 | submissions = [] 183 | for i in (range(125)): 184 | event = get_event(i) 185 | data0 = pd.read_csv('../submissions/final_inner/' + event) 186 | submissions.append(data0) 187 | 188 | submission = pd.concat(submissions, axis=0) 189 | submission.track_id = (submission.track_id).astype('int64') 190 | submission.to_csv('../submissions/sub_final_inner.csv', index=False) 191 | 192 | 193 | if __name__ == "__main__": 194 | main() 195 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /models/scan_center_final.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import pandas as pd\n", 12 | "import numpy as np\n", 13 | "import pickle as pkl\n", 14 | "import gc\n", 15 | "\n", 16 | "from itertools import combinations\n", 17 | "from tqdm import tqdm_notebook as tqdm\n", 18 | "\n", 19 | "from matplotlib import pyplot as plt\n", 20 | "%matplotlib inline\n", 21 | "\n", 22 | "from trackml.dataset import load_event\n", 23 | "\n", 24 | "pd.options.display.max_columns = 200" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 6, 30 | "metadata": { 31 | "scrolled": true 32 | }, 33 | "outputs": [ 34 | { 35 | "name": "stdout", 36 | "output_type": "stream", 37 | "text": [ 38 | "event: ../input/train_1/event000001010\n", 39 | "event: ../input/train_1/event000001011\n", 40 | "event: ../input/train_1/event000001012\n", 41 | "event: ../input/train_1/event000001013\n", 42 | "event: ../input/train_1/event000001014\n", 43 | "event: ../input/train_1/event000001015\n", 44 | "event: ../input/train_1/event000001016\n", 45 | "event: ../input/train_1/event000001017\n", 46 | "event: ../input/train_1/event000001018\n", 47 | "event: ../input/train_1/event000001019\n", 48 | "event: ../input/train_1/event000001020\n", 49 | "event: ../input/train_1/event000001021\n", 50 | "event: ../input/train_1/event000001022\n", 51 | "event: ../input/train_1/event000001023\n", 52 | "event: ../input/train_1/event000001024\n", 53 | "event: ../input/train_1/event000001025\n", 54 | "event: ../input/train_1/event000001026\n", 55 | "event: ../input/train_1/event000001027\n", 56 | "event: ../input/train_1/event000001028\n", 57 | "event: ../input/train_1/event000001029\n", 58 | "event: ../input/train_1/event000001030\n", 59 | "event: ../input/train_1/event000001031\n", 60 | "event: ../input/train_1/event000001032\n", 61 | "event: ../input/train_1/event000001033\n", 62 | "event: ../input/train_1/event000001034\n", 63 | "event: ../input/train_1/event000001035\n", 64 | "event: ../input/train_1/event000001036\n", 65 | "event: ../input/train_1/event000001037\n", 66 | "event: ../input/train_1/event000001038\n", 67 | "event: ../input/train_1/event000001039\n", 68 | "event: ../input/train_1/event000001040\n", 69 | "event: ../input/train_1/event000001041\n", 70 | "event: ../input/train_1/event000001042\n", 71 | "event: ../input/train_1/event000001043\n", 72 | "event: ../input/train_1/event000001044\n", 73 | "event: ../input/train_1/event000001045\n", 74 | "event: ../input/train_1/event000001046\n", 75 | "event: ../input/train_1/event000001047\n", 76 | "event: ../input/train_1/event000001048\n", 77 | "event: ../input/train_1/event000001049\n", 78 | "event: ../input/train_1/event000001050\n", 79 | "event: ../input/train_1/event000001051\n", 80 | "event: ../input/train_1/event000001052\n", 81 | 
"event: ../input/train_1/event000001053\n", 82 | "event: ../input/train_1/event000001054\n", 83 | "event: ../input/train_1/event000001055\n", 84 | "event: ../input/train_1/event000001056\n", 85 | "event: ../input/train_1/event000001057\n", 86 | "event: ../input/train_1/event000001058\n", 87 | "event: ../input/train_1/event000001059\n", 88 | "event: ../input/train_1/event000001060\n", 89 | "event: ../input/train_1/event000001061\n", 90 | "event: ../input/train_1/event000001062\n", 91 | "event: ../input/train_1/event000001063\n", 92 | "event: ../input/train_1/event000001064\n", 93 | "event: ../input/train_1/event000001065\n", 94 | "event: ../input/train_1/event000001066\n", 95 | "event: ../input/train_1/event000001067\n", 96 | "event: ../input/train_1/event000001068\n", 97 | "event: ../input/train_1/event000001069\n", 98 | "event: ../input/train_1/event000001070\n", 99 | "event: ../input/train_1/event000001071\n", 100 | "event: ../input/train_1/event000001072\n", 101 | "event: ../input/train_1/event000001073\n", 102 | "event: ../input/train_1/event000001074\n", 103 | "event: ../input/train_1/event000001075\n", 104 | "event: ../input/train_1/event000001076\n", 105 | "event: ../input/train_1/event000001077\n", 106 | "event: ../input/train_1/event000001078\n", 107 | "event: ../input/train_1/event000001079\n", 108 | "event: ../input/train_1/event000001080\n", 109 | "event: ../input/train_1/event000001081\n", 110 | "event: ../input/train_1/event000001082\n", 111 | "event: ../input/train_1/event000001083\n", 112 | "event: ../input/train_1/event000001084\n", 113 | "event: ../input/train_1/event000001085\n", 114 | "event: ../input/train_1/event000001086\n", 115 | "event: ../input/train_1/event000001087\n", 116 | "event: ../input/train_1/event000001088\n", 117 | "event: ../input/train_1/event000001089\n", 118 | "event: ../input/train_1/event000001090\n", 119 | "event: ../input/train_1/event000001091\n", 120 | "event: ../input/train_1/event000001092\n", 121 | "event: ../input/train_1/event000001093\n", 122 | "event: ../input/train_1/event000001094\n", 123 | "event: ../input/train_1/event000001095\n", 124 | "event: ../input/train_1/event000001096\n", 125 | "event: ../input/train_1/event000001097\n", 126 | "event: ../input/train_1/event000001098\n", 127 | "event: ../input/train_1/event000001099\n" 128 | ] 129 | } 130 | ], 131 | "source": [ 132 | "data_l = []\n", 133 | "for i in range(10,100):\n", 134 | " event = '../input/train_1/event0000010%d' % i\n", 135 | " print('event:', event)\n", 136 | " hits, cells, particles, truth = load_event(event)\n", 137 | " data = hits\n", 138 | " data = data.merge(truth, how='left', on='hit_id')\n", 139 | " data = data.merge(particles, how='left', on='particle_id')\n", 140 | " data['rv'] = np.sqrt(data.vx**2 + data.vy**2)\n", 141 | " data = data[(data.rv <= 1) & (data.vz <= 50) & (data.vz >=-50)].copy()\n", 142 | " data = data[data.weight > 0]\n", 143 | " data['event_id'] = i\n", 144 | " \n", 145 | " data['pt'] = np.sqrt(data.px ** 2 + data.py ** 2)\n", 146 | " data['alpha0'] = np.exp(-8.115 - np.log(data.pt))\n", 147 | " \n", 148 | " data_l.append(data)\n", 149 | "\n", 150 | "data = pd.concat(data_l, axis=0)\n", 151 | "data = data.sample(frac=1, random_state=0)" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 7, 157 | "metadata": {}, 158 | "outputs": [ 159 | { 160 | "data": { 161 | "text/plain": [ 162 | "(591203, 2)" 163 | ] 164 | }, 165 | "execution_count": 7, 166 | "metadata": {}, 167 | "output_type": "execute_result" 168 | } 169 | ], 170 
| "source": [ 171 | "df = data.groupby(['event_id', 'particle_id'])[['alpha0', 'vz']].first()\n", 172 | "df = df.dropna()\n", 173 | "df.shape" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": 8, 179 | "metadata": { 180 | "collapsed": true 181 | }, 182 | "outputs": [], 183 | "source": [ 184 | "np.save('../data/scan_center.npy', df.values)" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 9, 190 | "metadata": {}, 191 | "outputs": [ 192 | { 193 | "data": { 194 | "text/html": [ 195 | "
\n", 196 | "\n", 209 | "\n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 
481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " \n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | "
alpha0vz
event_idparticle_id
1045038057858007040.0000128.77349
45039432247541760.0000148.77349
45040119442309120.0001118.77349
45040806637076480.0000778.77349
45041493831843840.0002828.77349
45042868221378560.0000868.77349
45043555416145920.0000458.77349
45046991389982720.0006768.77349
45049740169052160.0005308.77349
45052488948121600.0000638.77349
45053176142888960.0001648.77349
45054550532423680.0002538.77349
45055237727191040.0007698.77349
45055924921958400.0003368.77349
45056612116725760.0009258.77349
45057986506260480.0014518.77349
45058673701027840.0007288.77349
45064858453934080.0004708.77349
45068981622538240.0007928.77349
45069668817305600.0000828.77349
45071043206840320.0002428.77349
45072417596375040.0009328.77349
45073104791142400.0001048.77349
45074479180677120.0000978.77349
45075166375444480.0001638.77349
45075853570211840.0000738.77349
45078602349281280.0000438.77349
45079289544048640.0002118.77349
45079976738816000.0003018.77349
45081351128350720.0000728.77349
............
999277478454301818880.000653-4.81552
9277479828691353600.000590-4.81552
9277481203080888320.000667-4.81552
9277482577470423040.000343-4.81552
9277483264665190400.000756-4.81552
9277483951859957760.000312-4.81552
9277484639054725120.001089-4.81552
9277485326249492480.000872-4.81552
9277495634171002880.000606-4.81552
9277496321365770240.000192-4.81552
9277502506118676480.000513-4.81552
9277508003676815360.001171-4.81552
9277508690871582720.000838-4.81552
9277510752455884800.001128-4.81552
9277515562819256320.001347-4.81552
9277517624403558400.000463-4.81552
9277518311598325760.000330-4.81552
9277523121961697280.001004-4.81552
9277523809156464640.001004-4.81552
9322451915851694080.001669-8.56420
9322455351825530880.000532-8.56420
9322456039020298240.000467-8.56420
9322460162188902400.000606-8.56420
9322460849383669760.001358-8.56420
9322465659747041280.001310-8.56420
9322466346941808640.001439-8.56420
9322467034136576000.000648-8.56420
9322467721331343360.001242-8.56420
9322468408526110720.000960-8.56420
9322469095720878080.000994-8.56420
\n", 535 | "

591203 rows × 2 columns

\n", 536 | "
" 537 | ], 538 | "text/plain": [ 539 | " alpha0 vz\n", 540 | "event_id particle_id \n", 541 | "10 4503805785800704 0.000012 8.77349\n", 542 | " 4503943224754176 0.000014 8.77349\n", 543 | " 4504011944230912 0.000111 8.77349\n", 544 | " 4504080663707648 0.000077 8.77349\n", 545 | " 4504149383184384 0.000282 8.77349\n", 546 | " 4504286822137856 0.000086 8.77349\n", 547 | " 4504355541614592 0.000045 8.77349\n", 548 | " 4504699138998272 0.000676 8.77349\n", 549 | " 4504974016905216 0.000530 8.77349\n", 550 | " 4505248894812160 0.000063 8.77349\n", 551 | " 4505317614288896 0.000164 8.77349\n", 552 | " 4505455053242368 0.000253 8.77349\n", 553 | " 4505523772719104 0.000769 8.77349\n", 554 | " 4505592492195840 0.000336 8.77349\n", 555 | " 4505661211672576 0.000925 8.77349\n", 556 | " 4505798650626048 0.001451 8.77349\n", 557 | " 4505867370102784 0.000728 8.77349\n", 558 | " 4506485845393408 0.000470 8.77349\n", 559 | " 4506898162253824 0.000792 8.77349\n", 560 | " 4506966881730560 0.000082 8.77349\n", 561 | " 4507104320684032 0.000242 8.77349\n", 562 | " 4507241759637504 0.000932 8.77349\n", 563 | " 4507310479114240 0.000104 8.77349\n", 564 | " 4507447918067712 0.000097 8.77349\n", 565 | " 4507516637544448 0.000163 8.77349\n", 566 | " 4507585357021184 0.000073 8.77349\n", 567 | " 4507860234928128 0.000043 8.77349\n", 568 | " 4507928954404864 0.000211 8.77349\n", 569 | " 4507997673881600 0.000301 8.77349\n", 570 | " 4508135112835072 0.000072 8.77349\n", 571 | "... ... ...\n", 572 | "99 927747845430181888 0.000653 -4.81552\n", 573 | " 927747982869135360 0.000590 -4.81552\n", 574 | " 927748120308088832 0.000667 -4.81552\n", 575 | " 927748257747042304 0.000343 -4.81552\n", 576 | " 927748326466519040 0.000756 -4.81552\n", 577 | " 927748395185995776 0.000312 -4.81552\n", 578 | " 927748463905472512 0.001089 -4.81552\n", 579 | " 927748532624949248 0.000872 -4.81552\n", 580 | " 927749563417100288 0.000606 -4.81552\n", 581 | " 927749632136577024 0.000192 -4.81552\n", 582 | " 927750250611867648 0.000513 -4.81552\n", 583 | " 927750800367681536 0.001171 -4.81552\n", 584 | " 927750869087158272 0.000838 -4.81552\n", 585 | " 927751075245588480 0.001128 -4.81552\n", 586 | " 927751556281925632 0.001347 -4.81552\n", 587 | " 927751762440355840 0.000463 -4.81552\n", 588 | " 927751831159832576 0.000330 -4.81552\n", 589 | " 927752312196169728 0.001004 -4.81552\n", 590 | " 927752380915646464 0.001004 -4.81552\n", 591 | " 932245191585169408 0.001669 -8.56420\n", 592 | " 932245535182553088 0.000532 -8.56420\n", 593 | " 932245603902029824 0.000467 -8.56420\n", 594 | " 932246016218890240 0.000606 -8.56420\n", 595 | " 932246084938366976 0.001358 -8.56420\n", 596 | " 932246565974704128 0.001310 -8.56420\n", 597 | " 932246634694180864 0.001439 -8.56420\n", 598 | " 932246703413657600 0.000648 -8.56420\n", 599 | " 932246772133134336 0.001242 -8.56420\n", 600 | " 932246840852611072 0.000960 -8.56420\n", 601 | " 932246909572087808 0.000994 -8.56420\n", 602 | "\n", 603 | "[591203 rows x 2 columns]" 604 | ] 605 | }, 606 | "execution_count": 9, 607 | "metadata": {}, 608 | "output_type": "execute_result" 609 | } 610 | ], 611 | "source": [ 612 | "df" 613 | ] 614 | }, 615 | { 616 | "cell_type": "code", 617 | "execution_count": null, 618 | "metadata": { 619 | "collapsed": true 620 | }, 621 | "outputs": [], 622 | "source": [] 623 | } 624 | ], 625 | "metadata": { 626 | "kernelspec": { 627 | "display_name": "Python [conda env:tf15]", 628 | "language": "python", 629 | "name": "conda-env-tf15-py" 630 | }, 631 | "language_info": { 632 | 
"codemirror_mode": { 633 | "name": "ipython", 634 | "version": 3 635 | }, 636 | "file_extension": ".py", 637 | "mimetype": "text/x-python", 638 | "name": "python", 639 | "nbconvert_exporter": "python", 640 | "pygments_lexer": "ipython3", 641 | "version": "3.6.1" 642 | } 643 | }, 644 | "nbformat": 4, 645 | "nbformat_minor": 2 646 | } 647 | -------------------------------------------------------------------------------- /models/layer_prep_final.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import pandas as pd\n", 12 | "import numpy as np\n", 13 | "from sklearn.model_selection import train_test_split\n", 14 | "\n", 15 | "from sklearn.linear_model import LinearRegression, RANSACRegressor\n", 16 | "from sklearn.preprocessing import StandardScaler\n", 17 | "from sklearn.cluster import DBSCAN\n", 18 | "from sklearn.neighbors import KDTree\n", 19 | "\n", 20 | "import hdbscan\n", 21 | "from sklearn.neighbors import NearestNeighbors\n", 22 | "import pickle as pkl\n", 23 | "import gc\n", 24 | "\n", 25 | "from itertools import combinations\n", 26 | "from tqdm import tqdm_notebook as tqdm\n", 27 | "\n", 28 | "from matplotlib import pyplot as plt\n", 29 | "%matplotlib inline\n", 30 | "\n", 31 | "import seaborn as sns\n", 32 | "\n", 33 | "import mlcrate as mlc\n", 34 | "\n", 35 | "\n", 36 | "from trackml.dataset import load_event\n", 37 | "import pickle as pkl\n", 38 | "\n", 39 | "pd.options.display.max_columns = 200" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 2, 45 | "metadata": { 46 | "collapsed": true 47 | }, 48 | "outputs": [], 49 | "source": [ 50 | "def score_event(truth, submission):\n", 51 | " truth = truth[['hit_id', 'particle_id', 'weight']].merge(submission, how='left', on='hit_id')\n", 52 | " df = truth.groupby(['track_id', 'particle_id']).hit_id.count().to_frame('count_both').reset_index()\n", 53 | " \n", 54 | " df1 = df.groupby(['particle_id']).count_both.sum().to_frame('count_particle').reset_index()\n", 55 | " df = df.merge(df1, how='left', on='particle_id')\n", 56 | " df1 = df.groupby(['track_id']).count_both.sum().to_frame('count_track').reset_index()\n", 57 | " df = df.merge(df1, how='left', on='track_id')\n", 58 | " df['valid'] = (df.count_both > 0.5*np.maximum(df.count_particle, df.count_track))\n", 59 | " truth = truth.merge(df[['track_id', 'particle_id', 'valid']], how='left', on=['track_id', 'particle_id'])\n", 60 | "\n", 61 | " score = truth[truth.valid].weight.sum()\n", 62 | " return score" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 3, 68 | "metadata": { 69 | "collapsed": true 70 | }, 71 | "outputs": [], 72 | "source": [ 73 | "def score_data(data):\n", 74 | " truth = data\n", 75 | " truth['count_both'] = truth.groupby(['track_id', 'particle_id']).hit_id.transform('count') \n", 76 | " truth['count_particle'] = truth.groupby(['particle_id']).hit_id.transform('count')\n", 77 | " truth['count_track'] = truth.groupby(['track_id']).hit_id.transform('count')\n", 78 | " truth['valid'] = (truth.count_both > 0.5*truth.count_particle) & (truth.count_both > 0.5*truth.count_track)\n", 79 | " score = truth[truth.valid].weight.sum()\n", 80 | " truth.loc[truth.track_id == 0, 'count_track'] = 0\n", 81 | " return score" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 4, 87 | "metadata": { 88 | "scrolled": true 
89 | }, 90 | "outputs": [ 91 | { 92 | "name": "stdout", 93 | "output_type": "stream", 94 | "text": [ 95 | "event: ../input/train_1/event000001010\n", 96 | "event: ../input/train_1/event000001011\n", 97 | "event: ../input/train_1/event000001012\n", 98 | "event: ../input/train_1/event000001013\n", 99 | "event: ../input/train_1/event000001014\n", 100 | "event: ../input/train_1/event000001015\n", 101 | "event: ../input/train_1/event000001016\n", 102 | "event: ../input/train_1/event000001017\n", 103 | "event: ../input/train_1/event000001018\n", 104 | "event: ../input/train_1/event000001019\n", 105 | "event: ../input/train_1/event000001020\n", 106 | "event: ../input/train_1/event000001021\n", 107 | "event: ../input/train_1/event000001022\n", 108 | "event: ../input/train_1/event000001023\n", 109 | "event: ../input/train_1/event000001024\n", 110 | "event: ../input/train_1/event000001025\n", 111 | "event: ../input/train_1/event000001026\n", 112 | "event: ../input/train_1/event000001027\n", 113 | "event: ../input/train_1/event000001028\n", 114 | "event: ../input/train_1/event000001029\n", 115 | "event: ../input/train_1/event000001030\n", 116 | "event: ../input/train_1/event000001031\n", 117 | "event: ../input/train_1/event000001032\n", 118 | "event: ../input/train_1/event000001033\n", 119 | "event: ../input/train_1/event000001034\n", 120 | "event: ../input/train_1/event000001035\n", 121 | "event: ../input/train_1/event000001036\n", 122 | "event: ../input/train_1/event000001037\n", 123 | "event: ../input/train_1/event000001038\n", 124 | "event: ../input/train_1/event000001039\n", 125 | "event: ../input/train_1/event000001040\n", 126 | "event: ../input/train_1/event000001041\n", 127 | "event: ../input/train_1/event000001042\n", 128 | "event: ../input/train_1/event000001043\n", 129 | "event: ../input/train_1/event000001044\n", 130 | "event: ../input/train_1/event000001045\n", 131 | "event: ../input/train_1/event000001046\n", 132 | "event: ../input/train_1/event000001047\n", 133 | "event: ../input/train_1/event000001048\n", 134 | "event: ../input/train_1/event000001049\n", 135 | "event: ../input/train_1/event000001050\n", 136 | "event: ../input/train_1/event000001051\n", 137 | "event: ../input/train_1/event000001052\n", 138 | "event: ../input/train_1/event000001053\n", 139 | "event: ../input/train_1/event000001054\n", 140 | "event: ../input/train_1/event000001055\n", 141 | "event: ../input/train_1/event000001056\n", 142 | "event: ../input/train_1/event000001057\n", 143 | "event: ../input/train_1/event000001058\n", 144 | "event: ../input/train_1/event000001059\n", 145 | "event: ../input/train_1/event000001060\n", 146 | "event: ../input/train_1/event000001061\n", 147 | "event: ../input/train_1/event000001062\n", 148 | "event: ../input/train_1/event000001063\n", 149 | "event: ../input/train_1/event000001064\n", 150 | "event: ../input/train_1/event000001065\n", 151 | "event: ../input/train_1/event000001066\n", 152 | "event: ../input/train_1/event000001067\n", 153 | "event: ../input/train_1/event000001068\n", 154 | "event: ../input/train_1/event000001069\n", 155 | "event: ../input/train_1/event000001070\n", 156 | "event: ../input/train_1/event000001071\n", 157 | "event: ../input/train_1/event000001072\n", 158 | "event: ../input/train_1/event000001073\n", 159 | "event: ../input/train_1/event000001074\n", 160 | "event: ../input/train_1/event000001075\n", 161 | "event: ../input/train_1/event000001076\n", 162 | "event: ../input/train_1/event000001077\n", 163 | "event: ../input/train_1/event000001078\n", 164 | 
"event: ../input/train_1/event000001079\n", 165 | "event: ../input/train_1/event000001080\n", 166 | "event: ../input/train_1/event000001081\n", 167 | "event: ../input/train_1/event000001082\n", 168 | "event: ../input/train_1/event000001083\n", 169 | "event: ../input/train_1/event000001084\n", 170 | "event: ../input/train_1/event000001085\n", 171 | "event: ../input/train_1/event000001086\n", 172 | "event: ../input/train_1/event000001087\n", 173 | "event: ../input/train_1/event000001088\n", 174 | "event: ../input/train_1/event000001089\n", 175 | "event: ../input/train_1/event000001090\n", 176 | "event: ../input/train_1/event000001091\n", 177 | "event: ../input/train_1/event000001092\n", 178 | "event: ../input/train_1/event000001093\n", 179 | "event: ../input/train_1/event000001094\n", 180 | "event: ../input/train_1/event000001095\n", 181 | "event: ../input/train_1/event000001096\n", 182 | "event: ../input/train_1/event000001097\n", 183 | "event: ../input/train_1/event000001098\n", 184 | "event: ../input/train_1/event000001099\n" 185 | ] 186 | } 187 | ], 188 | "source": [ 189 | "data_l = []\n", 190 | "for i in range(10,100):\n", 191 | " event = '../input/train_1/event0000010%d' % i\n", 192 | " print('event:', event)\n", 193 | " hits, cells, particles, truth = load_event(event)\n", 194 | " data = hits\n", 195 | " data = data.merge(truth, how='left', on='hit_id')\n", 196 | " data = data.merge(particles, how='left', on='particle_id')\n", 197 | " data['rv'] = np.sqrt(data.vx**2 + data.vy**2)\n", 198 | " data = data[(data.rv <= 1) & (data.vz <= 50) & (data.vz >= -50)].copy()\n", 199 | " data = data[data.weight > 0]\n", 200 | " data['event_id'] = i\n", 201 | " \n", 202 | " data_l.append(data)\n", 203 | "\n", 204 | "data = pd.concat(data_l, axis=0)\n", 205 | "data = data.sample(frac=1, random_state=0)" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": 5, 211 | "metadata": {}, 212 | "outputs": [ 213 | { 214 | "data": { 215 | "text/html": [ 216 | "
\n", 217 | "\n", 230 | "\n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | "
layers
event_idparticle_id
1045038057858007041704 1702 1308 1306 1304 1302 1302 808 806 804...
4503943224754176802 802 804 806 1302 1402 1404 1406 1408 1810 ...
4504011944230912802 804 806 902 1302 1402 1402 1404 1406 1408 ...
4504080663707648802 804 806 1302 1302 1302 1304 1402 1404 1406...
4504149383184384802 804 806 806 808 1302 1304 1304 1402 1404 1...
4504286822137856802 802 802 804 804 902 904 906 908 910 1406
4504355541614592802 802 804 902 904 904 906 906 908 908 910 14...
4504699138998272802 802 804 902 902 904 904 906 906
4504974016905216802 902 904 906 908 910 1406 1408 1410 1412
4505248894812160802 804 806 1302 1402 1404 1406 1408 1810 1812
4505317614288896802 804 806 808 1302 1304 1402 1404 1808 1810
4505455053242368802 804 806 1302 1402 1404 1406 1810 1812 1812
4505523772719104802 802 804 804 806 902 904 1404 1406 1408 141...
4505592492195840802 802 804 902 902 904 904 906 906 1402 1404 ...
4505661211672576802 802 804 806 808 1302 1302 1304
4505798650626048702 704 706 706 708 708 710 710 712 712 714 714
45058673701027841202 1204 1206 1206 1206 1208 1210 1210 710 71...
4506485845393408802 804 804 806 808 808 1302 1304 1306 1804 1804
45068981622538241202 1204 702 704 704 706 706 708 710 710 712 ...
45069668817305601602 1206 1208 1210 1212 1302 714 806 804 802
45071043206840321610 1612 1308 1306 1304 1304 1302 808 806 804...
4507241759637504802 802 802 804
4507310479114240802 802 804 902 904 904 906 908 1404 1404 1406...
4507447918067712802 804 806 1302 1402 1404 1406 1406 1408 1810...
4507516637544448802 804 804 806 1302 1402 1402 1404 1406 1408 ...
4507585357021184802 804 806 1302 1402 1404 1406 1408 1810 1812...
4507860234928128802 804 806 1302 1402 1404 1406 1408 1408 1812
4507928954404864802 804 806 808 1302 1304 1304 1402 1404 1406 ...
4507997673881600802 804 806 1302 1302 1304 1402 1402 1404 1406...
4508135112835072802 804 804 806 1302 1402 1404 1406 1408 1810 ...
.........
99927747845430181888902 902 904 904 906 906 908 908 910 912 914
927747982869135360802 804 902 902 904 904 906 1404 1404 1406 140...
9277481203080888321202 1202 1204 1206 1208 1210 1210 1212 710 71...
927748257747042304802 802 802 804 902 904 906 908 910 1406 1408 ...
927748326466519040802 802 804 806 1302 1402 1404 1406 1408 1810 ...
9277483951859957761602 1602 1604 1206 1208 1210 1212 1302 806 80...
9277484639054725121610 1612 1612 1612 1702 1308 1306 1304 1304 1...
927748532624949248702 702 704 704 706 706 706 708 708 708 710 710
927749563417100288802 802 804 806 1302 1402 1404 1406 1408 1810 ...
927749632136577024802 804 806 902 1302 1402 1404 1406 1408 1812
927750250611867648702 702 704 704 706 706 708 708 710 710 712 712
9277508003676815361308 1306 1306 1304 1304 1302 808 806 804 802
9277508690871582721602 1604 1604 1604 1606 1208 1210 1210 1210 1...
927751075245588480906 906 906 908 908 908 910 910 912 912 914 914
9277515562819256321202 1204 1206 704 706 706 708 708 710 712 714...
927751762440355840702 702 704 704 706 706 708 708 710 710 710 71...
927751831159832576702 702 704 704 706 706 706 708 708 708 710 71...
927752312196169728802 802 804 806 808 1302 1304 1306 1308 1308 1...
927752380915646464802 804 806 808 1302 1302 1304 1306 1308 1804
932245191585169408902 902 902 904 904 906 906 908 908 910 910 91...
932245535182553088802 804 806 808 1302 1304 1306 1308 1308 1702 ...
932245603902029824802 804 806 808 1302 1302 1304 1304 1306 1308 ...
9322460162188902401704 1702 1308 1306 1306 1304 1302 808 806 804...
932246084938366976802 802 802 802 804 806 808 808 1302 1304 1306...
932246565974704128802 804 806 902 904 1402 1404 1406 1408 1410 1412
932246634694180864802 902 904 906 908 910 912 1408 1410 1412
932246703413657600802 804 806 808 1302 1304 1306 1308 1702 1702 ...
932246772133134336802 804 804 806 808 1302 1304 1306 1308
932246840852611072902 902 904 906 908 910 910 912 912 914 1410 1...
932246909572087808802 804 806 808 1302 1304 1306 1306 1308 1702 ...
\n", 493 | "

591203 rows × 1 columns

\n", 494 | "
" 495 | ], 496 | "text/plain": [ 497 | " layers\n", 498 | "event_id particle_id \n", 499 | "10 4503805785800704 1704 1702 1308 1306 1304 1302 1302 808 806 804...\n", 500 | " 4503943224754176 802 802 804 806 1302 1402 1404 1406 1408 1810 ...\n", 501 | " 4504011944230912 802 804 806 902 1302 1402 1402 1404 1406 1408 ...\n", 502 | " 4504080663707648 802 804 806 1302 1302 1302 1304 1402 1404 1406...\n", 503 | " 4504149383184384 802 804 806 806 808 1302 1304 1304 1402 1404 1...\n", 504 | " 4504286822137856 802 802 802 804 804 902 904 906 908 910 1406\n", 505 | " 4504355541614592 802 802 804 902 904 904 906 906 908 908 910 14...\n", 506 | " 4504699138998272 802 802 804 902 902 904 904 906 906\n", 507 | " 4504974016905216 802 902 904 906 908 910 1406 1408 1410 1412\n", 508 | " 4505248894812160 802 804 806 1302 1402 1404 1406 1408 1810 1812\n", 509 | " 4505317614288896 802 804 806 808 1302 1304 1402 1404 1808 1810\n", 510 | " 4505455053242368 802 804 806 1302 1402 1404 1406 1810 1812 1812\n", 511 | " 4505523772719104 802 802 804 804 806 902 904 1404 1406 1408 141...\n", 512 | " 4505592492195840 802 802 804 902 902 904 904 906 906 1402 1404 ...\n", 513 | " 4505661211672576 802 802 804 806 808 1302 1302 1304\n", 514 | " 4505798650626048 702 704 706 706 708 708 710 710 712 712 714 714\n", 515 | " 4505867370102784 1202 1204 1206 1206 1206 1208 1210 1210 710 71...\n", 516 | " 4506485845393408 802 804 804 806 808 808 1302 1304 1306 1804 1804\n", 517 | " 4506898162253824 1202 1204 702 704 704 706 706 708 710 710 712 ...\n", 518 | " 4506966881730560 1602 1206 1208 1210 1212 1302 714 806 804 802\n", 519 | " 4507104320684032 1610 1612 1308 1306 1304 1304 1302 808 806 804...\n", 520 | " 4507241759637504 802 802 802 804\n", 521 | " 4507310479114240 802 802 804 902 904 904 906 908 1404 1404 1406...\n", 522 | " 4507447918067712 802 804 806 1302 1402 1404 1406 1406 1408 1810...\n", 523 | " 4507516637544448 802 804 804 806 1302 1402 1402 1404 1406 1408 ...\n", 524 | " 4507585357021184 802 804 806 1302 1402 1404 1406 1408 1810 1812...\n", 525 | " 4507860234928128 802 804 806 1302 1402 1404 1406 1408 1408 1812\n", 526 | " 4507928954404864 802 804 806 808 1302 1304 1304 1402 1404 1406 ...\n", 527 | " 4507997673881600 802 804 806 1302 1302 1304 1402 1402 1404 1406...\n", 528 | " 4508135112835072 802 804 804 806 1302 1402 1404 1406 1408 1810 ...\n", 529 | "... 
...\n", 530 | "99 927747845430181888 902 902 904 904 906 906 908 908 910 912 914\n", 531 | " 927747982869135360 802 804 902 902 904 904 906 1404 1404 1406 140...\n", 532 | " 927748120308088832 1202 1202 1204 1206 1208 1210 1210 1212 710 71...\n", 533 | " 927748257747042304 802 802 802 804 902 904 906 908 910 1406 1408 ...\n", 534 | " 927748326466519040 802 802 804 806 1302 1402 1404 1406 1408 1810 ...\n", 535 | " 927748395185995776 1602 1602 1604 1206 1208 1210 1212 1302 806 80...\n", 536 | " 927748463905472512 1610 1612 1612 1612 1702 1308 1306 1304 1304 1...\n", 537 | " 927748532624949248 702 702 704 704 706 706 706 708 708 708 710 710\n", 538 | " 927749563417100288 802 802 804 806 1302 1402 1404 1406 1408 1810 ...\n", 539 | " 927749632136577024 802 804 806 902 1302 1402 1404 1406 1408 1812\n", 540 | " 927750250611867648 702 702 704 704 706 706 708 708 710 710 712 712\n", 541 | " 927750800367681536 1308 1306 1306 1304 1304 1302 808 806 804 802\n", 542 | " 927750869087158272 1602 1604 1604 1604 1606 1208 1210 1210 1210 1...\n", 543 | " 927751075245588480 906 906 906 908 908 908 910 910 912 912 914 914\n", 544 | " 927751556281925632 1202 1204 1206 704 706 706 708 708 710 712 714...\n", 545 | " 927751762440355840 702 702 704 704 706 706 708 708 710 710 710 71...\n", 546 | " 927751831159832576 702 702 704 704 706 706 706 708 708 708 710 71...\n", 547 | " 927752312196169728 802 802 804 806 808 1302 1304 1306 1308 1308 1...\n", 548 | " 927752380915646464 802 804 806 808 1302 1302 1304 1306 1308 1804\n", 549 | " 932245191585169408 902 902 902 904 904 906 906 908 908 910 910 91...\n", 550 | " 932245535182553088 802 804 806 808 1302 1304 1306 1308 1308 1702 ...\n", 551 | " 932245603902029824 802 804 806 808 1302 1302 1304 1304 1306 1308 ...\n", 552 | " 932246016218890240 1704 1702 1308 1306 1306 1304 1302 808 806 804...\n", 553 | " 932246084938366976 802 802 802 802 804 806 808 808 1302 1304 1306...\n", 554 | " 932246565974704128 802 804 806 902 904 1402 1404 1406 1408 1410 1412\n", 555 | " 932246634694180864 802 902 904 906 908 910 912 1408 1410 1412\n", 556 | " 932246703413657600 802 804 806 808 1302 1304 1306 1308 1702 1702 ...\n", 557 | " 932246772133134336 802 804 804 806 808 1302 1304 1306 1308\n", 558 | " 932246840852611072 902 902 904 906 908 910 910 912 912 914 1410 1...\n", 559 | " 932246909572087808 802 804 806 808 1302 1304 1306 1306 1308 1702 ...\n", 560 | "\n", 561 | "[591203 rows x 1 columns]" 562 | ] 563 | }, 564 | "execution_count": 5, 565 | "metadata": {}, 566 | "output_type": "execute_result" 567 | } 568 | ], 569 | "source": [ 570 | "data['layer'] = 100 * data.volume_id + data.layer_id\n", 571 | "data = data.sort_values(by=['particle_id', 'z']).reset_index(drop=True)\n", 572 | "df = data.groupby(['event_id', 'particle_id']).layer.apply(lambda s: ' '.join([str(i) for i in s]))\n", 573 | "\n", 574 | "df = df.to_frame('layers')\n", 575 | "df" 576 | ] 577 | }, 578 | { 579 | "cell_type": "code", 580 | "execution_count": 6, 581 | "metadata": {}, 582 | "outputs": [ 583 | { 584 | "data": { 585 | "application/vnd.jupyter.widget-view+json": { 586 | "model_id": "266476ebbb684e18850c8810e5af8670" 587 | } 588 | }, 589 | "metadata": {}, 590 | "output_type": "display_data" 591 | }, 592 | { 593 | "name": "stdout", 594 | "output_type": "stream", 595 | "text": [ 596 | "\n" 597 | ] 598 | } 599 | ], 600 | "source": [ 601 | "from collections import Counter\n", 602 | "\n", 603 | "cnt = Counter()\n", 604 | "\n", 605 | "for x in tqdm(df.itertuples(name=None, index=False)):\n", 606 | " layers = 
x[0].split()\n", 607 | " for i in range(len(layers) - 3):\n", 608 | " s = ' '.join(layers[i:i+4])\n", 609 | " cnt[s] += 1\n", 610 | " " 611 | ] 612 | }, 613 | { 614 | "cell_type": "code", 615 | "execution_count": 7, 616 | "metadata": { 617 | "collapsed": true 618 | }, 619 | "outputs": [], 620 | "source": [ 621 | "with open('../data/layers_4_center_fix.pkl', 'wb') as file:\n", 622 | " pkl.dump(cnt, file)" 623 | ] 624 | }, 625 | { 626 | "cell_type": "code", 627 | "execution_count": 8, 628 | "metadata": {}, 629 | "outputs": [ 630 | { 631 | "data": { 632 | "text/plain": [ 633 | "[('802 804 806 808', 86119),\n", 634 | " ('808 806 804 802', 86049),\n", 635 | " ('804 806 808 1302', 79225),\n", 636 | " ('1302 808 806 804', 78704),\n", 637 | " ('806 808 1302 1304', 78216),\n", 638 | " ('1304 1302 808 806', 77929),\n", 639 | " ('1306 1304 1302 808', 65604),\n", 640 | " ('808 1302 1304 1306', 65527),\n", 641 | " ('902 902 904 904', 52282),\n", 642 | " ('712 712 714 714', 52111)]" 643 | ] 644 | }, 645 | "execution_count": 8, 646 | "metadata": {}, 647 | "output_type": "execute_result" 648 | } 649 | ], 650 | "source": [ 651 | "cnt.most_common(10)" 652 | ] 653 | }, 654 | { 655 | "cell_type": "code", 656 | "execution_count": null, 657 | "metadata": { 658 | "collapsed": true 659 | }, 660 | "outputs": [], 661 | "source": [] 662 | } 663 | ], 664 | "metadata": { 665 | "kernelspec": { 666 | "display_name": "Python [conda env:tf15]", 667 | "language": "python", 668 | "name": "conda-env-tf15-py" 669 | }, 670 | "language_info": { 671 | "codemirror_mode": { 672 | "name": "ipython", 673 | "version": 3 674 | }, 675 | "file_extension": ".py", 676 | "mimetype": "text/x-python", 677 | "name": "python", 678 | "nbconvert_exporter": "python", 679 | "pygments_lexer": "ipython3", 680 | "version": "3.6.1" 681 | } 682 | }, 683 | "nbformat": 4, 684 | "nbformat_minor": 2 685 | } 686 | --------------------------------------------------------------------------------
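Note on the artifacts produced above: scan_center_final.ipynb saves ../data/scan_center.npy, the (alpha0, vz) values of true tracks from the 90 train events, and layer_prep_final.ipynb saves ../data/layers_4_center_fix.pkl, a pickled Counter of length-4 layer sequences (tracklets), with layer = 100 * volume_id + layer_id. How the tracking scripts in src/ consume these files is described in doc/trackML.pdf; the snippet below is only a minimal illustrative sketch of reloading and querying them, where tracklet_seen and its min_count parameter are hypothetical helpers and not part of this repository.

import pickle as pkl
import numpy as np

# (alpha0, vz) of true tracks from the 90 train events, one row per track
scan_centers = np.load('../data/scan_center.npy')          # shape (n_tracks, 2)

# frequency of every length-4 layer sequence observed on true tracks
with open('../data/layers_4_center_fix.pkl', 'rb') as file:
    tracklet_counts = pkl.load(file)

def tracklet_seen(layers, min_count=1):
    # layers: sequence of 4 layer codes (100 * volume_id + layer_id), ordered along z
    key = ' '.join(str(layer) for layer in layers)
    return tracklet_counts.get(key, 0) >= min_count

print(scan_centers.shape)
print(tracklet_seen([802, 804, 806, 808]))   # most frequent training tracklet, so True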