├── models ├── __init__.py ├── cnn_classifier.py ├── rcnn.py └── cnn_segmentator.py ├── gui_utils ├── __init__.py ├── fine_tuning.py ├── threading.py ├── abstract_main_window.py ├── data_splitting.py ├── visualization.py ├── auxilary_utils.py ├── processing.py ├── training.py ├── evaluation.py └── mining.py ├── cython_utils ├── __init__.py └── roi.pyx ├── data └── tmp_weights │ └── .gitkeep ├── processing_utils ├── __init__.py ├── postprocess.py ├── roi.py ├── runner.py └── matching.py ├── training_utils ├── __init__.py ├── dataset.py └── training.py ├── requirements.txt ├── setup.py ├── .gitignore ├── LICENSE ├── README.md ├── tests ├── run_utils_test.py └── matching_test.py └── peakonly.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /gui_utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cython_utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/tmp_weights/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /gui_utils/fine_tuning.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /processing_utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /training_utils/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | bintrees 2 | matplotlib 3 | numpy 4 | pandas 5 | pymzML 6 | PyQt5 7 | scipy 8 | torch >= 1.2.0 9 | tqdm 10 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup  # build script for the Cython extension cython_utils/roi.pyx 2 | from Cython.Build import cythonize  # NOTE(review): Cython is imported here but is not listed in requirements.txt; add it there if building the extension is part of the supported install 3 | import numpy as np 4 | 5 | setup( 6 | ext_modules=cythonize('cython_utils/roi.pyx'), 7 | include_dirs=[np.get_include()]  # NumPy C header directory, passed to the extension build 8 | ) 9 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.mzML 2 | *.tmp 3 | .idea/ 4 | *.ipynb 5 | *__pycache__* 6 | data/test/ 7 | data/train/ 8 | data/val/ 9 | data/annotation/ 10 | data/weights/ 11 | data/tmp_weights/* 12 | !data/tmp_weights/.gitkeep 13 | *.pt 14 | build/ 15 | *.cpp 16 | *.so 17 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Arsenty Melnikov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /models/cnn_classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Block(nn.Module):  # Conv1d (kernel_size=5) followed by ReLU 6 | def __init__(self, in_channels, out_channels, padding=2, dilation=1, stride=1): 7 | super().__init__() 8 | 9 | self.basic_block = nn.Sequential( 10 | nn.Conv1d(in_channels, out_channels, 5, padding=padding, dilation=dilation, stride=stride),  # with the defaults (padding=2, dilation=1, stride=1) the output length equals the input length 11 | nn.ReLU() 12 | ) 13 | 14 | def forward(self, x): 15 | return self.basic_block(x) 16 | 17 | 18 | class Classifier(nn.Module):  # 2-class classifier for a 1-channel 1D signal: 8 conv blocks with 7 MaxPool1d(2) halvings, global max over length, dropout + linear 19 | def __init__(self): 20 | super().__init__() 21 | self.convBlock = nn.Sequential( 22 | Block(1, 8), 23 | nn.MaxPool1d(kernel_size=2), 24 | Block(8, 16), 25 | nn.MaxPool1d(kernel_size=2), 26 | Block(16, 32), 27 | nn.MaxPool1d(kernel_size=2), 28 | Block(32, 64), 29 | nn.MaxPool1d(kernel_size=2), 30 | Block(64, 64), 31 | nn.MaxPool1d(kernel_size=2), 32 | Block(64, 128), 33 | nn.MaxPool1d(kernel_size=2), 34 | Block(128, 256), 35 | nn.MaxPool1d(kernel_size=2), 36 | Block(256, 512) 37 | ) 38 | self.classification = nn.Sequential( 39 | nn.Dropout(p=0.5), 40 | nn.Linear(512, 2) 41 | ) 42 | 43 | def forward(self, x): 44 | x = self.convBlock(x) 45 | x, _ = torch.max(x, dim=2)  # max over the length axis -> (batch, 512) 46 | return self.classification(x), None  # (class scores, None): this model has no per-point output 47 | -------------------------------------------------------------------------------- /models/rcnn.py:
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class EncodingCNN(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | 9 | self.encoding = nn.Sequential( 10 | nn.Conv1d(in_channels=2, out_channels=16, kernel_size=5, padding=2), 11 | nn.ReLU(), 12 | nn.Conv1d(in_channels=16, out_channels=32, kernel_size=5, padding=2), 13 | nn.ReLU(), 14 | nn.Conv1d(in_channels=32, out_channels=64, kernel_size=5, padding=2), 15 | nn.ReLU() 16 | ) 17 | 18 | def forward(self, x): 19 | return self.encoding(x).transpose(2, 1) 20 | 21 | 22 | class RecurrentCNN(nn.Module): 23 | def __init__(self): 24 | super().__init__() 25 | 26 | self.encoding = EncodingCNN() 27 | self.biLSTM = nn.LSTM(64, 64, batch_first=True, bidirectional=True) 28 | self.LSTM = nn.LSTM(128, 128, batch_first=True, bidirectional=False) 29 | self.classifier = nn.Linear(128, 2) 30 | self.integrator = nn.Linear(128, 2) 31 | 32 | def _preprocessing(self, batch): 33 | batch_size, _, n_points = batch.shape 34 | processed_batch = batch.clone().view(batch_size, n_points) 35 | # TO DO: rewrite without loop 36 | for x in processed_batch: 37 | x[x < 1e-4] = 0 38 | pos = (x != 0) 39 | x[pos] = torch.log10(x[pos]) 40 | x[pos] = x[pos] - torch.min(x[pos]) 41 | x[pos] = x[pos] / torch.max(x[pos]) 42 | return processed_batch.view(batch_size, 1, n_points) 43 | 44 | def forward(self, x): 45 | x = torch.cat((x, self._preprocessing(x)), dim=1) 46 | x = self.encoding(x) 47 | x, _ = self.biLSTM(x) 48 | integrator_input, (classifier_input, _) = self.LSTM(x) 49 | classifier_output = self.classifier(classifier_input[0]) 50 | integrator_output = self.integrator(integrator_input) 51 | return classifier_output, integrator_output.transpose(2, 1) 52 | -------------------------------------------------------------------------------- /processing_utils/postprocess.py: -------------------------------------------------------------------------------- 1 | import 
pymzml 2 | import numpy as np 3 | import pandas as pd 4 | from tqdm import tqdm 5 | 6 | 7 | class ResultTable: 8 | def __init__(self, files, features): 9 | n_features = len(features) 10 | n_files = len(files) 11 | self.files = {k: v for v, k in enumerate(files)} 12 | self.intensities = np.zeros((n_files, n_features)) 13 | self.mz = np.zeros(n_features) 14 | self.rtmin = np.zeros(n_features) 15 | self.rtmax = np.zeros(n_features) 16 | # fill in intensities values 17 | for i, feature in enumerate(features): 18 | self.mz[i] = feature.mz 19 | self.rtmin[i] = feature.rtmin 20 | self.rtmax[i] = feature.rtmax 21 | for j, sample in enumerate(feature.samples): 22 | self.intensities[self.files[sample], i] = feature.intensities[j] 23 | 24 | def fill_zeros(self, delta_mz): 25 | print('zero filling...') 26 | for file, k in tqdm(self.files.items()): 27 | # read all scans in mzML file 28 | run = pymzml.run.Reader(file) 29 | scans = [] 30 | for scan in run: 31 | scans.append(scan) 32 | 33 | begin_time = scans[0].scan_time[0] 34 | end_time = scans[-1].scan_time[0] 35 | frequency = len(scans) / (end_time - begin_time) 36 | for m, intensity in enumerate(self.intensities[k]): 37 | if intensity == 0: 38 | mz = self.mz[m] 39 | begin = int((self.rtmin[m] - begin_time) * frequency) - 1 40 | end = int((self.rtmax[m] - begin_time) * frequency) + 1 41 | for scan in scans[begin:end]: 42 | pos = np.searchsorted(scan.mz, mz) 43 | if pos < len(scan.mz) and mz - delta_mz < scan.mz[pos] < mz + delta_mz: 44 | self.intensities[k, m] += scan.i[pos] 45 | if pos >= 1 and mz - delta_mz < scan.mz[pos - 1] < mz + delta_mz: 46 | self.intensities[k, m] += scan.i[pos - 1] 47 | 48 | def to_csv(self, path): 49 | df = pd.DataFrame() 50 | df['mz'] = self.mz 51 | df['rtmin'] = self.rtmin / 60 52 | df['rtmax'] = self.rtmax / 60 53 | for file, k in self.files.items(): 54 | df[file] = self.intensities[k] 55 | df.to_csv(path) 56 | -------------------------------------------------------------------------------- 
/gui_utils/threading.py: -------------------------------------------------------------------------------- 1 | from PyQt5 import QtCore 2 | 3 | 4 | class WorkerSignals(QtCore.QObject): 5 | """ 6 | Defines the signals available from a running worker thread. 7 | 8 | Attributes 9 | ---------- 10 | finished : QtCore.pyqtSignal 11 | No data 12 | error : QtCore.pyqtSignal 13 | `tuple` (exctype, value, traceback.format_exc() ) 14 | result : QtCore.pyqtSignal 15 | `object` data returned from processing, anything 16 | progress / operation : QtCore.pyqtSignal 17 | `int` indicating % progress / `str` label for the current operation 18 | download_progress : QtCore.pyqtSignal 19 | `int`, `int`, `int` used to show a count of blocks transferred, 20 | a block size in bytes, the total size of the file 21 | """ 22 | finished = QtCore.pyqtSignal() 23 | error = QtCore.pyqtSignal(tuple) 24 | result = QtCore.pyqtSignal(object) 25 | progress = QtCore.pyqtSignal(int) 26 | operation = QtCore.pyqtSignal(str) 27 | download_progress = QtCore.pyqtSignal(int, int, int) 28 | 29 | 30 | class Worker(QtCore.QRunnable): 31 | """ 32 | Run `function` on a thread-pool thread, reporting through `self.signals`. 33 | 34 | Parameters 35 | ---------- 36 | function : callable 37 | Called as function(*args, **kwargs); kwargs always include 38 | `progress_callback` (plus `operation_callback` if multiple_process) 39 | download : bool 40 | if True, progress_callback is signals.download_progress (three ints), 41 | otherwise signals.progress (one int) 42 | multiple_process : bool 43 | if True, operation_callback=signals.operation is passed as well 44 | 45 | Attributes 46 | ---------- 47 | signals : WorkerSignals 48 | signals of this worker; see `run` for the order they are emitted in 49 | args, kwargs : tuple, dict 50 | stored call arguments (kwargs includes the injected callbacks) 51 | 52 | """ 53 | def __init__(self, function, *args, download=False, multiple_process=False, **kwargs): 54 | super(Worker, self).__init__() 55 | 56 | self.function = function 57 | self.args = args 58 | self.kwargs = kwargs 59 | self.signals = WorkerSignals() 60 | 61 | # Add the callback to our kwargs 62 | if not download: 63 | self.kwargs['progress_callback'] = self.signals.progress 64 | else: 65 | self.kwargs['progress_callback'] =
self.signals.download_progress 66 | 67 | if multiple_process: 68 | self.kwargs['operation_callback'] = self.signals.operation 69 | 70 | @QtCore.pyqtSlot() 71 | def run(self): 72 | result = self.function(*self.args, **self.kwargs) 73 | self.signals.result.emit(result) # return results 74 | self.signals.finished.emit() # done 75 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | **peakonly** 2 | ________ 3 | 4 | *peakonly* is a novel approach written in Python (v3.6) for peaks (aka features) detection in raw LC-MS data. The main idea underlying the approach is the training of two subsequent artificial neural networks to firstly classify ROIs (regions of interest) into ~~three~~ two classes (noise, peaks, ~~uncertain peaks~~) and then to determine boundaries for every peak to integrate its area. Current approach was developed for the high-resolution LC-MS data for the purposes of metabolomics, but can be applied with several adaptations in other fields that utilize data from high-resolution GC- or LC-MS techniques. 5 | 6 | - **Article**: [Deep learning for the precise peak detection in high-resolution LC-MS data, *Analytical Chemistry*.](http://dx.doi.org/10.1021/acs.analchem.9b04811) 7 | - **Releases**: https://github.com/arseha/peakonly/releases/ 8 | - **Instruction:** [detailed instruction for *peakonly* v.0.2.0-beta](https://bit.ly/peakonly_manual) 9 | - **High-level API:** [ms-peakonly](https://github.com/sorenwacker/ms-peakonly) 10 | 11 | 12 | Supported formats: 13 | 14 | - .mzML 15 | 16 | Operating System Compatibility 17 | ------------------------------ 18 | peakonly has been tested successfully with: 19 | 20 | - Ubuntu 16.04 21 | - macOS Catalina 22 | - Windows 10 23 | - Windows 7 24 | 25 | For Windows7/10 commands should be entered through Windows PowerShell. [Detailed instruction is available](https://bit.ly/peakonly_manual). 
Be sure that your python version is at least 3.6. 26 | 27 | 28 | Installing and running the application 29 | ---------------------------- 30 | To install and run *peakonly* you should do a few simple steps: 31 | 32 | - download [the latest release of *peakonly*](https://github.com/Arseha/peakonly/releases) 33 | - install requirements in the following automated way (or you can simply open reqirements.txt file and download listed libraries in any other convenient way): 34 | ``` 35 | pip3 install -r requirements.txt 36 | ``` 37 | - run **peakonly.py**: 38 | ``` 39 | python3 peakonly.py 40 | ``` 41 | 42 | The more detailed instruction on how to install and run the application as well as a thorough manual on how to use it is available via [the link.](https://bit.ly/peakonly_manual) 43 | 44 | 45 | Call for Contributions 46 | ---------------------- 47 | 48 | peakonly appreciates help from a wide range of different backgrounds. 49 | Small improvements or fixes are always appreciated. 50 | If you are considering larger contributions, have some questions or an offer for cooperation, 51 | please contact the main contributor (melnikov.arsenty@gmail.com). 52 | -------------------------------------------------------------------------------- /tests/run_utils_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | 4 | from processing_utils import run_utils 5 | 6 | 7 | class MyTestCase(unittest.TestCase): 8 | def test_intersection(self): 9 | interval_1 = [0., 100.] 10 | interval_2 = [101., 201.] 11 | interval_3 = [25., 125.] 12 | interval_4 = [25., 50.] 
13 | # symmetry: 14 | self.assertEqual(run_utils.intersection(*interval_1, *interval_3), 15 | run_utils.intersection(*interval_3, *interval_1)) 16 | self.assertEqual(run_utils.intersection(*interval_1, *interval_4), 17 | run_utils.intersection(*interval_4, *interval_1)) 18 | # No intersection: [0, 100] and [101, 201] are disjoint 19 | self.assertEqual(run_utils.intersection(*interval_1, *interval_2), 0.) 20 | # Identity: 21 | self.assertEqual(run_utils.intersection(*interval_1, *interval_1), 1.) 22 | self.assertEqual(run_utils.intersection(*interval_4, *interval_4), 1.) 23 | # Intersections (expected values equal overlap length / length of the shorter interval): 24 | self.assertAlmostEqual(run_utils.intersection(*interval_1, *interval_3), 0.75) 25 | self.assertAlmostEqual(run_utils.intersection(*interval_1, *interval_4), 1.) 26 | self.assertAlmostEqual(run_utils.intersection(*interval_2, *interval_3), 0.24) 27 | 28 | def test_border2average_correction(self): 29 | # Issue #10 regression: 30 | borders = [[231, 453], [460, 477]] 31 | averaged_borders = [[232, 325], [330, 333], [333, 476]] 32 | self.assertEqual(run_utils.border2average_correction(borders, 33 | averaged_borders), 34 | [[231, 325], [453, 453], [460, 477]]) 35 | # len(borders) > len(average_borders): 36 | borders = [[232, 243], [256, 266], [268, 437], [470, 487]] 37 | averaged_borders = [[228, 243], [344, 432], [474, 488]] 38 | self.assertEqual(run_utils.border2average_correction(borders, 39 | averaged_borders), 40 | [[232, 243], [268, 437], [470, 487]]) 41 | # no borders: the averaged borders come back unchanged, and the reverse call gives [] 42 | borders = [] 43 | averaged_borders = [[228, 243], [344, 432], [474, 488]] 44 | self.assertEqual(run_utils.border2average_correction(borders, 45 | averaged_borders), 46 | [[228, 243], [344, 432], [474, 488]]) 47 | self.assertEqual(run_utils.border2average_correction(averaged_borders, 48 | borders), 49 | []) 50 | 51 | 52 | if __name__ == '__main__': 53 | unittest.main() 54 | -------------------------------------------------------------------------------- /models/cnn_segmentator.py:
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def preprocessing(batch): 6 | batch_size, _, n_peaks = batch.shape 7 | processed_batch = batch.clone().view(batch_size, n_peaks) 8 | # TO DO: rewrite without loop 9 | for x in processed_batch: 10 | x[x < 1e-4] = 0 11 | pos = (x != 0) 12 | x[pos] = torch.log10(x[pos]) 13 | x[pos] = x[pos] - torch.min(x[pos]) 14 | x[pos] = x[pos] / torch.max(x[pos]) 15 | return processed_batch.view(batch_size, 1, n_peaks) 16 | 17 | 18 | class Block(nn.Module): 19 | def __init__(self, in_channels, out_channels, padding=2, dilation=1, stride=1): 20 | super().__init__() 21 | 22 | self.basic_block = nn.Sequential( 23 | nn.Conv1d(in_channels, out_channels, 5, padding=padding, dilation=dilation, stride=stride), 24 | nn.BatchNorm1d(out_channels), 25 | nn.ReLU() 26 | ) 27 | 28 | def forward(self, x): 29 | return self.basic_block(x) 30 | 31 | 32 | class Segmentator(nn.Module): 33 | def __init__(self): 34 | super().__init__() 35 | 36 | self.starter = nn.Sequential( 37 | Block(2, 16), 38 | Block(16, 20), 39 | nn.AvgPool1d(kernel_size=2), 40 | Block(20, 24), 41 | nn.AvgPool1d(kernel_size=2), 42 | Block(24, 28), 43 | nn.AvgPool1d(kernel_size=2), 44 | Block(28, 32) 45 | ) 46 | 47 | self.pass_down1 = nn.Sequential( 48 | nn.AvgPool1d(kernel_size=2), 49 | Block(32, 48) 50 | ) 51 | 52 | self.pass_down2 = nn.Sequential( 53 | nn.AvgPool1d(kernel_size=2), 54 | Block(48, 64) 55 | ) 56 | 57 | self.code = nn.Sequential( 58 | nn.AvgPool1d(kernel_size=2), 59 | Block(64, 96), 60 | nn.Upsample(scale_factor=2), 61 | Block(96, 64) 62 | ) 63 | 64 | self.pass_up2 = nn.Sequential( 65 | nn.Upsample(scale_factor=2), 66 | Block(128, 64) 67 | ) 68 | 69 | self.pass_up1 = nn.Sequential( 70 | nn.Upsample(scale_factor=2), 71 | Block(112, 48) 72 | ) 73 | 74 | self.finisher = nn.Sequential( 75 | nn.Upsample(scale_factor=2), 76 | Block(80, 64), 77 | nn.Upsample(scale_factor=2), 78 | Block(64, 
32), 79 | nn.Upsample(scale_factor=2), 80 | Block(32, 16), 81 | nn.Conv1d(16, 2, 1, padding=0) 82 | ) 83 | 84 | def forward(self, x):  # NOTE: the skip concatenations need equal lengths; with six AvgPool1d(2) halvings on the deepest path this holds when the input length is a multiple of 64 85 | x = torch.cat((x, preprocessing(x)), dim=1) 86 | starter = self.starter(x) 87 | pass1 = self.pass_down1(starter) 88 | pass2 = self.pass_down2(pass1) 89 | x = self.code(pass2) 90 | x = torch.cat((x, pass2), dim=1) 91 | x = self.pass_up2(x) 92 | x = torch.cat((x, pass1), dim=1) 93 | x = self.pass_up1(x) 94 | x = torch.cat((x, starter), dim=1) 95 | x = self.finisher(x) 96 | return None, x  # (None, 2-channel per-point output of the final 1x1 Conv1d) 97 | -------------------------------------------------------------------------------- /tests/matching_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | from collections import defaultdict 4 | 5 | from processing_utils.matching import stitch_component, align_component 6 | from processing_utils.roi import ROI 7 | 8 | 9 | class MyTestCase(unittest.TestCase): 10 | def test_stitch_component(self): 11 | component = defaultdict(list) 12 | roi1 = ROI(scan=(48, 54), 13 | rt=(48, 54), 14 | i=[1.]*7, 15 | mz=[100] * 7, 16 | mzmean=100) 17 | roi2 = ROI(scan=(25, 30), 18 | rt=(25, 30), 19 | i=list(np.random.randn(6)), 20 | mz=[100] * 6, 21 | mzmean=100) 22 | component['sample'] = [roi1, roi2] 23 | new_component = stitch_component(component) 24 | self.assertEqual(roi1.i, list(new_component['sample'][0].i[-7:]))  # roi1 (scans 48-54) is the later ROI, so it ends the stitched signal 25 | self.assertEqual(roi2.i, list(new_component['sample'][0].i[:6])) 26 | 27 | def test_align_component_simple(self): 28 | component = defaultdict(list) 29 | roi1_sample1 = ROI(scan=(-5, 1), 30 | rt=(5, 10), 31 | i=[0, 2, 2, 0, 1, 1, 0], 32 | mz=[105] * 7, 33 | mzmean=105) 34 | component['sample1'] = [roi1_sample1] 35 | 36 | roi1_sample2 = ROI(scan=(-6, 0), 37 | rt=(5, 10), 38 | i=[1, 1, 0, 0.9, 0.8, 0, 0], 39 | mz=[105] * 7, 40 | mzmean=105) 41 | component['sample2'] = [roi1_sample2] 42 | group = align_component(component) 43 | shifts = dict() 44 | for sample, shift
in zip(group.samples, group.shifts): 45 | shifts[sample] = shift 46 | self.assertEqual({'sample1': 0, 'sample2': 2}, shifts) 47 | 48 | def test_align_component_complex(self): 49 | component = defaultdict(list) 50 | roi1_sample1 = ROI(scan=(11, 20), 51 | rt=(11, 20), 52 | i=[0] + [1, 2, 3, 4, 3, 2, 1] + [0] * 2, 53 | mz=[100] * 10, 54 | mzmean=100) 55 | component['sample1'] = [roi1_sample1] 56 | 57 | roi1_sample2 = ROI(scan=(12, 19), 58 | rt=(12, 19), 59 | i=[0] + [1, 2, 3, 2, 3, 2, 1], 60 | mz=[100] * 8, 61 | mzmean=100) 62 | roi2_sample2 = ROI(scan=(25, 31), 63 | rt=(25, 31), 64 | i=[0] + [1, 1, 1, 1, 1] + [0], 65 | mz=[100] * 7, 66 | mzmean=100) 67 | component['sample2'] = [roi1_sample2, roi2_sample2] 68 | 69 | roi1_sample3 = ROI(scan=(10, 28), 70 | rt=(10, 28), 71 | i=[0] + [1, 2, 3, 3, 3, 2, 1] + [0] * 3 + [1, 1, 1, 1, 1] + [0]*3, 72 | mz=[100] * 19, 73 | mzmean=100) 74 | component['sample3'] = [roi1_sample3] 75 | 76 | group = align_component(component) 77 | shifts = dict() 78 | for sample, shift in zip(group.samples, group.shifts): 79 | shifts[sample] = shift 80 | self.assertEqual({'sample1': 0, 'sample2': -1, 'sample3': 1}, shifts) 81 | 82 | def test_align_component_strange(self):  # ROIs with far-apart scan ranges: both shifts are expected to be 0 83 | component = defaultdict(list) 84 | roi1_sample1 = ROI(scan=(10, 20), 85 | rt=(10, 20), 86 | i=np.random.randn(11), 87 | mz=[100] * 11, 88 | mzmean=100) 89 | component['sample1'] = [roi1_sample1] 90 | 91 | roi1_sample2 = ROI(scan=(100, 108), 92 | rt=(100, 108), 93 | i=np.random.randn(9)+10, 94 | mz=[100] * 9, 95 | mzmean=100) 96 | component['sample2'] = [roi1_sample2] 97 | 98 | group = align_component(component) 99 | shifts = dict()  # built but not asserted on; the check below uses group.shifts directly 100 | for sample, shift in zip(group.samples, group.shifts): 101 | shifts[sample] = shift 102 | self.assertEqual([0, 0], group.shifts) 103 | 104 | 105 | if __name__ == '__main__': 106 | unittest.main() 107 | -------------------------------------------------------------------------------- /training_utils/dataset.py:
-------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from copy import deepcopy 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import Dataset 7 | from scipy.interpolate import interp1d 8 | 9 | 10 | # to do: Reflection should take a ROI (dict) 11 | class Reflection: 12 | """ 13 | class that just reflects any signal 14 | """ 15 | def __init__(self, p=0.5): 16 | self.p = p 17 | 18 | def __call__(self, signal): 19 | if np.random.choice([True, False], p=[self.p, 1 - self.p]): 20 | signal = signal[::-1] 21 | return signal 22 | 23 | 24 | class ROIDataset(Dataset): 25 | """ 26 | A dataset for a training 27 | """ 28 | def __init__(self, path, device, interpolate=False, adaptive_interpolate=False, 29 | length=None, augmentations=None, balanced=False, return_roi_code=False): 30 | """ 31 | :param path: a path to annotated ROIs 32 | :param device: a device where training will occur (GPU / CPU) 33 | :param interpolate: bool, if interpolation is needed 34 | :param adaptive_interpolate: to do: add interpolation to the closest power of 2 35 | :param length: only needed if 'interpolate' is True 36 | :param augmentations: roi augmantations 37 | :param balanced: bool, noise and peaks are returned 50/50 38 | :param return_roi_code: explicitly return the code of the roi 39 | """ 40 | super().__init__() 41 | self.balanced = balanced 42 | self.device = device 43 | self.data = {0: [], 1: []} # a dict from label2roi 44 | self.interpolate = interpolate 45 | self.adaptive_interpolate = interpolate 46 | self.length = length 47 | self.return_roi_code = return_roi_code 48 | for file in os.listdir(path): 49 | if file[0] != '.': 50 | with open(os.path.join(path, file)) as json_file: 51 | roi = json.load(json_file) 52 | roi['intensity'] = np.array(roi['intensity']) 53 | roi['borders'] = np.array(roi['borders']) 54 | if self.interpolate: 55 | roi = self._interpolate(roi) 56 | 57 | self.data[roi['label']].append(roi) 58 | 
self.augmentations = [] if augmentations is None else augmentations 59 | 60 | def __len__(self): 61 | if self.balanced: 62 | return min(len(self.data[0]), len(self.data[1])) 63 | else: 64 | return len(self.data[0]) + len(self.data[1]) 65 | 66 | @staticmethod 67 | def _get_mask(roi): 68 | integration_mask = np.zeros_like(roi['intensity']) 69 | if roi['number of peaks'] >= 1: 70 | for b, e in roi['borders']: 71 | integration_mask[int(b):int(e)] = 1 72 | 73 | intersection_mask = np.zeros_like(roi['intensity']) 74 | if roi['number of peaks'] >= 2: 75 | for e, b in zip(roi['borders'][:-1, 1], roi['borders'][1:, 0]): 76 | if b - e > 5: 77 | intersection_mask[e + 1:b] = 1 78 | else: 79 | intersection_mask[e - 1:b + 2] = 1 80 | return integration_mask, intersection_mask 81 | 82 | def _interpolate(self, roi): 83 | roi = deepcopy(roi) 84 | points = len(roi['intensity']) 85 | interpolate = interp1d(np.arange(points), roi['intensity'], kind='linear') 86 | roi['intensity'] = interpolate(np.arange(self.length) / (self.length - 1.) 
* (points - 1.)) 87 | roi['borders'] = np.array(roi['borders']) 88 | roi['borders'] = roi['borders'] * (self.length - 1) // (points - 1)  # rescale border indices to the new length (floor division) 89 | return roi 90 | 91 | def __getitem__(self, idx): 92 | if self.balanced: 93 | roi = np.random.choice(self.data[idx % 2])  # balanced mode: idx parity selects label 0 or 1, then a random ROI of that label is drawn 94 | else: 95 | roi = self.data[0][idx] if idx < len(self.data[0]) else self.data[1][idx - len(self.data[0])] 96 | 97 | for aug in self.augmentations: 98 | roi = deepcopy(roi) 99 | roi = aug(roi) 100 | 101 | x = roi['intensity'] 102 | x = torch.tensor(x, dtype=torch.float32, device=self.device).view(1, -1) 103 | x = x / torch.max(x)  # NOTE(review): an all-zero intensity trace gives 0/0 = NaN here 104 | y = torch.tensor(roi['label'], dtype=torch.long, device=self.device) 105 | 106 | integration_mask, intersection_mask = self._get_mask(roi) 107 | integration_mask = torch.tensor(integration_mask, dtype=torch.float32, device=self.device) 108 | intersection_mask = torch.tensor(intersection_mask, dtype=torch.float32, device=self.device) 109 | 110 | if self.return_roi_code: 111 | original_length = len(roi['mz'])  # _interpolate only resamples 'intensity' and 'borders', so 'mz' keeps the original length 112 | return x, y, integration_mask, intersection_mask, roi['code'], original_length 113 | 114 | return x, y, integration_mask, intersection_mask 115 | -------------------------------------------------------------------------------- /gui_utils/abstract_main_window.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from functools import partial 3 | from PyQt5 import QtWidgets, QtCore 4 | from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas 5 | from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar 6 | from processing_utils.roi import construct_tic, construct_eic 7 | from gui_utils.auxilary_utils import ProgressBarsList, ProgressBarsListItem, FileListWidget, FeatureListWidget 8 | from gui_utils.threading import Worker 9 | 10 | 11 | class AbtractMainWindow(QtWidgets.QMainWindow):  # base window: file list, feature list, matplotlib canvas and thread-pool helpers; NOTE(review): 'Abtract' is misspelled, but renaming would break any subclass/import 12 | def __init__(self): 13 | super().__init__() 14 | 15 |
self._thread_pool = QtCore.QThreadPool() 16 | self._pb_list = ProgressBarsList(self) 17 | 18 | self._list_of_files = FileListWidget() 19 | 20 | self._list_of_features = FeatureListWidget() 21 | self._feature_parameters = None 22 | 23 | self._figure = plt.figure() 24 | self._ax = self._figure.add_subplot(111) # plot here 25 | self._ax.set_xlabel('Retention time [min]') 26 | self._ax.set_ylabel('Intensity') 27 | self._ax.ticklabel_format(axis='y', scilimits=(0, 0)) 28 | self._label2line = dict() # a label (aka line name) to plotted line 29 | self._canvas = FigureCanvas(self._figure) 30 | self._toolbar = NavigationToolbar(self._canvas, self) 31 | 32 | def run_thread(self, caption: str, worker: Worker, text=None, icon=None):  # show a progress bar for `worker`, start it on the thread pool; the bar is removed (and an optional message box shown) when it finishes 33 | pb = ProgressBarsListItem(caption, parent=self._pb_list) 34 | self._pb_list.addItem(pb) 35 | worker.signals.progress.connect(pb.setValue) 36 | worker.signals.operation.connect(pb.setLabel) 37 | worker.signals.finished.connect(partial(self._threads_finisher, 38 | text=text, icon=icon, pb=pb)) 39 | self._thread_pool.start(worker) 40 | 41 | def _threads_finisher(self, text=None, icon=None, pb=None):  # connected to a worker's `finished` signal 42 | if pb is not None: 43 | self._pb_list.removeItem(pb) 44 | pb.setParent(None) 45 | if text is not None: 46 | msg = QtWidgets.QMessageBox(self) 47 | msg.setText(text) 48 | msg.setIcon(icon) 49 | msg.exec_() 50 | 51 | def set_features(self, obj):  # obj is a (features, parameters) pair; features are listed sorted by m/z 52 | features, parameters = obj 53 | self._list_of_features.clear() 54 | for feature in sorted(features, key=lambda x: x.mz): 55 | self._list_of_features.add_feature(feature) 56 | self._feature_parameters = parameters 57 | 58 | def plotter(self, obj):  # slot for a worker result: obj is a dict with 'x', 'y' and 'label' keys 59 | if not self._label2line: # nothing plotted yet (plot_feature also empties _label2line): start from fresh axes 60 | self._figure.clear() 61 | self._ax = self._figure.add_subplot(111) 62 | self._ax.set_xlabel('Retention time [min]') 63 | self._ax.set_ylabel('Intensity') 64 | self._ax.ticklabel_format(axis='y', scilimits=(0, 0)) 65 | 66 | line = self._ax.plot(obj['x'], obj['y'], label=obj['label']) 67 |
self._label2line[obj['label']] = line[0] # save line 68 | self._ax.legend(loc='best') 69 | self._figure.tight_layout() 70 | self._canvas.draw() 71 | 72 | def close_file(self, item): 73 | self._list_of_files.deleteFile(item) 74 | 75 | def get_selected_files(self): 76 | return self._list_of_files.selectedItems() 77 | 78 | def get_selected_features(self): 79 | return self._list_of_features.selectedItems() 80 | 81 | def get_plotted_lines(self): 82 | return list(self._label2line.keys())  # labels of the lines currently on the axes 83 | 84 | def plot_feature(self, item, shifted=True):  # replace the axes contents with the selected feature's own plot 85 | feature = self._list_of_features.get_feature(item) 86 | self._label2line = dict() # empty plotted TIC and EIC 87 | self._figure.clear() 88 | self._ax = self._figure.add_subplot(111) 89 | feature.plot(self._ax, shifted=shifted) 90 | self._ax.set_title(item.text()) 91 | self._ax.set_xlabel('Retention time') 92 | self._ax.set_ylabel('Intensity') 93 | self._ax.ticklabel_format(axis='y', scilimits=(0, 0)) 94 | self._figure.tight_layout() 95 | self._canvas.draw() # refresh canvas 96 | 97 | def plot_tic(self, file):  # unless already plotted, start a worker that builds the TIC and hands it to self.plotter; returns (worker_started, label) 98 | label = f'TIC: {file[:file.rfind(".")]}'  # NOTE(review): if `file` has no '.', rfind returns -1 and the last character is cut off 99 | plotted = False 100 | if label not in self._label2line: 101 | path = self._list_of_files.file2path[file] 102 | 103 | pb = ProgressBarsListItem(f'Plotting TIC: {file}', parent=self._pb_list) 104 | self._pb_list.addItem(pb) 105 | worker = Worker(construct_tic, path, label) 106 | worker.signals.progress.connect(pb.setValue) 107 | worker.signals.result.connect(self.plotter) 108 | worker.signals.finished.connect(partial(self._threads_finisher, pb=pb)) 109 | 110 | self._thread_pool.start(worker) 111 | 112 | plotted = True 113 | return plotted, label 114 | 115 | def plot_eic(self, file, mz, delta):  # same as plot_tic, for the EIC of mz ± delta 116 | label = f'EIC {mz:.4f} ± {delta:.4f}: {file[:file.rfind(".")]}' 117 | plotted = False 118 | if label not in self._label2line: 119 | path = self._list_of_files.file2path[file] 120 | 121 | pb = ProgressBarsListItem(f'Plotting EIC (mz={mz:.4f}): {file}', parent=self._pb_list) 122 |
self._pb_list.addItem(pb) 123 | worker = Worker(construct_eic, path, label, mz, delta) 124 | worker.signals.progress.connect(pb.setValue) 125 | worker.signals.result.connect(self.plotter) 126 | worker.signals.finished.connect(partial(self._threads_finisher, pb=pb)) 127 | 128 | self._thread_pool.start(worker) 129 | 130 | plotted = True 131 | return plotted, label 132 | 133 | def delete_line(self, label): 134 | self._ax.lines.remove(self._label2line[label]) 135 | del self._label2line[label] 136 | 137 | def refresh_canvas(self): 138 | if self._label2line: 139 | self._ax.legend(loc='best') 140 | self._ax.relim() # recompute the ax.dataLim 141 | self._ax.autoscale_view() # update ax.viewLim using the new dataLim 142 | else: 143 | self._figure.clear() 144 | self._ax = self._figure.add_subplot(111) 145 | self._ax.set_xlabel('Retention time [min]') 146 | self._ax.set_ylabel('Intensity') 147 | self._ax.ticklabel_format(axis='y', scilimits=(0, 0)) 148 | self._canvas.draw() 149 | -------------------------------------------------------------------------------- /gui_utils/data_splitting.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | from shutil import rmtree, copyfile 5 | from PyQt5 import QtWidgets 6 | from gui_utils.auxilary_utils import GetFoldersWidget 7 | 8 | 9 | class SplitterParameterWindow(QtWidgets.QDialog): 10 | def __init__(self, parent=None): 11 | self.parent = parent 12 | super().__init__(parent) 13 | self.setWindowTitle('peakonly: data splitting') 14 | 15 | self.json_files = set() 16 | 17 | # folder selection 18 | main_layout = QtWidgets.QVBoxLayout() 19 | self.folder_widget = GetFoldersWidget('Choose folder with annotated data: ') 20 | main_layout.addWidget(self.folder_widget) 21 | 22 | # calculation of total number of ROIs 23 | number_layout = QtWidgets.QHBoxLayout() 24 | calculate_rois_button = QtWidgets.QPushButton('Calculate number of ROIs in chosen folders: ') 25 | 
calculate_rois_button.clicked.connect(self.get_rois_number) 26 | self.rois_number_label = QtWidgets.QLabel() 27 | self.rois_number_label.setText('...') 28 | number_layout.addWidget(calculate_rois_button) 29 | number_layout.addWidget(self.rois_number_label) 30 | 31 | # set sizes 32 | val_size_layout = QtWidgets.QHBoxLayout() 33 | val_size_label = QtWidgets.QLabel() 34 | val_size_label.setText('Size of validation subset: ') 35 | self.val_size_getter = QtWidgets.QLineEdit(self) 36 | self.val_size_getter.setText('...') 37 | val_size_layout.addWidget(val_size_label) 38 | val_size_layout.addWidget(self.val_size_getter) 39 | 40 | test_size_layout = QtWidgets.QHBoxLayout() 41 | test_size_label = QtWidgets.QLabel() 42 | test_size_label.setText('Size of test subset: ') 43 | self.test_size_getter = QtWidgets.QLineEdit(self) 44 | self.test_size_getter.setText('...') 45 | test_size_layout.addWidget(test_size_label) 46 | test_size_layout.addWidget(self.test_size_getter) 47 | 48 | # button 49 | self.split_button = QtWidgets.QPushButton('Split data') 50 | self.split_button.clicked.connect(self.split_data) 51 | 52 | # set main layout 53 | main_layout = QtWidgets.QVBoxLayout() 54 | main_layout.addWidget(self.folder_widget) 55 | main_layout.addLayout(number_layout) 56 | main_layout.addLayout(val_size_layout) 57 | main_layout.addLayout(test_size_layout) 58 | main_layout.addWidget(self.split_button) 59 | self.setLayout(main_layout) 60 | 61 | def get_rois_number(self): 62 | self.json_files = set() 63 | folders = self.folder_widget.get_folders() 64 | for folder in folders: 65 | self.search_json_files(folder) 66 | self.rois_number_label.setText(f'{len(self.json_files)}') 67 | val_size = int(0.15 * len(self.json_files)) 68 | test_size = int(0.15 * len(self.json_files)) 69 | self.val_size_getter.setText(f'{val_size}') 70 | self.test_size_getter.setText(f'{test_size}') 71 | 72 | def search_json_files(self, path): 73 | for sub_path in os.listdir(path): 74 | if not sub_path.startswith('.') and 
sub_path != '__MACOSX': 75 | sub_path = os.path.join(path, sub_path) 76 | if os.path.isdir(sub_path): 77 | self.search_json_files(sub_path) 78 | elif sub_path.endswith('.json'): 79 | self.json_files.add(sub_path) 80 | 81 | def split_data(self): 82 | try: 83 | self.json_files = set() 84 | folders = self.folder_widget.get_folders() 85 | for folder in folders: 86 | self.search_json_files(folder) 87 | if not self.json_files: 88 | raise ValueError 89 | val_size = int(self.val_size_getter.text()) 90 | test_size = int(self.test_size_getter.text()) 91 | except ValueError: 92 | # popup window with exception 93 | msg = QtWidgets.QMessageBox(self) 94 | msg.setText("Directory should include any *.json files and \n" 95 | "sizes of test and validation datasets should be integers") 96 | msg.setIcon(QtWidgets.QMessageBox.Warning) 97 | msg.exec_() 98 | return None 99 | 100 | # delete old data and create new folders 101 | def remove_dir(path): 102 | try: 103 | rmtree(path) 104 | except OSError: 105 | pass 106 | 107 | def create_dir(path): 108 | try: 109 | os.mkdir(path) 110 | except OSError: 111 | pass 112 | 113 | train_dir = 'data/train' 114 | val_dir = 'data/val' 115 | test_dir = 'data/test' 116 | 117 | remove_dir(train_dir) 118 | remove_dir(val_dir) 119 | remove_dir(test_dir) 120 | 121 | create_dir(train_dir) 122 | create_dir(val_dir) 123 | create_dir(test_dir) 124 | 125 | # get labels of ROIs 126 | label2file = {0: [], 1: []} 127 | for file in self.json_files: 128 | with open(file) as json_file: 129 | roi = json.load(json_file) 130 | label2file[roi['label']].append(file) 131 | 132 | # copy files to val folder 133 | val0size = val_size // 2 134 | val1size = val_size - val0size 135 | for i in range(val0size): 136 | a = np.random.choice(np.arange(len(label2file[0]))) 137 | file_name = label2file[0][a] 138 | copyfile(file_name, os.path.join(val_dir, os.path.basename(file_name))) 139 | label2file[0].pop(a) 140 | 141 | for i in range(val1size): 142 | a = 
np.random.choice(np.arange(len(label2file[1]))) 143 | file_name = label2file[1][a] 144 | copyfile(file_name, os.path.join(val_dir, os.path.basename(file_name))) 145 | label2file[1].pop(a) 146 | 147 | # copy files to test folder 148 | test0size = test_size // 2 149 | test1size = test_size - test0size 150 | for i in range(test0size): 151 | a = np.random.choice(np.arange(len(label2file[0]))) 152 | file_name = label2file[0][a] 153 | copyfile(file_name, os.path.join(test_dir, os.path.basename(file_name))) 154 | label2file[0].pop(a) 155 | 156 | for i in range(test1size): 157 | a = np.random.choice(np.arange(len(label2file[1]))) 158 | file_name = label2file[1][a] 159 | copyfile(file_name, os.path.join(test_dir, os.path.basename(file_name))) 160 | label2file[1].pop(a) 161 | 162 | # the rest files copy to train folder 163 | for k, v in label2file.items(): 164 | for file_name in v: 165 | copyfile(file_name, os.path.join(train_dir, os.path.basename(file_name))) 166 | -------------------------------------------------------------------------------- /gui_utils/visualization.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from PyQt5 import QtWidgets, QtGui 3 | from gui_utils.abstract_main_window import AbtractMainWindow 4 | from gui_utils.auxilary_utils import ClickableListWidget, FileListWidget 5 | 6 | 7 | class EICParameterWindow(QtWidgets.QDialog): 8 | def __init__(self, parent: AbtractMainWindow): 9 | self.parent = parent 10 | super().__init__(self.parent) 11 | self.setWindowTitle('peakonly: plot EIC') 12 | 13 | mz_layout = QtWidgets.QHBoxLayout() 14 | mz_label = QtWidgets.QLabel(self) 15 | mz_label.setText('m/z=') 16 | self.mz_getter = QtWidgets.QLineEdit(self) 17 | self.mz_getter.setText('100.000') 18 | mz_layout.addWidget(mz_label) 19 | mz_layout.addWidget(self.mz_getter) 20 | 21 | delta_layout = QtWidgets.QHBoxLayout() 22 | delta_label = QtWidgets.QLabel(self) 23 | delta_label.setText('delta=±') 24 
| self.delta_getter = QtWidgets.QLineEdit(self) 25 | self.delta_getter.setText('0.005') 26 | delta_layout.addWidget(delta_label) 27 | delta_layout.addWidget(self.delta_getter) 28 | 29 | plot_button = QtWidgets.QPushButton('Plot') 30 | plot_button.clicked.connect(self.plot) 31 | 32 | layout = QtWidgets.QVBoxLayout() 33 | layout.addLayout(mz_layout) 34 | layout.addLayout(delta_layout) 35 | layout.addWidget(plot_button) 36 | self.setLayout(layout) 37 | 38 | def plot(self): 39 | try: 40 | mz = float(self.mz_getter.text()) 41 | delta = float(self.delta_getter.text()) 42 | for file in self.parent.get_selected_files(): 43 | file = file.text() 44 | self.parent.plot_eic(file, mz, delta) 45 | self.close() 46 | except ValueError: 47 | # popup window with exception 48 | msg = QtWidgets.QMessageBox(self) 49 | msg.setText("'m/z' and 'delta' should be float numbers!") 50 | msg.setIcon(QtWidgets.QMessageBox.Warning) 51 | msg.exec_() 52 | 53 | 54 | class VisualizationWindow(QtWidgets.QDialog): 55 | def __init__(self, files, parent: AbtractMainWindow): 56 | self.parent = parent 57 | super().__init__(self.parent) 58 | self.setWindowTitle('peakonly: visualization') 59 | 60 | # files selection 61 | files_layout = QtWidgets.QVBoxLayout() 62 | choose_file_label = QtWidgets.QLabel() 63 | choose_file_label.setText('Choose files to visualize:') 64 | self._list_of_files = FileListWidget() 65 | for file in files: 66 | self._list_of_files.addFile(file) 67 | self._list_of_files.setSelectionMode(QtWidgets.QAbstractItemView.ExtendedSelection) 68 | files_layout.addWidget(choose_file_label) 69 | files_layout.addWidget(self._list_of_files) 70 | 71 | # plotted mode 72 | self.plotted_mode_getter = QtWidgets.QComboBox(self) 73 | self.plotted_mode_getter.addItems(['Total Ion Chromatogram (TIC)', 'Extracted Ion Chromatogram (EIC)']) 74 | 75 | mz_layout = QtWidgets.QHBoxLayout() 76 | mz_label = QtWidgets.QLabel(self) 77 | mz_label.setText('m/z=') 78 | self.mz_getter = QtWidgets.QLineEdit(self) 79 | 
self.mz_getter.setText('100.000') 80 | mz_layout.addWidget(mz_label) 81 | mz_layout.addWidget(self.mz_getter) 82 | 83 | delta_layout = QtWidgets.QHBoxLayout() 84 | delta_label = QtWidgets.QLabel(self) 85 | delta_label.setText('delta=±') 86 | self.delta_getter = QtWidgets.QLineEdit(self) 87 | self.delta_getter.setText('0.005') 88 | delta_layout.addWidget(delta_label) 89 | delta_layout.addWidget(self.delta_getter) 90 | 91 | plot_button = QtWidgets.QPushButton('Plot') 92 | plot_button.clicked.connect(self._plot) 93 | 94 | files_layout.addWidget(self.plotted_mode_getter) 95 | files_layout.addLayout(mz_layout) 96 | files_layout.addLayout(delta_layout) 97 | files_layout.addWidget(plot_button) 98 | 99 | # list of lines 100 | plotted_lines_layout = QtWidgets.QVBoxLayout() 101 | plotted_label = QtWidgets.QLabel() 102 | plotted_label.setText('Currently plotted: ') 103 | self._plotted_list = ClickableListWidget() 104 | self._plotted_list.setSelectionMode( 105 | QtWidgets.QAbstractItemView.ExtendedSelection 106 | ) 107 | for label in self.parent.get_plotted_lines(): 108 | self._plotted_list.addItem(label) 109 | self._plotted_list.connectRightClick(partial(LineContextMenu, self)) 110 | plotted_lines_layout.addWidget(plotted_label) 111 | plotted_lines_layout.addWidget(self._plotted_list) 112 | 113 | # main layout 114 | layout = QtWidgets.QHBoxLayout() 115 | layout.addLayout(files_layout) 116 | layout.addLayout(plotted_lines_layout) 117 | self.setLayout(layout) 118 | 119 | def get_selected_lines(self): 120 | return self._plotted_list.selectedItems() 121 | 122 | def delete_selected(self): 123 | for item in self._plotted_list.selectedItems(): 124 | self.parent.delete_line(item.text()) 125 | self._plotted_list.takeItem(self._plotted_list.row(item)) # delete item from list 126 | self.parent.refresh_canvas() 127 | 128 | def _plot(self): 129 | mode = self.plotted_mode_getter.currentText() 130 | if mode == 'Total Ion Chromatogram (TIC)': 131 | for file in 
self._list_of_files.selectedItems(): 132 | file = file.text() 133 | plotted, label = self.parent.plot_tic(file) 134 | if plotted: 135 | self._plotted_list.addItem(label) 136 | elif mode == 'Extracted Ion Chromatogram (EIC)': 137 | try: 138 | mz = float(self.mz_getter.text()) 139 | delta = float(self.delta_getter.text()) 140 | for file in self._list_of_files.selectedItems(): 141 | file = file.text() 142 | plotted, label = self.parent.plot_eic(file, mz, delta) 143 | if plotted: 144 | self._plotted_list.addItem(label) 145 | except ValueError: 146 | # popup window with exception 147 | msg = QtWidgets.QMessageBox(self) 148 | msg.setText("'mz' and 'delta' should be float numbers!") 149 | msg.setIcon(QtWidgets.QMessageBox.Warning) 150 | msg.exec_() 151 | 152 | 153 | class LineContextMenu(QtWidgets.QMenu): 154 | def __init__(self, parent: VisualizationWindow): 155 | self.parent = parent 156 | super().__init__(parent) 157 | lines = list(self.parent.get_selected_lines()) 158 | 159 | menu = QtWidgets.QMenu(parent) 160 | 161 | clear = QtWidgets.QAction('Clear', parent) 162 | 163 | menu.addAction(clear) 164 | 165 | action = menu.exec_(QtGui.QCursor.pos()) 166 | 167 | if action == clear: 168 | self.parent.delete_selected() 169 | -------------------------------------------------------------------------------- /gui_utils/auxilary_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PyQt5 import QtWidgets, QtCore 3 | 4 | 5 | class ClickableListWidget(QtWidgets.QListWidget): 6 | def __init__(self, *args, **kwargs): 7 | super().__init__(*args, **kwargs) 8 | self.double_click = None 9 | self.right_click = None 10 | 11 | def mousePressEvent(self, QMouseEvent): 12 | super(QtWidgets.QListWidget, self).mousePressEvent(QMouseEvent) 13 | if QMouseEvent.button() == QtCore.Qt.RightButton and self.right_click is not None: 14 | self.right_click() 15 | 16 | def mouseDoubleClickEvent(self, QMouseEvent): 17 | if self.double_click is not 
None: 18 | if QMouseEvent.button() == QtCore.Qt.LeftButton: 19 | item = self.itemAt(QMouseEvent.pos()) 20 | if item is not None: 21 | self.double_click(item) 22 | 23 | def connectDoubleClick(self, method): 24 | """ 25 | Set a callable object which should be called when a user double-clicks on item 26 | Parameters 27 | ---------- 28 | method : callable 29 | any callable object 30 | Returns 31 | ------- 32 | - : None 33 | """ 34 | self.double_click = method 35 | 36 | def connectRightClick(self, method): 37 | """ 38 | Set a callable object which should be called when a user double-clicks on item 39 | Parameters 40 | ---------- 41 | method : callable 42 | any callable object 43 | Returns 44 | ------- 45 | - : None 46 | """ 47 | self.right_click = method 48 | 49 | 50 | class FileListWidget(ClickableListWidget): 51 | def __init__(self, *args, **kwargs): 52 | super().__init__(*args, **kwargs) 53 | self.file2path = {} 54 | 55 | def addFile(self, path: str): 56 | filename = os.path.basename(path) 57 | self.file2path[filename] = path 58 | self.addItem(filename) 59 | 60 | def deleteFile(self, item: QtWidgets.QListWidgetItem): 61 | del self.file2path[item.text()] 62 | self.takeItem(self.row(item)) 63 | 64 | def getPath(self, item: QtWidgets.QListWidgetItem): 65 | return self.file2path[item.text()] 66 | 67 | 68 | class FeatureListWidget(ClickableListWidget): 69 | def __init__(self, *args, **kwargs): 70 | super().__init__(*args, **kwargs) 71 | self.features = [] 72 | 73 | def add_feature(self, feature): 74 | name = f'#{len(self.features)}: mz = {feature.mz:.4f}, rt = {feature.rtmin:.2f} - {feature.rtmax:.2f}' 75 | self.features.append(feature) 76 | self.addItem(name) 77 | 78 | def get_feature(self, item): 79 | number = item.text() 80 | number = int(number[number.find('#') + 1:number.find(':')]) 81 | return self.features[number] 82 | 83 | def get_all(self): 84 | features = [] 85 | for i in range(self.count()): 86 | item = self.item(i) 87 | features.append(self.get_feature(item)) 
88 | return features 89 | 90 | def clear(self): 91 | super(FeatureListWidget, self).clear() 92 | self.features = [] 93 | 94 | 95 | class ProgressBarsListItem(QtWidgets.QWidget): 96 | def __init__(self, text, pb=None, parent=None): 97 | super().__init__(parent) 98 | self.pb = pb 99 | if self.pb is None: 100 | self.pb = QtWidgets.QProgressBar() 101 | 102 | self.label = QtWidgets.QLabel(self) 103 | self.label.setText(text) 104 | 105 | main_layout = QtWidgets.QHBoxLayout() 106 | main_layout.addWidget(self.label, 30) 107 | main_layout.addWidget(self.pb, 70) 108 | 109 | self.setLayout(main_layout) 110 | 111 | def setValue(self, value): 112 | self.pb.setValue(value) 113 | 114 | def setLabel(self, text): 115 | self.pb.setValue(0) 116 | self.label.setText(text) 117 | 118 | 119 | class ProgressBarsList(QtWidgets.QWidget): 120 | def __init__(self, parent=None): 121 | super().__init__(parent) 122 | 123 | self.main_layout = QtWidgets.QVBoxLayout() 124 | self.setLayout(self.main_layout) 125 | 126 | def removeItem(self, item): 127 | self.layout().removeWidget(item) 128 | 129 | def addItem(self, item): 130 | self.layout().addWidget(item) 131 | 132 | 133 | class GetFolderWidget(QtWidgets.QWidget): 134 | def __init__(self, default_directory='', parent=None): 135 | super().__init__(parent) 136 | 137 | button = QtWidgets.QToolButton() 138 | button.setText('...') 139 | button.clicked.connect(self.set_folder) 140 | 141 | if not default_directory: 142 | default_directory = os.getcwd() 143 | self.lineEdit = QtWidgets.QToolButton() 144 | self.lineEdit.setText(default_directory) 145 | 146 | layout = QtWidgets.QHBoxLayout() 147 | layout.addWidget(self.lineEdit, 85) 148 | layout.addWidget(button, 15) 149 | 150 | self.setLayout(layout) 151 | 152 | def set_folder(self): 153 | directory = str(QtWidgets.QFileDialog.getExistingDirectory()) 154 | if directory: 155 | self.lineEdit.setText(directory) 156 | 157 | def get_folder(self): 158 | return self.lineEdit.text() 159 | 160 | 161 | class 
GetFoldersWidget(QtWidgets.QWidget): 162 | def __init__(self, label, parent=None): 163 | super().__init__(parent) 164 | 165 | button = QtWidgets.QToolButton() 166 | button.setText('...') 167 | button.clicked.connect(self.add_folder) 168 | 169 | self.lineEdit = QtWidgets.QToolButton() 170 | self.lineEdit.setText(label) 171 | 172 | folder_getter_layout = QtWidgets.QHBoxLayout() 173 | folder_getter_layout.addWidget(self.lineEdit, 85) 174 | folder_getter_layout.addWidget(button, 15) 175 | 176 | self.list_widget = QtWidgets.QListWidget() 177 | self.list_widget.setSelectionMode(QtWidgets.QAbstractItemView.ExtendedSelection) 178 | 179 | main_layout = QtWidgets.QVBoxLayout() 180 | main_layout.addLayout(folder_getter_layout) 181 | main_layout.addWidget(self.list_widget) 182 | 183 | self.setLayout(main_layout) 184 | 185 | def add_folder(self): 186 | directory = str(QtWidgets.QFileDialog.getExistingDirectory()) 187 | if directory: 188 | self.list_widget.addItem(directory) 189 | 190 | def get_folders(self): 191 | folders = [f.text() for f in self.list_widget.selectedItems()] 192 | return folders 193 | 194 | 195 | class GetFileWidget(QtWidgets.QWidget): 196 | def __init__(self, extension, default_file, parent): 197 | super().__init__(parent) 198 | 199 | self.extension = extension 200 | 201 | button = QtWidgets.QToolButton() 202 | button.setText('...') 203 | button.clicked.connect(self.set_file) 204 | 205 | self.lineEdit = QtWidgets.QToolButton() 206 | self.lineEdit.setText(default_file) 207 | 208 | layout = QtWidgets.QHBoxLayout() 209 | layout.addWidget(self.lineEdit, 85) 210 | layout.addWidget(button, 15) 211 | 212 | self.setLayout(layout) 213 | 214 | def set_file(self): 215 | filter = f'{self.extension} (*.{self.extension})' 216 | file, _ = QtWidgets.QFileDialog.getOpenFileName(None, None, None, filter) 217 | if file: 218 | self.lineEdit.setText(file) 219 | 220 | def get_file(self): 221 | return self.lineEdit.text() 222 | 
--------------------------------------------------------------------------------
/cython_utils/roi.pyx:
--------------------------------------------------------------------------------
# distutils: language = c++
import pymzml
from libcpp.map cimport map
from libcpp.vector cimport vector
from cython.operator cimport dereference, postincrement, postdecrement
from processing_utils.roi import ROI

cdef struct cROI:
    int scan_begin
    int scan_end
    float rt_begin
    float rt_end
    vector[float] i
    vector[float] mz
    float mz_mean
    int points  # calculate number of non_zero points

cdef struct MsScan:
    vector[float] i
    vector[float] mz
    float rt


def get_ROIs(str path, float delta_mz=0.005, int required_points=15, int dropped_points=3, progress_callback=None):
    """Detect regions of interest (ROIs) in the MS1 scans of an mzML file.

    Every non-zero point is attached to the neighbouring open ROI (by m/z) whose
    smoothed m/z (0.9 * old + 0.1 * new) is closest, if closer than ``delta_mz``;
    otherwise it opens a new ROI. An ROI is closed once more than
    ``dropped_points`` scans pass without a new point, and kept only if it has at
    least ``required_points`` points. Kept ROIs are padded with ``dropped_points``
    zero-intensity points on both sides.

    Returns a list of ``processing_utils.roi.ROI`` objects (empty when the file has
    no MS1 scans). ``progress_callback``, if given, must provide ``emit(int)``.
    """
    # read all scans in mzML file
    run = pymzml.run.Reader(path)
    cdef vector[MsScan] scans
    for scan in run:
        if scan.ms_level == 1:
            scans.push_back(MsScan(scan.i, scan.mz, scan.scan_time[0]))
    if scans.empty():
        return []  # nothing to do; scans[0] below would be out of bounds

    cdef vector[cROI] rois  # completed ROIs (vector)
    cdef map[float, cROI] process_rois  # processing ROIs (map)

    # initialize a processed data
    cdef MsScan init_scan = scans[0]
    cdef float start_time = init_scan.rt
    cdef cROI new_roi
    for n in range(init_scan.i.size()):
        if init_scan.i[n] != 0:
            new_roi = cROI(0, 0, start_time, start_time, vector[float](),
                           vector[float](), init_scan.mz[n], 1)
            new_roi.i.push_back(init_scan.i[n])
            new_roi.mz.push_back(init_scan.mz[n])
            process_rois[init_scan.mz[n]] = new_roi

    cdef float ceiling_mz  # the closest m/z not less than the given
    cdef cROI* ceiling
    cdef map[float, cROI].iterator ceiling_it
    cdef float floor_mz  # the closest m/z not greater than the given
    cdef cROI* floor
    cdef map[float, cROI].iterator floor_it
    cdef float closest_mz  # the closest m/z
    cdef cROI* closest

    cdef cROI* roi
    cdef float mz
    cdef MsScan current_scan

    cdef map[float, cROI].iterator map_it

    for number in range(1, scans.size()):
        current_scan = scans[number]
        for n in range(current_scan.i.size()):
            if current_scan.i[n] == 0:
                continue
            mz = current_scan.mz[n]
            # neighbours of mz among open ROIs: ceiling = first key >= mz,
            # floor = last key < mz; either is end() when it does not exist
            # (including floor when mz exceeds every key, which must still find
            # the last ROI as its floor)
            ceiling_it = process_rois.lower_bound(mz)
            floor_it = ceiling_it
            if floor_it != process_rois.begin():
                postdecrement(floor_it)
            else:
                floor_it = process_rois.end()
            # get ceiling and floor
            if ceiling_it != process_rois.end():
                ceiling = &dereference(ceiling_it).second
                ceiling_mz = ceiling.mz_mean
            if floor_it != process_rois.end():
                floor = &dereference(floor_it).second
                floor_mz = floor.mz_mean
            # getting closest roi (if possible)
            if ceiling_it == process_rois.end() and floor_it == process_rois.end():
                # process_rois is empty: open a new ROI and move on, so the logic
                # below never runs with a stale 'closest' pointer
                new_roi = cROI(number, number, current_scan.rt, current_scan.rt, vector[float](),
                               vector[float](), mz, 1)
                new_roi.i.push_back(current_scan.i[n])
                new_roi.mz.push_back(current_scan.mz[n])
                process_rois[mz] = new_roi
                continue
            elif ceiling_it == process_rois.end():
                closest_mz = floor_mz
                closest = floor
            elif floor_it == process_rois.end():
                closest_mz = ceiling_mz
                closest = ceiling
            else:
                if ceiling_mz - mz > mz - floor_mz:
                    closest_mz = floor_mz
                    closest = floor
                else:
                    closest_mz = ceiling_mz
                    closest = ceiling
            # expanding existing roi or creates a new one
            if abs(closest_mz - mz) < delta_mz:
                roi = closest
                if roi.scan_end == number:
                    # ROIs is already extended (two peaks in one mz window) (almost not possible):
                    # merge into the last point (intensity-weighted m/z, summed intensity)
                    roi.mz_mean = 0.9 * roi.mz_mean + 0.1 * mz
                    roi.points += 1
                    roi.mz[roi.mz.size() - 1] = ((roi.i[roi.mz.size() - 1] * roi.mz[roi.mz.size() - 1] +
                                                  current_scan.i[n] * mz) / (roi.i[roi.mz.size() - 1]
                                                                             + current_scan.i[n]))
                    roi.i[roi.i.size() - 1] = roi.i[roi.i.size() - 1] + current_scan.i[n]
                else:
                    roi.mz_mean = 0.9 * roi.mz_mean + 0.1 * mz
                    roi.points += 1
                    roi.mz.push_back(mz)
                    roi.i.push_back(current_scan.i[n])
                    roi.scan_end = number
                    roi.rt_end = current_scan.rt
            else:
                new_roi = cROI(number, number, current_scan.rt, current_scan.rt, vector[float](),
                               vector[float](), mz, 1)
                new_roi.i.push_back(current_scan.i[n])
                new_roi.mz.push_back(current_scan.mz[n])
                process_rois[mz] = new_roi
        # Check and cleanup
        map_it = process_rois.begin()
        while map_it != process_rois.end():
            roi = &dereference(map_it).second
            mz = roi.mz_mean
            if roi.scan_end < number <= roi.scan_end + dropped_points:
                # insert 'zero' in the end
                roi.mz.push_back(mz)
                roi.i.push_back(0)
                postincrement(map_it)
            elif roi.scan_end != number:
                # too many missed scans: close the ROI (keep it if long enough)
                if roi.points >= required_points:
                    new_roi = dereference(map_it).second
                    rois.push_back(new_roi)
                process_rois.erase(postincrement(map_it))
            else:
                postincrement(map_it)
        if progress_callback is not None and not number % 10:
            progress_callback.emit(int(number * 100 / scans.size()))
    # add final rois
    map_it = process_rois.begin()
    while map_it != process_rois.end():
        roi = &dereference(map_it).second
        if roi.points >= required_points:
            # top up the trailing zeros to dropped_points (zeros for scans after
            # scan_end were already appended during cleanup); signed arithmetic
            for n in range(dropped_points - (<int>scans.size() - 1 - roi.scan_end)):
                roi.mz.push_back(roi.mz_mean)
                roi.i.push_back(0)
            rois.push_back(dereference(roi))
        postincrement(map_it)
    # expand constructed roi and creating python object
    cdef vector[cROI].iterator roi_it = rois.begin()
    python_rois = []
    while roi_it != rois.end():
        roi = &dereference(roi_it)
        roi.i.insert(roi.i.begin(), dropped_points, 0)
        roi.mz.insert(roi.mz.begin(), dropped_points, roi.mz_mean)
        # change scan numbers (necessary for future matching)
        roi.scan_begin = roi.scan_begin - dropped_points
        roi.scan_end = roi.scan_end + dropped_points

        python_rois.append(ROI([roi.scan_begin, roi.scan_end], [roi.rt_begin, roi.rt_end],
                               roi.i, roi.mz, roi.mz_mean))
        postincrement(roi_it)
    return python_rois
--------------------------------------------------------------------------------
/gui_utils/processing.py:
--------------------------------------------------------------------------------
import os
import torch
from PyQt5 import QtWidgets
from gui_utils.abstract_main_window import AbtractMainWindow
from gui_utils.auxilary_utils import FileListWidget, GetFileWidget
from gui_utils.threading import Worker
from processing_utils.runner import FilesRunner
from models.rcnn import RecurrentCNN
from models.cnn_classifier import Classifier
from models.cnn_segmentator import Segmentator


class ProcessingParameterWindow(QtWidgets.QDialog):
    """
    Main Processing Window, where one can choose files and parameters for data processing

    Parameters
    ----------
    files : list of str
        A list of *.mzML files
    mode : str
        One of 'all in one', 'sequential' or 'simple' ('simple' runs the
        sequential models with the default weights from data/weights)
    parent : MainWindow(QtWidgets.QMainWindow)
        -
    Attributes
    ----------
    mode : str
        One of 'all in one', 'sequential' or 'simple'
    parent : MainWindow(QtWidgets.QMainWindow)
        -
    list_of_files : FileListWidget
        QtWidget which stores and shows *.mzML files
    weights_widget : GetFileWidget
        Stores a path for 'All in one' ANN (only in 'all in one' mode)
    weights_classifier_widget : GetFileWidget
        Stores a path for a classifier (only in 'sequential' mode)
    weights_segmentator_widget : GetFileWidget
        Stores a path for a segmentator (only in 'sequential' mode)
    mz_getter : QtWidgets.QLineEdit
        A getter for delta_mz parameter
    roi_points_getter : QtWidgets.QLineEdit
        A getter for required_points parameter
    dropped_points_getter : QtWidgets.QLineEdit
        A getter for dropped_points parameter
    peak_points_getter : QtWidgets.QLineEdit
        A getter for peak_minimum_points parameter
    """
    def __init__(self, files, mode, parent: AbtractMainWindow):
        self.parent = parent
        self.mode = mode
        super().__init__(parent)
        self.setWindowTitle('peakonly: feature detection')
        self._init_ui(files)  # initialize user interface

    def _init_ui(self, files):
        # files selection
        choose_file_label = QtWidgets.QLabel()
        choose_file_label.setText('Choose files to process:')
        self.list_of_files = FileListWidget()
        self.list_of_files.setSelectionMode(QtWidgets.QAbstractItemView.ExtendedSelection)
        for file in files:
            self.list_of_files.addFile(file)
        for i in range(self.list_of_files.count()):
            self.list_of_files.item(i).setSelected(True)

        # left 'half' layout
        left_half_layout = QtWidgets.QVBoxLayout()
        left_half_layout.addWidget(choose_file_label)
        left_half_layout.addWidget(self.list_of_files)

        # ANN's weights selection
        weights_layout = QtWidgets.QVBoxLayout()
        if self.mode == 'all in one':
            choose_weights_label = QtWidgets.QLabel()
            choose_weights_label.setText("Choose weights for a 'all in one' model:")
            self.weights_widget = GetFileWidget('pt', os.path.join(os.getcwd(), 'data', 'weights', 'RecurrentCNN.pt'),
                                                self.parent)
            weights_layout.addWidget(choose_weights_label)
            weights_layout.addWidget(self.weights_widget)
        elif self.mode == 'sequential':
            choose_classifier_weights_label = QtWidgets.QLabel()
            choose_classifier_weights_label.setText('Choose weights for a Classifier:')
            self.weights_classifier_widget = GetFileWidget('pt', os.path.join(os.getcwd(),
                                                                              'data', 'weights', 'Classifier.pt'),
                                                           self.parent)
            choose_segmentator_weights_label = QtWidgets.QLabel()
            choose_segmentator_weights_label.setText('Choose weights for a Segmentator:')
            self.weights_segmentator_widget = GetFileWidget('pt', os.path.join(os.getcwd(),
                                                                               'data', 'weights', 'Segmentator.pt'),
                                                            self.parent)
            weights_layout.addWidget(choose_classifier_weights_label)
            weights_layout.addWidget(self.weights_classifier_widget)
            weights_layout.addWidget(choose_segmentator_weights_label)
            weights_layout.addWidget(self.weights_segmentator_widget)

        # Selection of parameters
        parameters_layout = QtWidgets.QVBoxLayout()

        mz_label = QtWidgets.QLabel()
        mz_label.setText('m/z deviation:')
        self.mz_getter = QtWidgets.QLineEdit(self)
        self.mz_getter.setText('0.005')

        roi_points_label = QtWidgets.QLabel()
        roi_points_label.setText('Minimal length of ROI:')
        self.roi_points_getter = QtWidgets.QLineEdit(self)
        self.roi_points_getter.setText('15')

        dropped_points_label = QtWidgets.QLabel()
        dropped_points_label.setText('Maximal number of zero points in a row:')
        self.dropped_points_getter = QtWidgets.QLineEdit(self)
        self.dropped_points_getter.setText('3')

        peak_points_label = QtWidgets.QLabel()
        peak_points_label.setText('Minimal length of peak:')
        self.peak_points_getter = QtWidgets.QLineEdit(self)
        self.peak_points_getter.setText('8')

        parameters_layout.addWidget(mz_label)
        parameters_layout.addWidget(self.mz_getter)
        parameters_layout.addWidget(roi_points_label)
        parameters_layout.addWidget(self.roi_points_getter)
        parameters_layout.addWidget(dropped_points_label)
        parameters_layout.addWidget(self.dropped_points_getter)
        parameters_layout.addWidget(peak_points_label)
        parameters_layout.addWidget(self.peak_points_getter)

        # run button
        run_button = QtWidgets.QPushButton('Run processing')
        run_button.clicked.connect(self.start_processing)

        # right 'half' layout
        right_half_layout = QtWidgets.QVBoxLayout()
        right_half_layout.addLayout(weights_layout)
        right_half_layout.addLayout(parameters_layout)
        right_half_layout.addWidget(run_button)

        # main layout
        main_layout = QtWidgets.QHBoxLayout()
        main_layout.addLayout(left_half_layout, 30)
        main_layout.addLayout(right_half_layout, 70)
        self.setLayout(main_layout)

    def _warn(self, text):
        """Show a modal warning message box with ``text``."""
        msg = QtWidgets.QMessageBox(self)
        msg.setText(text)
        msg.setIcon(QtWidgets.QMessageBox.Warning)
        msg.exec_()

    @staticmethod
    def _load_model(model_cls, path2weights, device):
        """Instantiate ``model_cls`` on ``device``, load its weights and set eval mode."""
        model = model_cls().to(device)
        model.load_state_dict(torch.load(path2weights, map_location=device))
        model.eval()
        return model

    def _load_models(self, device):
        """Return ``(runner_mode, models)`` for the window's mode.

        'simple' uses the sequential models with the default weight files and is
        reported to the runner as 'sequential'; ``self.mode`` is left unchanged.
        Raises OSError / RuntimeError when weights cannot be read or do not match.
        """
        if self.mode == 'all in one':
            # to do: save models as pytorch scripts
            path2weights = self.weights_widget.get_file()
            return 'all in one', [self._load_model(RecurrentCNN, path2weights, device)]
        if self.mode == 'sequential':
            path2classifier_weights = self.weights_classifier_widget.get_file()
            path2segmentator_weights = self.weights_segmentator_widget.get_file()
        elif self.mode == 'simple':
            path2classifier_weights = os.path.join('data', 'weights', 'Classifier.pt')
            path2segmentator_weights = os.path.join('data', 'weights', 'Segmentator.pt')
        else:
            raise AssertionError(self.mode)  # unlike 'assert', not stripped by -O
        classifier = self._load_model(Classifier, path2classifier_weights, device)
        segmentator = self._load_model(Segmentator, path2segmentator_weights, device)
        return 'sequential', [classifier, segmentator]

    def start_processing(self):
        """Validate the parameters, load the models and start processing in a worker."""
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        # to do: device should be customizable parameter
        try:
            delta_mz = float(self.mz_getter.text())
            required_points = int(self.roi_points_getter.text())
            dropped_points = int(self.dropped_points_getter.text())
            minimum_peak_points = int(self.peak_points_getter.text())
            path2mzml = [self.list_of_files.file2path[file.text()]
                         for file in self.list_of_files.selectedItems()]
            if not path2mzml:
                raise ValueError
            try:
                mode, models = self._load_models(device)
            except (OSError, RuntimeError) as e:
                # missing/unreadable weight files or weights not matching the model
                self._warn(f"Failed to load model weights:\n{e}")
                return

            runner = FilesRunner(mode, models, delta_mz,
                                 required_points, dropped_points,
                                 minimum_peak_points, device)

            worker = Worker(runner, path2mzml, multiple_process=True)
            worker.signals.result.connect(self.parent.set_features)
            self.parent.run_thread('Data processing:', worker)

            self.close()
        except ValueError:
            # popup window with exception
            self._warn("Check parameters. Something is wrong!")
--------------------------------------------------------------------------------
/processing_utils/roi.py:
--------------------------------------------------------------------------------
import json
import pymzml
import numpy as np
from tqdm import tqdm
from bintrees import FastAVLTree


def construct_ROI(roi_dict):
    """
    Construct an ROI object from dict
    :param roi_dict: a dict with 'description' (not necessary),
                     'code' (basically the name of file, not necessary),
                     'label' (annotated class),
                     'number of peaks' (quantity of peaks within ROI),
                     'begins' (a list of scan numbers),
                     'ends' (a list of scan numbers),
                     'intersections' (a list of scan numbers),
                     'scan' (first and last scan of ROI),
                     'rt',
                     'intensity',
                     'mz'
    Only 'scan', 'rt', 'intensity' and 'mz' are read; mzmean is the mean of 'mz'.
    """
    return ROI(roi_dict['scan'], roi_dict['rt'], roi_dict['intensity'], roi_dict['mz'], np.mean(roi_dict['mz']))


class ROI:
    def __init__(self, scan, rt, i, mz, mzmean):
        self.scan = scan
        self.rt = rt
        self.i = i
        self.mz = mz
        self.mzmean = mzmean

    def __repr__(self):
        return 'mz = {:.4f}, rt = {:.2f} - {:.2f}'.format(self.mzmean, self.rt[0], self.rt[1])

    def save_annotated(self, path, code=None, label=0, number_of_peaks=0, peaks_labels=None, borders=None,
                       description=None):
        roi = dict()
        roi['code'] = code
        roi['label'] = label
        roi['number of peaks'] = number_of_peaks
        roi["peaks' labels"] = [] if peaks_labels is None else peaks_labels
        roi['borders'] = [] if borders is None else borders
        roi['description'] = description

        roi['rt'] = self.rt
        roi['scan'] = self.scan
        roi['intensity'] = list(map(float, self.i))
        roi['mz'] = list(map(float, self.mz))

        with open(path, 'w') as
jsonfile: 53 | json.dump(roi, jsonfile) 54 | 55 | 56 | class ProcessROI(ROI): 57 | def __init__(self, scan, rt, i, mz, mzmean): 58 | super().__init__(scan, rt, i, mz, mzmean) 59 | self.points = 1 60 | 61 | 62 | def get_closest(mzmean, mz, pos): 63 | if pos == len(mzmean): 64 | res = pos - 1 65 | elif pos == 0: 66 | res = pos 67 | else: 68 | res = pos if (mzmean[pos] - mz) < (mz - mzmean[pos - 1]) else pos - 1 69 | return res 70 | 71 | 72 | def get_ROIs(path, delta_mz=0.005, required_points=15, dropped_points=3, progress_callback=None): 73 | ''' 74 | :param path: path to mzml file 75 | :param delta_mz: 76 | :param required_points: 77 | :param dropped_points: can be zero points 78 | :param pbar: an pyQt5 progress bar to visualize 79 | :return: ROIs - a list of ROI objects found in current file 80 | ''' 81 | # read all scans in mzML file 82 | run = pymzml.run.Reader(path) 83 | scans = [] 84 | for scan in run: 85 | if scan.ms_level == 1: 86 | scans.append(scan) 87 | 88 | ROIs = [] # completed ROIs 89 | process_ROIs = FastAVLTree() # processed ROIs 90 | 91 | # initialize a processed data 92 | number = 1 # number of processed scan 93 | init_scan = scans[0] 94 | start_time = init_scan.scan_time[0] 95 | 96 | min_mz = max(init_scan.mz) 97 | max_mz = min(init_scan.mz) 98 | for mz, i in zip(init_scan.mz, init_scan.i): 99 | if i != 0: 100 | process_ROIs[mz] = ProcessROI([1, 1], 101 | [start_time, start_time], 102 | [i], 103 | [mz], 104 | mz) 105 | min_mz = min(min_mz, mz) 106 | max_mz = max(max_mz, mz) 107 | 108 | for scan in tqdm(scans): 109 | if number == 1: # already processed scan 110 | number += 1 111 | continue 112 | # expand ROI 113 | for n, mz in enumerate(scan.mz): 114 | if scan.i[n] != 0: 115 | ceiling_mz, ceiling_item = None, None 116 | floor_mz, floor_item = None, None 117 | if mz < max_mz: 118 | _, ceiling_item = process_ROIs.ceiling_item(mz) 119 | ceiling_mz = ceiling_item.mzmean 120 | if mz > min_mz: 121 | _, floor_item = process_ROIs.floor_item(mz) 122 | 
floor_mz = floor_item.mzmean 123 | # choose closest 124 | if ceiling_mz is None and floor_mz is None: 125 | time = scan.scan_time[0] 126 | process_ROIs[mz] = ProcessROI([number, number], 127 | [time, time], 128 | [scan.i[n]], 129 | [mz], 130 | mz) 131 | continue 132 | elif ceiling_mz is None: 133 | closest_mz, closest_item = floor_mz, floor_item 134 | elif floor_mz is None: 135 | closest_mz, closest_item = ceiling_mz, ceiling_item 136 | else: 137 | if ceiling_mz - mz > mz - floor_mz: 138 | closest_mz, closest_item = floor_mz, floor_item 139 | else: 140 | closest_mz, closest_item = ceiling_mz, ceiling_item 141 | 142 | if abs(closest_item.mzmean - mz) < delta_mz: 143 | roi = closest_item 144 | if roi.scan[1] == number: 145 | # ROIs is already extended (two peaks in one mz window) 146 | roi.mzmean = (roi.mzmean * roi.points + mz) / (roi.points + 1) 147 | roi.points += 1 148 | roi.mz[-1] = (roi.i[-1]*roi.mz[-1] + scan.i[n]*mz) / (roi.i[-1] + scan.i[n]) 149 | roi.i[-1] = (roi.i[-1] + scan.i[n]) 150 | else: 151 | roi.mzmean = (roi.mzmean * roi.points + mz) / (roi.points + 1) 152 | roi.points += 1 153 | roi.mz.append(mz) 154 | roi.i.append(scan.i[n]) 155 | roi.scan[1] = number # show that we extended the roi 156 | roi.rt[1] = scan.scan_time[0] 157 | else: 158 | time = scan.scan_time[0] 159 | process_ROIs[mz] = ProcessROI([number, number], 160 | [time, time], 161 | [scan.i[n]], 162 | [mz], 163 | mz) 164 | # Check and cleanup 165 | to_delete = [] 166 | for mz, roi in process_ROIs.items(): 167 | if roi.scan[1] < number <= roi.scan[1] + dropped_points: 168 | # insert 'zero' in the end 169 | roi.mz.append(roi.mzmean) 170 | roi.i.append(0) 171 | elif roi.scan[1] != number: 172 | to_delete.append(mz) 173 | if roi.points >= required_points: 174 | ROIs.append(ROI( 175 | roi.scan, 176 | roi.rt, 177 | roi.i, 178 | roi.mz, 179 | roi.mzmean 180 | )) 181 | process_ROIs.remove_items(to_delete) 182 | try: 183 | min_mz, _ = process_ROIs.min_item() 184 | max_mz, _ = process_ROIs.max_item() 
185 | except ValueError: 186 | min_mz = float('inf') 187 | max_mz = 0 188 | number += 1 189 | if progress_callback is not None and not number % 10: 190 | progress_callback.emit(int(number * 100 / len(scans))) 191 | # add final rois 192 | for mz, roi in process_ROIs.items(): 193 | if roi.points >= required_points: 194 | for n in range(dropped_points - (number - 1 - roi.scan[1])): 195 | # insert 'zero' in the end 196 | roi.mz.append(roi.mzmean) 197 | roi.i.append(0) 198 | ROIs.append(ROI( 199 | roi.scan, 200 | roi.rt, 201 | roi.i, 202 | roi.mz, 203 | roi.mzmean 204 | )) 205 | # expand constructed roi 206 | for roi in ROIs: 207 | for n in range(dropped_points): 208 | # insert in the begin 209 | roi.i.insert(0, 0) 210 | roi.mz.insert(0, roi.mzmean) 211 | # change scan numbers (necessary for future matching) 212 | roi.scan = (roi.scan[0] - dropped_points, roi.scan[1] + dropped_points) 213 | assert roi.scan[1] - roi.scan[0] == len(roi.i) - 1 214 | return ROIs 215 | 216 | 217 | def construct_tic(path, label, progress_callback=None): 218 | run = pymzml.run.Reader(path) 219 | t_measure = None 220 | time = [] 221 | tic = [] 222 | spectrum_count = run.get_spectrum_count() 223 | for i, scan in enumerate(run): 224 | if scan.ms_level == 1: 225 | tic.append(scan.TIC) # get total ion of scan 226 | t, measure = scan.scan_time # get scan time 227 | time.append(t) 228 | if not t_measure: 229 | t_measure = measure 230 | if progress_callback is not None and not i % 10: 231 | progress_callback.emit(int(i * 100 / spectrum_count)) 232 | if t_measure == 'second': 233 | time = np.array(time) / 60 234 | return {'x': time, 'y': tic, 'label': label} 235 | 236 | 237 | def construct_eic(path, label, mz, delta, progress_callback=None): 238 | run = pymzml.run.Reader(path) 239 | t_measure = None 240 | time = [] 241 | eic = [] 242 | spectrum_count = run.get_spectrum_count() 243 | for i, scan in enumerate(run): 244 | if scan.ms_level == 1: 245 | t, measure = scan.scan_time # get scan time 246 | 
time.append(t) 247 | pos = np.searchsorted(scan.mz, mz) 248 | closest = get_closest(scan.mz, mz, pos) 249 | if abs(scan.mz[closest] - mz) < delta: 250 | eic.append(scan.i[closest]) 251 | else: 252 | eic.append(0) 253 | if not t_measure: 254 | t_measure = measure 255 | if progress_callback is not None and not i % 10: 256 | progress_callback.emit(int(i * 100 / spectrum_count)) 257 | if t_measure == 'second': 258 | time = np.array(time) / 60 259 | return {'x': time, 'y': eic, 'label': label} 260 | -------------------------------------------------------------------------------- /training_utils/training.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | 6 | 7 | def accuracy(logits, y_true): 8 | """ 9 | :param logits: np.ndarray, output of the model 10 | :param y_true: np.ndarray 11 | """ 12 | predictions = np.argmax(logits, axis=1) 13 | correct_samples = np.sum(predictions == y_true) 14 | total_samples = y_true.shape[0] 15 | return float(correct_samples) / total_samples 16 | 17 | 18 | def compute_accuracy(model, loader): 19 | """ 20 | :param model: a model which returns classifier_output and segmentator_output 21 | :param loader: data loader 22 | """ 23 | model.eval() # enter evaluation mode 24 | score_accum = 0 25 | count = 0 26 | 27 | for x, y, _, _ in loader: 28 | classifier_output, _ = model(x) 29 | score_accum += accuracy(classifier_output.data.cpu().numpy(), y.data.cpu().numpy()) * y.shape[0] 30 | count += y.shape[0] 31 | 32 | return float(score_accum / count) 33 | 34 | 35 | def iou(logits, y_true, smooth=1e-2): 36 | """ 37 | :param logits: np.ndarray, output of the model 38 | :param y_true: np.ndarray 39 | :param smooth: float 40 | """ 41 | batch_size, channels, samples = logits.shape 42 | values = np.zeros(channels) 43 | 44 | for i in range(channels): 45 | pred = logits[:, i, :] > 0.5 46 | gt = y_true[:, i, :].astype(np.bool) 47 | intersection = 
(pred & gt).sum(axis=1) 48 | union = (pred | gt).sum(axis=1) 49 | values[i] = np.mean((intersection + smooth) / (union + smooth)) 50 | 51 | return np.mean(values) 52 | 53 | 54 | def compute_iou(model, loader): 55 | """ 56 | Computes intersection over union on the dataset wrapped in a loader 57 | :param model: a model which returns classifier_output and segmentator_output 58 | :param loader: data loader 59 | returns: IoU (jaccard index) for integration and intersection masks 60 | """ 61 | model.eval() # Evaluation mode 62 | integration_score = [] 63 | intersection_score = [] 64 | 65 | for x, _, integration_mask, intersection_mask in loader: 66 | _, segmentator_output = model(x) 67 | predicted_integration_mask = segmentator_output[:, 0, :].data.cpu().numpy() 68 | predicted_intersection_mask = segmentator_output[:, 1, :].data.cpu().numpy() 69 | 70 | integration_mask = integration_mask.data.cpu().numpy() 71 | intersection_mask = intersection_mask.data.cpu().numpy() 72 | 73 | integration_score.append(iou(predicted_integration_mask, integration_mask)) 74 | intersection_score.append(iou(predicted_intersection_mask, intersection_mask)) 75 | 76 | return np.mean(integration_score), np.mean(intersection_score) 77 | 78 | 79 | class WeightedBCE: 80 | def __init__(self, weights=None): 81 | self.weights = weights 82 | self.logsigmoid = nn.LogSigmoid() 83 | 84 | def __call__(self, output, target): 85 | if self.weights is not None: 86 | assert len(self.weights) == 2 87 | loss = self.weights[1] * (target * self.logsigmoid(output)) + \ 88 | self.weights[0] * ((1 - target) * self.logsigmoid(-output)) 89 | else: 90 | loss = target * self.logsigmoid(output) + (1 - target) * self.logsigmoid(-output) 91 | return torch.neg(torch.mean(loss)) 92 | 93 | 94 | class DiceLoss: 95 | def __init__(self, smooth=1e-2): 96 | self.smooth = smooth 97 | 98 | def __call__(self, output, target): 99 | output = output.sigmoid() 100 | numerator = torch.sum(output * target, dim=1) 101 | denominator = 
torch.sum(torch.sqrt(output) + target, dim=1) 102 | return 1 - torch.mean((2 * numerator + self.smooth) / (denominator + self.smooth)) 103 | 104 | 105 | class CombinedLoss: 106 | def __init__(self, weights=None): 107 | self.dice = DiceLoss() 108 | self.bce = WeightedBCE(weights) 109 | 110 | def __call__(self, output, target): 111 | return self.dice(output, target) + self.bce(output, target) 112 | 113 | 114 | def train_model(model, loader, val_loader, 115 | optimizer, num_epoch, 116 | print_epoch=10, 117 | classification_metric=None, 118 | segmentation_metric=None, 119 | scheduler=None, 120 | label_criterion=None, 121 | integration_criterion=None, 122 | intersection_criterion=None, 123 | accumulation=1, 124 | loss_ax=None, 125 | classification_score_ax=None, 126 | segmentation_score_ax=None, 127 | figure=None, canvas=None): 128 | loss_history = [] 129 | train_classification_score_history = [] 130 | train_segmentation_score_history = [] 131 | val_classification_score_history = [] 132 | val_segmentation_score_history = [] 133 | best_score = 0 134 | for epoch in range(num_epoch): 135 | model.train() # enter train mode 136 | loss_accum = 0 137 | classification_score_accum = 0 138 | segemntation_score_accum = 0 139 | count = 0 140 | step = 0 141 | for x, y, integration_mask, intersection_mask in loader: 142 | classifier_output, integrator_output = model(x) 143 | # classifier_output = classifier_output.view(1, -1) 144 | # calculate loss and gradients 145 | loss = torch.tensor(0, dtype=torch.float32, device=x.device) 146 | if label_criterion is not None: 147 | loss = loss + label_criterion(classifier_output, y) 148 | if integration_criterion is not None: 149 | loss = loss + integration_criterion(integrator_output[:, 0, :], integration_mask) 150 | if intersection_criterion is not None: 151 | loss = loss + intersection_criterion(integrator_output[:, 1, :], intersection_mask) 152 | loss = loss / accumulation 153 | loss.backward() 154 | 155 | step += 1 156 | if step == 
accumulation: # accumulate loss over few batches 157 | optimizer.step() 158 | optimizer.zero_grad() 159 | step = 0 160 | 161 | if classification_metric is not None: 162 | classification_score_accum += classification_metric(classifier_output.detach().cpu().numpy(), 163 | y.detach().cpu().numpy()) * len(y) 164 | if segmentation_metric is not None: 165 | gt = np.stack((integration_mask.data.cpu().numpy(), 166 | intersection_mask.data.cpu().numpy())).transpose(1, 0, 2) 167 | segemntation_score_accum += segmentation_metric(integrator_output.detach().cpu().sigmoid().numpy(), 168 | gt) * len(y) 169 | loss_accum += loss 170 | count += len(y) 171 | loss_history.append(float(loss_accum / count)) # average loss over epoch 172 | train_classification_score_history.append(float(classification_score_accum / count)) 173 | train_segmentation_score_history.append(float(segemntation_score_accum / count)) 174 | 175 | model.eval() # enter evaluation mode 176 | classification_score_accum = 0 177 | segemntation_score_accum = 0 178 | count = 0 179 | for x, y, integration_mask, intersection_mask in val_loader: 180 | classifier_output, integrator_output = model(x) 181 | if classification_metric is not None: 182 | classification_score_accum += classification_metric(classifier_output.detach().cpu().numpy(), 183 | y.detach().cpu().numpy()) * len(y) 184 | if segmentation_metric is not None: 185 | gt = np.stack((integration_mask.data.cpu().numpy(), 186 | intersection_mask.data.cpu().numpy())).transpose(1, 0, 2) 187 | segemntation_score_accum += segmentation_metric(integrator_output.detach().cpu().sigmoid().numpy(), 188 | gt) * len(y) 189 | count += len(y) 190 | val_classification_score_history.append(float(classification_score_accum / count)) 191 | val_segmentation_score_history.append(float(segemntation_score_accum / count)) 192 | 193 | # save best model based on classification score (if it is not None) 194 | if classification_metric is not None and segmentation_metric is not None: 195 | if 
best_score < val_classification_score_history[-1] * val_segmentation_score_history[-1]: 196 | best_score = val_classification_score_history[-1] * val_segmentation_score_history[-1] 197 | torch.save(model.state_dict(), 198 | os.path.join('data/tmp_weights', model.__class__.__name__)) # save best model 199 | elif classification_metric is not None: 200 | if best_score < val_classification_score_history[-1]: 201 | best_score = val_classification_score_history[-1] 202 | torch.save(model.state_dict(), 203 | os.path.join('data/tmp_weights', model.__class__.__name__)) # save best model 204 | elif segmentation_metric is not None: 205 | if best_score < val_segmentation_score_history[-1]: 206 | best_score = val_segmentation_score_history[-1] 207 | torch.save(model.state_dict(), 208 | os.path.join('data/tmp_weights', model.__class__.__name__)) # save best model 209 | 210 | if scheduler: 211 | scheduler.step() 212 | 213 | if not epoch % print_epoch or epoch == num_epoch - 1: 214 | print('Epoch #{}, train loss: {:.4f}'.format( 215 | epoch, loss_history[-1])) 216 | if classification_metric is not None: 217 | print('Train classification score: {:.4f}, val classificiation score: {:.4f}'.format( 218 | train_classification_score_history[-1], 219 | val_classification_score_history[-1] 220 | )) 221 | if segmentation_metric is not None: 222 | print('Train segmentation score: {:.4f}, val segmentation score: {:.4f}'.format( 223 | train_segmentation_score_history[-1], 224 | val_segmentation_score_history[-1] 225 | )) 226 | 227 | # visualization 228 | if loss_ax is not None: 229 | loss_ax.clear() 230 | loss_ax.plot(loss_history) 231 | loss_ax.set_title('Loss function') 232 | if classification_score_ax is not None: 233 | classification_score_ax.clear() 234 | classification_score_ax.plot(train_classification_score_history, label='train') 235 | classification_score_ax.plot(val_classification_score_history, label='validation') 236 | classification_score_ax.legend(loc='best') 237 | 
classification_score_ax.set_title('Classification score') 238 | if segmentation_score_ax is not None: 239 | segmentation_score_ax.clear() 240 | segmentation_score_ax.plot(train_segmentation_score_history, label='train') 241 | segmentation_score_ax.plot(val_segmentation_score_history, label='validation') 242 | segmentation_score_ax.legend(loc='best') 243 | segmentation_score_ax.set_title('Segmentation score') 244 | if figure is not None: 245 | figure.tight_layout() 246 | if canvas is not None: 247 | canvas.draw() 248 | return (loss_history, 249 | train_classification_score_history, 250 | train_segmentation_score_history, 251 | val_classification_score_history, 252 | val_segmentation_score_history) 253 | -------------------------------------------------------------------------------- /processing_utils/runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | try: 4 | from cython_utils.roi import get_ROIs 5 | except ImportError: 6 | from processing_utils.roi import get_ROIs 7 | from processing_utils.matching import construct_mzregions, rt_grouping, align_component 8 | from processing_utils.run_utils import preprocess, get_borders, Feature, \ 9 | border_correction, build_features, feature_collapsing 10 | 11 | 12 | class BasicRunner: 13 | """ 14 | A runner to process single roi 15 | 16 | Parameters 17 | ---------- 18 | mode : str 19 | A one of two 'all in one' of 'sequential' 20 | models : list 21 | a list of models 22 | peak_minimum_points : int 23 | - 24 | 25 | Attributes 26 | ---------- 27 | mode : str 28 | A one of two 'all in one' of 'sequential' 29 | model : nn.Module 30 | an ANN model if mode is 'all in one' (optional) 31 | classifier : nn.Module 32 | an ANN model for classification (optional) 33 | segmentator : nn.Module 34 | an ANN model for segmentation (optional) 35 | peak_minimum_points : int 36 | minimum peak length in points 37 | 38 | """ 39 | def __init__(self, mode, models, 
peak_minimum_points, device): 40 | self.mode = mode 41 | if self.mode == 'all in one': 42 | self.model = models[0] 43 | elif self.mode == 'sequential': 44 | self.classifier, self.segmentator = models 45 | else: 46 | assert False, mode 47 | self.peak_minimum_points = peak_minimum_points 48 | self.device = device 49 | 50 | def __call__(self, roi, sample_name, progress_callback=None, operation_callback=None): 51 | """ 52 | Processing single roi 53 | 54 | Parameters 55 | ---------- 56 | roi : ROI 57 | - 58 | sample_name : str 59 | Arbitrary sample name 60 | Returns 61 | ------- 62 | feature : list 63 | a list of 'Feature' objects 64 | """ 65 | if self.mode == 'all in one': 66 | signal = preprocess(roi.i, self.device) 67 | classifier_output, segmentator_output = self.model(signal) 68 | elif self.mode == 'sequential': 69 | signal = preprocess(roi.i, self.device, interpolate=True, length=256) 70 | classifier_output, _ = self.classifier(signal) 71 | # to do: second step should be only for peaks 72 | _, segmentator_output = self.segmentator(signal) 73 | else: 74 | assert False, self.mode 75 | classifier_output = classifier_output.data.cpu().numpy() 76 | segmentator_output = segmentator_output.data.sigmoid().cpu().numpy() 77 | 78 | # get label 79 | label = np.argmax(classifier_output) 80 | # get borders 81 | features = [] 82 | if label == 1: 83 | borders = get_borders(segmentator_output[0, 0, :], segmentator_output[0, 1, :], 84 | peak_minimum_points=self.peak_minimum_points, 85 | interpolation_factor=len(signal[0, 0]) / len(roi.i)) 86 | for border in borders: 87 | # to do: check correctness of rt calculations 88 | scan_frequency = (roi.scan[1] - roi.scan[0]) / (roi.rt[1] - roi.rt[0]) 89 | rtmin = roi.rt[0] + border[0] / scan_frequency 90 | rtmax = roi.rt[0] + border[1] / scan_frequency 91 | feature = Feature([sample_name], [roi], [border], [0], [np.sum(roi.i[border[0]:border[1]])], 92 | roi.mzmean, rtmin, rtmax, 0, 0) 93 | features.append(feature) 94 | return features 95 | 
96 | 97 | class FilesRunner(BasicRunner): 98 | """ 99 | A runner to process *.mzML files 100 | 101 | Parameters 102 | ---------- 103 | mode : str 104 | A one of two 'all in one' of 'sequential' 105 | models : list 106 | a list of models 107 | delta_mz : float 108 | - 109 | required_points : int 110 | - 111 | peak_minimum_points : int 112 | - 113 | 114 | Attributes 115 | ---------- 116 | mode : str 117 | A one of two 'all in one' of 'sequential' 118 | model : nn.Module 119 | an ANN model if mode is 'all in one' (optional) 120 | classifier : nn.Module 121 | an ANN model for classification (optional) 122 | segmentator : nn.Module 123 | an ANN model for segmentation (optional) 124 | delta_mz : float 125 | a parameters for mz window in ROI detection 126 | required_points : int 127 | minimum ROI length in points 128 | dropped_points : int 129 | maximal number of zero points in a row (for ROI detection) 130 | peak_minimum_points : int 131 | minimum peak length in points 132 | 133 | """ 134 | def __init__(self, mode, models, delta_mz, 135 | required_points, dropped_points, 136 | peak_minimum_points, device): 137 | super(FilesRunner, self).__init__(mode, models, peak_minimum_points, device) 138 | self.delta_mz = delta_mz 139 | self.required_points = required_points 140 | self.dropped_points = dropped_points 141 | 142 | def __call__(self, files, progress_callback=None, operation_callback=None): 143 | if len(files) == 1: 144 | file = files[0] 145 | features = self._single_run(file, progress_callback, operation_callback) 146 | elif len(files) > 1: 147 | features = self._batch_run(files, progress_callback, operation_callback) 148 | else: 149 | features = [] 150 | return features 151 | 152 | def _single_run(self, file, progress_callback=None, operation_callback=None): 153 | """ 154 | Processing single *.mzML file 155 | 156 | Parameters 157 | ---------- 158 | file : str 159 | path to *.mzML file 160 | 161 | Returns 162 | ------- 163 | features : list 164 | a list of 'Feature' 
objects (each consist of single ROI) 165 | """ 166 | # get ROIs from raw spectrum 167 | if operation_callback is not None: 168 | operation_callback.emit(f'Detecting ROIs in {os.path.basename(file)}:') 169 | rois = get_ROIs(file, self.delta_mz, self.required_points, self.dropped_points, progress_callback) 170 | features = [] 171 | percentage = -1 172 | if operation_callback is not None: 173 | operation_callback.emit(f'Finding peaks in detected ROIs:') 174 | for i, roi in enumerate(rois): 175 | features_from_roi = super(FilesRunner, self).__call__(roi, file) 176 | features.extend(features_from_roi) 177 | new_percentage = int(i * 100 / len(rois)) 178 | if progress_callback is not None and new_percentage > percentage: 179 | percentage = new_percentage 180 | progress_callback.emit(percentage) 181 | 182 | parameters = {'files': [file], 'delta mz': self.delta_mz, 'required points': self.required_points, 183 | 'dropped_points': self.dropped_points, 'peak minimum points': self.peak_minimum_points} 184 | return features, parameters 185 | 186 | def _batch_run(self, files, progress_callback=None, operation_callback=None): 187 | """ 188 | Processing a batch of *.mzML files 189 | 190 | Parameters 191 | ---------- 192 | files : list 193 | list of paths to *.mzML files 194 | 195 | Returns 196 | ------- 197 | features : list 198 | a list of 'Feature' objects 199 | """ 200 | # ROI detection 201 | rois = {} 202 | for file in files: # get ROIs for every file 203 | if operation_callback is not None: 204 | operation_callback.emit(f'Detecting ROIs in {os.path.basename(file)}:') 205 | rois[file] = get_ROIs(file, self.delta_mz, self.required_points, self.dropped_points, progress_callback) 206 | 207 | 208 | if operation_callback is not None: 209 | operation_callback.emit(f'Alignment of ROIs:') 210 | # ROI alignment 211 | mzregions = construct_mzregions(rois, self.delta_mz) # construct mz regions 212 | components = rt_grouping(mzregions) # group ROIs in mz regions based on RT 213 | 
aligned_components = [] # component alignment 214 | percentage = -1 215 | for i, component in enumerate(components): 216 | aligned_components.append(align_component(component)) 217 | new_percentage = int(i * 100 / len(components)) 218 | if progress_callback is not None and new_percentage > percentage: 219 | percentage = new_percentage 220 | progress_callback.emit(percentage) 221 | 222 | if operation_callback is not None: 223 | operation_callback.emit(f'Finding peaks in detected ROIs:') 224 | # Classification, integration and correction 225 | component_number = 0 226 | features = [] 227 | percentage = -1 228 | for j, component in enumerate(aligned_components): # run through components 229 | borders = {} # borders for rois with peaks 230 | to_delete = [] # noisy rois in components 231 | for i, (sample, roi) in enumerate(zip(component.samples, component.rois)): 232 | if self.mode == 'all in one': 233 | signal = preprocess(roi.i, self.device) 234 | classifier_output, segmentator_output = self.model(signal) 235 | classifier_output = classifier_output.data.cpu().numpy() 236 | label = np.argmax(classifier_output) 237 | if label == 1: 238 | segmentator_output = segmentator_output.data.sigmoid().cpu().numpy() 239 | borders[sample] = get_borders(segmentator_output[0, 0, :], segmentator_output[0, 1, :], 240 | peak_minimum_points=self.peak_minimum_points) 241 | else: 242 | to_delete.append(i) 243 | elif self.mode == 'sequential': 244 | signal = preprocess(roi.i, self.device, interpolate=True, length=256) 245 | classifier_output, _ = self.classifier(signal) 246 | classifier_output = classifier_output.data.cpu().numpy() 247 | label = np.argmax(classifier_output) 248 | if label == 1: 249 | _, segmentator_output = self.segmentator(signal) 250 | segmentator_output = segmentator_output.data.sigmoid().cpu().numpy() 251 | borders[sample] = get_borders(segmentator_output[0, 0, :], segmentator_output[0, 1, :], 252 | peak_minimum_points=self.peak_minimum_points, 253 | 
interpolation_factor=len(signal[0, 0]) / len(roi.i)) 254 | else: 255 | to_delete.append(i) 256 | else: 257 | assert False, self.mode 258 | 259 | if len(borders) > len(files) // 3: # enough rois contain a peak 260 | component.pop(to_delete) # delete ROIs which don't contain peaks 261 | border_correction(component, borders) 262 | features.extend(build_features(component, borders, component_number)) 263 | component_number += 1 264 | 265 | new_percentage = int(j * 100 / len(aligned_components)) 266 | if progress_callback is not None and new_percentage > percentage: 267 | percentage = new_percentage 268 | progress_callback.emit(percentage) 269 | 270 | features = feature_collapsing(features) 271 | # to do: is it necessary? 272 | # explicitly delete features which were found in not enough quantity of ROIs 273 | to_delete = [] 274 | for i, feature in enumerate(features): 275 | if len(feature) <= len(files) // 3: # to do: adjustable parameter 276 | to_delete.append(i) 277 | for j in to_delete[::-1]: 278 | features.pop(j) 279 | print('total number of features: {}'.format(len(features))) 280 | parameters = {'files': files, 'delta mz': self.delta_mz, 'required points': self.required_points, 281 | 'dropped_points': self.dropped_points, 'peak minimum points': self.peak_minimum_points} 282 | return features, parameters 283 | -------------------------------------------------------------------------------- /processing_utils/matching.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | from collections import defaultdict 5 | from scipy.sparse.csgraph import connected_components 6 | from processing_utils.roi import ROI 7 | 8 | 9 | class mzRegion: 10 | """ 11 | A class that stores the beginning and the of the mass region 12 | and all the ROIs that lays there 13 | """ 14 | def __init__(self, mzbegin, mzend, rois=None): 15 | """ 16 | :param mzbegin: begin of the mass region 17 | 
:param mzend: end of the mass region 18 | :param rois: ROIs in the region should be defaultdict(list) 19 | """ 20 | self.mzbegin = mzbegin 21 | self.mzend = mzend 22 | self.rois = defaultdict(list) if rois is None else rois 23 | 24 | def __contains__(self, mz): 25 | return self.mzbegin <= mz <= self.mzend 26 | 27 | def __len__(self): 28 | return len(self.rois) 29 | 30 | def extend(self, sample_rois): 31 | for k, v in sample_rois.items(): 32 | self.rois[k].extend(v) 33 | 34 | def append(self, sample, roi): 35 | self.rois[sample].append(roi) 36 | 37 | 38 | def construct_mzregions(ROIs, delta_mz): 39 | """ 40 | :param ROIs: a dictionary, where keys are file names and values are lists of rois 41 | :param delta_mz: int 42 | :return: a list of mzRegion objects 43 | """ 44 | mz_mins = np.array([min(roi.mz) for s in ROIs.values() for roi in s]) 45 | mz_maxs = np.array([max(roi.mz) for s in ROIs.values() for roi in s]) 46 | rois = np.array([(name, roi) for name, s in ROIs.items() for roi in s]) 47 | 48 | # reorder values based on mins 49 | order = np.argsort(mz_mins) 50 | mz_mins = mz_mins[order] 51 | mz_maxs = mz_maxs[order] 52 | rois = rois[order] 53 | 54 | mzregions = [] 55 | roi_dict = defaultdict(list) # save all rois within current region 56 | region_begin, region_end = mz_mins[0], mz_maxs[0] 57 | for begin, end, name_roi in zip(mz_mins, mz_maxs, rois): 58 | if begin > region_end + delta_mz: 59 | # add new mzRegion object 60 | mzregions.append(mzRegion(region_begin, region_end, roi_dict)) 61 | roi_dict = defaultdict(list) 62 | region_begin, region_end = begin, end 63 | else: 64 | region_end = end if end > region_end else region_end 65 | name, roi = name_roi 66 | roi_dict[name].append(roi) 67 | mzregions.append(mzRegion(region_begin, region_end, roi_dict)) 68 | return mzregions 69 | 70 | 71 | def intersected(begin1, end1, begin2, end2, percentage=None): 72 | """ 73 | A simple function which determines if two segments intersect 74 | :return: bool 75 | """ 76 | lower = 
(end1 <= end2) and (end1 > begin2) 77 | bigger = (end1 > end2) and (end2 > begin1) 78 | if percentage is None: 79 | ans = (lower or bigger) 80 | else: 81 | if lower: 82 | intersection = end1 - np.max([begin1, begin2]) 83 | smallest = min((end1 - begin1, end2 - begin2)) 84 | ans = (intersection / smallest) > percentage 85 | elif bigger: 86 | intersection = end2 - np.max([begin1, begin2]) 87 | smallest = min((end1 - begin1, end2 - begin2)) 88 | ans = (intersection / smallest) > percentage 89 | else: 90 | ans = False 91 | return ans 92 | 93 | 94 | def roi_intersected(one_roi, two_roi, percentage=0.3): 95 | """ 96 | A function that determines if two roi intersect based on rt and mz. 97 | :return: bool 98 | """ 99 | ans = False 100 | if intersected(min(one_roi.mz), 101 | max(one_roi.mz), 102 | min(two_roi.mz), 103 | max(two_roi.mz)) and \ 104 | intersected(one_roi.rt[0], 105 | one_roi.rt[1], 106 | two_roi.rt[0], 107 | two_roi.rt[1], 108 | percentage): 109 | ans = True 110 | return ans 111 | 112 | 113 | def rt_grouping(mzregions): 114 | """ 115 | A function that groups roi inside mzregions. 
116 | :param mzregions: a list of mzRegion objects 117 | :return: a list of defaultdicts, where the key is the name of file and value is a list of ROIs 118 | """ 119 | components = [] 120 | for region in mzregions: 121 | region = np.array([(name, roi) for name, s in region.rois.items() for roi in s]) 122 | n = len(region) 123 | graph = np.zeros((n, n), dtype=np.uint8) 124 | for i in range(n - 1): 125 | for j in range(i + 1, n): 126 | graph[i, j] = roi_intersected(region[i][1], region[j][1]) 127 | n_components, labels = connected_components(graph, directed=False) 128 | 129 | for k in range(n_components): 130 | rois = region[labels == k] 131 | component = defaultdict(list) 132 | for roi in rois: 133 | component[roi[0]].append(roi[1]) 134 | components.append(component) 135 | return components 136 | 137 | 138 | class groupedROI: 139 | """ 140 | A class that represents a group of ROIs 141 | """ 142 | 143 | def __init__(self, rois, shifts, samples, grouping): 144 | self.rois = rois # rois 145 | self.shifts = shifts # shifts for each roi 146 | self.samples = samples # samples names 147 | self.grouping = grouping # similarity groups 148 | 149 | def __len__(self): 150 | length = len(self.rois) 151 | assert length == len(self.shifts) 152 | assert length == len(self.samples) 153 | assert length == len(self.grouping) 154 | return length 155 | 156 | def append(self, roi, shift, sample, group_number): 157 | self.rois.append(roi) 158 | self.shifts.append(shift) 159 | self.samples.append(sample) 160 | self.grouping.append(group_number) 161 | 162 | def pop(self, idx): 163 | if isinstance(idx, list): 164 | for j in sorted(idx, reverse=True): 165 | self.rois.pop(j) 166 | self.shifts.pop(j) 167 | self.samples.pop(j) 168 | self.grouping.pop(j) 169 | else: 170 | assert isinstance(idx, int) 171 | self.rois.pop(idx) 172 | self.shifts.pop(idx) 173 | self.samples.pop(idx) 174 | self.grouping.pop(idx) 175 | 176 | def plot(self, based_on_grouping=False): 177 | """ 178 | Visualize a groupedROI 
object 179 | """ 180 | name2label = {} 181 | label2class = {} 182 | labels = set() 183 | if based_on_grouping: 184 | labels = set(self.grouping) 185 | for i, sample in enumerate(self.samples): 186 | name2label[sample] = self.grouping[i] 187 | for label in labels: 188 | label2class[label] = label # identical transition 189 | else: 190 | for sample in self.samples: 191 | label = os.path.basename(os.path.dirname(sample)) 192 | labels.add(label) 193 | name2label[sample] = label 194 | 195 | for i, label in enumerate(labels): 196 | label2class[label] = i 197 | 198 | m = len(labels) 199 | mz = [] 200 | scan_begin = [] 201 | scan_end = [] 202 | fig, axes = plt.subplots(1, 2) 203 | for sample, roi, shift in zip(self.samples, self.rois, self.shifts): 204 | mz.append(roi.mzmean) 205 | scan_begin.append(roi.scan[0] + shift) 206 | scan_end.append(roi.scan[1] + shift) 207 | y = roi.i 208 | x = np.linspace(roi.scan[0], roi.scan[1], len(y)) 209 | x_shifted = np.linspace(roi.scan[0] + shift, roi.scan[1] + shift, len(y)) 210 | label = label2class[name2label[sample]] 211 | c = [label / m, 0.0, (m - label) / m] 212 | axes[0].plot(x, y, color=c) 213 | axes[1].plot(x_shifted, y, color=c) 214 | fig.suptitle('mz = {:.4f}, scan = {:.2f} -{:.2f}'.format(np.mean(mz), min(scan_begin), max(scan_end))) 215 | 216 | def adjust(self, history, adjustment_threshold): 217 | labels, counts = np.unique(self.grouping, return_counts=True) 218 | counter = {label: count for label, count in zip(labels, counts)} 219 | 220 | for i, sample in enumerate(self.samples): 221 | if counter[self.grouping[i]] == 1: 222 | best_gn, best_corr, best_shift = None, None, None 223 | for gn, corr, shift in history[sample]: 224 | if (best_corr is None or corr > best_corr) and counter[gn] != 1: 225 | best_gn = gn 226 | best_corr = corr 227 | best_shift = shift 228 | if best_corr is not None and best_corr > adjustment_threshold: 229 | self.grouping[i] = best_gn 230 | self.shifts[i] = best_shift 231 | 232 | 233 | def 
stitch_component(component): 234 | """ 235 | Stitching roi, which resulted from one file in one group 236 | :param component: defaultdict where the key is the name of file and value is a list of ROIs 237 | :return: new_component with stitched ROIs 238 | """ 239 | new_component = defaultdict(list) 240 | for file in component: 241 | begin_scan, end_scan = component[file][0].scan 242 | begin_rt, end_rt = component[file][0].rt 243 | for roi in component[file]: 244 | begin, end = roi.scan 245 | if begin < begin_scan: 246 | begin_scan = begin 247 | begin_rt = roi.rt[0] 248 | if end > end_scan: 249 | end_scan = end 250 | end_rt = roi.rt[1] 251 | 252 | # to do: use parameter with missing zeros (not 7) 253 | i = np.zeros(end_scan - begin_scan + 1) 254 | mz = np.zeros(end_scan - begin_scan + 1) 255 | for roi in component[file]: 256 | begin, end = roi.scan 257 | i[begin - begin_scan:end - begin_scan + 1] = roi.i 258 | mz[begin - begin_scan:end - begin_scan + 1] = roi.mz 259 | mzmean = np.mean(mz[mz != 0]) # mean based on nonzero elements 260 | new_component[file] = [ROI([begin_scan, end_scan], 261 | [begin_rt, end_rt], 262 | i, mz, mzmean)] 263 | return new_component 264 | 265 | 266 | def conv2correlation(roi_i, base_roi_i, conv_vector): 267 | n = np.zeros_like(conv_vector) 268 | x = np.sum(roi_i) 269 | y = np.sum(base_roi_i) 270 | x_square = np.sum(np.power(roi_i, 2)) # to do: make roi.i np.array by default 271 | y_square = np.sum(np.power(base_roi_i, 2)) 272 | 273 | min_l = min((len(roi_i), len(base_roi_i))) 274 | max_l = max((len(roi_i), len(base_roi_i))) 275 | 276 | n[:min_l - 1] = np.arange(len(conv_vector), max_l, -1) 277 | n[min_l - 1:min_l + max_l - min_l] = max_l 278 | n[min_l + max_l - min_l:] = np.arange(max_l + 1, len(conv_vector) + 1, 1) 279 | 280 | return (n * conv_vector - x * y) / (np.sqrt(n * x_square - x ** 2) * np.sqrt(n * y_square - y ** 2)) 281 | 282 | 283 | def align_component(component, max_shift=20): 284 | """ 285 | Align ROIs in component based on 
point-wise correlation 286 | :param component: defaultdict where the key is the name of file and value is a list of ROIs 287 | :param max_shift: maximum shift in scans 288 | :return: an groupedROI object 289 | """ 290 | # stitching first 291 | component = stitch_component(component) 292 | # find base_sample which correspond to the sample with highest intensity within roi 293 | correlation_threshold = 0.8 294 | adjustment_threshold = 0.4 295 | group_number = 0 296 | aligned_component = groupedROI([], [], [], []) 297 | # save (group-number, correlation, shift) 298 | history = defaultdict(list) 299 | 300 | while len(component) != 0: 301 | # choose base ROI from the remaining 302 | max_i = 0 303 | base_sample, base_roi = None, None 304 | for sample in component: 305 | assert len(component[sample]) == 1 306 | for roi in component[sample]: # in fact there are only one roi after stitching 307 | i = np.max(roi.i) 308 | if i > max_i: 309 | max_i = i 310 | base_sample, base_roi = sample, roi 311 | 312 | component.pop(base_sample) # delete chosen ROI from component 313 | aligned_component.append(base_roi, 0, base_sample, group_number) 314 | 315 | to_delete = [] 316 | for sample in component: 317 | roi = component[sample][0] 318 | # position, when two ROIs begins simultaneously 319 | pos = len(roi.i) - 1 # to do: check it 320 | conv_vector = np.convolve(roi.i[::-1], base_roi.i, mode='full') # reflection is necessary 321 | corr_vector = conv2correlation(roi.i, base_roi.i, conv_vector) 322 | # to do: find local maxima greater than threshold 323 | shift = np.argmax(corr_vector) - pos - roi.scan[0] + base_roi.scan[0] 324 | max_corr = np.max(corr_vector) 325 | if max_corr > correlation_threshold: 326 | to_delete.append(sample) # delete ROI from component 327 | aligned_component.append(roi, shift, sample, group_number) 328 | history[sample].append((group_number, max_corr, shift)) # history for after adjustment 329 | 330 | for sample in to_delete: 331 | component.pop(sample) 332 | 
333 | group_number += 1 # increase group number 334 | 335 | aligned_component.adjust(history, adjustment_threshold) # to do: decide is it necessary 336 | return aligned_component 337 | -------------------------------------------------------------------------------- /gui_utils/training.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import shutil 6 | import threading 7 | from torch.utils.data import DataLoader 8 | from PyQt5 import QtWidgets 9 | import matplotlib.pyplot as plt 10 | from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas 11 | from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar 12 | from gui_utils.auxilary_utils import GetFolderWidget 13 | from training_utils.dataset import ROIDataset 14 | from training_utils.training import train_model, CombinedLoss, accuracy, iou 15 | from models.rcnn import RecurrentCNN 16 | from models.cnn_classifier import Classifier 17 | from models.cnn_segmentator import Segmentator 18 | 19 | 20 | class TrainingParameterWindow(QtWidgets.QDialog): 21 | """ 22 | Training Parameter Window, where one should choose parameters for training 23 | 24 | Parameters 25 | ---------- 26 | mode : str 27 | A one of two 'all in one' of 'sequential' 28 | parent : MainWindow(QtWidgets.QMainWindow) 29 | - 30 | Attributes 31 | ---------- 32 | mode : str 33 | A one of two 'all in one' of 'sequential' 34 | parent : MainWindow(QtWidgets.QMainWindow) 35 | - 36 | train_folder_getter : GetFolderWidget 37 | A getter for a path to train data 38 | val_folder_getter : GetFolderWidget 39 | A getter for a path to validation data 40 | """ 41 | def __init__(self, mode, parent=None): 42 | self.mode = mode 43 | self.parent = parent 44 | super().__init__(parent) 45 | self.setWindowTitle('peakonly: models') 46 | 47 | train_folder_label = QtWidgets.QLabel() 48 | 
train_folder_label.setText('Choose a folder with train data:') 49 | self.train_folder_getter = GetFolderWidget(os.path.join(os.getcwd(), 'data', 'train'), self) 50 | 51 | val_folder_label = QtWidgets.QLabel() 52 | val_folder_label.setText('Choose a folder with validation data:') 53 | self.val_folder_getter = GetFolderWidget(os.path.join(os.getcwd(), 'data', 'val'), self) 54 | 55 | continue_button = QtWidgets.QPushButton('Continue') 56 | continue_button.clicked.connect(self._continue) 57 | 58 | main_layout = QtWidgets.QVBoxLayout() 59 | main_layout.addWidget(train_folder_label) 60 | main_layout.addWidget(self.train_folder_getter) 61 | main_layout.addWidget(val_folder_label) 62 | main_layout.addWidget(self.val_folder_getter) 63 | main_layout.addWidget(continue_button) 64 | 65 | self.setLayout(main_layout) 66 | 67 | def _continue(self): 68 | try: 69 | train_folder = self.train_folder_getter.get_folder() 70 | val_folder = self.val_folder_getter.get_folder() 71 | main_window = TrainingMainWindow(self.mode, train_folder, val_folder, self.parent) 72 | main_window.show() 73 | self.close() 74 | except ValueError: 75 | # popup window with exception 76 | msg = QtWidgets.QMessageBox(self) 77 | msg.setText("Check parameters. 
Something is wrong!") 78 | msg.setIcon(QtWidgets.QMessageBox.Warning) 79 | msg.exec_() 80 | 81 | 82 | class TrainingMainWindow(QtWidgets.QDialog): 83 | """ 84 | Training Main Window, where training process occurs 85 | 86 | Parameters 87 | ---------- 88 | mode : str 89 | A one of two 'all in one' of 'sequential' 90 | train_folder : str 91 | A path to the folder with training data 92 | val_folder : str 93 | A path to the folder with validation data 94 | parent : MainWindow(QtWidgets.QMainWindow) 95 | - 96 | Attributes 97 | ---------- 98 | mode : str 99 | A one of two 'all in one' of 'sequential' 100 | parent : MainWindow(QtWidgets.QMainWindow) 101 | - 102 | 103 | """ 104 | def __init__(self, mode, train_folder, val_folder, parent): 105 | self.mode = mode 106 | self.parent = parent 107 | super().__init__(parent) 108 | main_layout = QtWidgets.QVBoxLayout() 109 | # to do: device should be adjustable parameter 110 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 111 | if self.mode == 'all in one': 112 | # create data loaders 113 | train_dataset = ROIDataset(path=train_folder, device=device, balanced=True) 114 | train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True) 115 | val_dataset = ROIDataset(path=val_folder, device=device, balanced=False) 116 | val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False) 117 | # create model 118 | model = RecurrentCNN().to(device) 119 | optimizer = optim.Adam(params=model.parameters(), lr=1e-3) 120 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=15, eta_min=1e-6) 121 | label_criterion = nn.CrossEntropyLoss() 122 | integration_criterion = CombinedLoss([0.4, 0.2]) 123 | intersection_criterion = CombinedLoss([0.1, 2]) 124 | # add training widget 125 | main_layout.addWidget(TrainingMainWidget(train_loader, val_loader, model, optimizer, accuracy, iou, 126 | scheduler, label_criterion, integration_criterion, 127 | intersection_criterion, 64, self)) 128 | elif self.mode == 
'sequential': 129 | # create data loaders 130 | batch_size = 64 131 | train_dataset = ROIDataset(path=train_folder, device=device, interpolate=True, length=256, balanced=True) 132 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 133 | val_dataset = ROIDataset(path=val_folder, device=device, interpolate=True, length=256, balanced=False) 134 | val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) 135 | # classifier 136 | classifier = Classifier().to(device) 137 | optimizer = optim.Adam(params=classifier.parameters(), lr=1e-3) 138 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=1e-6) 139 | label_criterion = nn.CrossEntropyLoss() 140 | main_layout.addWidget(TrainingMainWidget(train_loader, val_loader, classifier, optimizer, accuracy, None, 141 | scheduler, label_criterion, None, None, 1, self)) 142 | # segmentator 143 | segmentator = Segmentator().to(device) 144 | optimizer = optim.Adam(params=segmentator.parameters(), lr=1e-2) 145 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=15, eta_min=1e-6) 146 | integration_criterion = CombinedLoss([0.4, 0.2]) 147 | intersection_criterion = CombinedLoss([0.1, 2]) 148 | main_layout.addWidget(TrainingMainWidget(train_loader, val_loader, segmentator, optimizer, None, iou, 149 | scheduler, None, integration_criterion, 150 | intersection_criterion, 1, self)) 151 | self.setLayout(main_layout) 152 | 153 | 154 | class TrainingMainWidget(QtWidgets.QWidget): 155 | """ 156 | Training Main Widget, where training process of one model occurs 157 | 158 | Parameters 159 | ---------- 160 | train_loader : DataLoader 161 | - 162 | val_loader : DataLoader 163 | - 164 | model : nn.Module 165 | model to train 166 | parent : QDialog 167 | - 168 | Attributes 169 | ---------- 170 | parent : MainWindow(QtWidgets.QMainWindow) 171 | - 172 | """ 173 | def __init__(self, train_loader, val_loader, model, optimizer, classification_metric, 
segmenatation_metric, 174 | scheduler, label_criterion, integration_criterion, intersection_criterion, accumulation, parent): 175 | self.parent = parent 176 | super().__init__(parent) 177 | self.setWindowTitle('peakonly: training') 178 | 179 | self.train_loader = train_loader 180 | self.val_loader = val_loader 181 | self.model = model 182 | self.optimizer = optimizer 183 | self.classification_metric = classification_metric 184 | self.segmentation_metric = segmenatation_metric 185 | self.scheduler = scheduler 186 | self.label_criterion = label_criterion 187 | self.integration_criterion = integration_criterion 188 | self.intersection_criterion = intersection_criterion 189 | self.accumulation = accumulation 190 | 191 | self._init_ui() 192 | 193 | def _init_ui(self): 194 | # canvas layout (with 3 subplots) 195 | self.figure = plt.figure() 196 | self.loss_ax = self.figure.add_subplot(131) 197 | self.loss_ax.set_title('Loss function') 198 | self.classification_score_ax = self.figure.add_subplot(132) 199 | self.classification_score_ax.set_title('Classification score') 200 | self.segmentation_score_ax = self.figure.add_subplot(133) 201 | self.segmentation_score_ax.set_title('Segmentation score') 202 | self.canvas = FigureCanvas(self.figure) 203 | toolbar = NavigationToolbar(self.canvas, self) 204 | canvas_layout = QtWidgets.QVBoxLayout() 205 | canvas_layout.addWidget(toolbar) 206 | canvas_layout.addWidget(self.canvas) 207 | self.figure.tight_layout() 208 | 209 | 210 | # training parameters layout 211 | parameters_layout = QtWidgets.QVBoxLayout() 212 | empty_label = QtWidgets.QLabel() 213 | 214 | number_of_epochs_label = QtWidgets.QLabel() 215 | number_of_epochs_label.setText('Number of epochs:') 216 | self.number_of_epochs_getter = QtWidgets.QLineEdit(self) 217 | self.number_of_epochs_getter.setText('100') 218 | 219 | learning_rate_label = QtWidgets.QLabel() 220 | learning_rate_label.setText('Learning rate:') 221 | self.learning_rate_getter = QtWidgets.QLineEdit(self) 222 
| self.learning_rate_getter.setText('1e-3') 223 | 224 | parameters_layout.addWidget(empty_label, 80) 225 | parameters_layout.addWidget(number_of_epochs_label, 5) 226 | parameters_layout.addWidget(self.number_of_epochs_getter, 5) 227 | parameters_layout.addWidget(learning_rate_label, 5) 228 | parameters_layout.addWidget(self.learning_rate_getter, 5) 229 | 230 | # buttons layout 231 | buttons_layout = QtWidgets.QHBoxLayout() 232 | restart_button = QtWidgets.QPushButton('Restart') 233 | restart_button.clicked.connect(self.restart) 234 | buttons_layout.addWidget(restart_button) 235 | save_weights_button = QtWidgets.QPushButton('Save weights') 236 | save_weights_button.clicked.connect(self.save_weights) 237 | buttons_layout.addWidget(save_weights_button) 238 | run_training_button = QtWidgets.QPushButton('Run training') 239 | run_training_button.clicked.connect(self.run_training) 240 | buttons_layout.addWidget(run_training_button) 241 | 242 | # main layouts 243 | upper_layout = QtWidgets.QHBoxLayout() 244 | upper_layout.addLayout(canvas_layout, 85) 245 | upper_layout.addLayout(parameters_layout, 15) 246 | 247 | main_layout = QtWidgets.QVBoxLayout() 248 | main_layout.addLayout(upper_layout) 249 | main_layout.addLayout(buttons_layout) 250 | self.setLayout(main_layout) 251 | 252 | def restart(self): 253 | # to do: change restart (problem with optimizer, etc.) 
254 | self.loss_ax.clear() 255 | self.loss_ax.set_title('Loss function') 256 | self.classification_score_ax.clear() 257 | self.classification_score_ax.set_title('Classification score') 258 | self.segmentation_score_ax.clear() 259 | self.classification_score_ax.set_title('Segmentation score') 260 | self.figure.tight_layout() 261 | self.canvas.draw() 262 | self.model = self.model.__class__() 263 | 264 | def save_weights(self): 265 | subwindow = SaveModelWindow(self.model, self) 266 | subwindow.show() 267 | 268 | def run_training(self): 269 | try: 270 | number_of_epoch = int(self.number_of_epochs_getter.text()) 271 | learning_rate = float(self.learning_rate_getter.text()) 272 | for param_group in self.optimizer.param_groups: 273 | param_group['lr'] = learning_rate 274 | 275 | thread = threading.Thread(target=train_model, args=(self.model, self.train_loader, self.val_loader, 276 | self.optimizer, number_of_epoch, 10, 277 | self.classification_metric, self.segmentation_metric, 278 | self.scheduler, self.label_criterion, 279 | self.integration_criterion, self.intersection_criterion, 280 | self.accumulation, self.loss_ax, 281 | self.classification_score_ax, 282 | self.segmentation_score_ax, 283 | self.figure, self.canvas)) 284 | thread.start() 285 | except ValueError: 286 | # popup window with exception 287 | msg = QtWidgets.QMessageBox(self) 288 | msg.setText("Check parameters. 
Something is wrong!") 289 | msg.setIcon(QtWidgets.QMessageBox.Warning) 290 | msg.exec_() 291 | 292 | 293 | class SaveModelWindow(QtWidgets.QDialog): 294 | def __init__(self, model, parent): 295 | self.parent = parent 296 | super().__init__(parent) 297 | self.model = model 298 | 299 | folder_label = QtWidgets.QLabel() 300 | folder_label.setText('Choose a folder where to save:') 301 | self.folder_getter = GetFolderWidget(os.path.join(os.getcwd(), 'data', 'weights'), self) 302 | 303 | name_label = QtWidgets.QLabel() 304 | name_label.setText('Set a name of file: ') 305 | self.name_getter = QtWidgets.QLineEdit(self) 306 | self.name_getter.setText('model.pt') 307 | 308 | save_button = QtWidgets.QPushButton('Save') 309 | save_button.clicked.connect(self.save) 310 | 311 | main_layout = QtWidgets.QVBoxLayout() 312 | main_layout.addWidget(folder_label) 313 | main_layout.addWidget(self.folder_getter) 314 | main_layout.addWidget(name_label) 315 | main_layout.addWidget(self.name_getter) 316 | main_layout.addWidget(save_button) 317 | 318 | self.setLayout(main_layout) 319 | 320 | def save(self): 321 | folder = self.folder_getter.get_folder() 322 | name = self.name_getter.text() 323 | shutil.copyfile(os.path.join('data/tmp_weights', self.model.__class__.__name__), 324 | os.path.join(folder, name)) 325 | -------------------------------------------------------------------------------- /gui_utils/evaluation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import threading 5 | import numpy as np 6 | from functools import partial 7 | from PyQt5 import QtWidgets 8 | import matplotlib.pyplot as plt 9 | from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas 10 | from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar 11 | from gui_utils.auxilary_utils import GetFolderWidget, GetFileWidget, FeatureListWidget 12 | from models.rcnn import 
RecurrentCNN 13 | from models.cnn_classifier import Classifier 14 | from models.cnn_segmentator import Segmentator 15 | from processing_utils.runner import BasicRunner 16 | from processing_utils.roi import construct_ROI 17 | from processing_utils.run_utils import Feature 18 | 19 | 20 | class EvaluationParameterWindow(QtWidgets.QDialog): 21 | """ 22 | Evaluation Parameter Window, where one should choose parameters for evaluation 23 | 24 | Parameters 25 | ---------- 26 | mode : str 27 | A one of two 'all in one' of 'sequential' 28 | parent : MainWindow(QtWidgets.QMainWindow) 29 | - 30 | Attributes 31 | ---------- 32 | mode : str 33 | A one of two 'all in one' of 'sequential' 34 | parent : MainWindow(QtWidgets.QMainWindow) 35 | - 36 | test_folder_getter : GetFolderWidget 37 | A getter for a path to test data 38 | model_weights_getter : GetFileWidget 39 | A getter for a path to weights for 'all-in-one' model (optional) 40 | classifier_weights_getter : GetFileWidget 41 | A getter for a path to weights for 'all-in-one' model (optional) 42 | peak_points_getter : QtWidgets.QLineEdit 43 | A getter for peak_minimum_points parameter 44 | segmentator_weights_getter : GetFileWidget 45 | A getter for a path to weights for 'all-in-one' model (optional) 46 | """ 47 | def __init__(self, mode, parent=None): 48 | self.mode = mode 49 | self.parent = parent 50 | super().__init__(parent) 51 | self.setWindowTitle('peakonly: evaluation') 52 | 53 | test_folder_label = QtWidgets.QLabel() 54 | test_folder_label.setText('Choose a folder with test data:') 55 | self.test_folder_getter = GetFolderWidget(os.path.join(os.getcwd(), 'data', 'test'), self) 56 | 57 | if mode == 'all in one': 58 | model_weights_label = QtWidgets.QLabel() 59 | model_weights_label.setText("Choose weights for 'all-in-one' model") 60 | # to do: save a pytorch script, not a model state 61 | self.model_weights_getter = GetFileWidget('pt', os.path.join(os.getcwd(), 62 | 'data/weights/RecurrentCNN.pt'), self) 63 | elif mode == 
'sequential': 64 | classifier_weights_label = QtWidgets.QLabel() 65 | classifier_weights_label.setText('Choose weights for a classifier') 66 | # to do: save a pytorch script, not a model state 67 | self.classifier_weights_getter = GetFileWidget('pt', os.path.join(os.getcwd(), 68 | 'data/weights/Classifier.pt'), self) 69 | segmentator_weights_label = QtWidgets.QLabel() 70 | segmentator_weights_label.setText('Choose weights for a segmentator') 71 | # to do: save a pytorch script, not a model state 72 | self.segmentator_weights_getter = GetFileWidget('pt', os.path.join(os.getcwd(), 73 | 'data/weights/Segmentator.pt'), self) 74 | else: 75 | assert False, mode 76 | 77 | peak_points_label = QtWidgets.QLabel() 78 | peak_points_label.setText('Minimal length of peak:') 79 | self.peak_points_getter = QtWidgets.QLineEdit(self) 80 | self.peak_points_getter.setText('8') 81 | 82 | run_button = QtWidgets.QPushButton('Run evaluation') 83 | run_button.clicked.connect(self._run_evaluation) 84 | 85 | main_layout = QtWidgets.QVBoxLayout() 86 | main_layout.addWidget(test_folder_label) 87 | main_layout.addWidget(self.test_folder_getter) 88 | if mode == 'all in one': 89 | main_layout.addWidget(model_weights_label) 90 | main_layout.addWidget(self.model_weights_getter) 91 | elif mode == 'sequential': 92 | main_layout.addWidget(classifier_weights_label) 93 | main_layout.addWidget(self.classifier_weights_getter) 94 | main_layout.addWidget(segmentator_weights_label) 95 | main_layout.addWidget(self.segmentator_weights_getter) 96 | main_layout.addWidget(peak_points_label) 97 | main_layout.addWidget(self.peak_points_getter) 98 | main_layout.addWidget(run_button) 99 | 100 | self.setLayout(main_layout) 101 | 102 | def _run_evaluation(self): 103 | try: 104 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 105 | # to do: device should be customizable parameter 106 | test_folder = self.test_folder_getter.get_folder() 107 | if self.mode == 'all in one': 108 | # to do: save 
models as pytorch scripts 109 | model = RecurrentCNN().to(device) 110 | path2weights = self.model_weights_getter.get_file() 111 | model.load_state_dict(torch.load(path2weights, map_location=device)) 112 | model.eval() 113 | models = [model] 114 | elif self.mode == 'sequential': 115 | classifier = Classifier().to(device) 116 | path2classifier_weights = self.classifier_weights_getter.get_file() 117 | classifier.load_state_dict(torch.load(path2classifier_weights, map_location=device)) 118 | classifier.eval() 119 | segmentator = Segmentator().to(device) 120 | path2segmentator_weights = self.segmentator_weights_getter.get_file() 121 | segmentator.load_state_dict(torch.load(path2segmentator_weights, map_location=device)) 122 | segmentator.eval() 123 | models = [classifier, segmentator] 124 | else: 125 | assert False, self.mode 126 | minimum_peak_points = int(self.peak_points_getter.text()) 127 | 128 | runner = BasicRunner(self.mode, models, 129 | minimum_peak_points, device) 130 | 131 | main_window = EvaluationMainWindow(test_folder, runner, self.parent) 132 | main_window.show() 133 | self.close() 134 | except ValueError: 135 | # popup window with exception 136 | msg = QtWidgets.QMessageBox(self) 137 | msg.setText("Check parameters. 
Something is wrong!") 138 | msg.setIcon(QtWidgets.QMessageBox.Warning) 139 | msg.exec_() 140 | 141 | 142 | class EvaluationMainWindow(QtWidgets.QDialog): 143 | """ 144 | Evaluation Main Window, where one can look into the model quality 145 | 146 | Parameters 147 | ---------- 148 | test_folder : str 149 | A path to folder with test data 150 | runner : BasicRunner 151 | - 152 | parent : MainWindow(QtWidgets.QMainWindow) 153 | - 154 | Attributes 155 | ---------- 156 | test_folder : str 157 | A path to folder with test data 158 | runner : BasicRunner 159 | - 160 | parent : MainWindow(QtWidgets.QMainWindow) 161 | - 162 | tp_features : FeatureListWidget 163 | true positives features 164 | tn_features : FeatureListWidget 165 | true negatives features 166 | fp_features : FeatureListWidget 167 | false positives features 168 | fn_features : FeatureListWidget 169 | false negatives features 170 | figure : Figure 171 | - 172 | ax : Axes 173 | - 174 | canvas : FigureCanvasQTAgg 175 | - 176 | """ 177 | def __init__(self, test_folder, runner, parent): 178 | self.parent = parent 179 | super().__init__(parent) 180 | self.setWindowTitle('peakonly: evaluation') 181 | 182 | self.test_folder = test_folder 183 | self.runner = runner 184 | 185 | self._init_ui() 186 | 187 | def _init_ui(self): 188 | # create lists of features 189 | lists_layout = QtWidgets.QHBoxLayout() 190 | 191 | tp_layout = QtWidgets.QVBoxLayout() 192 | tp_label = QtWidgets.QLabel() 193 | tp_label.setText('True positives:') 194 | tp_layout.addWidget(tp_label) 195 | self.tp_features = self.create_list_of_features() 196 | tp_layout.addWidget(self.tp_features) 197 | tp_next_button = QtWidgets.QPushButton('Next') 198 | tp_next_button.clicked.connect(partial(self.next_feature, self.tp_features)) 199 | tp_layout.addWidget(tp_next_button) 200 | lists_layout.addLayout(tp_layout) 201 | 202 | tn_layout = QtWidgets.QVBoxLayout() 203 | tn_label = QtWidgets.QLabel() 204 | tn_label.setText('True negatives:') 205 | 
tn_layout.addWidget(tn_label) 206 | self.tn_features = self.create_list_of_features() 207 | tn_layout.addWidget(self.tn_features) 208 | tn_next_button = QtWidgets.QPushButton('Next') 209 | tn_next_button.clicked.connect(partial(self.next_feature, self.tn_features)) 210 | tn_layout.addWidget(tn_next_button) 211 | lists_layout.addLayout(tn_layout) 212 | 213 | fp_layout = QtWidgets.QVBoxLayout() 214 | fp_label = QtWidgets.QLabel() 215 | fp_label.setText('False positives:') 216 | fp_layout.addWidget(fp_label) 217 | self.fp_features = self.create_list_of_features() 218 | fp_layout.addWidget(self.fp_features) 219 | fp_next_button = QtWidgets.QPushButton('Next') 220 | fp_next_button.clicked.connect(partial(self.next_feature, self.fp_features)) 221 | fp_layout.addWidget(fp_next_button) 222 | lists_layout.addLayout(fp_layout) 223 | 224 | fn_layout = QtWidgets.QVBoxLayout() 225 | fn_label = QtWidgets.QLabel() 226 | fn_label.setText('False negatives:') 227 | fn_layout.addWidget(fn_label) 228 | self.fn_features = self.create_list_of_features() 229 | fn_layout.addWidget(self.fn_features) 230 | fn_next_button = QtWidgets.QPushButton('Next') 231 | fn_next_button.clicked.connect(partial(self.next_feature, self.fn_features)) 232 | fn_layout.addWidget(fn_next_button) 233 | lists_layout.addLayout(fn_layout) 234 | 235 | # statistic button 236 | right_half_layout = QtWidgets.QVBoxLayout() 237 | right_half_layout.addLayout(lists_layout) 238 | 239 | statistics_button = QtWidgets.QPushButton('Plot confusion matrix') 240 | statistics_button.clicked.connect(self.plot_confusion_matrix) 241 | right_half_layout.addWidget(statistics_button) 242 | 243 | # Main canvas and toolbar 244 | self.figure = plt.figure() 245 | self.ax = self.figure.add_subplot(111) # plot here 246 | self.canvas = FigureCanvas(self.figure) 247 | toolbar = NavigationToolbar(self.canvas, self) 248 | canvas_layout = QtWidgets.QVBoxLayout() 249 | canvas_layout.addWidget(toolbar) 250 | canvas_layout.addWidget(self.canvas) 251 | 
252 | main_layout = QtWidgets.QHBoxLayout() 253 | main_layout.addLayout(canvas_layout, 60) 254 | main_layout.addLayout(right_half_layout, 40) 255 | self.setLayout(main_layout) 256 | 257 | thread = threading.Thread(target=self.update) 258 | thread.start() 259 | 260 | def create_list_of_features(self): 261 | list_of_features = FeatureListWidget() 262 | list_of_features.connectDoubleClick(self.feature_click) 263 | return list_of_features 264 | 265 | def feature_click(self, item): 266 | list_widget = item.listWidget() 267 | feature = list_widget.get_feature(item) 268 | self.plot_feature(feature) 269 | 270 | def next_feature(self, list_widget): 271 | raw = list_widget.currentRow() 272 | item = list_widget.item(min(raw + 1, list_widget.count() - 1)) 273 | list_widget.setCurrentItem(item) 274 | feature = list_widget.get_feature(item) 275 | self.plot_feature(feature) 276 | 277 | def update(self): 278 | for file in os.listdir(self.test_folder): 279 | if file[0] != '.': 280 | with open(os.path.join(self.test_folder, file)) as json_file: 281 | dict_roi = json.load(json_file) 282 | # get predicted features 283 | roi = construct_ROI(dict_roi) 284 | features = self.runner(roi, 'predicted/' + file) 285 | # append gt (ground truth) features 286 | for border in dict_roi['borders']: 287 | gt = np.zeros(len(roi.i), dtype=np.bool) 288 | gt[border[0]:border[1]+1] = 1 289 | scan_frequency = (roi.scan[1] - roi.scan[0]) / (roi.rt[1] - roi.rt[0]) 290 | rtmin = roi.rt[0] + border[0] / scan_frequency 291 | rtmax = roi.rt[0] + border[1] / scan_frequency 292 | match = False 293 | for feature in features: 294 | if len(feature) == 1 and feature.samples[0][:2] == 'pr': 295 | predicted_border = feature.borders[0] 296 | pred = np.zeros(len(roi.i), dtype=np.bool) 297 | pred[predicted_border[0]:predicted_border[1]+1] = 1 298 | # calculate iou 299 | intersection = (pred & gt).sum() # will be zero if Truth=0 or Prediction=0 300 | union = (pred | gt).sum() 301 | if intersection / union > 0.5: 302 | 
match = True 303 | feature.append('gt/' + file, roi, border, 0, np.sum(roi.i[border[0]:border[1]]), 304 | roi.mzmean, rtmin, rtmax) 305 | break 306 | if not match: 307 | features.append(Feature(['gt/' + file], [roi], [border], [0], [np.sum(roi.i[border[0]:border[1]])], 308 | roi.mzmean, rtmin, rtmax, 0, 0)) 309 | 310 | # append tp, tn, fp, fn 311 | for feature in features: 312 | if len(feature) == 2: 313 | self.tp_features.add_feature(feature) 314 | elif len(feature) == 1 and feature.samples[0][:2] == 'pr': 315 | self.fp_features.add_feature(feature) 316 | elif len(feature) == 1 and feature.samples[0][:2] == 'gt': 317 | self.fn_features.add_feature(feature) 318 | else: 319 | print(len(feature)), print(feature.samples[0][:2]) 320 | assert False, feature.samples 321 | 322 | if len(features) == 0: 323 | noise_feature = Feature(['noise/' + file], [roi], [[0, 0]], [0], [0], 324 | roi.mzmean, roi.rt[0], roi.rt[1], 0, 0) 325 | self.tn_features.add_feature(noise_feature) 326 | 327 | def plot_feature(self, feature): 328 | self.ax.clear() 329 | feature.plot(self.ax, shifted=False, show_legend=True) 330 | self.canvas.draw() # refresh canvas 331 | 332 | def plot_confusion_matrix(self): 333 | # to do: create a window with stats 334 | tp_features = self.tp_features.get_all() 335 | tn_features = self.tn_features.get_all() 336 | fp_features = self.fp_features.get_all() 337 | fn_features = self.fn_features.get_all() 338 | subwindow = EvaluationStatisticsWindow(tp_features, tn_features, fp_features, fn_features, self) 339 | subwindow.show() 340 | 341 | 342 | class EvaluationStatisticsWindow(QtWidgets.QDialog): 343 | def __init__(self, tp_features, tn_features, fp_features, fn_features, parent): 344 | self.parent = parent 345 | super().__init__(parent) 346 | self.setWindowTitle('evaluation: confusion matrix') 347 | # auxiliary calculations 348 | precision = len(tp_features) / (len(tp_features) + len(fp_features)) 349 | recall = len(tp_features) / (len(tp_features) + len(fn_features)) 
350 | 351 | integration_accuracy = np.zeros(len(tp_features)) 352 | for i, feature in enumerate(tp_features): 353 | integration_accuracy[i] = np.abs(feature.intensities[0] - feature.intensities[1]) / feature.intensities[1] 354 | integration_accuracy = 1 - np.mean(integration_accuracy) 355 | 356 | # print metrics 357 | precision_label = QtWidgets.QLabel() 358 | precision_label.setText(f'Precision = {precision:.2f}') 359 | recall_label = QtWidgets.QLabel() 360 | recall_label.setText(f'Recall = {recall:.2f}') 361 | integration_accuracy_label = QtWidgets.QLabel() 362 | integration_accuracy_label.setText(f'Integration accuracy = {integration_accuracy:.2f}') 363 | 364 | # canvas for confusion matrix 365 | self.figure = plt.figure() 366 | self.ax = self.figure.add_subplot(111) # plot here 367 | self.canvas = FigureCanvas(self.figure) 368 | 369 | main_layout = QtWidgets.QVBoxLayout() 370 | main_layout.addWidget(self.canvas) 371 | main_layout.addWidget(precision_label) 372 | main_layout.addWidget(recall_label) 373 | main_layout.addWidget(integration_accuracy_label) 374 | 375 | self.setLayout(main_layout) 376 | 377 | self.plot_confusion_matrix(len(tp_features), len(tn_features), len(fp_features), len(fn_features)) 378 | 379 | def plot_confusion_matrix(self, tp, tn, fp, fn): 380 | confusion_matrix = np.zeros((2, 2), np.int) 381 | confusion_matrix[0, 0] = tp 382 | confusion_matrix[0, 1] = fp 383 | confusion_matrix[1, 0] = fn 384 | confusion_matrix[1, 1] = tn 385 | 386 | self.ax.set_title("Confusion matrix") 387 | res = self.ax.imshow(confusion_matrix, cmap='GnBu', interpolation='nearest') 388 | self.figure.colorbar(res) 389 | self.ax.set_xticks(np.arange(2)) 390 | self.ax.set_xticklabels(['peak', 'noise']) 391 | self.ax.set_yticks(np.arange(2)) 392 | self.ax.set_yticklabels(['peak', 'noise']) 393 | self.ax.set_ylabel("predicted") 394 | self.ax.set_xlabel("ground truth") 395 | for i, row in enumerate(confusion_matrix): 396 | for j, count in enumerate(row): 397 | plt.text(j, i, 
count, fontsize=14, horizontalalignment='center', verticalalignment='center') 398 | 399 | self.canvas.draw() 400 | -------------------------------------------------------------------------------- /peakonly.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import urllib.request 4 | import zipfile 5 | from functools import partial 6 | import matplotlib.pyplot as plt 7 | from PyQt5 import QtWidgets, QtGui, QtCore 8 | from processing_utils.postprocess import ResultTable 9 | from processing_utils.run_utils import find_mzML 10 | from gui_utils.abstract_main_window import AbtractMainWindow 11 | from gui_utils.auxilary_utils import ProgressBarsListItem 12 | from gui_utils.mining import AnnotationParameterWindow, ReAnnotationParameterWindow 13 | from gui_utils.visualization import EICParameterWindow, VisualizationWindow 14 | from gui_utils.processing import ProcessingParameterWindow 15 | from gui_utils.training import TrainingParameterWindow 16 | from gui_utils.evaluation import EvaluationParameterWindow 17 | from gui_utils.data_splitting import SplitterParameterWindow 18 | from gui_utils.threading import Worker 19 | 20 | 21 | class MainWindow(AbtractMainWindow): 22 | # Initialization 23 | def __init__(self): 24 | super().__init__() 25 | # create menu 26 | self._create_menu() 27 | 28 | # tune list of files 29 | self._list_of_files.setSelectionMode(QtWidgets.QAbstractItemView.ExtendedSelection) 30 | self._list_of_files.connectRightClick(partial(FileContextMenu, self)) 31 | 32 | # tune list of features 33 | self._list_of_features.connectDoubleClick(self.plot_feature) 34 | self._list_of_features.connectRightClick(partial(FeatureContextMenu, self)) 35 | 36 | self._init_ui() 37 | 38 | # Set geometry and title 39 | self.setGeometry(300, 300, 900, 600) 40 | self.setWindowTitle('peakonly') 41 | self.show() 42 | 43 | def _create_menu(self): 44 | menu = self.menuBar() 45 | 46 | # file submenu 47 | file = 
menu.addMenu('File') 48 | 49 | file_import = QtWidgets.QMenu('Open', self) 50 | file_import_mzML = QtWidgets.QAction('Open *.mzML', self) 51 | file_import_mzML.triggered.connect(self._open_file) 52 | file_import.addAction(file_import_mzML) 53 | file_import_folder_mzML = QtWidgets.QAction('Open folder with *.mzML files', self) 54 | file_import_folder_mzML.triggered.connect(self._open_folder) 55 | file_import.addAction(file_import_folder_mzML) 56 | 57 | 58 | file_export = QtWidgets.QMenu('Save', self) 59 | file_export_features_csv = QtWidgets.QAction('Save a *.csv file with detected features', self) 60 | file_export_features_csv.triggered.connect(partial(self._export_features, 'csv')) 61 | file_export.addAction(file_export_features_csv) 62 | file_export_features_png = QtWidgets.QAction('Save features as *.png files', self) 63 | file_export_features_png.triggered.connect(partial(self._export_features, 'png')) 64 | file_export.addAction(file_export_features_png) 65 | 66 | file_clear = QtWidgets.QMenu('Clear', self) 67 | file_clear_features = QtWidgets.QAction('Clear panel with detected features', self) 68 | file_clear_features.triggered.connect(self._list_of_features.clear) 69 | file_clear.addAction(file_clear_features) 70 | 71 | file_exit = QtWidgets.QAction("Exit", self) 72 | file_exit.triggered.connect(QtWidgets.QApplication.quit) # to do: create visualization 73 | 74 | file.addMenu(file_import) 75 | file.addMenu(file_export) 76 | file.addMenu(file_clear) 77 | file.addAction(file_exit) 78 | 79 | # data submenu 80 | data = menu.addMenu('Data') 81 | 82 | data_processing = QtWidgets.QAction('Feature detection', self) 83 | data_processing.triggered.connect(partial(self._data_processing, 'simple')) 84 | 85 | data_download = QtWidgets.QMenu('Download', self) 86 | data_download_models = QtWidgets.QAction('Download trained models', self) 87 | data_download_models.triggered.connect(partial(self._download_button, mode='models')) 88 | 
data_download.addAction(data_download_models) 89 | data_download_annotated_data = QtWidgets.QAction('Download annotated data', self) 90 | data_download_annotated_data.triggered.connect(partial(self._download_button, mode='data')) 91 | data_download.addAction(data_download_annotated_data) 92 | data_download_example = QtWidgets.QAction('Download *.mzML example', self) 93 | data_download_example.triggered.connect(partial(self._download_button, mode='example')) 94 | data_download.addAction(data_download_example) 95 | 96 | data_visualization = QtWidgets.QAction('Visualization', self) 97 | data_visualization.triggered.connect(self._open_visualization_window) # to do: create visualization 98 | 99 | data.addAction(data_processing) 100 | data.addMenu(data_download) 101 | data.addAction(data_visualization) 102 | 103 | # advanced submenu 104 | advanced = menu.addMenu('Advanced') 105 | 106 | advanced_data_processing = QtWidgets.QMenu('Advanced feature detection', self) 107 | advanced_data_processing_all = QtWidgets.QAction('RecurrentCNN (testing)', self) 108 | advanced_data_processing_all.triggered.connect(partial(self._data_processing, 'all in one')) 109 | advanced_data_processing.addAction(advanced_data_processing_all) 110 | advanced_data_processing_sequential = QtWidgets.QAction('Two subsequent CNNs', self) 111 | advanced_data_processing_sequential.triggered.connect(partial(self._data_processing, 'sequential')) 112 | advanced_data_processing.addAction(advanced_data_processing_sequential) 113 | 114 | advanced_data_mining = QtWidgets.QMenu('Data mining', self) 115 | advanced_data_mining_manual = QtWidgets.QAction('Manual annotation', self) 116 | advanced_data_mining_manual.triggered.connect(partial(self._data_mining, mode='manual')) 117 | advanced_data_mining.addAction(advanced_data_mining_manual) 118 | advanced_data_mining_reannotation = QtWidgets.QAction('Reannotation', self) 119 | advanced_data_mining_reannotation.triggered.connect(partial(self._data_mining, 
mode='reannotation')) 120 | advanced_data_mining.addAction(advanced_data_mining_reannotation) 121 | advanced_data_mining_split = QtWidgets.QAction('Split data', self) 122 | advanced_data_mining_split.triggered.connect(self._split_data) 123 | 124 | advanced_model = QtWidgets.QMenu('Model', self) 125 | advanced_model_training = QtWidgets.QMenu('Training', self) # training 126 | advanced_model_training_all = QtWidgets.QAction('RecurrentCNN (testing)', self) 127 | advanced_model_training_all.triggered.connect(partial(self._model_training, 'all in one')) 128 | advanced_model_training.addAction(advanced_model_training_all) 129 | advanced_model_training_sequential = QtWidgets.QAction('Two subsequent CNNs', self) 130 | advanced_model_training_sequential.triggered.connect(partial(self._model_training, 'sequential')) 131 | advanced_model_training.addAction(advanced_model_training_sequential) 132 | advanced_model_fine_tuning = QtWidgets.QMenu('Fine-tuning (in developing)', self) # fine-tuning 133 | advanced_model_fine_tuning_all = QtWidgets.QAction('RecurrentCNN (testing)', self) 134 | advanced_model_fine_tuning_all.triggered.connect(partial(self._model_fine_tuning, 'all in one')) 135 | advanced_model_fine_tuning.addAction(advanced_model_fine_tuning_all) 136 | advanced_model_fine_tuning_sequential = QtWidgets.QAction('Two subsequent CNNs', self) 137 | advanced_model_fine_tuning_sequential.triggered.connect(partial(self._model_fine_tuning, 'sequential')) 138 | advanced_model_fine_tuning.addAction(advanced_model_fine_tuning_sequential) 139 | advanced_model_evaluation = QtWidgets.QMenu('Evaluation', self) # evaluation 140 | advanced_model_evaluation_all = QtWidgets.QAction('RecurrentCNN (testing)', self) 141 | advanced_model_evaluation_all.triggered.connect(partial(self._model_evaluation, 'all in one')) 142 | advanced_model_evaluation.addAction(advanced_model_evaluation_all) 143 | advanced_model_evaluation_sequential = QtWidgets.QAction('Two subsequent CNNs', self) 144 | 
advanced_model_evaluation_sequential.triggered.connect(partial(self._model_evaluation, 'sequential')) 145 | advanced_model_evaluation.addAction(advanced_model_evaluation_sequential) 146 | advanced_model.addMenu(advanced_model_training) # add to menu 147 | advanced_model.addMenu(advanced_model_fine_tuning) 148 | advanced_model.addMenu(advanced_model_evaluation) 149 | 150 | advanced.addMenu(advanced_data_processing) 151 | advanced.addMenu(advanced_data_mining) 152 | advanced.addMenu(advanced_model) 153 | 154 | def _init_ui(self): 155 | # Layouts 156 | files_layout = QtWidgets.QVBoxLayout() 157 | files_label = QtWidgets.QLabel(self) 158 | files_label.setText('Opened files:') 159 | files_layout.addWidget(files_label) 160 | files_layout.addWidget(self._list_of_files) 161 | 162 | features_layout = QtWidgets.QVBoxLayout() 163 | features_label = QtWidgets.QLabel(self) 164 | features_label.setText('Detected features:') 165 | features_layout.addWidget(features_label) 166 | features_layout.addWidget(self._list_of_features) 167 | 168 | canvas_layout = QtWidgets.QVBoxLayout() 169 | canvas_layout.addWidget(self._toolbar) 170 | canvas_layout.addWidget(self._canvas) 171 | 172 | canvas_files_features_layout = QtWidgets.QHBoxLayout() 173 | canvas_files_features_layout.addLayout(files_layout, 15) 174 | canvas_files_features_layout.addLayout(canvas_layout, 70) 175 | canvas_files_features_layout.addLayout(features_layout, 15) 176 | 177 | scrollable_pb_list = QtWidgets.QScrollArea() 178 | scrollable_pb_list.setWidget(self._pb_list) 179 | scrollable_pb_list.setWidgetResizable(True) 180 | 181 | main_layout = QtWidgets.QVBoxLayout() 182 | main_layout.addLayout(canvas_files_features_layout, 90) 183 | main_layout.addWidget(scrollable_pb_list, 10) 184 | 185 | widget = QtWidgets.QWidget() 186 | widget.setLayout(main_layout) 187 | 188 | self.setCentralWidget(widget) 189 | 190 | # Auxiliary methods 191 | def _open_file(self): 192 | files_names = QtWidgets.QFileDialog.getOpenFileNames(None, '', 
'', 'mzML (*.mzML)')[0] 193 | for name in files_names: 194 | self._list_of_files.addFile(name) 195 | 196 | def _open_folder(self): 197 | path = str(QtWidgets.QFileDialog.getExistingDirectory()) 198 | for name in sorted(find_mzML(path)): 199 | self._list_of_files.addFile(name) 200 | 201 | def _export_features(self, mode): 202 | if self._list_of_features.count() > 0: 203 | if mode == 'csv': 204 | # to do: features should be QTreeWidget (root should keep basic information: files and parameters) 205 | files = self._feature_parameters['files'] 206 | table = ResultTable(files, self._list_of_features.features) 207 | table.fill_zeros(self._feature_parameters['delta mz']) 208 | file_name, _ = QtWidgets.QFileDialog.getSaveFileName(self, 'Export features', '', 209 | 'csv (*.csv)') 210 | if file_name: 211 | table.to_csv(file_name) 212 | elif mode == 'png': 213 | directory = str(QtWidgets.QFileDialog.getExistingDirectory(self, 'Choose a directory where to save')) 214 | 215 | worker = Worker(self._save_features_png, features=self._list_of_features.features, directory=directory) 216 | self.run_thread('Saving features as *.png files:', worker) 217 | else: 218 | assert False, mode 219 | else: 220 | msg = QtWidgets.QMessageBox(self) 221 | msg.setText('You should firstly detect features in *mzML files:\n' 222 | 'Data -> Feature detection') 223 | msg.setIcon(QtWidgets.QMessageBox.Warning) 224 | msg.exec_() 225 | 226 | def _get_eic_parameters(self): 227 | subwindow = EICParameterWindow(self) 228 | subwindow.show() 229 | 230 | @staticmethod 231 | def _show_downloading_progress(number_of_block, size_of_block, total_size, pb): 232 | pb.setValue(int(number_of_block * size_of_block * 100 / total_size)) 233 | 234 | # Buttons, which creates threads 235 | def _download_button(self, mode): 236 | if mode == 'models': 237 | text = 'Downloading trained models:' 238 | elif mode == 'data': 239 | text = 'Downloading annotated data:' 240 | elif mode == 'example': 241 | text = 'Downloading *.mzML 
example:' 242 | else: 243 | assert False, mode 244 | 245 | pb = ProgressBarsListItem(text, parent=self._pb_list) 246 | self._pb_list.addItem(pb) 247 | worker = Worker(self._download, download=True, mode=mode) 248 | worker.signals.download_progress.connect(partial(self._show_downloading_progress, pb=pb)) 249 | worker.signals.finished.connect(partial(self._threads_finisher, 250 | text='Download is successful', 251 | icon=QtWidgets.QMessageBox.Information, 252 | pb=pb)) 253 | self._thread_pool.start(worker) 254 | 255 | # Main functionality 256 | @staticmethod 257 | def _download(mode, progress_callback): 258 | """ 259 | Download necessary data 260 | Parameters 261 | ---------- 262 | mode : str 263 | one of three ('models', 'data', 'example') 264 | progress_callback : QtCore.pyqtSignal 265 | indicating progress in % 266 | """ 267 | if mode == 'models': 268 | folder = 'data/weights' 269 | if not os.path.exists(folder): 270 | os.mkdir(folder) 271 | # Classifier 272 | url = 'https://getfile.dokpub.com/yandex/get/https://yadi.sk/d/rAhl2u7WeIUGYA' 273 | file = os.path.join(folder, 'Classifier.pt') 274 | urllib.request.urlretrieve(url, file, progress_callback.emit) 275 | # Segmentator 276 | url = 'https://getfile.dokpub.com/yandex/get/https://yadi.sk/d/9m5e3C0q0HKbuw' 277 | file = os.path.join(folder, 'Segmentator.pt') 278 | urllib.request.urlretrieve(url, file, progress_callback.emit) 279 | # RecurrentCNN 280 | url = 'https://getfile.dokpub.com/yandex/get/https://yadi.sk/d/1IrXRWDWhANqKw' 281 | file = os.path.join(folder, 'RecurrentCNN.pt') 282 | urllib.request.urlretrieve(url, file, progress_callback.emit) 283 | elif mode == 'data': 284 | folder = 'data/annotation' 285 | if not os.path.exists(folder): 286 | os.mkdir(folder) 287 | url = 'https://getfile.dokpub.com/yandex/get/https://yadi.sk/d/f6BiwqWYF4UVnA' 288 | file = 'data/annotation/annotation.zip' 289 | urllib.request.urlretrieve(url, file, progress_callback.emit) 290 | with zipfile.ZipFile(file) as zip_file: 291 | 
zip_file.extractall(folder) 292 | os.remove(file) 293 | elif mode == 'example': 294 | url = 'https://getfile.dokpub.com/yandex/get/https://yadi.sk/d/BhQNge3db7M2Lw' 295 | file = 'data/mix.mzML' 296 | urllib.request.urlretrieve(url, file, progress_callback.emit) 297 | else: 298 | assert False, mode 299 | 300 | @staticmethod 301 | def _save_features_png(features, directory, progress_callback): 302 | fig = plt.figure() 303 | for i, feature in enumerate(features): 304 | ax = fig.add_subplot(111) 305 | feature.plot(ax, shifted=True) 306 | fig.savefig(os.path.join(directory, f'{i}.png')) 307 | fig.clear() 308 | progress_callback.emit(int(i * 100 / len(features))) 309 | plt.close(fig) 310 | 311 | def _split_data(self): 312 | subwindow = SplitterParameterWindow(self) 313 | subwindow.show() 314 | 315 | def _data_mining(self, mode='manual'): 316 | if mode != 'reannotation': 317 | files = [self._list_of_files.file2path[self._list_of_files.item(i).text()] 318 | for i in range(self._list_of_files.count())] 319 | subwindow = AnnotationParameterWindow(files, mode, self) 320 | subwindow.show() 321 | else: 322 | subwindow = ReAnnotationParameterWindow(self) 323 | subwindow.show() 324 | 325 | def _data_processing(self, mode): 326 | if mode == 'simple' and (not os.path.isfile(os.path.join('data', 'weights', 'Classifier.pt')) 327 | or not os.path.isfile(os.path.join('data', 'weights', 'Segmentator.pt'))): 328 | msg = QtWidgets.QMessageBox(self) 329 | msg.setText('You should download models in order to process your data:\n' 330 | 'Data -> Download -> Download trained models') 331 | msg.setIcon(QtWidgets.QMessageBox.Warning) 332 | msg.exec_() 333 | else: 334 | files = [self._list_of_files.file2path[self._list_of_files.item(i).text()] 335 | for i in range(self._list_of_files.count())] 336 | if not files: 337 | msg = QtWidgets.QMessageBox(self) 338 | msg.setText('You should firstly open *.mzML files:\n' 339 | 'File -> Open -> Open *.mzML') 340 | msg.setIcon(QtWidgets.QMessageBox.Warning) 
341 | msg.exec_() 342 | else: 343 | subwindow = ProcessingParameterWindow(files, mode, self) 344 | subwindow.show() 345 | 346 | def _open_visualization_window(self): 347 | files = [self._list_of_files.file2path[self._list_of_files.item(i).text()] 348 | for i in range(self._list_of_files.count())] 349 | subwindow = VisualizationWindow(files, self) 350 | subwindow.show() 351 | 352 | # Model functionality 353 | def _model_training(self, mode): 354 | subwindow = TrainingParameterWindow(mode, self) 355 | subwindow.show() 356 | 357 | def _model_fine_tuning(self, mode): 358 | pass 359 | 360 | def _model_evaluation(self, mode): 361 | subwindow = EvaluationParameterWindow(mode, self) 362 | subwindow.show() 363 | 364 | 365 | class FileContextMenu(QtWidgets.QMenu): 366 | def __init__(self, parent: MainWindow): 367 | self.parent = parent 368 | super().__init__(parent) 369 | 370 | menu = QtWidgets.QMenu(parent) 371 | 372 | tic = QtWidgets.QAction('Plot TIC', parent) 373 | eic = QtWidgets.QAction('Plot EIC', parent) 374 | close = QtWidgets.QAction('Close', parent) 375 | 376 | menu.addAction(tic) 377 | menu.addAction(eic) 378 | menu.addAction(close) 379 | 380 | action = menu.exec_(QtGui.QCursor.pos()) 381 | 382 | if action == tic: 383 | for file in self.parent.get_selected_files(): 384 | file = file.text() 385 | self.parent.plot_tic(file) 386 | elif action == eic: 387 | subwindow = EICParameterWindow(self.parent) 388 | subwindow.show() 389 | elif action == close: 390 | self.close_files() 391 | 392 | def close_files(self): 393 | for item in self.parent.get_selected_files(): 394 | self.parent.close_file(item) 395 | 396 | 397 | class FeatureContextMenu(QtWidgets.QMenu): 398 | def __init__(self, parent: MainWindow): 399 | self.parent = parent 400 | super().__init__(parent) 401 | feature = None 402 | for item in self.parent.get_selected_features(): 403 | feature = item 404 | 405 | menu = QtWidgets.QMenu(parent) 406 | 407 | with_rt_correction = QtWidgets.QAction('Plot with rt 
correction', parent) 408 | without_rt_correction = QtWidgets.QAction('Plot without rt correction', parent) 409 | 410 | menu.addAction(with_rt_correction) 411 | menu.addAction(without_rt_correction) 412 | 413 | action = menu.exec_(QtGui.QCursor.pos()) 414 | 415 | if action == with_rt_correction: 416 | self.parent.plot_feature(feature, shifted=True) 417 | elif action == without_rt_correction: 418 | self.parent.plot_feature(feature, shifted=False) 419 | 420 | 421 | if __name__ == '__main__': 422 | plt.switch_backend('Agg') # to do: check if it is alright??? 423 | app = QtWidgets.QApplication(sys.argv) 424 | window = MainWindow() 425 | sys.exit(app.exec_()) 426 | -------------------------------------------------------------------------------- /gui_utils/mining.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | from functools import partial 6 | from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas 7 | from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar 8 | from PyQt5 import QtWidgets, QtGui 9 | try: 10 | from cython_utils.roi import get_ROIs 11 | except ImportError: 12 | from processing_utils.roi import get_ROIs 13 | from processing_utils.roi import construct_ROI 14 | # from processing_utils.run_utils import classifier_prediction 15 | from gui_utils.abstract_main_window import AbtractMainWindow 16 | from gui_utils.auxilary_utils import FileListWidget, GetFolderWidget, ProgressBarsListItem 17 | from gui_utils.threading import Worker 18 | 19 | 20 | class ReAnnotationParameterWindow(QtWidgets.QDialog): 21 | def __init__(self, parent: AbtractMainWindow): 22 | self.mode = 'reannotation' 23 | self.parent = parent 24 | super().__init__(parent) 25 | self.setWindowTitle('peakonly: reannotation') 26 | 27 | save_to_label = QtWidgets.QLabel() 28 | save_to_label.setText('Choose a folder with 
annotated ROIs:') 29 | self.folder_widget = GetFolderWidget() 30 | 31 | self.run_button = QtWidgets.QPushButton('Run reannotation') 32 | self.run_button.clicked.connect(self.start_reannotation) 33 | 34 | main_layout = QtWidgets.QVBoxLayout() 35 | main_layout.addWidget(save_to_label) 36 | main_layout.addWidget(self.folder_widget) 37 | main_layout.addWidget(self.run_button) 38 | 39 | self.setLayout(main_layout) 40 | 41 | def start_reannotation(self): 42 | folder = self.folder_widget.get_folder() 43 | subwindow = AnnotationMainWindow([], folder, None, None, 44 | None, self.mode, 45 | None, parent=self.parent) 46 | subwindow.show() 47 | self.close() 48 | 49 | 50 | class AnnotationParameterWindow(QtWidgets.QDialog): 51 | def __init__(self, files, mode, parent: AbtractMainWindow): 52 | self.mode = mode 53 | self.parent = parent 54 | super().__init__(parent) 55 | self.setWindowTitle('peakonly: manual annotation') 56 | 57 | self.files = files 58 | self.description = None 59 | self.file_prefix = None 60 | self.file_suffix = None 61 | self.minimum_peak_points = None 62 | self.folder = None 63 | 64 | self._init_ui() 65 | 66 | def _init_ui(self): 67 | # file and folder selection 68 | choose_file_label = QtWidgets.QLabel() 69 | choose_file_label.setText('Choose a file to annotate:') 70 | self.list_of_files = FileListWidget() 71 | for file in self.files: 72 | self.list_of_files.addFile(file) 73 | 74 | save_to_label = QtWidgets.QLabel() 75 | save_to_label.setText('Choose a folder where to save annotated ROIs:') 76 | self.folder_widget = GetFolderWidget() 77 | 78 | file_layout = QtWidgets.QVBoxLayout() 79 | file_layout.addWidget(choose_file_label) 80 | file_layout.addWidget(self.list_of_files) 81 | file_layout.addWidget(save_to_label) 82 | file_layout.addWidget(self.folder_widget) 83 | 84 | # parameters selection 85 | 86 | instrumental_label = QtWidgets.QLabel() 87 | instrumental_label.setText('Instrumentals description') 88 | self.instrumental_getter = QtWidgets.QLineEdit(self) 
89 | self.instrumental_getter.setText('Q-oa-TOF, total time=10 min, scan frequency=10Hz') 90 | 91 | prefix_label = QtWidgets.QLabel() 92 | prefix_label.setText('Prefix of filename: ') 93 | self.prefix_getter = QtWidgets.QLineEdit(self) 94 | self.prefix_getter.setText('Example') 95 | 96 | suffix_label = QtWidgets.QLabel() 97 | suffix_label.setText('Code of file (suffix, will be increased during annotation): ') 98 | self.suffix_getter = QtWidgets.QLineEdit(self) 99 | self.suffix_getter.setText('0') 100 | 101 | mz_label = QtWidgets.QLabel() 102 | mz_label.setText('m/z deviation:') 103 | self.mz_getter = QtWidgets.QLineEdit(self) 104 | self.mz_getter.setText('0.005') 105 | 106 | roi_points_label = QtWidgets.QLabel() 107 | roi_points_label.setText('Minimal length of ROI:') 108 | self.roi_points_getter = QtWidgets.QLineEdit(self) 109 | self.roi_points_getter.setText('15') 110 | 111 | if self.mode == 'semi-automatic': 112 | peak_points_label = QtWidgets.QLabel() 113 | peak_points_label.setText('Minimal length of peak:') 114 | self.peak_points_getter = QtWidgets.QLineEdit(self) 115 | self.peak_points_getter.setText('8') 116 | 117 | dropped_points_label = QtWidgets.QLabel() 118 | dropped_points_label.setText('Maximal number of zero points in a row:') 119 | self.dropped_points_getter = QtWidgets.QLineEdit(self) 120 | self.dropped_points_getter.setText('3') 121 | 122 | run_button = QtWidgets.QPushButton('Run annotation') 123 | run_button.clicked.connect(self._run_button) 124 | 125 | parameter_layout = QtWidgets.QVBoxLayout() 126 | parameter_layout.addWidget(instrumental_label) 127 | parameter_layout.addWidget(self.instrumental_getter) 128 | parameter_layout.addWidget(prefix_label) 129 | parameter_layout.addWidget(self.prefix_getter) 130 | parameter_layout.addWidget(suffix_label) 131 | parameter_layout.addWidget(self.suffix_getter) 132 | parameter_layout.addWidget(mz_label) 133 | parameter_layout.addWidget(self.mz_getter) 134 | parameter_layout.addWidget(roi_points_label) 135 
| parameter_layout.addWidget(self.roi_points_getter) 136 | # if self.mode == 'semi-automatic': 137 | # parameter_layout.addWidget(peak_points_label) 138 | # parameter_layout.addWidget(self.peak_points_getter) 139 | parameter_layout.addWidget(dropped_points_label) 140 | parameter_layout.addWidget(self.dropped_points_getter) 141 | parameter_layout.addWidget(run_button) 142 | 143 | # main layout 144 | main_layout = QtWidgets.QHBoxLayout() 145 | main_layout.addLayout(file_layout) 146 | main_layout.addLayout(parameter_layout) 147 | 148 | self.setLayout(main_layout) 149 | 150 | def _run_button(self): 151 | try: 152 | self.description = self.instrumental_getter.text() 153 | self.file_prefix = self.prefix_getter.text() 154 | self.file_suffix = int(self.suffix_getter.text()) 155 | delta_mz = float(self.mz_getter.text()) 156 | required_points = int(self.roi_points_getter.text()) 157 | dropped_points = int(self.dropped_points_getter.text()) 158 | if self.mode == 'semi-automatic': 159 | self.minimum_peak_points = int(self.peak_points_getter.text()) 160 | 161 | self.folder = self.folder_widget.get_folder() 162 | path2mzml = None 163 | for file in self.list_of_files.selectedItems(): 164 | path2mzml = self.list_of_files.file2path[file.text()] 165 | if path2mzml is None: 166 | raise ValueError 167 | 168 | worker = Worker(get_ROIs, path2mzml, delta_mz, required_points, dropped_points) 169 | worker.signals.result.connect(self._start_annotation) 170 | self.parent.run_thread('ROI detection:', worker) 171 | 172 | self.close() 173 | except ValueError: 174 | # popup window with exception 175 | msg = QtWidgets.QMessageBox(self) 176 | msg.setText("Check parameters. 
Something is wrong!") 177 | msg.setIcon(QtWidgets.QMessageBox.Warning) 178 | msg.exec_() 179 | 180 | def _start_annotation(self, rois): 181 | self.rois = rois 182 | subwindow = AnnotationMainWindow(self.rois, self.folder, self.file_prefix, self.file_suffix, 183 | self.description, self.mode, 184 | self.minimum_peak_points, parent=self.parent) 185 | subwindow.show() 186 | 187 | 188 | class AnnotationMainWindow(QtWidgets.QDialog): 189 | def __init__(self, ROIs, folder, file_prefix, file_suffix, description, mode, 190 | minimum_peak_points, parent=None): 191 | super().__init__(parent) 192 | self.setWindowTitle('peakonly: annotation window') 193 | self.file_prefix = file_prefix 194 | self.file_suffix = file_suffix 195 | self.description = description 196 | self.current_description = description 197 | self.folder = folder 198 | self.mode = mode 199 | self.plotted_roi = None 200 | self.plotted_path = None 201 | self.plotted_item = None # data reannotation 202 | self.current_flag = False 203 | 204 | # if self.mode == 'semi-automatic': # load models 205 | # self.minimum_peak_points = minimum_peak_points 206 | # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 207 | # self.classify = Classifier() 208 | # self.classify.load_state_dict(torch.load('data/Classifier', map_location=self.device)) 209 | # self.classify.to(self.device) 210 | # self.classify.eval() 211 | # self.integrate = Integrator() 212 | # self.integrate.load_state_dict(torch.load('data/Integrator', map_location=self.device)) 213 | # self.integrate.to(self.device) 214 | # self.integrate.eval() 215 | # # variables where save CNNs predictions 216 | # self.label = 0 217 | # self.borders = [] 218 | # if self.mode == 'skip noise': 219 | # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 220 | # self.classify = Classifier() 221 | # self.classify.load_state_dict(torch.load('data/Classifier', map_location=self.device)) 222 | # self.classify.to(self.device) 223 | # 
self.classify.eval() 224 | # # variables where save CNN predictions 225 | # self.label = 0 226 | # shuffle ROIs 227 | self.ROIs = ROIs 228 | np.random.seed(1313) 229 | np.random.shuffle(self.ROIs) 230 | 231 | self.figure = plt.figure() # a figure instance to plot on 232 | self.canvas = FigureCanvas(self.figure) 233 | 234 | self.rois_list = FileListWidget() 235 | self.rois_list.connectRightClick(self.file_right_click) 236 | self.rois_list.connectDoubleClick(self.file_double_click) 237 | files = [] 238 | for created_file in os.listdir(self.folder): 239 | if created_file.endswith('.json'): 240 | begin = created_file.find('_') + 1 241 | end = created_file.find('.json') 242 | code = int(created_file[begin:end]) 243 | files.append((code, created_file)) 244 | for _, file in sorted(files): 245 | self.rois_list.addFile(os.path.join(self.folder, file)) 246 | 247 | self._init_ui() # initialize user interface 248 | 249 | self.plot_current() # initial plot 250 | 251 | def _init_ui(self): 252 | """ 253 | Initialize all buttons and layouts. 
254 | """ 255 | # canvas layout 256 | toolbar = NavigationToolbar(self.canvas, self) 257 | canvas_layout = QtWidgets.QVBoxLayout() 258 | canvas_layout.addWidget(toolbar) 259 | canvas_layout.addWidget(self.canvas) 260 | 261 | # canvas and files list layout 262 | canvas_files_layout = QtWidgets.QHBoxLayout() 263 | canvas_files_layout.addLayout(canvas_layout, 80) 264 | canvas_files_layout.addWidget(self.rois_list, 20) 265 | 266 | if self.mode != 'reannotation': 267 | # plot current button 268 | plot_current_button = QtWidgets.QPushButton('Plot current ROI') 269 | plot_current_button.clicked.connect(self.plot_current) 270 | # noise button 271 | noise_button = QtWidgets.QPushButton('Noise') 272 | noise_button.clicked.connect(self.noise) 273 | # peak button 274 | peak_button = QtWidgets.QPushButton('Peak') 275 | peak_button.clicked.connect(self.peak) 276 | # skip button 277 | skip_button = QtWidgets.QPushButton('Skip') 278 | skip_button.clicked.connect(self.skip) 279 | # plot chosen button 280 | plot_chosen_button = QtWidgets.QPushButton('Plot chosen ROI') 281 | plot_chosen_button.clicked.connect(self.press_plot_chosen) 282 | 283 | 284 | # button layout 285 | button_layout = QtWidgets.QHBoxLayout() 286 | if self.mode != 'reannotation': 287 | button_layout.addWidget(plot_current_button) 288 | button_layout.addWidget(noise_button) 289 | button_layout.addWidget(peak_button) 290 | button_layout.addWidget(skip_button) 291 | # if self.mode == 'semi-automatic': 292 | # # agree button 293 | # agree_button = QtWidgets.QPushButton('Save CNNs annotation') 294 | # agree_button.clicked.connect(self.save_auto_annotation) 295 | # button_layout.addWidget(agree_button) 296 | button_layout.addWidget(plot_chosen_button) 297 | 298 | # main layout 299 | main_layout = QtWidgets.QVBoxLayout() 300 | main_layout.addLayout(canvas_files_layout) 301 | main_layout.addLayout(button_layout) 302 | self.setLayout(main_layout) 303 | 304 | # Auxiliary methods 305 | def file_right_click(self): 306 | 
FileContextMenu(self) 307 | 308 | def file_double_click(self, item): 309 | self.plotted_item = item 310 | self.plot_chosen() 311 | 312 | def get_chosen(self): 313 | chosen_item = None 314 | for item in self.rois_list.selectedItems(): 315 | chosen_item = item 316 | return chosen_item 317 | 318 | def close_file(self, item): 319 | if item == self.plotted_item: 320 | index = min(self.rois_list.row(self.plotted_item) + 1, self.rois_list.count() - 2) 321 | self.plotted_item = self.rois_list.item(index) 322 | self.plotted_item.setSelected(True) 323 | self.plot_chosen() 324 | self.rois_list.deleteFile(item) 325 | 326 | def delete_file(self, item): 327 | os.remove(self.rois_list.getPath(item)) 328 | self.close_file(item) 329 | 330 | # Buttons 331 | def noise(self): 332 | code = os.path.basename(self.plotted_path) 333 | code = code[:code.rfind('.')] 334 | label = 0 335 | self.plotted_roi.save_annotated(self.plotted_path, code, label, description=self.current_description) 336 | 337 | if self.current_flag: 338 | self.current_flag = False 339 | self.rois_list.addFile(self.plotted_path) 340 | self.file_suffix += 1 341 | self.plot_current() 342 | else: 343 | self.plotted_item.setSelected(False) 344 | index = min(self.rois_list.row(self.plotted_item) + 1, self.rois_list.count() - 1) 345 | self.plotted_item = self.rois_list.item(index) 346 | self.plotted_item.setSelected(True) 347 | self.plot_chosen() 348 | 349 | def peak(self): 350 | title = 'Annotate peak borders and press "save".' 
351 | subwindow = AnnotationGetNumberOfPeaksNovel(self) 352 | subwindow.show() 353 | 354 | def skip(self): 355 | if self.current_flag: 356 | self.file_suffix += 1 357 | self.current_flag = False 358 | self.plot_current() 359 | else: 360 | self.plotted_item.setSelected(False) 361 | index = min(self.rois_list.row(self.plotted_item) + 1, self.rois_list.count() - 1) 362 | self.plotted_item = self.rois_list.item(index) 363 | self.plotted_item.setSelected(True) 364 | self.plot_chosen() 365 | 366 | def save_auto_annotation(self): 367 | if self.current_flag: 368 | number_of_peaks = len(self.borders) 369 | begins = [] 370 | ends = [] 371 | for begin, end in self.borders: 372 | begins.append(int(begin)) 373 | ends.append(int(end)) 374 | intersections = [] 375 | for i in range(number_of_peaks - 1): 376 | intersections.append(int(np.argmin(self.plotted_roi.i[ends[i]:begins[i+1]]) + ends[i])) 377 | 378 | code = os.path.basename(self.plotted_path) 379 | code = code[:code.rfind('.')] 380 | self.plotted_roi.save_annotated(self.plotted_path, int(self.label), code, number_of_peaks, 381 | begins, ends, intersections, self.description) 382 | 383 | self.current_flag = False 384 | self.rois_list.addFile(self.plotted_path) 385 | self.file_suffix += 1 386 | self.plot_current() 387 | 388 | def press_plot_chosen(self): 389 | try: 390 | self.plotted_item = self.get_chosen() 391 | if self.plotted_item is None: 392 | raise ValueError 393 | self.plot_chosen() 394 | except ValueError: 395 | # popup window with exception 396 | msg = QtWidgets.QMessageBox(self) 397 | msg.setText('Choose a ROI to plot from the list!') 398 | msg.setIcon(QtWidgets.QMessageBox.Warning) 399 | msg.exec_() 400 | 401 | # Visualization 402 | def plot_current(self): 403 | if self.mode != 'reannotation': 404 | if not self.current_flag: 405 | self.current_flag = True 406 | self.current_description = self.description 407 | self.plotted_roi = self.ROIs[self.file_suffix] 408 | # if self.mode == 'skip noise': 409 | # self.label = 
classifier_prediction(self.plotted_roi, self.classify, self.device) 410 | # while self.label == 0: 411 | # self.file_suffix += 1 412 | # self.plotted_roi = self.ROIs[self.file_suffix] 413 | # self.label = classifier_prediction(self.plotted_roi, self.classify, self.device) 414 | 415 | filename = f'{self.file_prefix}_{self.file_suffix}.json' 416 | self.plotted_path = os.path.join(self.folder, filename) 417 | 418 | self.figure.clear() 419 | ax = self.figure.add_subplot(111) 420 | ax.plot(self.plotted_roi.i, label=filename) 421 | title = f'mz = {self.plotted_roi.mzmean:.3f}, ' \ 422 | f'rt = {self.plotted_roi.rt[0]:.1f} - {self.plotted_roi.rt[1]:.1f}' 423 | 424 | # if self.mode == 'semi-automatic': # label and border predictions 425 | # self.label = classifier_prediction(self.plotted_roi, self.classify, self.device) 426 | # self.borders = [] 427 | # if self.label != 0: 428 | # self.borders = border_prediction(self.plotted_roi, self.integrate, 429 | # self.device, self.minimum_peak_points) 430 | # if self.label == 0: 431 | # title = 'label = noise, ' + title 432 | # elif self.label == 1: 433 | # title = 'label = peak, ' + title 434 | # elif self.label == 2: 435 | # title = 'label = uncertain peak, ' + title 436 | # 437 | # for begin, end in self.borders: 438 | # ax.fill_between(range(begin, end + 1), self.plotted_roi.i[begin:end + 1], alpha=0.5) 439 | 440 | ax.legend(loc='best') 441 | ax.set_title(title) 442 | self.canvas.draw() # refresh canvas 443 | 444 | def plot_chosen(self): 445 | filename = self.plotted_item.text() 446 | path2roi = self.rois_list.file2path[filename] 447 | with open(path2roi) as json_file: 448 | roi = json.load(json_file) 449 | self.current_description = roi['description'] 450 | self.plotted_roi = construct_ROI(roi) 451 | self.plotted_path = path2roi 452 | self.figure.clear() 453 | ax = self.figure.add_subplot(111) 454 | ax.plot(self.plotted_roi.i, label=filename) 455 | title = f'mz = {self.plotted_roi.mzmean:.3f}, ' \ 456 | f'rt = 
{self.plotted_roi.rt[0]:.1f} - {self.plotted_roi.rt[1]:.1f}' 457 | 458 | if roi['label'] == 0: 459 | title = 'label = noise, ' + title 460 | elif roi['label'] == 1: 461 | title = 'label = peak, ' + title 462 | 463 | for border, peak_label in zip(roi['borders'], roi["peaks' labels"]): 464 | begin, end = border 465 | ax.fill_between(range(begin, end + 1), self.plotted_roi.i[begin:end + 1], alpha=0.5, 466 | label=f"pl: {peak_label}, borders={begin}-{end}") 467 | 468 | ax.set_title(title) 469 | ax.legend(loc='best') 470 | self.canvas.draw() 471 | self.current_flag = False 472 | 473 | def plot_preview(self, borders): 474 | filename = os.path.basename(self.plotted_path) 475 | self.figure.clear() 476 | ax = self.figure.add_subplot(111) 477 | ax.plot(self.plotted_roi.i, label=filename) 478 | title = f'mz = {self.plotted_roi.mzmean:.3f}, ' \ 479 | f'rt = {self.plotted_roi.rt[0]:.1f} - {self.plotted_roi.rt[1]:.1f}' 480 | 481 | for border in borders: 482 | begin, end = border 483 | ax.fill_between(range(begin, end + 1), self.plotted_roi.i[begin:end + 1], alpha=0.5) 484 | ax.set_title(title) 485 | ax.legend(loc='best') 486 | self.canvas.draw() # refresh canvas 487 | 488 | 489 | class AnnotationGetNumberOfPeaksNovel(QtWidgets.QDialog): 490 | def __init__(self, parent: AnnotationMainWindow): 491 | self.parent = parent 492 | super().__init__(parent) 493 | self.setWindowTitle('annotation: number of peaks') 494 | 495 | label = QtWidgets.QLabel() 496 | label.setText('Print number of peaks in current ROI:') 497 | 498 | n_of_peaks_layout = QtWidgets.QHBoxLayout() 499 | n_of_peaks_label = QtWidgets.QLabel() 500 | n_of_peaks_label.setText('number of peaks = ') 501 | self.n_of_peaks_getter = QtWidgets.QLineEdit(self) 502 | self.n_of_peaks_getter.setText('0') 503 | n_of_peaks_layout.addWidget(n_of_peaks_label) 504 | n_of_peaks_layout.addWidget(self.n_of_peaks_getter) 505 | 506 | continue_button = QtWidgets.QPushButton('Continue') 507 | continue_button.clicked.connect(self.proceed) 508 | 
509 | main_layout = QtWidgets.QVBoxLayout() 510 | main_layout.addWidget(label) 511 | main_layout.addLayout(n_of_peaks_layout) 512 | main_layout.addWidget(continue_button) 513 | 514 | self.setLayout(main_layout) 515 | 516 | def proceed(self): 517 | try: 518 | number_of_peaks = int(self.n_of_peaks_getter.text()) 519 | except ValueError: 520 | # popup window with exception 521 | msg = QtWidgets.QMessageBox(self) 522 | msg.setText("'Number of peaks' should be an integer value!") 523 | msg.setIcon(QtWidgets.QMessageBox.Warning) 524 | msg.exec_() 525 | return None 526 | 527 | subwindow = AnnotationGetBordersWindowNovel(number_of_peaks, self.parent) 528 | subwindow.show() 529 | self.close() 530 | 531 | 532 | class AnnotationPeakLayoutNovel(QtWidgets.QWidget): 533 | def __init__(self, peak_number, parent): 534 | super().__init__(parent) 535 | 536 | borders_layout = QtWidgets.QHBoxLayout() 537 | 538 | label = QtWidgets.QLabel() 539 | label.setText(f'Peak #{peak_number}') 540 | 541 | begin_label = QtWidgets.QLabel() 542 | begin_label.setText('begin = ') 543 | self.begin_getter = QtWidgets.QLineEdit(self) 544 | end_label = QtWidgets.QLabel() 545 | end_label.setText('end = ') 546 | self.end_getter = QtWidgets.QLineEdit(self) 547 | borders_layout.addWidget(begin_label) 548 | borders_layout.addWidget(self.begin_getter) 549 | borders_layout.addWidget(end_label) 550 | borders_layout.addWidget(self.end_getter) 551 | 552 | peak_label_layout = QtWidgets.QHBoxLayout() 553 | peak_label_label = QtWidgets.QLabel() 554 | peak_label_label.setText('peak label = ') 555 | self.peak_label_getter = QtWidgets.QComboBox(self) 556 | self.peak_label_getter.addItems(['', 'Good (smooth, high intensive)', 'Low intensive (close to LOD)', 557 | 'Lousy (not good)', 'Noisy, strange (probably chemical noise)']) 558 | peak_label_layout.addWidget(peak_label_label) 559 | peak_label_layout.addWidget(self.peak_label_getter) 560 | 561 | main_layout = QtWidgets.QVBoxLayout() 562 | main_layout.addWidget(label) 563 
| main_layout.addLayout(borders_layout) 564 | main_layout.addLayout(peak_label_layout) 565 | 566 | self.setLayout(main_layout) 567 | 568 | 569 | class AnnotationGetBordersWindowNovel(QtWidgets.QDialog): 570 | def __init__(self, number_of_peaks: int, parent: AnnotationMainWindow): 571 | self.str2label = {'': 0, '': 0, 'Good (smooth, high intensive)': 1, 572 | 'Low intensive (close to LOD)': 2, 'Lousy (not good)': 3, 573 | 'Noisy, strange (probably chemical noise)': 4} 574 | self.number_of_peaks = number_of_peaks 575 | self.parent = parent 576 | super().__init__(parent) 577 | self.setWindowTitle("annotation: peaks' borders") 578 | 579 | main_layout = QtWidgets.QVBoxLayout() 580 | self.peak_layouts = [] 581 | for i in range(number_of_peaks): 582 | self.peak_layouts.append(AnnotationPeakLayoutNovel(i + 1, self)) 583 | main_layout.addWidget(self.peak_layouts[-1]) 584 | 585 | preview_button = QtWidgets.QPushButton('Preview') 586 | preview_button.clicked.connect(self.preview) 587 | main_layout.addWidget(preview_button) 588 | 589 | save_button = QtWidgets.QPushButton('Save') 590 | save_button.clicked.connect(self.save) 591 | main_layout.addWidget(save_button) 592 | 593 | self.setLayout(main_layout) 594 | 595 | def preview(self): 596 | try: 597 | borders = [] 598 | for pl in self.peak_layouts: 599 | if pl.begin_getter.text() and pl.end_getter.text(): 600 | begin = pl.begin_getter.text() 601 | end = pl.end_getter.text() 602 | if begin and end: 603 | begin = int(begin) 604 | end = int(end) 605 | borders.append((begin, end)) 606 | self.parent.plot_preview(borders) 607 | except ValueError: 608 | # popup window with exception 609 | msg = QtWidgets.QMessageBox(self) 610 | msg.setText("'begin' and 'end' of each peak should be integer numbers!") 611 | msg.setIcon(QtWidgets.QMessageBox.Warning) 612 | msg.exec_() 613 | 614 | def save(self): 615 | try: 616 | code = os.path.basename(self.parent.plotted_path) 617 | code = code[:code.rfind('.')] 618 | label = 1 619 | number_of_peaks = 
self.number_of_peaks 620 | peaks_labels = [] 621 | borders = [] 622 | for pl in self.peak_layouts: 623 | peak_label = self.str2label[pl.peak_label_getter.currentText()] 624 | peaks_labels.append(peak_label) 625 | 626 | begin = int(pl.begin_getter.text()) 627 | end = int(pl.end_getter.text()) 628 | borders.append((begin, end)) 629 | except ValueError: 630 | # popup window with exception 631 | msg = QtWidgets.QMessageBox(self) 632 | msg.setText("Check parameters. Something is wrong!") 633 | msg.setIcon(QtWidgets.QMessageBox.Warning) 634 | msg.exec_() 635 | 636 | self.parent.plotted_roi.save_annotated(self.parent.plotted_path, code, label, number_of_peaks, 637 | peaks_labels, borders, description=self.parent.current_description) 638 | 639 | if self.parent.current_flag: 640 | self.parent.current_flag = False 641 | self.parent.rois_list.addFile(self.parent.plotted_path) 642 | self.parent.file_suffix += 1 643 | self.parent.plot_current() 644 | else: 645 | self.parent.plotted_item.setSelected(False) 646 | index = min(self.parent.rois_list.row(self.parent.plotted_item) + 1, self.parent.rois_list.count() - 1) 647 | self.parent.plotted_item = self.parent.rois_list.item(index) 648 | self.parent.plotted_item.setSelected(True) 649 | self.parent.plot_chosen() 650 | self.close() 651 | 652 | 653 | class FileContextMenu(QtWidgets.QMenu): 654 | def __init__(self, parent: AnnotationMainWindow): 655 | super().__init__(parent) 656 | 657 | self.parent = parent 658 | self.menu = QtWidgets.QMenu(parent) 659 | 660 | self.close = QtWidgets.QAction('Close', parent) 661 | self.delete = QtWidgets.QAction('Delete', parent) 662 | 663 | self.menu.addAction(self.close) 664 | self.menu.addAction(self.delete) 665 | 666 | action = self.menu.exec_(QtGui.QCursor.pos()) 667 | 668 | if action == self.close: 669 | self.close_file() 670 | elif action == self.delete: 671 | self.delete_file() 672 | 673 | def close_file(self): 674 | item = self.parent.get_chosen() 675 | self.parent.close_file(item) 676 | 677 
| def delete_file(self): 678 | item = self.parent.get_chosen() 679 | self.parent.delete_file(item) 680 | --------------------------------------------------------------------------------